首页 > 解决方案 > Pytorch 张量 - 如何通过特定张量获取索引

问题描述

我有张量

t = torch.tensor([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]])

和一个查询张量

q = torch.tensor([1, 0, 0, 0])

有没有办法获取qlike的索引

indexes = t.index(q) # get back [0, 3]

在火炬?

标签: pythonpytorch

解决方案


怎么样

In [1]: torch.nonzero((t == q).sum(dim=1) == t.size(1))
Out[1]: 
tensor([[ 0],
        [ 3]])

Comparing在和之间t == q执行逐元素比较,因为您正在寻找整行匹配,所以您需要沿着行查看哪一行是完美匹配。tq.sum(dim=1)== t.size(1)


从 v0.4.1 开始,torch.all()支持dim参数:

torch.all(t==q, dim=1)

推荐阅读