首页 > 解决方案 > 广播 numpy dot 产品

问题描述

假设我有一个a矩阵数组,例如 shape = N 3x3 矩阵和一个shape(N, 3, 3)数组。b(i, 3, k)

我想要点积 a * b 的这些行为

  1. 如果i = N,结果应该是一个(N, 3, k)数组,其中第一个元素是a[0].dot(b[0]),第二个是a[1].dot(b[1]),依此类推。
  2. 如果i = 1,则 的每个元素都a必须乘以 ,b[0]并且结果的形状应该是(N, 3, k)

如果我尝试只使用 numpy.dot() 结果几乎是好的,但结果形状不是我所期望的。有没有办法轻松有效地做到这一点,并使其适用于任何维度?

标签: pythonarrayspython-3.xnumpy

解决方案


你是对的,在(N, 3, 3) * (N, 3, k)你不能np.dot直接使用的情况下,因为结果会是(N, 3, N, k). 您实际上必须N从轴 0 和轴 2 中提取对角线元素,但这是大量不必要的计算。如果您后应用 a以删除不必要的第三轴,则(N, 3, 3) * (1, 3, k)可以使用以下方法解决此问题: .np.dotsqueezeresult = a.dot(b).squeeze()

好消息是您不需要np.dot获得点积。以下是三种选择:

最简单的,使用@运算符,相当于np.matmul,它需要前导维度一起广播:

a @ b
np.matmul(a, b)

如果您的矩阵不在最后两个维度中,您可以将它们转置为。这可能效率低下,因为它可能会复制数据。在这些情况下,请继续阅读。

另一种流行的解决方案是使用np.einsum显式指定匹配轴和要求和的轴:

np.einsum('ijk,ikl->ijl', a, b)

i = N广播将处理这两种i = 1情况。

当然,您始终可以使用*/np.multiply和手动获取 sum-product np.sum

(a[..., None] * b[:, None, ...]).sum(axis=2)

推荐阅读