pytorch - 张量的 Pytorch 成对串联
问题描述
我想以批处理方式计算特定维度上的成对连接。
例如,
x = torch.tensor([[[0],[1],[2]],[[3],[4],[5]]])
x.shape = torch.Size([2, 3, 1])
我想得到y
这样的结果y
,即所有向量对在一维上的串联,即:
y = torch.tensor([[[[0,0],[0,1],[0,2]],[[1,0],[1,1],[1,2]], [[2,0], [2,1], [2,2]]],
[[[3,3],[3,4],[3,5]],[[4,3],[4,4],[4,5]], [[5,3],[5,4],[5,5]]]])
y.shape = torch.Size([2, 3, 3, 2])
所以本质上,对于每个x[i,:]
,您生成所有向量对并将它们连接到最后一个维度。有没有一种简单的方法可以做到这一点?
解决方案
一种可能的方法是:
all_ordered_idx_pairs = torch.cartesian_prod(torch.tensor(range(x.shape[1])),torch.tensor(range(x.shape[1])))
y = torch.stack([x[i][all_ordered_idx_pairs] for i in range(x.shape[0])])
重塑张量后:
y = y.view(x.shape[0], x.shape[1], x.shape[1], -1)
你得到:
y = torch.tensor([[[[0,0],[0,1],[0,2]],[[1,0],[1,1],[1,2]], [[2,0], [2,1], [2,2]]],
[[[3,3],[3,4],[3,5]],[[4,3],[4,4],[4,5]], [[5,3],[5,4],[5,5]]]])
推荐阅读
- c - 仅打印带空格的五位数字的代码
- c++ - 普通变量可以保存 C++ 中其他变量的内存地址吗?
- sql - SQL:如何添加一个虚拟列(每个单元格包含“A”,或者一个整数或其他一些字符串
- curl - 从谷歌开发者那里获取授权令牌
- angular - NgrX 效果重定向角度
- java - Java中的数值集成
- excel - 根据唯一的“ID”列将行从一个工作表复制到另一个工作表?
- javascript - 我如何计算然后在 Web 表单上显示已选中但并非全部选中的特定复选框
- c# - .Net 5+ 创建平台无关的 RSA 密钥对的方法
- javascript - 如何在 React Native 中保持 FastImage 的图像行为?