首页 > 解决方案 > 是否可以在中间件中过滤掉一些数据库行?

问题描述

对不起这个问题的长度,我不能让它更短并且仍然有意义。

我们有一个非常简单的应用程序,其中包含两个简单的模型CompanyBuilding,它们之间具有多对多的关系。每个都有一个restricted属性。User是一个普通的 DjangoUser类,除了我们添加了一个show属性。

# models.py

class User(AbstractUser):
    show = models.BooleanField(default=True)


class Company(models.Model):
    name = models.CharField(max_length=100)
    restricted = models.BooleanField(default=False)


class Building(models.Model):
    name = models.CharField(max_length=100)
    restricted = models.BooleanField(default=False)
    companies = models.ManyToManyField(Company, related_name='buildings')

视图是常规的 Django REST Framework 视图集,序列化器尽可能简单:

# views.py

class CompanyViewSet(ModelViewSet):
    queryset = Company.objects.all()
    serializer_class = CompanySerializer


class BuildingViewSet(ModelViewSet):
    queryset = Building.objects.all()
    serializer_class = BuildingSerializer


# serializers.py

class CompanySerializer(serializers.ModelSerializer):
    class Meta:
        model = Company
        fields = '__all__'


class BuildingSerializer(serializers.ModelSerializer):
    class Meta:
        model = Building
        fields = '__all__'

现在我们要实现这个行为:如果user.showFalse,用户一定不能看到(在视图中)restricted Companyand Building

换句话说,如果johnUserand john.show is Falsejohn可以(在视图中)看到normal_companyand normal_building,但看不到 restricted_companyor restricted_building

为了实现这一点,如果可能的话,我们不想编辑视图/序列化程序,因为它们有很多(这是一个更大的真实项目的简化版本)。

我的团队考虑使用中间件。我们试图动态改变Company.objectsBuilding.objects

# middleware.py

class FilterMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        user = get_user()  # Get the user somehow.

        if not user.show:
            # Replace objects.
            for model in (Company, Building):
                model.objects = model.objects.filter(restricted=False)

        response = self.get_response(request)
        return response

但是,由于某种未知的原因,这不起作用:john仍然可以看到受限制的公司。然后我们尝试动态更新该django.db.models.Manager.get_queryset方法:

# middleware.py
class FilterMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        user = get_user()

        if not user.show:
            models.Manager.get_queryset = get_restricted_queryset

        response = self.get_response(request)
        return response


def get_restricted_queryset(self, *args, **kwargs):
    # Conditions for the filter later.
    hiding_conditions = {
        "Company": Q(restricted=True),
        "Building": Q(companies__restricted=True) | Q(restricted=True),
    }

    model_name = self.model.__name__

    if model_name in hiding_conditions:
        # We must filter the model out, so apply the hiding conditions.
        hiding_condition = hiding_conditions[model_name]

        return self._queryset_class(
            model=self.model, using=self._db, hints=self._hints
        ).exclude(hiding_condition)
    else:
        return self._queryset_class(model=self.model, using=self._db, hints=self._hints)

但这不起作用——这很奇怪:当我获取公司时,它实际上只是User由 调用的模型get_queryset,所以get_restricted_queryset没有任何效果。

现在,我们真的被困住了。有没有人有可以帮助我们的想法?或者只是中间件不应该做这样的事情?

标签: pythondjangodjango-rest-framework

解决方案


您不需要中间件(因为中间件只处理请求和响应,它们是 QuerySets 之下的抽象级别)。您可以在 DRF 中使用自定义 FilterBackend 执行此操作,如下所示: 更新:也可以过滤嵌套公司!

from rest_framework import filters

class IsRestrictedFilterBackend(filters.BaseFilterBackend):
    def filter_queryset(self, request, queryset, view):
        if request.user and user.is_authenticated and not user.show:
            if queryset.model and queryset.model in [Company, Building]:
                queryset = queryset.filter(restricted=False)
                if queryset.model == Building:
                    return queryset.filter(companies__restricted=False)
        return queryset

然后将此过滤器后端添加到您的设置中:

REST_FRAMEWORK = {
    'DEFAULT_FILTER_BACKENDS': ['yourapp.filter_backends.IsRestrictedFilterBackend']
}

...或者您可以在 ViewSet 的基础上使用它:

class BuildingViewSet(ModelViewSet):
    queryset = Building.objects.all()
    serializer_class = BuildingSerializer
    filter_backends = [yourapp.filter_backends.IsRestrictedFilterBackend]

推荐阅读