首页 > 解决方案 > Numpy ndarray 意外的形状广播错误

问题描述

我有一个形状为 (3,) 的 numpy ndarray。我有另一个形状为(3,100,100)的ndarray。以下作品:

a = np.array([1,1,1]) # Shape is (3,)
b = np.zeros((3,100,100)) # Shape is (3,100,100)
c = np.array([b[0], b[1], 0]) # Shape (3,)
c - a # works fine and as expected 

但以下中断:

c_wrong = np.array([b[0], b[1], b[2]]) # now c_wrong is (3,100,100) too

c_wrong - a # ValueError: operands could not be broadcast together with shapes (3,100,100) (3,)

有没有办法将 (3,100,100) 重塑为 (3,)?

我发现一个丑陋的走动只是添加一个虚拟的额外组件:

>>> c_wrong = np.array([b[0],b[1],b[2],0])
>>> a = np.array([1,1,1,1])
>>> d = c_wrong - a
>>> d[0:3]

虽然这很丑陋,但我希望它有助于理解问题和所需的行为。

标签: pythonnumpynumpy-ndarray

解决方案


多看外形!

In [82]: a = np.array([1,1,1]) # Shape is (3,) 
    ...: b = np.zeros((3,10,10)) # Shape is (3,10,10) 
    ...: c = np.array([b[0], b[1], 0]) # Shape (3,)                             
In [83]:                                                                        
In [83]: c                                                                      
Out[83]: 
array([array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
       array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
       0], dtype=object)
In [84]: c.shape                                                                
Out[84]: (3,)

是的,c只有 3 个元素,但每个元素都是数组或标量(最后一个 0)。

In [85]: c-a                                                                    
Out[85]: 
array([array([[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
       [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
       [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
       [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
       [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
       [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
       [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
       [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
       [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
       [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.]]),
       array([[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
       [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
       [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
       [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
       [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
       [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
       [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
       [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
       [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
       [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.]]),
       -1], dtype=object)

所以你设法从每个元素中减去 1!

c_wrong是一个非常不同的数组 - 它是带有数字 dtype 的 3d。将其替换0d[3]一切都不同。

In [88]: c_wrong.shape                                                          
Out[88]: (3, 10, 10)
In [89]: c_wrong.dtype                                                          
Out[89]: dtype('float64')

要从 (3,N,N) 中减去 (3,),您必须将尺寸调整a为 (3,1,1)。然后它可以进行适当的广播。

In [91]: c_wrong -  a[:,None,None]                                              
Out[91]: 
array([[[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
        [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
        ....
        [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.]]])

我认为你的c-a作品只是一个意外。通过c使用0元素定义,您创建了一个objectdtype 数组。具有对象 dtype 数组的数学是不择手段的。这种减法恰好是其中之一。但不要指望它;有很多方法使这种数组的数学不起作用 - 而且它总是更慢。


c_wrong本质上与b.


numpy 的核心是多维数值数组。 np.array默认情况下,尝试构建尽可能高的维数。在您的c_wrong情况下,它可以制作 3d;inc不能因为标量 0。所以它回退到制作一维对象数组。

制作所需形状的对象数组的最可靠方法是初始化一个“空白”数组,然后填充它。但即便如此,填充也可能很棘手。在这里,我设法做到了:

In [92]: c3 = np.empty(3, object)                                               
In [93]: c3                                                                     
Out[93]: array([None, None, None], dtype=object)
In [94]: c3[:] = list(b)                                                        
In [95]: c3                                                                     
Out[95]: 
array([array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       ....
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])], dtype=object)
In [96]: c3-a                                                                   
Out[96]: 
array([array([[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
....
       [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.]])], dtype=object)

无效的填充:

In [97]: c3[:] = b                                                              
------------------------------------------------------------------------ 
...
ValueError: could not broadcast input array from shape (3,10,10) into shape (3)

a[:,None,None]当你熟悉广播时,看起来并不那么难看。

比较时间:

In [98]: timeit c_wrong-a[:,None,None]                                          
5.22 µs ± 6.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [99]: timeit c3-a                                                            
9.53 µs ± 20.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [100]: timeit c-a                                                            
7.66 µs ± 10.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

或与dot

In [103]: timeit np.dot(a, b.reshape(3,-1)).shape                              
2.44 µs ± 9.63 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [104]: timeit np.dot(a,c).shape                                              
10.9 µs ± 16.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [105]: timeit np.dot(a,c3).shape                                             
11.6 µs ± 30.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

dot有非常具体的规则 - 的最后一个轴a必须与第二个到最后一个匹配b。这就是我使用reshape. 它会将任务传递给一个快速的“blas”例程。

使用 (3,) 对象数组,它执行一维dot乘积 - 但迭代。

@,matmul适用于 reshape b,但不适用于cor c3。同样适用于einsum:np.einsum('i,ijk->jk',a,b).shape有效,但没有任何使用c.


推荐阅读