首页 > 解决方案 > PyTorch - 另一个张量中对应值的索引

问题描述

我有一个张量,我只想复制一些值(按列)。相同的值在另一个张量中,但顺序是随机的。我想要的是来自tensor2的值的列索引tensor1。这是一个例子:

copy_ind = torch.tensor([0, 1, 3], dtype=torch.long)
tensor1 = torch.tensor([[4, 6, 5, 1, 8],[10, 0, 8, 2, 1]])
temp = torch.index_select(tensor1, 1, copy_ind) # values to copy
tensor2 = torch.tensor([[1, 4, 5, 6, 8],[2, 10, 8, 0, 1]], dtype=torch.long)
_, t_ind = torch.sort(temp[0], dim=0)
t2_ind = copy_ind[t_ind] # indices of tensor2

输出应该是:

t2_ind = [1, 3, 0]

这是另一个示例,我想根据以下方式获取张量的值c1_new

c1 = torch.tensor([[6, 7, 7, 8, 6, 8, 9, 4, 7, 6, 1, 3],[5, 11, 5, 7, 2, 9, 5, 5, 7, 11, 10, 7]], dtype=torch.long)
copy_ind = torch.tensor([1, 2, 3, 5, 7, 8], dtype=torch.long)
c1_new = torch.index_select(c1, 1, copy_ind)

indices = torch.as_tensor([[1, 3, 4, 6, 6, 6, 7, 7, 7, 8, 8, 9], [10, 7, 5, 2, 5, 11, 5, 7, 11, 7, 9, 5]])
values = torch.randn(12)
tensor = torch.sparse.FloatTensor(indices, values, (12, 12))

_, t_ind = torch.sort(c1[0], dim=0)
ind = t_ind[copy_ind] # should be [8, 6, 9, 10, 2, 7]

不幸的是,指数ind不正确。有人可以帮帮我吗?

标签: pythonpytorch

解决方案


如果您可以使用 for 循环,您可以使用以下内容:检查您的临时张量的每一列与 tensor2 的列:

编辑:使用torch.prod跨维度 1 确保两行匹配

[torch.prod((temp.T[i] == tesnor2.T), dim=1).nonzero()[0] for i in range(temp.size(1))]

我的第一个例子的输出是[tensor(1), tensor(3), tensor(0)]


推荐阅读