首页 > 解决方案 > 两个张量的 Pytorch 广播乘积

问题描述

我想将两个张量相乘,这就是我得到的:

第一个索引用于批量大小。我想要做的基本上是从B-(20, 1, 110)中获取每个张量,例如,我想将每个A张量相乘(20, n, 110)。所以产品将在最后:AB形状为的张量(20, 96 * 16, 110)

所以我想A通过广播来乘以每个张量B。PyTorch 中是否有一种方法可以做到这一点?

标签: pythondeep-learningmatrix-multiplicationpytorchtensor

解决方案


使用torch.einsum后跟torch.reshape

AB = torch.einsum("ijk,ilk->ijlk", (A, B)).reshape(A.shape[0], -1, A.shape[2])

例子:

import numpy as np
import torch

# A of shape (2, 3, 2):
A = torch.from_numpy(np.array([[[1, 1], [2, 2], [3, 3]], 
                               [[4, 4], [5, 5], [6, 6]]]))
# B of shape (2, 2, 2):
B = torch.from_numpy(np.array([[[1, 1], [10, 10]], 
                               [[2, 2], [20, 20]]]))

# AB of shape (2, 3*2, 2):
AB = torch.einsum("ijk,ilk->ijlk", (A, B)).reshape(A.shape[0], -1, A.shape[2])
# tensor([[[ 1, 1], [ 10, 10], [  2,  2], [ 20,   20], [ 3,   3], [ 30,  30]],
#         [[ 8, 8], [ 80, 80], [ 10, 10], [ 100, 100], [ 12, 12], [ 120, 120]]])

推荐阅读