首页 > 解决方案 > N-dim 数字列表类型

问题描述

如何在 Python 3.7 中为 N 维数值列表(张量)定义一种类型?这将用作 Pydantic 的道具之一BaseModel

我想要类似的东西

from typing import List, Union

NumericalList = Union[
    int, float,
    List[int], List[float],
    List[List[int]], List[List[float]],
    ...
]
n1: NumericalList = [0]
n3: NumericalList = [ [ [0, 1, 2], [1, 1, 2] ],
                      [ [1, 2, 3], [0, 1, 3] ],
                    ]

我知道在类上可以编写字符串文字来指示孩子属于同一类型。(或者只是添加from __future__ import annotations。)我想要一个可迭代/可切片的,而不是通过道具访问。

认为递归定义可能会起作用,但是在按模块AttributeError: __forward_arg__的多个级别之后它会失败。deepcopytyping

NumericalList = Union[int, float, List["NumericalList"]]  # AttributeError: __forward_arg__

但是请注意,在使用 Pydantic 而不是在 Python 的 IDE 中时,这会失败。这是 Pydantic 特有的,还是错误的做法?

标签: python-3.xfastapipython-typingpydantic

解决方案


虽然我希望 pydantic 支持本机递归类型,但您可以使用具有pydantic 严格类型的 pydantic 自定义根类型模型来确保浮点值不会变成 int

from __future__ import annotations
from typing import Union, List
from pydantic import BaseModel, StrictInt, StrictFloat

class NumericalList(BaseModel):
    __root__: Union[StrictInt, StrictFloat, List[NumericalList]]


NumericalList.update_forward_refs()


n1: NumericalList = NumericalList.parse_obj([0])
"""
NumericalList(__root__=[NumericalList(__root__=0)])
"""

n2: NumericalList = NumericalList.parse_obj(
                    [ [ [0, 1, 2], [1, 1, 2] ],
                      [ [1, 2, 3], [0, 1, 3] ],
                    ]
                   )
"""
NumericalList(__root__=[
    NumericalList(__root__=[
        NumericalList(__root__=[NumericalList(__root__=0), NumericalList(__root__=1), NumericalList(__root__=2)]),
        NumericalList(__root__=[NumericalList(__root__=1), NumericalList(__root__=1), NumericalList(__root__=2)])
    ]),
    NumericalList(__root__=[
        NumericalList(__root__=[NumericalList(__root__=1), NumericalList(__root__=2), NumericalList(__root__=3)]),
        NumericalList(__root__=[NumericalList(__root__=0), NumericalList(__root__=1), NumericalList(__root__=3)])
    ])
])

"""


n1.dict()
"""
{"__root__": [0]}
"""

n2.dict()
"""
{"__root__": [ 
                  [ [0, 1, 2], [1, 1, 2] ],
                  [ [1, 2, 3], [0, 1, 3] ],
             ]}
"""

您可以扩展该类以涵盖列表功能,但这是可选的

class NumericalList(BaseModel):
    __root__: Union[StrictInt, StrictFloat, List[NumericalList]]

    def __iter__(self):
        return iter(self.__root__)

    def __getitem__(self, index):
        return self.__root__[index]

    def __setitem__(self, index, value):
        self.__root__[index] = value


NumericalList.update_forward_refs()

n3: NumericalList = NumericalList.parse_obj([0, 5, [1]])
for i in n3:
    print(i)
"""
__root__=0
__root__=5
__root__=[NumericalList(__root__=1)]
"""

如果你想获得原生类型,你可以做

n3.dict()
"""
{'__root__': [0, 5, [1]]}
"""

推荐阅读