numpy.array_equal and numpy.testing.assert_array_equal do not treat NaN as equal when dtype is object

Problem description

What is the correct way to compare np.nan values when the dtype is object? I tried the following two approaches, but both failed:

1. np.array_equal

import numpy as np
np.array_equal(np.array([np.nan], dtype=object), np.array([np.nan], dtype=object), equal_nan=True)

What I get:

Traceback (most recent call last):
  File "c:\Users\xxxx\xxxxx\.venv\lib\site-packages\numpy\core\numeric.py", line 2455, in array_equal
    a1nan, a2nan = isnan(a1), isnan(a2)
TypeError: ufunc 'isnan' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''
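The TypeError comes from the `equal_nan=True` path, which calls `np.isnan`. That ufunc has no loop for object dtype. Below is a minimal workaround sketch. It assumes every element is a real number, or at least convertible to float. It casts both arrays to float64 first and then lets `np.array_equal` handle the NaNs. The helper name `arrays_equal_as_float` is hypothetical.

```python
import numpy as np


def arrays_equal_as_float(a, b):
    """Hypothetical helper: compare two arrays as float64, treating NaN == NaN.

    Assumes every element converts to float; astype(float) raises
    (for example ValueError on a non-numeric string) otherwise.
    """
    a_float = np.asarray(a).astype(float)
    b_float = np.asarray(b).astype(float)
    # The equal_nan parameter of np.array_equal is available since NumPy 1.19.
    return np.array_equal(a_float, b_float, equal_nan=True)


x = np.array([np.nan], dtype=object)
y = np.array([np.nan], dtype=object)
print(arrays_equal_as_float(x, y))  # True
```

Because the cast fails loudly on non-numeric elements, this only suits object arrays that really hold numbers. Mixed-type arrays need an element-by-element NaN check instead.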

2. np.testing.assert_array_equal

import numpy as np
np.testing.assert_array_equal(np.array([np.nan], dtype=object), np.array([np.nan], dtype=object))

What I get:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "c:\Users\xxxx\xxxxx\.venv\lib\site-packages\numpy\testing\_private\utils.py", line 932, in assert_array_equal
    assert_array_compare(operator.__eq__, x, y, err_msg=err_msg,
  File "c:\Users\xxxx\xxxxx\.venv\lib\site-packages\numpy\testing\_private\utils.py", line 842, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Arrays are not equal

Mismatched elements: 1 / 1 (100%)
Max absolute difference: nan
Max relative difference: nan
 x: array([nan], dtype=object)
 y: array([nan], dtype=object)
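In the NumPy version shown, `assert_array_equal` appears to apply its NaN handling only to numeric dtypes. Object arrays therefore fall back to plain `==`, and `nan == nan` is `False`. The sketch below is a test helper that also works for mixed-type object arrays, assuming the elements are scalars rather than nested sequences. It finds NaNs element by element and checks that both arrays have NaNs in the same positions. It then hands the remaining elements to `np.testing.assert_array_equal`. The names `_is_nan` and `assert_object_array_equal_nan` are hypothetical.

```python
import numpy as np


def _is_nan(value):
    """Return True only for float NaN (Python float or NumPy floating scalar)."""
    # NaN is the only float value that compares unequal to itself.
    return isinstance(value, (float, np.floating)) and value != value


def assert_object_array_equal_nan(actual, desired):
    """Hypothetical helper: like assert_array_equal, but NaN matches NaN
    in object arrays. Assumes scalar elements (no nested lists/arrays)."""
    actual = np.asarray(actual, dtype=object)
    desired = np.asarray(desired, dtype=object)
    if actual.shape != desired.shape:
        raise AssertionError(f"shape mismatch: {actual.shape} != {desired.shape}")

    # Apply _is_nan element by element; frompyfunc yields objects,
    # so convert the result into a boolean mask.
    is_nan = np.frompyfunc(_is_nan, 1, 1)
    actual_nan = np.asarray(is_nan(actual), dtype=bool)
    desired_nan = np.asarray(is_nan(desired), dtype=bool)
    if not np.array_equal(actual_nan, desired_nan):
        raise AssertionError("NaN positions differ")

    # NaNs line up, so compare only the non-NaN elements with NumPy's check.
    np.testing.assert_array_equal(actual[~actual_nan], desired[~desired_nan])


# No AssertionError is raised for the arrays from the question.
assert_object_array_equal_nan(
    np.array([np.nan], dtype=object), np.array([np.nan], dtype=object)
)
```

The NaN positions are checked separately. The rest of the comparison stays on NumPy's own assertion, so mismatches in ordinary elements still get its usual report.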

NumPy/Python version info:

1.20.3 3.9.5 (tags/v3.9.5:0a7dcbd, May 3 2021, 17:27:52) [MSC v.1928 64 bit (AMD64)]

Tags: python numpy
