首页 > 解决方案 > 显示导致测试失败的数组条目

问题描述

作为测试套件的一部分,我必须检查函数返回的 numpy 数组是否正确。np.array_equal使用which 返回一个关于所有数组元素是否相同的布尔值很容易进行此检查。

如果测试失败,则错误消息对于了解导致失败的原因并不是特别有帮助。

import unittest
import numpy as np

class TestArray(unittest.TestCase):
    def test_values(self):
        x = np.array([1, 2])
        self.assertTrue(np.array_equal(x, [1, 3]))


if __name__ == "__main__":
    unittest.main()

测试失败消息:

Traceback (most recent call last):
  File "test.py", line 7, in test_values
    self.assertTrue(np.array_equal(x, [1, 3]))
AssertionError: False is not true

有没有一种简单的方法来检查条目是否相等,它显示第一个不相等的条目的索引和值?我想要一条错误消息,例如:

AssertionError: Arrays not equal at index 1 (2 != 3) 

标签: pythonpython-3.xnumpypython-unittest

解决方案


np.array_equal我们可以获取代码并重写它,在最后添加另一个检查

def array_equal(a1, a2):
    try:
        a1, a2 = asarray(a1), asarray(a2)
    except Exception:
        return False
    if a1.shape != a2.shape:
        return False
    eq = asarray(a1 == a2) # [ True False False True]
    if not bool(eq.all()):
        errors = [f"idx:{idx} ({vals[0]}!={vals[1]})"
                  for idx, vals in enumerate(zip(a1, a2))
                  if not eq[idx]]
        raise AssertionError("Arrays not equal " + " ".join(errors))
    return True

class TestArray(unittest.TestCase):
    def test_values(self):
        x = np.array([1, 1, 1, 1])
        self.assertTrue(array_equal(x, [1, 2, 3, 1]))

if __name__ == "__main__":
    unittest.main()

AssertionError: Arrays not equal idx:1 (1!=2) idx:2 (1!=3)


推荐阅读