Django REST框架与API开发

课程目标

  • 理解Django REST Framework (DRF) 的核心概念
  • 掌握序列化器的定义和使用
  • 学会创建和管理API端点
  • 了解API认证和权限控制

Django REST Framework概述

Django REST Framework (DRF) 是一个强大且灵活的工具包,用于构建Web API。它建立在Django之上,提供了构建Web API所需的各种功能,包括序列化、认证、权限、限流等。

DRF的主要特性:

  • 强大的序列化系统
  • 可浏览的Web界面
  • 认证和权限系统
  • 限流控制
  • 内容协商
  • 广泛的文档支持

安装和配置

安装DRF

pip install djangorestframework

配置settings.py

# settings.py
# Applications enabled for the project. 'rest_framework' must be listed so
# that Django loads DRF alongside the project's own app.
INSTALLED_APPS = [
    'django.contrib.admin',
    'django.contrib.auth',
    'django.contrib.contenttypes',
    'django.contrib.sessions',
    'django.contrib.messages',
    'django.contrib.staticfiles',
    'rest_framework',  # enable Django REST Framework
    'myapp',  # your own application
]

# Project-wide DRF defaults; individual views may override any of them.
REST_FRAMEWORK = {
    # Authentication schemes, listed in the order they are declared here.
    # NOTE(review): TokenAuthentication presumably also needs
    # 'rest_framework.authtoken' in INSTALLED_APPS (shown later in this
    # document) -- confirm before relying on it.
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'rest_framework.authentication.SessionAuthentication',
        'rest_framework.authentication.TokenAuthentication',
    ],
    # Require an authenticated user unless a view declares otherwise.
    'DEFAULT_PERMISSION_CLASSES': [
        'rest_framework.permissions.IsAuthenticated',
    ],
    # Page-number pagination with 20 items per page.
    'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
    'PAGE_SIZE': 20,
    # JSON output plus the browsable HTML API.
    'DEFAULT_RENDERER_CLASSES': [
        'rest_framework.renderers.JSONRenderer',
        'rest_framework.renderers.BrowsableAPIRenderer',
    ],
}

序列化器(Serializers)

基础序列化器

# serializers.py
from rest_framework import serializers
from .models import Article, Category

class CategorySerializer(serializers.ModelSerializer):
    """Model serializer exposing the basic fields of ``Category``."""

    class Meta:
        model = Category
        fields = ['id', 'name', 'slug', 'created_at']
        # 'slug' and 'created_at' are output-only; clients cannot set them.
        read_only_fields = ['slug', 'created_at']

class ArticleSerializer(serializers.ModelSerializer):
    """Read/write serializer for ``Article``.

    Adds three read-only convenience fields on top of the model fields:
    ``author_name`` (from ``author.username``), ``category_name`` (from
    ``category.name``) and ``tags_list`` (names of the related tags).
    """

    author_name = serializers.CharField(source='author.username', read_only=True)
    category_name = serializers.CharField(source='category.name', read_only=True)
    tags_list = serializers.SerializerMethodField()

    class Meta:
        model = Article
        fields = [
            'id', 'title', 'content', 'summary', 'author', 'author_name',
            'category', 'category_name', 'status', 'tags', 'tags_list',
            'created_at', 'updated_at', 'view_count'
        ]
        read_only_fields = ['author', 'created_at', 'updated_at', 'view_count']

    def get_tags_list(self, obj):
        """Return the names of all tags attached to ``obj`` as a list."""
        names = []
        for tag in obj.tags.all():
            names.append(tag.name)
        return names

    def validate_title(self, value):
        """Accept titles of at least 5 characters; reject shorter ones."""
        if len(value) >= 5:
            return value
        raise serializers.ValidationError("标题至少需要5个字符")

    def validate_content(self, value):
        """Accept content of at least 10 characters; reject shorter content."""
        if len(value) >= 10:
            return value
        raise serializers.ValidationError("内容至少需要10个字符")

    def create(self, validated_data):
        """Create the article, then attach its tags.

        The ``tags`` entry is removed from ``validated_data`` before the
        model's ``create`` call and applied afterwards through the
        relation's ``set``.
        """
        tags = validated_data.pop('tags', [])
        new_article = Article.objects.create(**validated_data)

        if tags:
            new_article.tags.set(tags)

        return new_article

    def update(self, instance, validated_data):
        """Apply ``validated_data`` to ``instance`` and save it.

        Tags are only replaced when the payload actually contained a
        ``tags`` entry; an absent entry leaves the existing tags untouched.
        """
        tags = validated_data.pop('tags', None)

        # Plain attributes first, persisted with a single save().
        for field_name in validated_data:
            setattr(instance, field_name, validated_data[field_name])
        instance.save()

        # Tags last, and only if the client sent them.
        if tags is None:
            return instance

        instance.tags.set(tags)
        return instance

嵌套序列化器

class ArticleDetailSerializer(serializers.ModelSerializer):
    """Detail serializer for ``Article`` with nested, read-only relations.

    ``author`` and ``tags`` are rendered through their string
    representation, ``category`` through ``CategorySerializer`` and
    ``comments`` through ``get_comments``.
    """

    author = serializers.StringRelatedField(read_only=True)
    category = CategorySerializer(read_only=True)
    tags = serializers.StringRelatedField(many=True, read_only=True)
    comments = serializers.SerializerMethodField()

    class Meta:
        model = Article
        fields = [
            'id', 'title', 'content', 'summary', 'author', 'category',
            'status', 'tags', 'comments', 'created_at', 'updated_at', 'view_count'
        ]

    def get_comments(self, obj):
        """Return the serialized comments of ``obj`` that have ``is_approved=True``."""
        # Imported at call time rather than at module import time.
        from myapp.serializers import CommentSerializer

        approved = obj.comments.filter(is_approved=True)
        serialized = CommentSerializer(approved, many=True)
        return serialized.data

API视图开发

基于函数的视图

# views.py
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework import status
from rest_framework.decorators import permission_classes
from rest_framework.permissions import IsAuthenticated
from .models import Article
from .serializers import ArticleSerializer

@api_view(['GET'])
def article_list(request):
    """Return a paginated list of published articles, newest first.

    Uses a ``PageNumberPagination`` instance with its page size set to 10.
    """
    from rest_framework.pagination import PageNumberPagination

    published = Article.objects.filter(status='published')
    published = published.order_by('-created_at')

    # Paginate by hand, since this is a plain function view.
    pager = PageNumberPagination()
    pager.page_size = 10
    current_page = pager.paginate_queryset(published, request)

    payload = ArticleSerializer(current_page, many=True).data
    return pager.get_paginated_response(payload)

@api_view(['GET'])
def article_detail(request, pk):
    """Return one published article and bump its view counter.

    Args:
        request: The incoming DRF request.
        pk: Primary key of the article to fetch.

    Returns:
        The serialized article, or a 404 response with an ``error`` message
        when no published article has this primary key.
    """
    # Imported locally: this views module only imports ArticleSerializer at
    # the top, so the detail serializer would otherwise be an undefined name.
    from django.db.models import F
    from .serializers import ArticleDetailSerializer

    # Keep the try body to the lookup so unrelated errors are not turned
    # into a 404.
    try:
        article = Article.objects.get(pk=pk, status='published')
    except Article.DoesNotExist:
        return Response({'error': '文章不存在'}, status=status.HTTP_404_NOT_FOUND)

    # Increment in the database (view_count = view_count + 1) so concurrent
    # requests cannot overwrite each other's increments.
    article.view_count = F('view_count') + 1
    article.save(update_fields=['view_count'])
    # Replace the F() expression on the instance with the stored value
    # before serializing.
    article.refresh_from_db(fields=['view_count'])

    serializer = ArticleDetailSerializer(article)
    return Response(serializer.data)

@api_view(['POST'])
@permission_classes([IsAuthenticated])
def create_article(request):
    """Create an article owned by the requesting user.

    Responds with 201 and the serialized article on success, or with 400
    and the validation errors otherwise.
    """
    serializer = ArticleSerializer(data=request.data)

    if not serializer.is_valid():
        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

    # The author is taken from the session, never from the payload.
    serializer.save(author=request.user)
    return Response(serializer.data, status=status.HTTP_201_CREATED)

@api_view(['PUT', 'PATCH'])
@permission_classes([IsAuthenticated])
def update_article(request, pk):
    """Update an article; only its author may do so.

    A PATCH request is validated as a partial update; a PUT request must
    supply every required field.

    Args:
        request: The incoming DRF request.
        pk: Primary key of the article to update.

    Returns:
        The serialized article on success, 400 with the validation errors,
        403 when the requester is not the author, or 404 when the article
        does not exist.
    """
    # Keep the try body to the lookup so errors raised while validating or
    # saving are not reported as "not found".
    try:
        article = Article.objects.get(pk=pk)
    except Article.DoesNotExist:
        return Response({'error': '文章不存在'}, status=status.HTTP_404_NOT_FOUND)

    # Ownership check
    if article.author != request.user:
        return Response({'error': '无权限修改'}, status=status.HTTP_403_FORBIDDEN)

    serializer = ArticleSerializer(
        article,
        data=request.data,
        # Compare the HTTP verb; isinstance() with a string as its second
        # argument raises TypeError.
        partial=request.method == 'PATCH',
    )

    if not serializer.is_valid():
        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

    serializer.save()
    return Response(serializer.data)

@api_view(['DELETE'])
@permission_classes([IsAuthenticated])
def delete_article(request, pk):
    """Delete an article on behalf of its author.

    Responds with 204 on success, 403 when the requester is not the
    author, and 404 when the article does not exist.
    """
    try:
        target = Article.objects.get(pk=pk)

        is_owner = not (target.author != request.user)
        if is_owner:
            target.delete()
            return Response(status=status.HTTP_204_NO_CONTENT)

        return Response({'error': '无权限删除'}, status=status.HTTP_403_FORBIDDEN)
    except Article.DoesNotExist:
        return Response({'error': '文章不存在'}, status=status.HTTP_404_NOT_FOUND)

基于类的视图

from rest_framework import generics, permissions, status
from rest_framework.response import Response
from django.shortcuts import get_object_or_404

class ArticleListCreateView(generics.ListCreateAPIView):
    """List published articles (with optional filters) and create new ones.

    Supported query parameters: ``search`` (case-insensitive title
    substring), ``category`` (exact category name) and ``tag`` (exact tag
    name). Empty or missing parameters are ignored.
    """

    queryset = Article.objects.filter(status='published')
    serializer_class = ArticleSerializer
    permission_classes = [permissions.IsAuthenticatedOrReadOnly]

    def get_queryset(self):
        """Return published articles narrowed by the request's filters, newest first."""
        params = self.request.query_params
        articles = Article.objects.filter(status='published')

        # (query parameter, ORM lookup) pairs, applied in this order.
        filters = (
            ('search', 'title__icontains'),
            ('category', 'category__name'),
            ('tag', 'tags__name'),
        )
        for param, lookup in filters:
            wanted = params.get(param, '')
            if wanted:
                articles = articles.filter(**{lookup: wanted})

        return articles.order_by('-created_at')

    def perform_create(self, serializer):
        """Save the new article with the requesting user as its author."""
        serializer.save(author=self.request.user)

class ArticleDetailView(generics.RetrieveUpdateDestroyAPIView):
    """Retrieve, update or delete a single article.

    Reads are open to everyone; PUT, PATCH and DELETE require an
    authenticated user who is either the article's author or a staff
    member.
    """

    queryset = Article.objects.all()
    serializer_class = ArticleDetailSerializer
    permission_classes = [permissions.IsAuthenticatedOrReadOnly]

    def get_permissions(self):
        """Pick the permission instances from the HTTP method.

        Write methods get ``IsAuthenticated``; every other method gets
        ``AllowAny``.
        """
        if self.request.method in ['PUT', 'PATCH', 'DELETE']:
            return [permissions.IsAuthenticated()]
        return [permissions.AllowAny()]

    def check_object_permissions(self, request, obj):
        """Run the standard object checks, then enforce author-or-staff on writes."""
        super().check_object_permissions(request, obj)

        if request.method not in ['PUT', 'PATCH', 'DELETE']:
            return

        if obj.author == request.user or request.user.is_staff:
            return

        self.permission_denied(
            request,
            message='您没有权限执行此操作',
            code='permission_denied'
        )

视图集(ViewSets)

基础视图集

from rest_framework import viewsets, mixins
from rest_framework.decorators import action
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework.filters import SearchFilter, OrderingFilter

class ArticleViewSet(viewsets.ModelViewSet):
    """CRUD endpoints for ``Article`` plus like/unlike/my_articles actions.

    Filtering, searching and ordering are delegated to the configured
    filter backends. Non-staff users only see published articles through
    the standard endpoints.
    """

    queryset = Article.objects.all()
    serializer_class = ArticleSerializer
    permission_classes = [permissions.IsAuthenticatedOrReadOnly]
    filter_backends = [DjangoFilterBackend, SearchFilter, OrderingFilter]
    filterset_fields = ['status', 'category', 'author']
    search_fields = ['title', 'content', 'tags__name']
    ordering_fields = ['created_at', 'view_count', 'like_count']
    ordering = ['-created_at']

    def get_queryset(self):
        """Return every article for staff, only published ones for everyone else."""
        queryset = super().get_queryset()

        # Only staff may see articles that are not published.
        if not self.request.user.is_staff:
            queryset = queryset.filter(status='published')

        return queryset

    @action(detail=True, methods=['post'], permission_classes=[permissions.IsAuthenticated])
    def like(self, request, pk=None):
        """Increment the article's like counter and return the new total."""
        from django.db.models import F

        article = self.get_object()
        # Increment in the database so concurrent likes are not lost to a
        # read-modify-write race.
        article.like_count = F('like_count') + 1
        article.save(update_fields=['like_count'])
        # Load the stored number back in place of the F() expression.
        article.refresh_from_db(fields=['like_count'])
        return Response({'status': 'liked', 'likes': article.like_count})

    @action(detail=True, methods=['post'], permission_classes=[permissions.IsAuthenticated])
    def unlike(self, request, pk=None):
        """Decrement the article's like counter (never below zero) and return the total."""
        from django.db.models import F

        article = self.get_object()
        # Conditional UPDATE: the decrement only happens while the stored
        # counter is still positive, even under concurrent requests.
        Article.objects.filter(pk=article.pk, like_count__gt=0).update(
            like_count=F('like_count') - 1
        )
        article.refresh_from_db(fields=['like_count'])
        return Response({'status': 'unliked', 'likes': article.like_count})

    @action(detail=False, methods=['get'])
    def my_articles(self, request):
        """Return all articles written by the current user, whatever their status."""
        if not request.user.is_authenticated:
            return Response({'error': '请先登录'}, status=status.HTTP_401_UNAUTHORIZED)

        # Starts from the class-level queryset rather than get_queryset(), so
        # the published-only restriction does not hide the user's own
        # unpublished articles.
        articles = self.queryset.filter(author=request.user)
        serializer = self.get_serializer(articles, many=True)
        return Response(serializer.data)

# urls.py
from rest_framework.routers import DefaultRouter

# Register the viewset under the 'articles' prefix and let the router
# generate the URL patterns for it.
router = DefaultRouter()
router.register(r'articles', ArticleViewSet, basename='article')
urlpatterns = router.urls

认证和权限

Token认证

# settings.py
# Token authentication is listed ahead of session authentication, and every
# endpoint requires an authenticated user by default.
REST_FRAMEWORK = {
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'rest_framework.authentication.TokenAuthentication',
        'rest_framework.authentication.SessionAuthentication',
    ],
    'DEFAULT_PERMISSION_CLASSES': [
        'rest_framework.permissions.IsAuthenticated',
    ],
}

# Install and configure token authentication: add the authtoken app to
# INSTALLED_APPS, then run the migrations (steps shown in the string below).
"""
INSTALLED_APPS = [
    # ...
    'rest_framework.authtoken',
]

# 运行迁移
python manage.py migrate
"""

# 自定义权限类
from rest_framework.permissions import BasePermission

class IsOwnerOrReadOnly(BasePermission):
    """Object-level permission: anyone may read, only the author may write."""

    def has_object_permission(self, request, view, obj):
        """Return True for safe methods, or when ``obj.author`` is the requester."""
        is_read_only = request.method in permissions.SAFE_METHODS
        if is_read_only:
            return True

        # Objects without an ``author`` attribute are never writable.
        if not hasattr(obj, 'author'):
            return False
        return obj.author == request.user

class IsStaffOrTargetUser(BasePermission):
    """Staff may access any user object; other users only their own."""

    def has_permission(self, request, view):
        """Allow staff outright; otherwise require an authenticated user."""
        requester = request.user
        return True if requester.is_staff else requester.is_authenticated

    def has_object_permission(self, request, view, obj):
        """Allow staff outright; otherwise the object must be the requester."""
        requester = request.user
        return True if requester.is_staff else obj == requester

JWT认证

# 安装
"""
pip install djangorestframework-simplejwt
"""

# settings.py
# Use JWT as the only default authentication scheme.
REST_FRAMEWORK = {
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'rest_framework_simplejwt.authentication.JWTAuthentication',
    ],
}

from datetime import timedelta

# Token lifetimes and refresh behaviour for Simple JWT.
SIMPLE_JWT = {
    'ACCESS_TOKEN_LIFETIME': timedelta(minutes=60),  # access token: 60 minutes
    'REFRESH_TOKEN_LIFETIME': timedelta(days=7),  # refresh token: 7 days
    'ROTATE_REFRESH_TOKENS': True,
    # NOTE(review): blacklisting presumably requires the
    # 'rest_framework_simplejwt.token_blacklist' app in INSTALLED_APPS --
    # confirm against the Simple JWT documentation.
    'BLACKLIST_AFTER_ROTATION': True,
}

# URLs
# `path` is imported here because this snippet uses it before any other
# import of django.urls appears.
from django.urls import path
from rest_framework_simplejwt.views import TokenObtainPairView, TokenRefreshView

# Endpoints for obtaining a token pair and for refreshing a token.
urlpatterns = [
    path('api/token/', TokenObtainPairView.as_view(), name='token_obtain_pair'),
    path('api/token/refresh/', TokenRefreshView.as_view(), name='token_refresh'),
]

API文档

使用drf-yasg生成API文档

# 安装
"""
pip install drf-yasg
"""

# settings.py
# Add drf-yasg to the project's installed apps (other entries elided).
INSTALLED_APPS = [
    # ...
    'drf_yasg',
]

# urls.py
from django.urls import path
from drf_yasg.views import get_schema_view
from drf_yasg import openapi
from rest_framework import permissions

# Schema view built from the API's metadata; declared public and open to
# any requester via AllowAny.
schema_view = get_schema_view(
    openapi.Info(
        title="博客API文档",
        default_version='v1',
        description="博客应用的API接口文档",
        terms_of_service="https://www.google.com/policies/terms/",
        contact=openapi.Contact(email="admin@example.com"),
        license=openapi.License(name="BSD License"),
    ),
    public=True,
    permission_classes=[permissions.AllowAny],
)

# Two renderings of the same schema: Swagger UI and ReDoc, both created
# with cache_timeout=0.
urlpatterns = [
    path('swagger/', schema_view.with_ui(
        'swagger',
        cache_timeout=0
    ), name='schema-swagger-ui'),
    path('redoc/', schema_view.with_ui(
        'redoc',
        cache_timeout=0
    ), name='schema-redoc'),
]

API测试

使用APIClient进行测试

# tests.py
from django.test import TestCase
from rest_framework.test import APIClient
from rest_framework import status
from django.contrib.auth.models import User
from .models import Article

class ArticleAPITest(TestCase):
    """API tests for listing and creating articles at ``/api/articles/``."""

    def setUp(self):
        """Create an API client, a user, and one published article by that user."""
        self.client = APIClient()
        self.user = User.objects.create_user(
            username='testuser', password='testpass123'
        )
        self.article = Article.objects.create(
            title='Test Article',
            content='Test content',
            author=self.user,
            status='published',
        )

    def _post_new_article(self):
        """POST a fixed article payload and return the response."""
        payload = {
            'title': 'New Article',
            'content': 'New content',
            'status': 'published'
        }
        return self.client.post('/api/articles/', payload)

    def test_get_articles_list(self):
        """The list endpoint answers 200."""
        response = self.client.get('/api/articles/')
        self.assertEqual(response.status_code, status.HTTP_200_OK)

    def test_create_article_authenticated(self):
        """An authenticated user can create an article (201)."""
        self.client.force_authenticate(user=self.user)
        response = self._post_new_article()
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

    def test_create_article_unauthenticated(self):
        """An anonymous request to create an article is rejected with 401."""
        # NOTE(review): DRF answers 403 instead of 401 when the first
        # configured authentication class is SessionAuthentication --
        # confirm which settings this test runs under.
        response = self._post_new_article()
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

API性能优化

缓存配置

# settings.py
# Default cache: django-redis backend pointing at a local Redis server
# (127.0.0.1:6379, database 1).
CACHES = {
    'default': {
        'BACKEND': 'django_redis.cache.RedisCache',
        'LOCATION': 'redis://127.0.0.1:6379/1',
        'OPTIONS': {
            'CLIENT_CLASS': 'django_redis.client.DefaultClient',
        }
    }
}

# 视图中的缓存使用
from django.core.cache import cache
from django.views.decorators.cache import cache_page
from rest_framework.decorators import api_view

@api_view(['GET'])
@cache_page(60 * 15)  # cache the response for 15 minutes
def cached_article_list(request):
    """Return the published articles, served from the cache when present.

    On a cache miss the rows are loaded as a list of dicts and stored
    under the ``article_list`` key with the same 15-minute timeout.
    """
    cache_key = 'article_list'
    cached_rows = cache.get(cache_key)

    if cached_rows is not None:
        return Response(cached_rows)

    # Cache miss: materialise the queryset so a plain list is stored.
    published = Article.objects.filter(status='published')
    fresh_rows = list(published.values())
    cache.set(cache_key, fresh_rows, timeout=60*15)
    return Response(fresh_rows)

# 使用装饰器缓存
from rest_framework_extensions.cache.decorators import cache_response

class ArticleViewSet(viewsets.ModelViewSet):
    # ... (the rest of the viewset is elided in this snippet)
    
    # Cache the list response; the timeout is 60*15, presumably seconds
    # (i.e. 15 minutes) -- confirm against the drf-extensions documentation.
    @cache_response(timeout=60*15)
    def list(self, request, *args, **kwargs):
        """Delegate to the parent ``list``; the decorator caches the response."""
        return super().list(request, *args, **kwargs)

查询优化

# 使用select_related和prefetch_related优化查询
class ArticleListCreateView(generics.ListCreateAPIView):
    """List/create view whose queryset loads related objects up front."""

    serializer_class = ArticleSerializer

    def get_queryset(self):
        """Return published articles with author/category joined and tags/comments prefetched."""
        articles = Article.objects.select_related('author', 'category')
        articles = articles.prefetch_related('tags', 'comments')
        return articles.filter(status='published')

课程总结

本节课我们学习了Django REST Framework的各个方面,包括序列化器、API视图开发、认证授权、API文档生成和性能优化等。DRF是构建现代化Web API的强大工具,掌握它对于开发前后端分离的应用至关重要。