numpy - 从 Tensorflow 2.1 张量中每批次提取一个元素
问题描述
假设我有一个包含两个张量的批次,并且补丁中的张量大小为 3。
data = [[0.3, 0.5, 0.7], [-0.3, -0.5, -0.7]]
现在我想从补丁中的每个张量中提取一个基于索引的单个元素:
index = [0, 2]
因此输出应该是
out = [0.3, -0.7] # Get index 0 from the first tensor in the batch and index 2 from the second tensor in the batch.
当然,这应该可以扩展到大批量。的维度index
等于批量大小。
我试图申请tf.gather
,tf.gather_nd
但我没有得到我想要的结果。
例如下面的代码打印0.7
而不是上面指定的期望结果:
data = [[0.3, 0.5, 0.7], [-0.3, -0.5, 0.7]]
index = [0, 2]
out = tf.gather_nd(data, index)
print(out.numpy())
解决方案
如果您知道批量大小,您可以执行以下操作,
import tensorflow as tf
data = tf.constant([[0.3, 0.5, 0.7], [-0.3, -0.5, 0.7]])
index = [0,2]
gather_inds = np.stack([np.arange(len(index)), index], axis=1)
out = tf.gather_nd(data, gather_inds)
Why your gather didn't work is because you are gathering from the inner most dimension. Therefore, your indices need to be as same as the rank of your data
tensor. In other words, your indices should be,
[0,0] and [1,2]
推荐阅读
- emacs - 通过 org-sbe 将 org-mode 属性传递给源代码块
- javascript - Javascript没有按预期反应
- python - 在 python 正则表达式中,为什么 (h)* 和 (h)+ 不能产生相同的结果?
- google-maps-api-3 - 使用带有 Flutter 的 Google Maps 初始单击标记后如何更新 InfoWindowText?
- c# - HttpClient.SendAsync 显然崩溃或阻塞?
- matlab - 未分配一个或多个输出参数
- javascript - 节点文件系统模块重命名方法
- reactjs - 如何隐藏 MUI React ListItem?
- excel - 匹配两个表中的两个单元格并检查条件是否为真
- python - 此代码必须使用 for 循环打印每个其他字符