python - 仅将 tf.nn.softmax() 应用于张量的正元素
问题描述
我尝试了很长时间来解决这个问题,但在互联网上没有找到任何有用的东西,所以我不得不问:
假设给定一个张量T
,T = tf.random_normal([100])
我只想应用于softmax()
张量的正元素。类似的东西T = tf.nn.softmax(T[T>0])
当然在 Tensorflow 中不起作用。
简而言之:我想计算 softmax 并仅应用于 elements T > 0
。
我如何在 Tensorflow 中做到这一点?
解决方案
如果您希望 softmax 计算 + 仅应用于元素 T > 0:
一个想法可能是根据您的条件 ( T > 0
) 创建 2 个分区,将操作 ( softmax
) 应用于目标分区,然后将它们缝合在一起。
像这样,使用tf.dynamic_partition
and tf.dynamic_stitch
:
import tensorflow as tf
T = tf.random_normal(shape=(2, 3, 4))
# Creating partition based on condition:
condition_mask = tf.cast(tf.greater(T, 0.), tf.int32)
partitioned_T = tf.dynamic_partition(T, condition_mask, 2)
# Applying the operation to the target partition:
partitioned_T[1] = tf.nn.softmax(partitioned_T[1])
# Stitching back together, flattening T and its indices to make things easier::
condition_indices = tf.dynamic_partition(tf.range(tf.size(T)), tf.reshape(condition_mask, [-1]), 2)
res_T = tf.dynamic_stitch(condition_indices, partitioned_T)
res_T = tf.reshape(res_T, tf.shape(T))
with tf.Session() as sess:
t, res = sess.run([T, res_T])
print(t)
# [[[-1.92647386 0.7442674 1.86053932 -0.95315439]
# [-0.38296485 1.19349718 -1.27562618 -0.73016083]
# [-0.36333972 -0.90614134 -0.15798278 -0.38928652]]
#
# [[-0.42384467 0.69428021 1.94177043 -0.13672788]
# [-0.53473723 0.94478583 -0.52320045 0.36250541]
# [ 0.59011376 -0.77091616 -0.12464728 1.49722672]]]
print(res)
# [[[-1.92647386 0.06771058 0.20675084 -0.95315439]
# [-0.38296485 0.10610957 -1.27562618 -0.73016083]
# [-0.36333972 -0.90614134 -0.15798278 -0.38928652]]
#
# [[-0.42384467 0.06440912 0.22424641 -0.13672788]
# [-0.53473723 0.08274478 -0.52320045 0.04622314]
# [ 0.05803747 -0.77091616 -0.12464728 0.14376813]]]
上一个答案
仅当您希望对 的所有元素计算 softmaxT
但仅应用于大于的元素时,此答案才有效0
。
使用tf.where()
:
T = tf.where(tf.greater(T, 0.), tf.nn.softmax(T), T)
推荐阅读
- c# - 未处理的异常呈现组件:在“窗口”中找不到“AuthenticationService”
- r - 将 n 对开始日期和结束日期的宽格式转换为长格式
- python - 所以我正在尝试使用 atom 在我的 mac 上运行 python 脚本,但由于某种原因它无法正常工作
- javascript - 电子邮件功能在本地服务器上有效,但在 Web 服务器上无效
- javascript - Javascript中的原型概念,我说对了吗?
- spring-data-jpa - DB2 不接受 UUID
- reactjs - 如何在路由更改时获取 React 路由器路径
- android - 使用两个 RecyclerView 滚动页面(片段):最先进的技术以及如何回收?
- r - R,数据框操作,排序
- php - 如何使用 PHP 将字节流发送到套接字?