首页 > 解决方案 > 有没有更好的方法来沿第一维将两个 Pytorch 张量相乘和求和?

问题描述

我有两个形状分别为和的Pytorch 张量a& 。是我的批次维度。我想将两个张量相乘并求和,以使输出具有 shape 。也就是说,我想计算 的总和。b(S, M)(S, M, H)M(M, H)sa[s] * b[s]

例如,对于S=2, M=2, H=3:

>>> import torch
>>> S, M, H = 2, 2, 3
>>> a = torch.arange(S*M).view((S,M))
tensor([[0, 1],
        [2, 3]])
>>> b = torch.arange(S*M*H).view((S,M,H))
tensor([[[ 0,  1,  2],
         [ 3,  4,  5]],

        [[ 6,  7,  8],
         [ 9, 10, 11]]])

'''
DESIRED OUTPUT:
= [[0*[0, 1, 2] + 2*[6, 7, 8]], 
   [1*[3, 4, 5] + 3*[9, 10, 11]]]

= [[12, 14, 16],
   [30, 34, 38]]

note: shape is (2, 3) = (M, H)
'''

我找到了一种可行的方法,使用torch.tensordot

>>> output = torch.tensordot(a, b, ([0], [0]))
tensor([[[12, 14, 16],
         [18, 20, 22]],

        [[18, 22, 26],
         [30, 34, 38]]])
>>> output.shape
torch.Size([2, 2, 3]) # always (M, M, H)
>>> output = output[torch.arange(M), torch.arange(M), :]
tensor([[12, 14, 16],
        [30, 34, 38]])

但正如你所看到的,它会产生很多不必要的计算,我必须对与我相关的计算进行切片。

有没有更好的方法来做到这一点,而不涉及不必要的计算?

标签: pythonpytorchtensor

解决方案


这应该有效:

(torch.unsqueeze(a, 2)*b).sum(axis=0)
>>> tensor([[12, 14, 16],
            [30, 34, 38]])

推荐阅读