首页 > 解决方案 > 可以使用 Mypy 给 fold_left() 一个类型吗?

问题描述

我有一个简单的左折叠功能,如下:

from typing import Iterable, Callable, Optional, TypeVar, overload


S = TypeVar("S")
T = TypeVar("T")


def fold_left(it: Iterable[S], f: Callable[[T, S], T], init: Optional[T] = None) -> T:
    it = iter(it)

    if init is None:
        try:
            acc = next(it)
        except StopIteration:
            raise ValueError("fold_left given empty iterable with no init")
    else:
        acc = init

    for i in it:
        acc = f(acc, i)

    return acc

检查该代码时,Mypy 会引发以下错误:

10: error: Incompatible types in assignment (expression has type "T", variable has type "S")
13: error: Incompatible types in assignment (expression has type "T", variable has type "S")
13: error: Argument 1 has incompatible type "S"; expected "T"
15: error: Incompatible return value type (got "S", expected "T")

Mypy 似乎不喜欢 - 当init is None- 类型 S 和 T 将相同的事实。有什么方法可以修改代码以便正确输入检查吗?

我尝试使用以下行重载,但没有效果:

@overload
def fold_left(it: Iterable[S], f: Callable[[T, S], T], init: T) -> T:
    ...


@overload
def fold_left(it: Iterable[S], f: Callable[[S, S], S]) -> S:
    ...

标签: pythonmypy

解决方案


这个问题有两个部分。

首先,对函数的外部接口进行类型检查。这是overload相关的地方。

@overload
def fold_left(it: Iterable[S], f: Callable[[T, S], T], init: T) -> T:
    ...


@overload
def fold_left(it: Iterable[S], f: Callable[[S, S], S]) -> S:
    ...

这指定了函数的所有有效签名。

但这并不能(直接)帮助我们函数的类型检查。我们需要独立解决这个问题。

def fold_left(it: Iterable[S], f: Callable[[T, S], T], init: Optional[T] = None) -> T:
    # we can't change the type of it which is already defined to be an 
    # `Iterable`, so in order for `next` to type check we need a new 
    # variable of type `Iterator` 
    itor: Iterator[S] = iter(it) 

    # acc contains or return value it therefore has to be of type `T`
    acc: T

    if init is None:
        try:
            # mypy isn't smart enough to figure out that if `init` is 
            # `None`, then `S` is equal to `T`, we therefore need to tell it
            # that this assignment is correct
            acc = cast(T, next(itor)) 
        except StopIteration:
            raise ValueError("fold_left given empty iterable with no init")
    else:
        acc = init

    for i in itor:
        acc = f(acc, i)

    return acc

这些overload语句确实间接地帮助了我们,因为它们以某种方式限制了我们函数的公共接口,使我们能够推断cast我们在实现中使用的 是正确的。


推荐阅读