python - 将 softmax 应用于矩阵中跨维度的非零元素
问题描述
也许这是微不足道的,但也许不是。我花了太多时间试图弄清楚如何完成这项工作。这是代码:
# batch x time x events
batch = 2
time = 3
events = 4
tensor = np.random.rand(batch, time, events)
tensor[0][0][2] = 0
tensor[0][0][3] = 0
tensor[0][1][3] = 0
tensor[0][2][1] = 0
tensor[0][2][2] = 0
tensor[0][2][3] = 0
tensor[1][0][3] = 0
non_zero = ~tf.equal(tensor, 0.)
s = tf.Session()
g = tf.global_variables_initializer()
s.run(g)
s.run(non_zero)
我试图在每个维度上应用tf.nn.softmax
非零值。time
但是,当我使用tf.boolean_mask
then 它实际上将所有非零值收集在一起。那不是我想要的。我想保留尺寸。
所以tf.nn.softmax
应该只适用于那些群体,它应该“把它们放回”原来的位置。有谁知道如何做到这一点?
编辑:
在你们的帮助下,我几乎找到了我需要的解决方案。但我仍然缺少一步。将每个时间维度上的 softmax 分配给非零值:
def apply_sparse_softmax(time_vector):
non_zeros = ~tf.equal(time_vector, 0.)
sparse_softmax = tf.nn.softmax(tf.boolean_mask(time_vector, non_zeros))
new_time_vector = sparse_softmax * tf.cast(non_zeros, tf.float64) # won't work because dimensions are different
return time_vector
另请注意,此解决方案应处理整个时间维度均为零的情况。那么它应该保持不变。
解决方案
可能重复:仅将 tf.nn.softmax() 应用于张量的正元素
tf.map_fn
在和的帮助下tf.where
session.run(tf.map_fn(
lambda x : tf.where(x > 0, tf.nn.softmax(x,axis=2,name="pidgeon"), x), tensor))
经过测试np.random.seed(1992)
# tensor
[[[0.86018176 0.42148685 0. 0. ]
[0.64714 0.68271286 0.6449022 0. ]
[0.92037941 0. 0. 0. ]]
[[0.38479139 0.26825327 0.43027759 0. ]
[0.56077674 0.49309016 0.2433904 0.85396874]
[0.1267429 0.1861004 0.92251748 0.67904445]]]
# result
[[[0.34841156, 0.33845624, 0. , 0. ],
[0.28155918, 0.43949257, 0.48794109, 0. ],
[0.37002926, 0. , 0. , 0. ]],
[[0.33727059, 0.31513436, 0.2885575 , 0. ],
[0.40216839, 0.39458556, 0.23936921, 0.44145382],
[0.26056102, 0.29028008, 0.47207329, 0.37060957]]])
0.34841156 == np.exp(0.86018176) / (np.exp(0.86018176) + np.exp(0.64714) + np.exp(0.92037941))
推荐阅读
- aws-lambda - 从 Lambda 访问 AWS SQS 时,“客户端网络套接字在建立安全 TLS 连接之前已断开”
- node.js - Square Node SDK API 调用返回“未授权”错误
- r - 根据条件和向量对列进行变异
- git - 基线代码更改时 git 合并的正确方法
- javascript - 毛毛雨useCacheCall()期间出错:TypeError:无法读取未定义的属性“方法”
- javascript - 输入逻辑不应该适用于所有输入元素吗?
- sql-server - 多个用户竞争“下一个”订单
- python - keras 的准确性提高不超过 59%
- python-3.x - Ubuntu 18.04 Python 3.7.5 无法安装tensorflow?
- c# - 如何检测 UI 按钮是否已被释放?统一 - C#