首页 > 解决方案 > 有没有办法让 Python all() 函数与多维数组一起工作?

问题描述

我正在尝试__eq__为基类实现一个通用且灵活的方法,该方法将使用尽可能多的对象类型,包括可迭代对象和 numpy 数组。

这是我到目前为止所拥有的:

class Environment:

    def __init__(self, state):
        self.state = state

    def __eq__(self, other):
        """Compare two environments based on their states.
        """
        if isinstance(other, self.__class__):
            try:
                return all(self.state == other.state)
            except TypeError:
                return self.state == other.state
        return False

这适用于大多数对象类型,包括一维数组:

s = 'abcdef'
e1 = Environment(s)
e2 = Environment(s)

e1 == e2  # True

s = [[1, 2, 3], [4, 5, 6]]
e1 = Environment(s)
e2 = Environment(s)

e1 == e2  # True

s = np.array(range(6))
e1 = Environment(s)
e2 = Environment(s)

e1 == e2  # True

self.state问题是,当它是一个多维的 numpy 数组时,它会返回一个 ValueError 。

s = np.array(range(6)).reshape((2, 3))
e1 = Environment(s)
e2 = Environment(s)

e1 == e2

产生:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

显然,我可以检查isinstance(other, np.ndarray)然后做,(return self.state == other.state).all()但只是认为可能有一种更通用的方法可以用一个语句来处理任何类型的所有迭代、集合和数组。

我也有点困惑,为什么all()不遍历数组的所有元素,比如array.all(). 有没有办法触发np.nditer并做到这一点?

标签: pythonarraysnumpyequality

解决方案


Signature: all(iterable, /)
Docstring:
Return True if bool(x) is True for all values x in the iterable.

对于一维数组:

In [200]: x=np.ones(3)                                                               
In [201]: x                                                                          
Out[201]: array([1., 1., 1.])
In [202]: y = x==x                                                                   
In [203]: y          # 1d array of booleans                                                                      
Out[203]: array([ True,  True,  True])
In [204]: bool(y[0])                                                                 
Out[204]: True
In [205]: all(y)                                                                     
Out[205]: True

对于二维数组:

In [206]: x=np.ones((2,3))                                                           
In [207]: x                                                                          
Out[207]: 
array([[1., 1., 1.],
       [1., 1., 1.]])
In [208]: y = x==x                                                                   
In [209]: y                                                                          
Out[209]: 
array([[ True,  True,  True],
       [ True,  True,  True]])
In [210]: y[0]                                                                       
Out[210]: array([ True,  True,  True])
In [211]: bool(y[0])                                                                 
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-211-d0ce0868392c> in <module>
----> 1 bool(y[0])

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

但是对于不同的二维数组:

In [212]: x=np.ones((3,1))                                                           
In [213]: y = x==x                                                                   
In [214]: y                                                                          
Out[214]: 
array([[ True],
       [ True],
       [ True]])
In [215]: y[0]                                                                       
Out[215]: array([ True])
In [216]: bool(y[0])                                                                 
Out[216]: True
In [217]: all(y)                                                                     
Out[217]: True

numpy 数组的迭代沿第一个维度进行。 [i for i in x]

每当在需要标量布尔值的上下文中使用多值布尔数组时,都会引发此歧义 ValueError。 ifor/and表达是常见的。

In [223]: x=np.ones((2,3))                                                           
In [224]: y = x==x                                                                   
In [225]: np.all(y)                                                                  
Out[225]: True

np.all与 Pythonall的不同之处在于它“知道”维度。在这种情况下,它会将ravel数组视为 1d:

默认 ( axis= None) 是对输入数组的所有维度执行逻辑与。


推荐阅读