首页 > 解决方案 > 沿特定维度和特定通道索引整个张量

问题描述

假设我们有一个尺寸为dim(A)=[i, j, k=6, u, v]的张量A。现在我们有兴趣通过channels=[0:3]获得维度k的整个张量。我知道我们可以这样得到:

B = A[:, :, 0:3, :, :]

现在我想知道是否有任何更好的“pythonic”方式来实现相同的结果,而无需进行这种次优索引。我的意思是类似的东西。

B = subset(A, dim=2, index=[0, 1, 2])

无论在哪个框架下,即pytorch、tensorflow、numpy等。

非常感谢

标签: pythonnumpytensorflowindexingpytorch

解决方案


在numpy中,您可以使用以下take方法:

B = A.take([0,1,2], axis=2)

在 TensorFlow 中,没有比使用传统方法更简洁的方法了。使用tf.slice会非常冗长:

B = tf.slice(A,[0,0,0,0,0],[-1,-1,3,-1,-1])

您可能会使用take(自 TF 2.4 起)的实验版本:

B = tf.experimental.numpy.take(A, [0,1,2], axis=2)

在 PyTorch 中,您可以使用index_select

torch.index_select(A, dim=2, index=torch.tensor([0,1,2]))

请注意,您可以使用 : 明确地跳过列出第一个维度(或最后一个维度)ellipsis

# Both are equivalent in that case
B = A[..., 0:3, :, :]
B = A[:, :, 0:3, ...]

推荐阅读