首页 > 解决方案 > Pytorch 索引

问题描述

我有一个张量 [[1,2],[4,5],[7,8]] 和一个索引为 [0,1,0] 的张量。

我想将它们应用于第二维,以便它返回:[1,5,8]。

我该怎么做?

谢谢!

标签: pythonpytorch

解决方案


import torch

arr=torch.tensor([[1,2],[4,5],[7,8]])
indices_arr=torch.tensor([0,1,0])

ret=arr[[0,1,2],indices_arr]
# print(ret)
# tensor([1, 5, 7])

推荐阅读