首页 > 解决方案 > 当输入是许多相同的数组时使 np.einsum 更快?(或任何其他更快的方法)

问题描述

我有一段代码类型:

nnt = np.real(np.einsum('xa,xb,yc,yd,abcde->exy',evec,evec,evec,evec,quartic))

其中evec是(比如说)一个 L x Lnp.float32数组,并且quartic是一个 L x L x L x L x Tnp.complex64数组。

我发现这个程序相当慢。

我认为既然所有的evec' 都是相同的,那么可能有更快的方法来做到这一点?

提前致谢。

标签: pythonnumpynumpy-einsum

解决方案


首先,您可以重用第一个计算:

evec2 = np.real(np.einsum('xa,xb->xab',evec,evec))
nnt = np.real(np.einsum('xab,ycd,abcde->exy',evec2,evec2,quartic))

如果你不关心内存,只需要性能:

evec2 = np.real(np.einsum('xa,xb->xab',evec,evec))
nnt = np.real(np.einsum('xab,ycd,abcde->exy',evec2,evec2,quartic,optimize=True))

推荐阅读