Pydantic: getting field type hints

Problem description

I want to store the metadata for my ML models in pydantic. Is there a proper way to access a field's type? I know you can do BaseModel.__fields__['my_field'].type_, but I assume there's a better way.

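I want this so that, if a BaseModel fails to instantiate, it is very clear what data is required to create the missing field and which methods to use. Something like this: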

from pydantic import BaseModel
import pandas as pd

# basic model
class Metadata(BaseModel):
    peaks_per_day: float


class PeaksPerDayType(float):
    data_required = pd.Timedelta("180D")
    data_type = "foo"

    @classmethod
    def determine(cls, data):
        return cls(data)

# use our custom float
class Metadata(BaseModel):
    peaks_per_day: PeaksPerDayType

def get_data(data_type, required_data):
    # get enough of the appropriate data type
    return [1]


# Initial data we have
metadata_json = {}
try:
    metadata = Metadata(**metadata_json)
    # peaks per day is missing
except Exception as e:
    error_msg = e

missing_fields = error_msg.errors()
missing_fields = [missing_field['loc'][0] for missing_field in missing_fields]

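# For each missing field, use its type hint to find what data is required to
# determine it, and access the method that determines the value.
# (Metadata[missing_field] below is the lookup I would like to have;
# it is not valid pydantic syntax.)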

new_data = {}
for missing_field in missing_fields:
    req_data = Metadata[missing_field].data_required
    data_type = Metadata[missing_field].data_type
    data = get_data(data_type=data_type, required_data=req_data)

    new_data[missing_field] = Metadata[missing_field].determine(data)

metadata = Metadata(**metadata_json, **new_data)

Tags: python, pydantic

Solution


If you don't need to handle nested classes, this should work:

from pydantic import BaseModel, ValidationError

import typing

class PeaksPerDayType(float):
    data_required = 123.22
    data_type = "foo"

    @classmethod
    def determine(cls, data):
        return cls(data)

# use our custom float
class Metadata(BaseModel):
    peaks_per_day: PeaksPerDayType

def get_data(data_type, required_data):
    # get enough of the appropriate data type
    return required_data

metadata_json = {}
try:
    Metadata(**metadata_json)
except ValidationError as e:
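    # Map each field name to its annotated type, e.g. PeaksPerDayType
    field_to_type = typing.get_type_hints(Metadata)
    missing_fields = []
    for error in e.errors():
        # 'value_error.missing' is the error type pydantic v1 reports
        # for a required field that was not supplied
        if error['type'] == 'value_error.missing':
            missing_fields.append(error['loc'][0])
        else:
            # Any other validation problem is re-raised unchanged
            raise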

    new_data = {}
    for field in missing_fields:
        type_ = field_to_type[field]
        new_data[field] = get_data(type_.data_type, type_.data_required)

    print(Metadata(**metadata_json, **new_data))

Output:

peaks_per_day=123.22
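
For the nested case mentioned above, one possible approach is to follow the whole `loc` tuple of an error instead of only `loc[0]`, resolving type hints one level at a time. The following is a minimal sketch, not part of the original answer: `Stats` and `resolve_field_type` are hypothetical names, and plain annotated classes stand in for pydantic models so that it runs with the standard library alone. It assumes every element of `loc` is a field name, so list indices and wrappers such as `Optional[...]` or `List[...]` are not handled.

```python
import typing


class PeaksPerDayType(float):
    data_required = 123.22
    data_type = "foo"


# Plain annotated classes stand in for pydantic models in this sketch.
class Stats:
    peaks_per_day: PeaksPerDayType


class Metadata:
    stats: Stats


def resolve_field_type(model_cls, loc):
    """Follow a tuple of field names and return the annotated type of the last one."""
    current = model_cls
    for name in loc:
        # Look up the annotation of `name` on the class reached so far.
        current = typing.get_type_hints(current)[name]
    return current


# A missing nested field would be reported with a loc such as ("stats", "peaks_per_day").
field_type = resolve_field_type(Metadata, ("stats", "peaks_per_day"))
print(field_type.__name__, field_type.data_type, field_type.data_required)
```

Because `typing.get_type_hints` only reads class annotations, the same walk should apply to model classes as long as each intermediate annotation is itself a class with annotated fields.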

I'm not quite sure what the point of data_type or get_data is, but I assume it's some internal logic you want to add.
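
Going by the question, `get_data` appears to be meant to fetch raw data and `determine` to turn that data into the field value. The sketch below shows how the two might fit together once the field types are known. It is an illustration under assumptions: the body of `get_data`, the averaging inside `determine`, and the helper `fill_missing` are hypothetical, and a plain annotated class stands in for the pydantic model.

```python
import typing


class PeaksPerDayType(float):
    data_required = 123.22
    data_type = "foo"

    @classmethod
    def determine(cls, data):
        # Hypothetical reduction: average the raw samples.
        return cls(sum(data) / len(data))


# A plain annotated class stands in for the pydantic model in this sketch.
class Metadata:
    peaks_per_day: PeaksPerDayType


def get_data(data_type, required_data):
    # Hypothetical loader: a real one would fetch `required_data` worth
    # of samples of kind `data_type`.
    return [2.0, 4.0]


def fill_missing(model_cls, missing_fields):
    """Build a value for each missing field from the metadata on its type hint."""
    hints = typing.get_type_hints(model_cls)
    new_data = {}
    for field in missing_fields:
        type_ = hints[field]
        raw = get_data(type_.data_type, type_.data_required)
        new_data[field] = type_.determine(raw)
    return new_data


new_data = fill_missing(Metadata, ["peaks_per_day"])
print(new_data)
```

The resulting dictionary could then be passed to the model together with the original input, as in `Metadata(**metadata_json, **new_data)` above.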

