首页 > 解决方案 > 具有“@”的对象的 Python 类型提示(矩阵乘法)

问题描述

我有一个fun()接受 NumPy ArrayLike和“矩阵”的函数,并返回一个 numpy 数组。

from numpy.typing import ArrayLike
import numpy as np

def fun(A, x: ArrayLike) -> np.ndarray:
    return (A @ x) ** 2 - 27.0

type对于有操作的实体,什么是正确的@?请注意,fun()也可以接受scipy.sparse;也许更多。

标签: pythonnumpytype-hintingpython-typing

解决方案


您可以使用typing.Protocol来断言该类型实现了__matmul__.

class SupportsMatrixMultiplication(typing.Protocol):
    def __matmul__(self, x):
        ...


def fun(A: SupportsMatrixMultiplication, x: ArrayLike) -> np.ndarray:
    return (A @ x) ** 2 - 27.0

我相信,x如果您想要的不仅仅是@作为运算符的支持,您可以通过提供类型提示和返回类型提示来进一步完善这一点。


推荐阅读