首页 > 解决方案 > 使用排序索引重新排列 3-D 数组?

问题描述

我有一个大小为 3-D 的随机数数组[channels = 3, height = 10, width = 10]

在此处输入图像描述 然后我使用 pytorch 中的 sort 命令沿列对其进行排序并获得索引。

在此处输入图像描述

对应的索引如下图:

在此处输入图像描述

现在,我想使用这些索引返回原始矩阵。我目前使用for循环来执行此操作(不考虑批次)。代码是:

import torch
torch.manual_seed(1)
ch = 3
h = 10
w = 10
inp_unf = torch.randn(ch,h,w)
inp_sort, indices  = torch.sort(inp_unf,1)
resort = torch.zeros(inp_sort.shape)

for i in range(ch):
    for j in range(inp_sort.shape[1]):
        for k in range (inp_sort.shape[2]):
            temp = inp_sort[i,j,k]
            resort[i,indices[i,j,k],k] = temp

考虑到批次以及输入大小,我希望它被矢量化[batch, channel, height, width]

标签: python-3.xsortingfor-loopvectorizationpytorch

解决方案


使用Tensor.scatter_()

您可以使用以下提供的索引将已排序的张量直接分散回其原始状态sort()

torch.zeros(ch,h,w).scatter_(dim=1, index=indices, src=inp_sort)

直觉基于以下先前的答案。Asscatter()基本上是 , 的反面gather()inp_reunf = inp_sort.gather(dim=1, index=reverse_indices)与 相同inp_reunf.scatter_(dim=1, index=indices, src=inp_sort)


上一个答案

注意:虽然正确,但这可能性能较差,因为sort()第二次调用该操作。

您需要获取排序“反向索引”,这可以通过“排序返回的索引sort()”来完成。

换句话说,给定x_sort, indices = x.sort(),你有x[indices] -> x_sort;而你想要的是reverse_indices这样的x_sort[reverse_indices] -> x

这可以通过以下方式获得:_, reverse_indices = indices.sort()

import torch
torch.manual_seed(1)
ch, h, w = 3, 10, 10
inp_unf = torch.randn(ch,h,w)

inp_sort, indices = inp_unf.sort(dim=1)
_, reverse_indices = indices.sort(dim=1)
inp_reunf = inp_sort.gather(dim=1, index=reverse_indices)

print(torch.equal(inp_unf, inp_reunf))
# True

推荐阅读