首页 > 解决方案 > PyTorch - 在 torch.sort 之后取回原始张量顺序的更好方法

问题描述

我想在操作和对排序张量进行一些其他修改后取回原始张量顺序torch.sort,以便不再对张量进行排序。最好用一个例子来解释这一点:

x = torch.tensor([30., 40., 20.])
ordered, indices = torch.sort(x)
# ordered is [20., 30., 40.]
# indices is [2, 0, 1]
ordered = torch.tanh(ordered) # it doesn't matter what operation is
final = original_order(ordered, indices) 
# final must be equal to torch.tanh(x)

我以这种方式实现了该功能:

def original_order(ordered, indices):
    z = torch.empty_like(ordered)
    for i in range(ordered.size(0)):
        z[indices[i]] = ordered[i]
    return z

有一个更好的方法吗?特别是,是否可以避免循环并更有效地计算操作?

在我的情况下,我有一个大小的张量,我通过一次调用单独torch.Size([B, N])对每一行进行排序。所以,我必须用另一个循环来调用时间。Btorch.sortoriginal_order B

任何,更多pytorch-ic,想法?

编辑 1 - 摆脱内循环

我通过以这种方式简单地用索引索引 z 解决了部分问题:

def original_order(ordered, indices):
    z = torch.empty_like(ordered)
    z[indices] = ordered
    return z

现在,我只需要了解如何避免B维度上的外循环。

编辑 2 - 摆脱外循环

def original_order(ordered, indices, batch_size):
    # produce a vector to shift indices by lenght of the vector 
    # times the batch position
    add = torch.linspace(0, batch_size-1, batch_size) * indices.size(1)


    indices = indices + add.long().view(-1,1)

    # reduce tensor to single dimension. 
    # Now the indices take in consideration the new length
    long_ordered = ordered.view(-1)
    long_indices = indices.view(-1)

    # we are in the previous case with one dimensional vector
    z = torch.zeros_like(long_ordered).float()
    z[long_indices] = long_ordered

    # reshape to get back to the correct dimension
    return z.view(batch_size, -1)

标签: pythonpytorch

解决方案


def original_order(ordered, indices):
    return ordered.gather(1, indices.argsort(1))

例子

original = torch.tensor([
    [20, 22, 24, 21],
    [12, 14, 10, 11],
    [34, 31, 30, 32]])
sorted, index = original.sort()
unsorted = sorted.gather(1, index.argsort(1))
assert(torch.all(original == unsorted))

为什么有效

为简单起见,想象一下t = [30, 10, 20],省略张量符号。

t.sort()免费为我们提供排序张量s = [10, 20, 30]以及排序索引i = [1, 2, 0]i实际上是 的输出t.argsort()

i告诉我们如何从tst“要对进行排序s,请从“中获取元素 1,然后是 2,然后是 0 t。Argsortingi为我们提供了另一个排序索引j = [2, 0, 1],它告诉我们如何从i自然数的规范序列转到自然数的规范序列[0, 1, 2],实际上是反转排序。另一种看待它的方式是j告诉我们如何从st。“要排序st,请从“中获取元素 2,然后是 0,然后是 1 s。对排序索引进行 Argsorting 为我们提供了它的“逆索引”,反之亦然。

现在我们有了逆索引,我们将其转储到torch.gather()正确的dim中,这会取消对张量的排序。

来源

torch.gather torch.argsort

在研究这个问题时我找不到这个确切的解决方案,所以我认为这是一个原始答案。


推荐阅读