首页 > 解决方案 > 基于旧张量和二维索引的新张量

问题描述

我之前问过: PyTorch 张量:基于旧张量和索引的新张量

我现在有同样的问题,但需要使用 2d 索引张量。

我有一个大小为 [batch_size, k] 的张量 col,其值介于 0 和 k-1 之间:

idx = tensor([[0,1,2,0],
        [0,3,2,2],
        ...])

和以下矩阵:

x = tensor([[[0, 9],
 [1, 8],
 [2, 3],
 [4, 9]],
 [[0, 0],
 [1, 2],
 [3, 4],
 [5, 6]]])

我想创建一个新的张量,其中包含索引中指定的行,按该顺序。所以我想要:

tensor([[[0, 9],
 [1, 8],
 [2, 3],
 [0, 9]],
 [[0, 0],
 [5, 6],
 [3, 4],
 [3, 4]]])

目前我正在这样做:

for i, batch in enumerate(t):
    t[i] = batch[col[i]]

我怎样才能更有效地做到这一点?

标签: pythonindexingpytorchtensortorch

解决方案


你应该使用火炬收集来实现这一点。它实际上也适用于您链接的其他问题,但这留给读者作为练习:p

让我们称idx您的第一个张量和source第二个张量。它们各自的尺寸是(B,N)(B, K, p)p=2在你的例子中),所有的值idx都在0和之间K-1

所以要使用torch gather,我们首先需要将您的操作表示为嵌套的for循环。就您而言,您真正想要实现的是:

for b in range(B):
    for i in range(N):
        for j in range(p):
            # This kind of nested for loops is what torch.gether actually does
            target[b,i,j] = source[b, idx[b,i,j], j]

但这不起作用,因为idx它是 2D 张量,而不是 3D 张量。好吧,没什么大不了的,让我们把它变成一个 3D 张量。我们希望它具有形状(B, N, p)并且沿最后一个维度实际上是恒定的。然后我们可以用调用来替换 for 循环gather

reshaped_idx = idx.unsqueeze(-1).repeat(1,1,2)
target = source.gather(1, reshaped_idx)
# or : target = torch.gather(source, 1, reshaped_idx)

推荐阅读