首页 > 解决方案 > 优化 DRF ModelSerializer 中的查询数量

问题描述

在 Django Rest Framework 的序列化器中,可以向序列化对象中添加比原始模型中更多的数据。这对于在服务器端计算统计信息以及在响应 API 调用时添加此额外信息时很有用。

据我了解,添加额外数据是使用 a 完成的SerializerMethodField,其中每个字段都由一个get_...函数实现。但是,如果您有许多这样的 SerializerMethodField,每个都可以分别查询模型/数据库,以获取可能本质上相同的数据。

是否可以查询一次数据库,将列表/结果存储为ModelSerializer 对象的数据成员,并在许多函数中使用查询集的结果?

这是一个非常简单的示例,仅用于说明:

############## Model

class Employee(Model):
    SALARY_TYPE_CHOICES = (('HR', 'Hourly Rate'), ('YR', 'Annual Salary'))
    salary_type = CharField(max_length=2, choices=SALARY_TYPE_CHOICES, blank=False)
    salary = PositiveIntegerField(blank=True, null=True, default=0)
    company = ForeignKey(Company, related_name='employees')

class Company(Model):
    name = CharField(verbose_name='company name', max_length=100)


############## View

class CompanyView(RetrieveAPIView):
    queryset = Company.objects.all()
    lookup_field='id'
    serializer_class = CompanySerialiser

class CompanyListView(ListAPIView):
    queryset = Company.objects.all()
    serializer_class = CompanySerialiser


############## Serializer

class CompanySerialiser(ModelSerializer):
    number_employees = SerializerMethodField()
    total_salaries_estimate = SerializerMethodField()
    class Meta:
        model = Company
        fields = ['id', 'name',
                  'number_employees',
                  'total_salaries_estimate',
                 ]
    def get_number_employees(self, obj):
        return obj.employees.count()
    def get_total_salaries_estimate(self, obj):
        employee_list = obj.employees.all()
        salaries_estimate = 0
        HOURS_PER_YEAR = 8*200 # 8hrs/day, 200days/year
        for empl in employee_list:
            if empl.salary_type == 'YR':
                salaries_estimate += empl.salary
            elif empl.salary_type == 'HR':
                salaries_estimate += empl.salary * HOURS_PER_YEAR
        return salaries_estimate

序列化器可以优化为:

例子:

class CompanySerialiser(ModelSerializer):
    def __init__(self, *args, **kwargs):
        super(CompanySerialiser, self).__init__(*args, **kwargs)
        self.employee_list = None

    number_employees = SerializerMethodField()
    total_salaries_estimate = SerializerMethodField()
    class Meta:
        model = Company
        fields = ['id', 'name',
                  'number_employees',
                  'total_salaries_estimate',
                 ]
    def _populate_employee_list(self, obj):
        if not self.employee_list: # Query the database only once.
            self.employee_list = obj.employees.all()
    def get_number_employees(self, obj):
        self._populate_employee_list(obj)
        return len(self.employee_list)
    def get_total_salaries_estimate(self, obj):
        self._populate_employee_list(obj)
        salaries_estimate = 0
        HOURS_PER_YEAR = 8*200 # 8hrs/day, 200days/year
        for empl in self.employee_list:
            if empl.salary_type == 'YR':
                salaries_estimate += empl.salary
            elif empl.salary_type == 'HR':
                salaries_estimate += empl.salary * HOURS_PER_YEAR
        return salaries_estimate

这适用于单次检索CompanyView。而且,实际上节省了一次查询/上下文切换/往返数据库;我已经消除了“计数”查询。

但是,它不适用于列表视图CompanyListView,因为似乎序列化程序对象只创建一次并为每个公司重用。因此,只有第一家公司的员工列表存储在对象“ self.employee_list”数据成员中,因此,所有其他公司都错误地从第一家公司获得了数据。

是否有此类问题的最佳实践解决方案?或者我只是错误地使用 ListAPIView,如果是这样,是否有替代方案?

标签: djangodjango-rest-framework

解决方案


CompanySerialiser我认为如果您可以将查询集传递给已获取的数据,则可以解决此问题。

您可以进行以下更改

class CompanyListView(ListAPIView):
    queryset = Company.objects.all().prefetch_related('employee_set')
    serializer_class = CompanySerialiser`

而不是 count 使用len函数,因为 count 再次进行查询。

class CompanySerialiser(ModelSerializer):
    number_employees = SerializerMethodField()
    total_salaries_estimate = SerializerMethodField()
    class Meta:
        model = Company
        fields = ['id', 'name',
                  'number_employees',
                  'total_salaries_estimate',
                 ]
    def get_number_employees(self, obj):
        return len(obj.employees.all())
    def get_total_salaries_estimate(self, obj):
        employee_list = obj.employees.all()
        salaries_estimate = 0
        HOURS_PER_YEAR = 8*200 # 8hrs/day, 200days/year
        for empl in employee_list:
            if empl.salary_type == 'YR':
                salaries_estimate += empl.salary
            elif empl.salary_type == 'HR':
                salaries_estimate += empl.salary * HOURS_PER_YEAR
        return salaries_estimate

由于数据是预取的,序列化程序不会对all. 但请确保您没有进行任何类型的过滤,因为在这种情况下将执行另一个查询。


推荐阅读