首页 > 解决方案 > 从 Tensorflow 2.1 张量中每批次提取一个元素

问题描述

假设我有一个包含两个张量的批次,并且补丁中的张量大小为 3。

data = [[0.3, 0.5, 0.7], [-0.3, -0.5, -0.7]]

现在我想从补丁中的每个张量中提取一个基于索引的单个元素:

index = [0, 2]

因此输出应该是

out = [0.3, -0.7] # Get index 0 from the first tensor in the batch and index 2 from the second tensor in the batch.

当然,这应该可以扩展到大批量。的维度index等于批量大小。

我试图申请tf.gathertf.gather_nd但我没有得到我想要的结果。

例如下面的代码打印0.7不是上面指定的期望结果:

data = [[0.3, 0.5, 0.7], [-0.3, -0.5, 0.7]]

index = [0, 2]
out = tf.gather_nd(data, index)

print(out.numpy())

标签: numpytensorflow

解决方案


如果您知道批量大小,您可以执行以下操作,

import tensorflow as tf
data = tf.constant([[0.3, 0.5, 0.7], [-0.3, -0.5, 0.7]])

index = [0,2]
gather_inds = np.stack([np.arange(len(index)), index], axis=1)
out = tf.gather_nd(data, gather_inds)

Why your gather didn't work is because you are gathering from the inner most dimension. Therefore, your indices need to be as same as the rank of your data tensor. In other words, your indices should be,

[0,0] and [1,2]

推荐阅读