首页 > 技术文章 > python-Django-rest_framework用户级联及用户访问节流

Vera-y 2019-12-27 16:33 原文

models.py

# models.py
from django.db import models

# User table.
class User(models.Model):
    """Application user account.

    ``u_name`` is the unique login name (at most 32 characters).
    ``u_password`` holds the password exactly as it was submitted: nothing in
    this model hashes it, and the login view compares it with the submitted
    value directly.  NOTE(review): consider storing a password hash instead.
    """

    u_name = models.CharField(unique=True, max_length=32)
    u_password = models.CharField(max_length=256)


# Address table.
class Address(models.Model):
    """A postal address, optionally owned by a ``User``."""

    a_address = models.CharField(max_length=128)
    # ``on_delete`` must be a deletion handler such as ``models.CASCADE``.
    # A bare ``True`` is not callable, so Django cannot invoke it when the
    # referenced user is deleted.  CASCADE deletes a user's addresses together
    # with the user.  null=True still allows an address row with no owner.
    a_user = models.ForeignKey(User, on_delete=models.CASCADE, null=True)

views.py

# views.py
import hmac
import uuid

from django.core.cache import cache
from django.shortcuts import render

# Create your views here.
from rest_framework import exceptions, viewsets, status
from rest_framework.generics import CreateAPIView, RetrieveAPIView
from rest_framework.response import Response

from myapp.auth import LoginAuthentication
from myapp.models import User, Address
from myapp.permissions import LoginPermissions
from myapp.serializers import UserSerializer, AddressSerializer


def test(request):
    """Render the ``test.html`` template (a plain page for manual testing)."""
    template_name = 'test.html'
    return render(request, template_name)


# User registration and login.
class UsersAPIView(CreateAPIView):
    """Register or log in a user; both are POST requests.

    The ``action`` query parameter selects the operation:
    ``?action=register`` creates a user through the serializer,
    ``?action=login`` checks credentials and returns a token.  Any other
    value raises ``ParseError``.
    """

    serializer_class = UserSerializer
    queryset = User.objects.all()

    def post(self, request, *args, **kwargs):
        action = request.query_params.get('action')
        if action == 'login':
            return self._login(request)
        if action == 'register':
            return self.create(request, *args, **kwargs)
        # Neither login nor registration.
        raise exceptions.ParseError

    def _login(self, request):
        """Verify ``u_name``/``u_password`` and issue a cache-backed token.

        On success a random token is stored in the cache mapped to the user's
        id (no explicit timeout is passed, so the cache backend's default
        applies) and returned in the response body.

        Raises:
            AuthenticationFailed: unknown user name or wrong password.
        """
        u_name = request.data.get('u_name')
        u_password = request.data.get('u_password')
        try:
            user = User.objects.get(u_name=u_name)
        except User.DoesNotExist:
            raise exceptions.AuthenticationFailed from None  # unknown user name
        # A non-string password (e.g. a JSON number) can never equal the
        # stored string, so reject it before the byte comparison.
        # compare_digest runs in constant time, which avoids leaking how many
        # leading characters matched; comparing UTF-8 bytes keeps non-ASCII
        # passwords working (compare_digest rejects non-ASCII str).
        if not isinstance(u_password, str) or not hmac.compare_digest(
            user.u_password.encode('utf-8'), u_password.encode('utf-8')
        ):
            raise exceptions.AuthenticationFailed  # wrong password
        token = uuid.uuid4().hex
        cache.set(token, user.id)
        data = {
            'msg': 'ok',
            'status': 200,
            'token': token,
        }
        return Response(data)


# A single user, read-only.
class UserAPIView(RetrieveAPIView):
    """Return one user's data, only for the currently logged-in user.

    Requests must carry a valid token (``LoginAuthentication``) and pass
    ``LoginPermissions``.  A logged-in user who asks for a different user's
    id is refused with ``PermissionDenied``.
    """

    serializer_class = UserSerializer
    queryset = User.objects.all()
    # Callers must authenticate before reading user data.
    authentication_classes = (LoginAuthentication,)
    # Permission control.
    permission_classes = (LoginPermissions,)

    # RetrieveAPIView.get -> retrieve
    def retrieve(self, request, *args, **kwargs):
        # Check the id from the URL against the logged-in user's id before
        # querying the database.  Compare as strings so the check does not
        # depend on whether the URL converter already produced an int.
        lookup_kwarg = self.lookup_url_kwarg or self.lookup_field
        if str(kwargs.get(lookup_kwarg)) != str(request.user.pk):
            # The caller is authenticated but may not see this record: that
            # is a permission failure, not an authentication failure.
            raise exceptions.PermissionDenied
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return Response(serializer.data)


# Create, read, update and delete the logged-in user's addresses.
class AddressesAPIView(viewsets.ModelViewSet):
    """Address endpoints scoped to the authenticated user.

    Every queryset-based action (list, retrieve, update, destroy) sees only
    the requesting user's addresses.  ``create`` saves a new address with
    ``a_user`` already set to that user.
    """

    serializer_class = AddressSerializer
    queryset = Address.objects.all()
    # Callers must authenticate before working with addresses.
    authentication_classes = (LoginAuthentication,)
    # Permission control.
    permission_classes = (LoginPermissions,)

    def get_queryset(self):
        """Return only the addresses owned by the requesting user."""
        user = self.request.user
        if not isinstance(user, User):
            # Not logged in; LoginPermissions rejects such requests anyway.
            return Address.objects.none()
        return Address.objects.filter(a_user=user)

    def perform_create(self, serializer):
        # Attach the owner in the same INSERT, instead of saving the address,
        # re-fetching it and saving it a second time.
        serializer.save(a_user=self.request.user)


class AddressAPIView(viewsets.ModelViewSet):
    """Single-address endpoint, restricted to the address's owner.

    Requires the same token login and permission check as the other address
    endpoints.  Without them, any client could fetch any address by id.
    Addresses owned by other users are invisible, so they yield 404.
    """

    serializer_class = AddressSerializer
    queryset = Address.objects.all()
    authentication_classes = (LoginAuthentication,)
    permission_classes = (LoginPermissions,)

    def get_queryset(self):
        """Return only the addresses owned by the requesting user."""
        user = self.request.user
        if not isinstance(user, User):
            # Not logged in; LoginPermissions rejects such requests anyway.
            return Address.objects.none()
        return Address.objects.filter(a_user=user)

urls.py

# urls.py
from django.urls import path
from . import views

app_name = 'my_app'

# Route table.  The viewsets are wired without a router, so each one gets an
# explicit HTTP-method -> action mapping in ``as_view``.
urlpatterns = [
    # Plain template page.
    path('test/', views.test),
    # POST with ?action=register or ?action=login.
    path('user/', views.UsersAPIView.as_view()),
    # GET one user by primary key.
    path('user/<int:pk>/', views.UserAPIView.as_view()),
    # POST creates an address, GET lists addresses.
    path(
        'address/',
        views.AddressesAPIView.as_view({'post': 'create', 'get': 'list'}),
    ),
    # GET one address by primary key.
    path(
        'address/<int:pk>/',
        views.AddressAPIView.as_view({'get': 'retrieve'}),
    ),
]

serializers.py

# 序列化
from rest_framework import serializers

# serializers.py
from myapp.models import User, Address


# Address serializer.
class AddressSerializer(serializers.ModelSerializer):
    """Serialize an ``Address`` as its id and address text (owner omitted)."""

    class Meta:
        model = Address
        fields = ['id', 'a_address']


# User serializer.
class UserSerializer(serializers.ModelSerializer):
    """Serialize a ``User`` together with the addresses that reference it.

    ``address_set`` is the reverse accessor Django creates for the
    ``Address.a_user`` foreign key.  Giving that FK a ``related_name`` would
    change the attribute name to use here.  ``u_password`` is accepted on
    input (registration) but never included in serialized output.
    """

    # Nested, read-only display of the user's addresses.
    address_set = AddressSerializer(many=True, read_only=True)

    class Meta:
        model = User
        fields = ('id', 'u_name', 'u_password', 'address_set')
        # Keep the password writable for registration, but never echo it
        # back in responses (registration result, user detail).
        extra_kwargs = {'u_password': {'write_only': True}}

permissions.py

# permissions.py
# 权限控制
from rest_framework.permissions import BasePermission

from myapp.models import User


class LoginPermissions(BasePermission):
    """Grant access only when authentication produced a ``User`` instance."""

    def has_permission(self, request, view):
        # An unauthenticated request carries some other user object
        # (e.g. an anonymous user), which is refused.
        return isinstance(request.user, User)

auth.py

# auth.py
# 用户认证
from django.core.cache import cache
from rest_framework.authentication import BaseAuthentication

from myapp.models import User


class LoginAuthentication(BaseAuthentication):
    """Authenticate a request by its ``token`` query parameter.

    The token is looked up in Django's cache, which maps it to a user id.
    Returns ``(user, token)`` on success.  Otherwise it returns ``None``, and
    DRF treats the request as unauthenticated by this class.
    """

    def authenticate(self, request):
        token = request.query_params.get('token')
        if not token:
            # No token supplied: skip the cache lookup and database query.
            return None
        try:
            u_id = cache.get(token)
            if u_id is None:
                # Unknown or expired token.
                return None
            user = User.objects.get(pk=u_id)
        except Exception:
            # Best-effort by design: a deleted user, a token the cache backend
            # rejects, or an unexpected cached value all mean "not
            # authenticated".  Catch Exception, not a bare ``except``, so
            # KeyboardInterrupt/SystemExit still propagate.
            return None
        return user, token

throttles.py

# throttles.py
# 节流
from rest_framework.throttling import SimpleRateThrottle

from myapp.models import User


class UserThrottle(SimpleRateThrottle):
    """Rate-limit per logged-in user, or per client address otherwise.

    The rate is looked up under this class's ``scope`` ('user') in
    ``DEFAULT_THROTTLE_RATES``.
    """

    scope = 'user'

    def get_cache_key(self, request, view):
        if isinstance(request.user, User):
            # Key on the user's primary key, not on the token.  The login
            # view issues a new token on every login, so a per-token key
            # would give each fresh token its own budget and let one user
            # exceed the configured rate.
            ident = request.user.pk
        else:
            # Not logged in: fall back to the client identifier from get_ident.
            ident = self.get_ident(request)

        return self.cache_format % {
            'scope': self.scope,
            'ident': ident,
        }

settings.py

# Global throttling configuration.
REST_FRAMEWORK = {
    # APIView.throttle_classes defaults to api_settings.DEFAULT_THROTTLE_CLASSES,
    # so this throttle applies to every view that does not override it.
    'DEFAULT_THROTTLE_CLASSES': [
        'myapp.throttles.UserThrottle',  # dotted path to the throttle class
    ],
    'DEFAULT_THROTTLE_RATES': {
        # SimpleRateThrottle.parse_rate reads "<count>/<period>"; only the
        # first letter of the period matters ('m' = one minute).
        'user': '5/m',  # at most 5 requests per minute
    },
}

推荐阅读