首页 > 解决方案 > 如何在二维张量中获得相同值的张量索引?

问题描述

如标题中所述,给定一个二维张量,假设:

tensor([
    [0, 1, 0, 1], # A
    [1, 1, 0, 1], # B
    [1, 0, 0, 1], # C
    [0, 1, 0, 1], # D
    [1, 1, 0, 1], # E
    [1, 1, 0, 1]  # F
])

这很容易看出“A 和 D”、“B、E 和 F”是两组张量,

具有相同的值(这意味着 A == D 和 B == E == F)。

所以我的问题是:

如何获取这些组的索引?

细节:

输入:上面的张量

输出:(0, 3), (1, 4, 5)

标签: pythonpytorchtensorindices

解决方案


使用 PyTorch 函数的解决方案:

import torch

x = torch.tensor([
    [0, 1, 0, 1], # A
    [1, 1, 0, 1], # B
    [1, 0, 0, 1], # C
    [0, 1, 0, 1], # D
    [1, 1, 0, 1], # E
    [1, 1, 0, 1]  # F
])

_, inv, counts = torch.unique(x, dim=0, return_inverse=True, return_counts=True)
print([tuple(torch.where(inv == i)[0].tolist()) for i, c, in enumerate(counts) if counts[i] > 1])
# > [(0, 3), (1, 4, 5)]

推荐阅读