首页 > 解决方案 > DRF UniqueTogetherValidator及相关模型属性命名问题

问题描述

我正在尝试配置 `UniqueTogetherValidator`,以便通过 DRF 的 API 创建实例;如果不这样做,提交非唯一数据时 Django 会抛出 500 错误。

问题在于:在 `TagSerializer` 中我没有直接使用 `project`,而是把 `project.slug` 作为字段的 source,并将这个字段命名为 `project`。

class TagSerializer(ModelSerializer):
    """Serialize ``Tag``, exposing its project as the project's slug string."""

    # The field is *named* ``project`` but its *source* is the dotted path
    # ``project.slug``, so validated data arrives as {"project": {"slug": ...}}.
    # The stock UniqueTogetherValidator compares ``field.source`` ("project.slug")
    # against the top-level validated-data keys, which is why it reports
    # "This field is required." for ``project``.
    # NOTE(review): a SlugRelatedField(slug_field="slug",
    # queryset=Project.objects.all()) would avoid the dotted source entirely
    # (and DRF's default create() may reject writable dotted-source fields) —
    # confirm before switching.
    project = serializers.CharField(source="project.slug")

    class Meta:
        model = Tag
        fields = ("id", "description", "name", "project")
        validators = [
            UniqueTogetherValidator(
                fields=["project", "name"],
                queryset=Tag.objects.all(),
            )
        ]

这里是模型

class Tag(Model):
    """A tag belonging to one ``Project``; ``name`` is unique within a project."""

    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    description = models.TextField()
    name = models.CharField(max_length=100)
    # Lazy string reference: ``Project`` is declared after ``Tag`` in this
    # listing, so passing the class object would raise NameError when this
    # class body executes.
    project = models.ForeignKey("Project", on_delete=models.CASCADE)

    class Meta:
        # Database-level uniqueness that the serializer validator mirrors.
        unique_together = [["name", "project"]]

class Project(Model):
    """A project with a unique display name and a unique slug derived from it."""

    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    name = models.CharField(verbose_name="Project name", max_length=50, unique=True)
    # Slug is populated from ``name``; with always_update=False it is not
    # regenerated on later saves (per django-extensions AutoSlugField docs).
    slug = AutoSlugField(
        verbose_name="Project slug",
        populate_from="name",
        always_update=False,
        unique=True,
    )

我想使用 `UniqueTogetherValidator`,但它会对 `project` 字段报出 'This field is required.' 错误,因为它在检查必填字段时查找的是字段的 source(`project.slug`),而不是字段名。我提交的有效负载是:

{
    "name": "myname",
    "description": "some-description",
    "project": "test"
}

标签: python, django-rest-framework, deserialization

解决方案


我基于 `UniqueTogetherValidator` 编写了一个自定义验证器,并添加了 `nested_get` 和 `nested_getattr` 两个辅助函数来获取嵌套对象的值。这是验证器:

from collections.abc import Mapping
from functools import reduce
from typing import Any, Dict

from django.utils.translation import gettext_lazy as _
from rest_framework.exceptions import ValidationError
from rest_framework.utils.representation import smart_repr
from rest_framework.validators import qs_exists, qs_filter


def nested_get(dictionary: Dict, keys: str, default=None) -> Any:
    """
    Look up a dotted key path such as ``"project.slug"`` in nested mappings.

    Each ``.``-separated segment indexes one level deeper. As soon as a
    segment is absent, or the current value is not a mapping, ``default`` is
    returned; ``default`` itself is never traversed.

    Args:
        dictionary: The outermost mapping (e.g. a serializer's validated data).
        keys: Hierarchical key path separated with ``.``.
        default: Value returned when the path cannot be fully resolved.

    Returns:
        The value at the end of the path, or ``default``.
    """
    current: Any = dictionary
    for key in keys.split("."):
        # Return early so a mapping-valued ``default`` is never descended into.
        if not isinstance(current, Mapping) or key not in current:
            return default
        current = current[key]
    return current


def nested_getattr(instance: Any, attrs: str) -> Any:
    """
    Resolve a dotted attribute path such as ``"project.slug"`` on an instance.

    Each ``.``-separated name is read with ``getattr`` from the result of the
    previous step; a missing attribute raises ``AttributeError``.
    """
    target = instance
    for attr_name in attrs.split("."):
        target = getattr(target, attr_name)
    return target


class UniqueTogetherRelatedValidator:
    """
    Validator that corresponds to `unique_together = (...)` on a model class,
    extended to serializer fields whose ``source`` is a dotted path into a
    related object (e.g. ``CharField(source="project.slug")``).

    Each name in ``fields`` is mapped to its serializer field's ``source``.
    Values are read from the validated data with ``nested_get``. On updates,
    values the request omitted fall back to the existing instance via
    ``nested_getattr``. Dotted sources become Django lookups
    (``project.slug`` -> ``project__slug``) when filtering ``queryset``.

    Should be applied to the serializer class, not to an individual field.
    """

    message = _("The fields {field_names} must make a unique set.")
    missing_message = _("This field is required.")
    requires_context = True

    # Sentinel meaning "path not present in attrs", distinct from an explicit None.
    _MISSING = object()

    def __init__(self, queryset, fields, message=None):
        """
        Args:
            queryset: Queryset searched for conflicting rows.
            fields: Serializer field *names* (not sources) that must be unique
                together.
            message: Optional override for the uniqueness error message.
        """
        self.queryset = queryset
        self.fields = fields
        self.message = message or self.message

    def _sources(self, serializer):
        """Return the ``source`` path of each serializer field in ``self.fields``."""
        return [serializer.fields[field_name].source for field_name in self.fields]

    def _resolve_value(self, attrs, source, instance):
        """Read ``source`` from ``attrs``; if absent, fall back to ``instance`` (or None)."""
        value = nested_get(attrs, source, self._MISSING)
        if value is not self._MISSING:
            return value
        if instance is not None:
            return nested_getattr(instance, source)
        return None

    def enforce_required_fields(self, attrs, serializer):
        """
        The `UniqueTogetherValidator` always forces an implied 'required'
        state on the fields it applies to.

        Only enforced on create (no ``serializer.instance``). Presence is
        checked along the dotted source path, so ``source="project.slug"`` is
        satisfied by ``{"project": {"slug": ...}}``.

        Raises:
            ValidationError: mapping each missing field name to
                ``missing_message``, with code ``"required"``.
        """
        if serializer.instance is not None:
            return

        missing_items = {
            field_name: self.missing_message
            for field_name, source in zip(self.fields, self._sources(serializer))
            if nested_get(attrs, source, self._MISSING) is self._MISSING
        }
        if missing_items:
            raise ValidationError(missing_items, code="required")

    def filter_queryset(self, attrs, queryset, serializer):
        """
        Filter the queryset to all instances matching the given attributes.

        ``attrs`` is not mutated. On an update, any value the request did not
        provide is read from ``serializer.instance`` rather than being written
        back into the validated data.
        """
        instance = serializer.instance
        filter_kwargs = {
            source.replace(".", "__"): self._resolve_value(attrs, source, instance)
            for source in self._sources(serializer)
        }
        return qs_filter(queryset, **filter_kwargs)

    @staticmethod
    def exclude_current_instance(queryset, instance):
        """
        If an instance is being updated, then do not include
        that instance itself as a uniqueness conflict.
        """
        if instance is not None:
            return queryset.exclude(pk=instance.pk)
        return queryset

    def __call__(self, attrs, serializer):
        """
        Validate ``attrs`` for ``serializer``.

        Raises:
            ValidationError: code ``"required"`` when a field is missing on
                create, or code ``"unique"`` when another row already has the
                same combination of values.
        """
        self.enforce_required_fields(attrs, serializer)
        queryset = self.filter_queryset(attrs, self.queryset, serializer)
        queryset = self.exclude_current_instance(queryset, serializer.instance)

        # Ignore validation if any field is None (mirrors DRF: a NULL does not
        # conflict under a typical unique constraint). Check the resolved
        # source values rather than raw attrs keys so dotted sources count too.
        checked_values = [
            self._resolve_value(attrs, source, serializer.instance)
            for source in self._sources(serializer)
        ]
        if None not in checked_values and qs_exists(queryset):
            field_names = ", ".join(self.fields)
            message = self.message.format(field_names=field_names)
            raise ValidationError(message, code="unique")

    def __repr__(self):
        return "<%s(queryset=%s, fields=%s)>" % (
            self.__class__.__name__,
            smart_repr(self.queryset),
            smart_repr(self.fields),
        )

在序列化器中,我在原始问题的 `Meta` 类里使用这个验证器(替换原来的 `UniqueTogetherValidator`)。


推荐阅读