首页 > 解决方案 > 将 2d 数组与 3d 数组的每个切片相乘 - Numpy

问题描述

我正在寻找一种优化的方法来计算 2d 数组的元素乘法乘以 3d 数组的每个切片(使用 numpy)。

例如:

w = np.array([[1,5], [4,9], [12,15]]) y = np.ones((3,2,3))

我想得到一个与 3d 数组形状相同的结果y

不允许使用 * 运算符进行广播。就我而言,第三维很长,for 循环不方便。

标签: pythonarraysnumpymultiplicationlapack

解决方案


给定数组

import numpy as np

w = np.array([[1,5], [4,9], [12,15]])

print(w)

[[ 1  5]
 [ 4  9]
 [12 15]]

y = np.ones((3,2,3))

print(y)

[[[ 1.  1.  1.]
  [ 1.  1.  1.]]

 [[ 1.  1.  1.]
  [ 1.  1.  1.]]

 [[ 1.  1.  1.]
  [ 1.  1.  1.]]]

我们可以直接对数组进行多重化,

z = ( y.transpose() * w.transpose() ).transpose()

print(z)

[[[  1.   1.   1.]
  [  5.   5.   5.]]

 [[  4.   4.   4.]
  [  9.   9.   9.]]

 [[ 12.  12.  12.]
  [ 15.  15.  15.]]]

我们可能会注意到,这会产生与 np.einsum('ij,ijk->ijk',w,y) 相同的结果,可能需要更少的努力和开销。


推荐阅读