首页 > 解决方案 > 如何在 Python 2 中实现中缀运算符矩阵乘法?

问题描述

我已经使用 python 3.7 几个月了,但我最近不得不转向 python 2.7。由于我正在开发科学代码,因此我严重依赖使用中缀运算符@来乘以 nd 数组。这个操作符是在 python 3.5 中引入的(见这里),因此,我不能在我的新设置中使用它。

显而易见的解决方案是将 all 替换M1 @ M2numpy.matmul(M1, M2),这严重限制了我的代码的可读性。

我看到了这个hack,它包括定义一个中缀类,允许通过重载orror运算符创建自定义运算符。我的问题是:我怎样才能使用这个技巧使中缀|at|运算符像 一样工作@

我尝试的是:

import numpy as np

class Infix:
    def __init__(self, function):
        self.function = function
    def __ror__(self, other):      
        return Infix(lambda x, self=self, other=other: self.function(other, x))
    def __or__(self, other):
        return self.function(other)
    def __call__(self, value1, value2):
        return self.function(value1, value2)

# Matrix multiplication
at = Infix(lambda x,y: np.matmul(x,y))

M1 = np.ones((2,3))
M2 = np.ones((3,4))

print(M1 |at| M2)

当我执行此代码时,我得到:

ValueError: operands could not be broadcast together with shapes (2,3) (3,4) 

我想我知道什么是行不通的。当我只看 时M1|at,我可以看到它是一个 2*3 的函数数组:

array([[<__main__.Infix object at 0x7faa1c0d6da0>,
        <__main__.Infix object at 0x7faa1c0d6860>,
        <__main__.Infix object at 0x7faa1c0d6828>],
       [<__main__.Infix object at 0x7faa1c0d6f60>,
        <__main__.Infix object at 0x7faa1c0d61d0>,
        <__main__.Infix object at 0x7faa1c0d64e0>]], dtype=object)

这不是我所期望的,因为我希望我的代码将这个二维数组视为一个整体,而不是元素...

有人知道我应该做什么吗?

PS:我也考虑过使用这个答案,但我必须避免使用外部模块。

标签: pythonnumpyoperatorspython-2.xmatrix-multiplication

解决方案


我在这里找到了我的问题的解决方案。

正如评论中所建议的,理想的解决方法是使用 Python 3.x 或使用numpy.matmul,但这段代码似乎有效,甚至具有正确的优先级:

import numpy as np

class Infix(np.ndarray):
    def __new__(cls, function):
        obj = np.ndarray.__new__(cls, 0)
        obj.function = function
        return obj
    def __array_finalize__(self, obj):
        if obj is None: return
        self.function = getattr(obj, 'function', None)
    def __rmul__(self, other):
        return Infix(lambda x, self=self, other=other: self.function(other, x))
    def __mul__(self, other):
        return self.function(other)
    def __call__(self, value1, value2):
        return self.function(value1, value2)

at = Infix(np.matmul)

M1 = np.ones((2,3))
M2 = np.ones((3,4))
M3 = np.ones((2,4))

print(M1 *at* M2)
print(M3 + M1 *at* M2)

推荐阅读