首页 > 解决方案 > 用一维张量对二维张量进行子集化

问题描述

我想从二维张量的每一行中提取存储在另一个一维张量中的列。

import torch
test_tensor = tensor([1,-2,3], [-2,7,4]).float()
select_tensor = tensor([1,2])

所以在这个特定的例子中,我想获得位置 1 的第一行的元素(so -2)和位置 2 的第二行的元素(so 4)。我试过了:

test_tensor[:, select_tensor]

但这会为每一行选择位置 1 和 2 的元素。我怀疑这可能是我错过的非常简单的事情。

标签: pythonpytorchtensor

解决方案


如果您正在寻找带有索引的解决方案,您也需要建立索引axis=0,您可以这样做torch.arange

>>> test_tensor = torch.tensor([[1,-2,3], [-2,7,4]])
>>> select_tensor = torch.tensor([1,2])

>>> test_tensor[torch.arange(len(select_tensor)), select_tensor]
tensor([-2,  4])

推荐阅读