首页 > 解决方案 > Torch.sort 和 argsort 在相同元素的情况下随机排序

问题描述

当遇到相同的元素时,torch.sort 和 argsort 以随机方式对张量进行排序。在 numpy 中不是这种情况。我有一个已经根据第二列排序的元素列表,现在我想使用第一列对其进行排序,但保留较早的排序,以防在新排序中出现平局。

import torch

a = torch.tensor(
        [[ 0., 3.],
        [ 2., 3.],
        [ 2., 2.],
        [10., 2.],
        [ 0., 2.],
        [ 6., 2.],
        [10., 1.],
        [ 2., 1.],
        [ 0., 1.],
        [ 6., 1.],
        [10., 0.],
        [12., 0.]]
)
print(a[torch.argsort(a[:, 0])])

输出:

tensor([[ 0.,  3.],
        [ 0.,  2.],
        [ 0.,  1.],
        [ 2.,  1.],
        [ 2.,  2.],
        [ 2.,  3.],
        [ 6.,  1.],
        [ 6.,  2.],
        [10.,  1.],
        [10.,  2.],
        [10.,  0.],
        [12.,  0.]])

麻木:

import numpy as np

a = np.array(
        [[ 0., 3.],
        [ 2., 3.],
        [ 2., 2.],
        [10., 2.],
        [ 0., 2.],
        [ 6., 2.],
        [10., 1.],
        [ 2., 1.],
        [ 0., 1.],
        [ 6., 1.],
        [10., 0.],
        [12., 0.]]
)
print(a[np.argsort(a[:, 0])])

输出:

[[ 0.  3.]
 [ 0.  2.]
 [ 0.  1.]
 [ 2.  3.]
 [ 2.  2.]
 [ 2.  1.]
 [ 6.  2.]
 [ 6.  1.]
 [10.  2.]
 [10.  1.]
 [10.  0.]
 [12.  0.]]

这可能是什么原因?我能做些什么来避免它?

标签: pythonnumpysortingpytorch

解决方案


根据 torch 1.9.0,您可以使用 option 运行排序stable=True。请参阅https://pytorch.org/docs/1.9.0/generated/torch.sort.html?highlight=sort#torch.sort

>>> x = torch.tensor([0, 1] * 9)
>>> x.sort()
torch.return_types.sort(
    values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
    indices=tensor([ 2, 16,  4,  6, 14,  8,  0, 10, 12,  9, 17, 15, 13, 11,  7,  5,  3,  1]))
>>> x.sort(stable=True)
torch.return_types.sort(
    values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
    indices=tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16,  1,  3,  5,  7,  9, 11, 13, 15, 17]))

文档说这仅在 CPU 上,但很快就会用于 GPU 排序,因为该文档警告已在 github 的主分支中删除(根据https://github.com/pytorch/pytorch/pull/61685 )


推荐阅读