首页 > 解决方案 > Pytorch Tensor 如何获取元素的索引?

问题描述

我有 2 个名为xlist的张量,它们的定义如下:

x = torch.tensor(3)
list = torch.tensor([1,2,3,4,5])

现在我想从list获取元素x的索引。预期的输出是一个整数:

2

我怎样才能以简单的方式做?

标签: pythonpytorchtorchtensor

解决方案


import torch

x = torch.tensor(3)

list = torch.tensor([1,2,3,4,5])
idx = (list == x).nonzero().flatten()
print (idx.tolist()) # [2]

list = torch.tensor([1,2,3,3,5])
idx = (list == x).nonzero().flatten()
print (idx.tolist()) # [2, 3]

推荐阅读