首页 > 解决方案 > PyTorch 中两个张量的所有可能串联

问题描述

假设我有两个张量ST定义为:

S = torch.rand((3,2,1))
T = torch.ones((3,2,1))

我们可以将这些视为包含具有形状的张量批次(2, 1)。在这种情况下,批量大小为3.

我想连接批次之间所有可能的配对。批次的单个串联产生一个张量 shape (4, 1)。并且存在3*3组合,因此最终产生的张量C必须具有 的形状(3, 3, 4, 1)

一种解决方案是执行以下操作:

for i in range(S.shape[0]):
  for j in range(T.shape[0]):
    C[i,j,:,:] = torch.cat((S[i,:,:],T[j,:,:]))

但是 for 循环不能很好地扩展到大批量。是否有 PyTorch 命令可以执行此操作?

标签: pytorch

解决方案


在 numpy 中,使用了一种叫做 np.meshgrid 的东西。

https://stackoverflow.com/a/35608701/3259896

所以在pytorch中,它会是

torch.stack(
torch.meshgrid(x, y)
).T.reshape(-1,2)

其中 x 和 y 是您的两个列表。您可以使用任何号码。x、y、z 等

然后你将它重塑为你使用的列表数量。

因此,如果您使用了三个列表,则使用.reshape(-1,3)、 四个使用.reshape(-1,4)等。

所以对于 5 个张量,使用

torch.stack(
torch.meshgrid(a, b, c, d, e)
).T.reshape(-1,5)

推荐阅读