首页 > 解决方案 > 张量的 Pytorch 成对串联

问题描述

我想以批处理方式计算特定维度上的成对连接。

例如,

x = torch.tensor([[[0],[1],[2]],[[3],[4],[5]]])
x.shape = torch.Size([2, 3, 1])

我想得到y这样的结果y,即所有向量对在一维上的串联,即:

y = torch.tensor([[[[0,0],[0,1],[0,2]],[[1,0],[1,1],[1,2]], [[2,0], [2,1], [2,2]]], 
                 [[[3,3],[3,4],[3,5]],[[4,3],[4,4],[4,5]], [[5,3],[5,4],[5,5]]]])

y.shape = torch.Size([2, 3, 3, 2])

所以本质上,对于每个x[i,:],您生成所有向量对并将它们连接到最后一个维度。有没有一种简单的方法可以做到这一点?

标签: pytorchconcatenationtensorpairwise

解决方案


一种可能的方法是:

    all_ordered_idx_pairs = torch.cartesian_prod(torch.tensor(range(x.shape[1])),torch.tensor(range(x.shape[1])))
    y = torch.stack([x[i][all_ordered_idx_pairs] for i in range(x.shape[0])])

重塑张量后:

y = y.view(x.shape[0], x.shape[1], x.shape[1], -1)

你得到:

y = torch.tensor([[[[0,0],[0,1],[0,2]],[[1,0],[1,1],[1,2]], [[2,0], [2,1], [2,2]]], 
                 [[[3,3],[3,4],[3,5]],[[4,3],[4,4],[4,5]], [[5,3],[5,4],[5,5]]]])

推荐阅读