首页 > 解决方案 > 有效地计算数组最后一维的点积

问题描述

在多维ndarray的最后一维上计算点积的最快方法是什么?

目前我正在这样做:

import numpy as np

a=np.reshape(np.arange(90),[3,3,2,5])
b=np.reshape(np.arange(90),[3,3,2,5])
# for the sake of simplicity, a and b are the same for this example

ab=(a*b).sum(axis=-1)

我认为这einsum在这里可能有用,但我很难将其应用于我的案例。

谢谢!

标签: python-3.xnumpynumpy-einsum

解决方案


对于通用 ndim 数组以沿最后一个轴进行减和 -

np.einsum('...i,...i->...',a,b)

替代np.matmul-

np.matmul(a[...,None,:],b[...,None])[...,0,0]

注意:在 Python 3.xnp.matmul上可以替换为@ operator.


推荐阅读