首页 > 解决方案 > 张量之间的详尽串联

问题描述

我正在尝试在张量之间进行详尽的连接。因此,例如,我有张量:

a = torch.randn(3, 512)

我想连接像 concat(t1,t1),concat(t1,t2), concat(t1,t3), concat(t2,t1), concat(t2,t2)....

作为一个天真的解决方案,我使用了for循环:

ans = []
result = []
split = torch.split(a, [1, 1, 1], dim=0)

for i in range(len(split)):
    ans.append(split[i])

for t1 in ans:
    for t2 in ans:
        result.append(torch.cat((t1,t2), dim=1))

问题是每个时代都需要很长时间,而且代码很慢。我尝试了在PyTorch上发布的解决方案:How to implement attention for graph attention layer但这会产生内存错误。

t1 = a.repeat(1, a.shape[0]).view(a.shape[0] * a.shape[0], -1)
t2 = a.repeat(a.shape[0], 1)
result.append(torch.cat((t1, t2), dim=1))

我确信有一种更快的方法,但我无法弄清楚。

标签: pythondeep-learningpytorchattention-model

解决方案


推荐阅读