首页 > 解决方案 > numpy.einsum 的输出形状

问题描述

有没有一种优雅的方法可以根据np.einsum给定的 einsum 的输入参数预先计算结果的形状(不运行计算)?

# Given a, b and signature with 
# a.shape == (1, 2, 5)
# b.shape == (4, 5)
einsum_shape('ijk,mk->ik', a, b) # returns (1, 5)

标签: pythonnumpynumpy-einsum

解决方案


这是适用于通用输入数量和相关 einsum 表达式的东西,也适用于特定的标量减少情况 -

def einsum_outshape(einsum_expr, inputs):
    shps = np.concatenate([in_.shape for in_ in inputs])
    p = einsum_expr.split(',')
    s = p[:-1] + p[-1].split('->')
    if s[-1]=='':
        return ()
    else:
        inop = list(map(list,s))
        return tuple(shps[(np.concatenate(inop[:-1])[:,None]==inop[-1]).argmax(0)])

样品运行 -

In [42]: a = np.random.rand(1,2,5)
    ...: b = np.random.rand(4,5)
    ...: c = np.random.rand(5,7,8)
    ...: d = np.random.rand(7,9)

In [43]: einsum_outshape('ijk,mk,kpq,pr->ikpqr', inputs=(a,b,c,d))
Out[43]: (1, 5, 7, 8, 9)

# Reduction to a scalar
In [44]: einsum_outshape('ijk,mk,kpq,pr->', inputs=(a,b,c,d))
Out[44]: ()

推荐阅读