首页 > 解决方案 > python __eq__ 测试失败

问题描述

我是 python 新手,我正在学习使用pytest. 我有一个类定义为:

class Matrix:

    def __init__(self, *rows):
        row_length = len(rows[0])
        for row in rows:
            # TODO skip first
            if len(row) != row_length:
                raise SystemError("Rows does not have equal length")

        self._rows = [*rows]

    def __eq__(self, other):
        return isinstance(self, other.__class__) and \
               all([x == y for x, y in zip_longest(self._rows, other._rows)])

    # other methods omitted for simplicity...

我为此写了一个测试__eq__(self, other)

def test_eq():
    m1 = Matrix([[1,2,3],[4,5,6]])
    m2 = Matrix([1,2,3],[4,5,6])
    m3 = Matrix([1,2,3],[5,4,6])
    assert m1 == m2
    assert m2 == m1
    assert m2 != m3

Wich 应该通过,因为m1m2具有相同的行,并且m3在第二行有差异。但是,当我运行此测试时,我有输出:

    def test_eq():
        m1 = Matrix([[1,2,3],[4,5,6]])
        m2 = Matrix([1,2,3],[4,5,6])
        m3 = Matrix([1,2,3],[5,4,6])
>       assert m1 == m2
E       assert <exercises.matrix.Matrix object at 0x10ccd67d0> == <exercises.matrix.Matrix object at 0x10ccd6810>

我在这里想念什么?我正在使用 Python 3.7.4 和 pytest 版本 5.1.2。提前感谢您的评论/回答


注意:我根据 ggorlen 的答案更改了实现,但我遇到了类似的问题


标签: python-3.xpytest

解决方案


比较中的行应类似于:

for i, i_row in enumerate(self._rows):
    if i_row != other._rows[i]:
        return False

other但是,如果行数多于,这仍然不会返回正确的结果self,因此:

def __eq__(self, other):
    return isinstance(self, other.__class__) and \
           len(other._rows) == len(self._rows) and \
           all([x == y for x, y in zip(self._rows, other._rows)])

该属性称为_rows,我们需要使用它[]来索引列表,而不是括号。

一个可能更快的版本,可以在失败的比较早期保释是:

def __eq__(self, other):
    if isinstance(self, other.__class__) and \
      len(other._rows) == len(self._rows):
        for i, row in enumerate(self._rows):
            if row != other._rows[i]:
                return False

        return True

    return False

在您的测试中,您可能有一个错字:

m1 = Matrix([[1,2,3],[4,5,6]]) # <-- this matrix has an extra `[]` wrapper
m2 = Matrix([1,2,3],[4,5,6])   # <-- but this one just uses flat lists

所以这些矩阵将不相等。


小建议:

  • 提高 a ValueErrororArgumentError而不是 aSystemError上的错误参数。
  • 考虑使用Numpy.matrix而不是滚动您自己的矩阵。

推荐阅读