Django RESTful API 开发
2026/3/20 · 阅读时间约 11 分钟
Django RESTful API 开发
一、Django REST Framework 简介
1.1 安装配置
pip install djangorestframework
pip install django-filter # filtering support (the django_filters app)
pip install markdown # Markdown support
# settings.py
INSTALLED_APPS = [
    # ...
    'rest_framework',
    # Provides the Token model used by TokenAuthentication (configured below).
    'rest_framework.authtoken',
    'django_filters',
]

REST_FRAMEWORK = {
    # Authentication classes, tried in order
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'rest_framework.authentication.SessionAuthentication',
        'rest_framework.authentication.TokenAuthentication',
    ],
    # Default permission: views that do not override it require a logged-in user
    'DEFAULT_PERMISSION_CLASSES': [
        'rest_framework.permissions.IsAuthenticated',
    ],
    # Pagination: page-number based, 20 items per page
    'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
    'PAGE_SIZE': 20,
    # Filtering / search / ordering backends
    'DEFAULT_FILTER_BACKENDS': [
        'django_filters.rest_framework.DjangoFilterBackend',
        'rest_framework.filters.SearchFilter',
        'rest_framework.filters.OrderingFilter',
    ],
    # Renderers: JSON plus the browsable HTML API
    'DEFAULT_RENDERER_CLASSES': [
        'rest_framework.renderers.JSONRenderer',
        'rest_framework.renderers.BrowsableAPIRenderer',
    ],
    # Throttling: anonymous vs. authenticated request rates
    'DEFAULT_THROTTLE_CLASSES': [
        'rest_framework.throttling.AnonRateThrottle',
        'rest_framework.throttling.UserRateThrottle',
    ],
    'DEFAULT_THROTTLE_RATES': {
        'anon': '100/hour',
        'user': '1000/hour',
    },
    # Versioning
    'DEFAULT_VERSIONING_CLASS': 'rest_framework.versioning.URLPathVersioning',
    # Exception handling
    'EXCEPTION_HANDLER': 'api.exceptions.custom_exception_handler',
}
1.2 URL 配置
# urls.py
from django.urls import include, path
from rest_framework.routers import DefaultRouter

from .views import ArticleViewSet, CategoryViewSet

# Register each viewset on the router under its URL prefix.
router = DefaultRouter()
for prefix, viewset in (('articles', ArticleViewSet), ('categories', CategoryViewSet)):
    router.register(prefix, viewset)

urlpatterns = [
    # All router-generated routes live under /api/.
    path('api/', include(router.urls)),
    # DRF's bundled auth URLs (rest_framework.urls).
    path('api-auth/', include('rest_framework.urls')),
]
二、序列化器(Serializer)
2.1 基础序列化器
from rest_framework import serializers

from .models import Article, Category, Tag


class ArticleSerializer(serializers.Serializer):
    """Plain (non-model) Article serializer with hand-written create/update."""

    id = serializers.IntegerField(read_only=True)
    title = serializers.CharField(max_length=200)
    content = serializers.CharField()
    status = serializers.ChoiceField(choices=Article.STATUS_CHOICES)
    created_at = serializers.DateTimeField(read_only=True)
    author_id = serializers.IntegerField()

    def create(self, validated_data):
        """Create and return a new Article from all validated fields."""
        return Article.objects.create(**validated_data)

    def update(self, instance, validated_data):
        """Copy title/content/status onto *instance*, save and return it.

        Fields absent from *validated_data* keep their current value;
        author_id is not copied.
        """
        for attr in ('title', 'content', 'status'):
            setattr(instance, attr, validated_data.get(attr, getattr(instance, attr)))
        instance.save()
        return instance
2.2 ModelSerializer
class ArticleSerializer(serializers.ModelSerializer):
    """Model-based Article serializer."""

    class Meta:
        model = Article
        # 'author' is listed so the read_only_fields entry below takes effect:
        # read_only_fields only applies to fields that appear in `fields`.
        fields = ['id', 'title', 'content', 'category', 'tags', 'status', 'author', 'created_at']
        # Or expose every model field:
        # fields = '__all__'
        # Or exclude some fields:
        # exclude = ['created_at', 'updated_at']
        read_only_fields = ['id', 'created_at', 'author']
        extra_kwargs = {
            'content': {'required': True, 'allow_blank': False},
            'status': {'default': 'draft'},
        }
class ArticleDetailSerializer(serializers.ModelSerializer):
    """Detail serializer that renders author, category and tags as strings."""

    author = serializers.StringRelatedField()
    category = serializers.StringRelatedField()
    tags = serializers.StringRelatedField(many=True)
    # A nested serializer is the alternative, e.g.:
    # category = CategorySerializer()

    class Meta:
        fields = '__all__'
        model = Article
2.3 字段类型
class FieldExamplesSerializer(serializers.Serializer):
    """Catalogue of common serializer field types (declaration order is kept)."""

    # --- Basic scalars ---
    string_field = serializers.CharField(max_length=100)
    integer_field = serializers.IntegerField(max_value=100, min_value=0)
    float_field = serializers.FloatField()
    decimal_field = serializers.DecimalField(decimal_places=2, max_digits=10)
    boolean_field = serializers.BooleanField()

    # --- Dates and times ---
    date_field = serializers.DateField()
    datetime_field = serializers.DateTimeField()
    time_field = serializers.TimeField()
    duration_field = serializers.DurationField()

    # --- Choices ---
    choice_field = serializers.ChoiceField(choices=['A', 'B', 'C'])
    multiple_choice_field = serializers.MultipleChoiceField(choices=['A', 'B', 'C'])

    # --- Files ---
    file_field = serializers.FileField()
    image_field = serializers.ImageField()

    # --- Formatted strings ---
    email_field = serializers.EmailField()
    url_field = serializers.URLField()
    uuid_field = serializers.UUIDField()
    ip_field = serializers.IPAddressField()
    slug_field = serializers.SlugField()
    regex_field = serializers.RegexField(regex=r'^\d{6}$')  # exactly six digits

    # --- Composite values ---
    list_field = serializers.ListField(child=serializers.IntegerField())
    dict_field = serializers.DictField(child=serializers.CharField())
    json_field = serializers.JSONField()

    # --- Direction-restricted ---
    read_only_field = serializers.CharField(read_only=True)
    write_only_field = serializers.CharField(write_only=True)

    # --- Computed ---
    method_field = serializers.SerializerMethodField()

    def get_method_field(self, obj):
        """Value for `method_field`: a string built from str(obj)."""
        return 'computed value for {}'.format(obj)
2.4 关系字段
class ArticleSerializer(serializers.ModelSerializer):
    """Showcase of the relational field types for Article's relations."""

    # NOTE: CategorySerializer and TagSerializer must already be defined when
    # this class body runs; in this document they are defined in section 2.5.
    # NOTE(review): category_id and category_slug are both writable and target
    # the same `category` source; a real serializer should keep only one.

    # Primary-key relation, validated against Category.objects.all()
    category_id = serializers.PrimaryKeyRelatedField(queryset=Category.objects.all(), source='category')
    # String representation of the author
    author_name = serializers.StringRelatedField(source='author')
    # Relation addressed by the category's `slug`
    category_slug = serializers.SlugRelatedField(queryset=Category.objects.all(), slug_field='slug', source='category')
    # Hyperlink to the 'category-detail' route
    category_url = serializers.HyperlinkedRelatedField(source='category', view_name='category-detail', read_only=True)
    # Nested (read-only) relations
    category = CategorySerializer(read_only=True)
    tags = TagSerializer(many=True, read_only=True)

    class Meta:
        fields = '__all__'
        model = Article
2.5 嵌套序列化器
class TagSerializer(serializers.ModelSerializer):
    """Compact Tag representation: id, name and slug."""

    class Meta:
        model = Tag
        fields = ('id', 'name', 'slug')
class CategorySerializer(serializers.ModelSerializer):
    """Compact Category representation: id, name and slug."""

    class Meta:
        model = Category
        fields = ('id', 'name', 'slug')
class ArticleListSerializer(serializers.ModelSerializer):
    """Slim article representation intended for list endpoints."""

    # Author and category are rendered via their string representation.
    author = serializers.StringRelatedField()
    category = serializers.StringRelatedField()

    class Meta:
        model = Article
        fields = ('id', 'title', 'author', 'category', 'created_at')
class ArticleDetailSerializer(serializers.ModelSerializer):
    """Full article representation for detail views.

    The nested category/tags serializers are read-only: ModelSerializer's
    default create()/update() reject writable nested data, so writes should
    go through a dedicated write serializer (e.g. ArticleCreateSerializer).
    """

    author = serializers.SerializerMethodField()
    category = CategorySerializer(read_only=True)
    tags = TagSerializer(many=True, read_only=True)
    comments_count = serializers.SerializerMethodField()

    class Meta:
        model = Article
        fields = '__all__'

    def get_author(self, obj):
        """Return ``{'id', 'username', 'avatar'}`` for the article's author.

        ``avatar`` is the avatar URL, or None when the author has no profile
        or the profile has no avatar.
        """
        author = obj.author
        # A missing reverse one-to-one raises RelatedObjectDoesNotExist, which
        # subclasses AttributeError, so getattr's default covers it
        # (assumes `profile` is such a relation — confirm against the models).
        profile = getattr(author, 'profile', None)
        avatar = getattr(profile, 'avatar', None)
        return {
            'id': author.id,
            'username': author.username,
            'avatar': avatar.url if avatar else None,
        }

    def get_comments_count(self, obj):
        """Return the number of comments on *obj*.

        Reuses a ``comments_count`` queryset annotation when present, which
        avoids one COUNT query per article; otherwise it queries.
        """
        annotated = getattr(obj, 'comments_count', None)
        if annotated is not None:
            return annotated
        return obj.comments.count()
class ArticleCreateSerializer(serializers.ModelSerializer):
    """Write serializer for articles; tags are supplied as primary keys."""

    tags = serializers.PrimaryKeyRelatedField(
        many=True,
        queryset=Tag.objects.all(),
        required=False,
    )

    class Meta:
        model = Article
        # 'id' (the auto primary key, read-only in ModelSerializer) is included
        # so the create response tells the client which article was created.
        fields = ['id', 'title', 'content', 'category', 'tags', 'status']

    def create(self, validated_data):
        """Create the article, then attach its tags.

        Tags are popped first because a many-to-many relation can only be set
        once the article row exists.
        """
        tags = validated_data.pop('tags', [])
        article = Article.objects.create(**validated_data)
        article.tags.set(tags)
        return article
2.6 验证
from rest_framework import serializers
from rest_framework.validators import UniqueTogetherValidator, UniqueValidator


class ArticleSerializer(serializers.ModelSerializer):
    """Article serializer demonstrating validator-, field- and object-level checks."""

    title = serializers.CharField(
        max_length=200,
        validators=[
            UniqueValidator(
                queryset=Article.objects.all(),
                message='标题已存在',
            ),
        ],
    )

    class Meta:
        model = Article
        fields = '__all__'
        validators = [
            UniqueTogetherValidator(
                queryset=Article.objects.all(),
                fields=['author', 'slug'],
                message='该作者已有同名文章',
            ),
        ]

    # Field-level validation
    def validate_title(self, value):
        """Require at least 5 characters and no sensitive word."""
        if len(value) < 5:
            raise serializers.ValidationError('标题至少需要5个字符')
        if '敏感词' in value:
            raise serializers.ValidationError('标题包含敏感词')
        return value

    def validate_content(self, value):
        """Require at least 100 characters of content."""
        if len(value) < 100:
            raise serializers.ValidationError('内容至少需要100个字符')
        return value

    # Object-level validation
    def validate(self, attrs):
        """Publishing requires a category and non-empty content.

        On partial updates (PATCH), fields missing from the request fall back
        to the instance's current values. For example,
        ``{"status": "published"}`` is checked against the article's existing
        category and content instead of being rejected outright.
        """
        def current(field):
            if field in attrs:
                return attrs[field]
            # self.instance is None on create, so getattr yields None there.
            return getattr(self.instance, field, None)

        if current('status') == 'published':
            if not current('category'):
                raise serializers.ValidationError({
                    'category': '发布文章必须选择分类'
                })
            if not current('content'):
                raise serializers.ValidationError({
                    'content': '发布文章必须有内容'
                })
        return attrs
# Custom validator
class TitleValidator:
    """Callable field validator that rejects values containing a forbidden word."""

    def __init__(self, forbidden_words=None):
        # A missing or empty argument becomes a fresh empty list.
        self.forbidden_words = forbidden_words or []

    def __call__(self, value):
        # Report the first forbidden word (in list order) found in the value.
        offending = next((w for w in self.forbidden_words if w in value), None)
        if offending is not None:
            raise serializers.ValidationError(f'标题不能包含:{offending}')

    def __repr__(self):
        return 'TitleValidator(forbidden_words={})'.format(self.forbidden_words)
三、视图
3.1 函数视图
from rest_framework import status
from rest_framework.decorators import api_view, permission_classes
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response


@api_view(['GET', 'POST'])
@permission_classes([IsAuthenticated])
def article_list(request):
    """GET: serialize every article. POST: validate and create one authored by the caller."""
    if request.method == 'GET':
        return Response(ArticleSerializer(Article.objects.all(), many=True).data)

    if request.method == 'POST':
        serializer = ArticleSerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
        serializer.save(author=request.user)
        return Response(serializer.data, status=status.HTTP_201_CREATED)


@api_view(['GET', 'PUT', 'DELETE'])
def article_detail(request, pk):
    """Retrieve (GET), replace (PUT) or delete (DELETE) the article with primary key *pk*.

    Responds 404 with an empty body when the article does not exist.
    Permissions come from the project defaults.
    """
    try:
        article = Article.objects.get(pk=pk)
    except Article.DoesNotExist:
        return Response(status=status.HTTP_404_NOT_FOUND)

    if request.method == 'GET':
        return Response(ArticleSerializer(article).data)

    if request.method == 'PUT':
        serializer = ArticleSerializer(article, data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
        serializer.save()
        return Response(serializer.data)

    if request.method == 'DELETE':
        article.delete()
        return Response(status=status.HTTP_204_NO_CONTENT)
3.2 类视图(APIView)
from django.http import Http404
from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import APIView


class ArticleList(APIView):
    """List all articles (GET) or create a new one (POST)."""

    def get(self, request):
        articles = Article.objects.all()
        serializer = ArticleSerializer(articles, many=True)
        return Response(serializer.data)

    def post(self, request):
        """Create an article authored by the requesting user."""
        serializer = ArticleSerializer(data=request.data)
        if serializer.is_valid():
            serializer.save(author=request.user)
            return Response(serializer.data, status=status.HTTP_201_CREATED)
        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)


class ArticleDetail(APIView):
    """Retrieve, replace or delete a single article."""

    def get_object(self, pk):
        """Return the Article with primary key *pk*.

        Raises:
            Http404: if it does not exist. Http404 must be imported (above);
                otherwise this line raises NameError and the client gets a 500.
        """
        try:
            return Article.objects.get(pk=pk)
        except Article.DoesNotExist:
            raise Http404

    def get(self, request, pk):
        article = self.get_object(pk)
        serializer = ArticleSerializer(article)
        return Response(serializer.data)

    def put(self, request, pk):
        article = self.get_object(pk)
        serializer = ArticleSerializer(article, data=request.data)
        if serializer.is_valid():
            serializer.save()
            return Response(serializer.data)
        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

    def delete(self, request, pk):
        article = self.get_object(pk)
        article.delete()
        return Response(status=status.HTTP_204_NO_CONTENT)
3.3 通用视图
from rest_framework import generics


class ArticleList(generics.ListCreateAPIView):
    """List articles and create new ones."""

    serializer_class = ArticleSerializer
    queryset = Article.objects.all()

    def perform_create(self, serializer):
        # Attach the requesting user as the author.
        serializer.save(author=self.request.user)


class ArticleDetail(generics.RetrieveUpdateDestroyAPIView):
    """Retrieve, update and delete a single article."""

    serializer_class = ArticleSerializer
    queryset = Article.objects.all()


# Other generic views:
# CreateAPIView                - create
# ListAPIView                  - list
# RetrieveAPIView              - retrieve
# DestroyAPIView               - delete
# UpdateAPIView                - update
# ListCreateAPIView            - list + create
# RetrieveUpdateAPIView        - retrieve + update
# RetrieveDestroyAPIView       - retrieve + delete
# RetrieveUpdateDestroyAPIView - retrieve + update + delete
3.4 视图集(ViewSet)
from django.db.models import Q
from django.utils import timezone
from rest_framework import viewsets
from rest_framework.decorators import action
from rest_framework.response import Response


class ArticleViewSet(viewsets.ModelViewSet):
    """Article CRUD plus publish / like / featured / my_articles actions."""

    queryset = Article.objects.all()
    serializer_class = ArticleSerializer

    def get_queryset(self):
        """Articles visible to the current user, optionally filtered by ``?category=<slug>``.

        Staff see everything. Other authenticated users see published
        articles plus their own, so authors can still reach their drafts
        through ``publish`` (which uses get_object()) and ``my_articles``.
        Anonymous users see only published articles.
        """
        queryset = super().get_queryset()
        user = self.request.user
        if not user.is_staff:
            visible = Q(status='published')
            if user.is_authenticated:
                visible |= Q(author=user)
            queryset = queryset.filter(visible)
        # Filter by category slug
        category = self.request.query_params.get('category')
        if category:
            queryset = queryset.filter(category__slug=category)
        return queryset.select_related('author', 'category').prefetch_related('tags')

    def get_serializer_class(self):
        """Pick the serializer for the current action."""
        if self.action == 'list':
            return ArticleListSerializer
        elif self.action == 'retrieve':
            return ArticleDetailSerializer
        elif self.action in ['create', 'update', 'partial_update']:
            return ArticleCreateSerializer
        return ArticleSerializer

    def perform_create(self, serializer):
        serializer.save(author=self.request.user)

    # Custom actions
    @action(detail=True, methods=['post'])
    def publish(self, request, pk=None):
        """Mark the article as published and stamp ``published_at``."""
        article = self.get_object()
        article.status = 'published'
        article.published_at = timezone.now()
        article.save()
        return Response({'status': 'published'})

    @action(detail=True, methods=['post'])
    def like(self, request, pk=None):
        """Toggle the current user's like; respond with the new state and like count."""
        article = self.get_object()
        like, created = Like.objects.get_or_create(
            article=article,
            user=request.user
        )
        if not created:
            # Already liked: remove the like (unlike).
            like.delete()
            return Response({'liked': False, 'count': article.likes.count()})
        return Response({'liked': True, 'count': article.likes.count()})

    @action(detail=False, methods=['get'])
    def featured(self, request):
        """Up to 10 featured articles."""
        articles = self.get_queryset().filter(is_featured=True)[:10]
        serializer = self.get_serializer(articles, many=True)
        return Response(serializer.data)

    @action(detail=False, methods=['get'])
    def my_articles(self, request):
        """The current user's articles, paginated when pagination is configured."""
        articles = self.get_queryset().filter(author=request.user)
        page = self.paginate_queryset(articles)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)
        serializer = self.get_serializer(articles, many=True)
        return Response(serializer.data)
四、认证与权限
4.1 认证类
# settings.py
REST_FRAMEWORK = dict(
    DEFAULT_AUTHENTICATION_CLASSES=[
        'rest_framework.authentication.SessionAuthentication',
        'rest_framework.authentication.TokenAuthentication',
        # JWT (rest_framework_simplejwt)
        'rest_framework_simplejwt.authentication.JWTAuthentication',
    ],
)
# Token authentication
from django.urls import path
from rest_framework.authtoken.views import obtain_auth_token

urlpatterns = [
    path('api/token/', obtain_auth_token),
]
# JWT authentication. This is an alternative to the token route above: both
# use 'api/token/', so pick one.
from django.urls import path
from rest_framework_simplejwt.views import TokenObtainPairView, TokenRefreshView

urlpatterns = [
    path('api/token/', TokenObtainPairView.as_view()),
    path('api/token/refresh/', TokenRefreshView.as_view()),
]
# Custom authentication
from django.contrib.auth import get_user_model
from rest_framework.authentication import BaseAuthentication
from rest_framework.exceptions import AuthenticationFailed


class CustomAuthentication(BaseAuthentication):
    """Authenticate a request from its ``X-API-Key`` header.

    The key is looked up in the user model's ``api_key`` field (presumably a
    custom field on the project's user model — confirm).
    """

    def authenticate(self, request):
        """Return ``(user, None)``; return None when the header is absent.

        Returning None lets the next authentication class try.

        Raises:
            AuthenticationFailed: unknown key, or the user is inactive.
        """
        # Django exposes the X-API-Key request header as HTTP_X_API_KEY.
        token = request.META.get('HTTP_X_API_KEY')
        if not token:
            return None
        User = get_user_model()
        try:
            user = User.objects.get(api_key=token)
        except User.DoesNotExist:
            raise AuthenticationFailed('无效的 API Key') from None
        # Same guard as DRF's TokenAuthentication: deactivated accounts are rejected.
        if not user.is_active:
            raise AuthenticationFailed('用户已停用')
        return (user, None)
4.2 权限类
from rest_framework.permissions import (
    SAFE_METHODS,
    AllowAny,
    BasePermission,
    DjangoModelPermissions,
    IsAdminUser,
    IsAuthenticated,
    IsAuthenticatedOrReadOnly,
)


# Using a built-in permission
class ArticleViewSet(viewsets.ModelViewSet):
    permission_classes = [IsAuthenticatedOrReadOnly]


# Custom permissions
class IsOwnerOrReadOnly(BasePermission):
    """Anyone may read; only the object's author may modify it."""

    def has_object_permission(self, request, view, obj):
        # Read-only methods (SAFE_METHODS: GET, HEAD, OPTIONS) are open to everyone.
        if request.method in SAFE_METHODS:
            return True
        # Write access only for the author.
        return obj.author == request.user


class IsAdminOrReadOnly(BasePermission):
    """Anyone may read; only staff users may write."""

    def has_permission(self, request, view):
        if request.method in SAFE_METHODS:
            return True
        # bool(): return a real boolean rather than a user object/attribute.
        return bool(request.user and request.user.is_staff)


class ArticlePermission(BasePermission):
    """Action-based article permissions (reads ``view.action``, so ViewSets only)."""

    def has_permission(self, request, view):
        # list and retrieve are open to everyone.
        if view.action in ['list', 'retrieve']:
            return True
        # create requires a logged-in user.
        if view.action == 'create':
            return request.user.is_authenticated
        # Other actions are allowed here; detail actions are further
        # restricted by has_object_permission.
        return True

    def has_object_permission(self, request, view, obj):
        # Reading a single article is open to everyone.
        if view.action == 'retrieve':
            return True
        # Only the author or a staff user may modify/delete.
        return obj.author == request.user or request.user.is_staff


class ArticleViewSet(viewsets.ModelViewSet):
    permission_classes = [ArticlePermission]
4.3 限流
from rest_framework.throttling import AnonRateThrottle, UserRateThrottle

# Global configuration (settings.py)
REST_FRAMEWORK = {
    'DEFAULT_THROTTLE_CLASSES': [
        'rest_framework.throttling.AnonRateThrottle',
        'rest_framework.throttling.UserRateThrottle',
    ],
    'DEFAULT_THROTTLE_RATES': {
        'anon': '100/hour',
        'user': '1000/hour',
    },
}


# Custom throttles: each scope's rate is looked up in DEFAULT_THROTTLE_RATES.
class BurstRateThrottle(UserRateThrottle):
    scope = 'burst'


class SustainedRateThrottle(UserRateThrottle):
    scope = 'sustained'


# Add the custom scopes to the existing rates. Re-assigning REST_FRAMEWORK
# here would drop the 'anon'/'user' rates (which AnonRateThrottle and
# UserRateThrottle still need) and the throttle classes configured above.
REST_FRAMEWORK['DEFAULT_THROTTLE_RATES'].update({
    'burst': '60/minute',
    'sustained': '1000/day',
})
# View-level throttling
class ArticleViewSet(viewsets.ModelViewSet):
    throttle_classes = [UserRateThrottle]

    @action(detail=True, methods=['post'])
    def like(self, request, pk=None):
        # Placeholder: `like` gets its own throttle via get_throttles().
        # A real implementation must return a Response.
        pass

    def get_throttles(self):
        """Throttle `like` with BurstRateThrottle; other actions use throttle_classes."""
        return [BurstRateThrottle()] if self.action == 'like' else super().get_throttles()
五、过滤与搜索
5.1 DjangoFilterBackend
from django_filters import rest_framework as filters


class ArticleFilter(filters.FilterSet):
    """Query-string filters for the article list."""

    title = filters.CharFilter(lookup_expr='icontains')
    status = filters.ChoiceFilter(choices=Article.STATUS_CHOICES)
    category = filters.CharFilter(field_name='category__slug')
    author = filters.CharFilter(field_name='author__username')
    # Compare on the date part of created_at (assumed to be a DateTimeField —
    # confirm). A plain `lte` against a date is compared with midnight, which
    # would exclude everything created later on the end day.
    created_after = filters.DateFilter(field_name='created_at', lookup_expr='date__gte')
    created_before = filters.DateFilter(field_name='created_at', lookup_expr='date__lte')
    min_views = filters.NumberFilter(field_name='views', lookup_expr='gte')
    tags = filters.CharFilter(method='filter_by_tags')

    class Meta:
        model = Article
        fields = ['title', 'status', 'category', 'author']

    def filter_by_tags(self, queryset, name, value):
        """Keep articles tagged with any of the comma-separated tag slugs in *value*.

        Whitespace around slugs and empty entries (e.g. ``"a, b,"``) are
        ignored. If no slug remains, the queryset is returned unchanged.
        """
        slugs = [slug.strip() for slug in value.split(',') if slug.strip()]
        if not slugs:
            return queryset
        # distinct(): the tag join yields one row per matching tag.
        return queryset.filter(tags__slug__in=slugs).distinct()


class ArticleViewSet(viewsets.ModelViewSet):
    queryset = Article.objects.all()
    serializer_class = ArticleSerializer
    filterset_class = ArticleFilter
5.2 搜索和排序
from rest_framework.filters import OrderingFilter, SearchFilter


class ArticleViewSet(viewsets.ModelViewSet):
    queryset = Article.objects.all()
    serializer_class = ArticleSerializer
    # Replaces DEFAULT_FILTER_BACKENDS for this view (so no DjangoFilterBackend here).
    filter_backends = [SearchFilter, OrderingFilter]

    # Fields searched by ?search=; the prefix selects the lookup type.
    # (Several variants of `title` are listed only for illustration.)
    search_fields = [
        'title',   # no prefix: case-insensitive "contains" (icontains)
        '^title',  # starts-with (istartswith)
        '=title',  # exact match (iexact)
        '@title',  # full-text search (PostgreSQL only)
        '$title',  # regex search (iregex)
        'content',
        'author__username',
    ]

    # Fields allowed in ?ordering=. 'likes_count' presumably comes from a
    # queryset annotation — confirm it exists on this queryset.
    ordering_fields = ['created_at', 'views', 'likes_count', 'title']
    ordering = ['-created_at']  # default ordering


# Example requests
# GET /api/articles/?search=django
# GET /api/articles/?ordering=-views
# GET /api/articles/?ordering=created_at,-views
六、分页
6.1 分页类
from rest_framework.pagination import (
    CursorPagination,
    LimitOffsetPagination,
    PageNumberPagination,
)


# Page-number pagination: ?page=N&page_size=M (page_size capped at 100)
class StandardPagination(PageNumberPagination):
    page_query_param = 'page'
    page_size_query_param = 'page_size'
    page_size = 20
    max_page_size = 100


# Limit/offset pagination: ?limit=N&offset=M (limit capped at 100)
# NOTE(review): this class reuses its base class's name, so from here on
# `LimitOffsetPagination` refers to this subclass, not DRF's class.
# Consider a distinct name (e.g. StandardLimitOffsetPagination).
class LimitOffsetPagination(LimitOffsetPagination):
    limit_query_param = 'limit'
    offset_query_param = 'offset'
    default_limit = 20
    max_limit = 100


# Cursor pagination (suited to large data sets), ordered newest first
class ArticleCursorPagination(CursorPagination):
    cursor_query_param = 'cursor'
    ordering = '-created_at'
    page_size = 20


class ArticleViewSet(viewsets.ModelViewSet):
    pagination_class = StandardPagination
# Custom paginated response
class CustomPagination(PageNumberPagination):
    """Page-number pagination that nests page metadata under 'pagination'."""

    def get_paginated_response(self, data):
        """Wrap *data* as ``{'pagination': {...}, 'results': data}``."""
        paginator = self.page.paginator
        return Response({
            'pagination': {
                'total': paginator.count,
                'page': self.page.number,
                # per_page is the size actually used for this request.
                # self.page_size is only the class default and would be wrong
                # once a page_size_query_param override is enabled.
                'page_size': paginator.per_page,
                'total_pages': paginator.num_pages,
                'has_next': self.page.has_next(),
                'has_previous': self.page.has_previous(),
            },
            'results': data
        })
七、版本控制
7.1 URL 路径版本
# settings.py
REST_FRAMEWORK = dict(
    DEFAULT_VERSIONING_CLASS='rest_framework.versioning.URLPathVersioning',
    DEFAULT_VERSION='v1',
    ALLOWED_VERSIONS=['v1', 'v2'],
    # Name of the URL kwarg that carries the version (<version> in urls.py).
    VERSION_PARAM='version',
)
# urls.py: the <version> path segment is passed to views as the `version` kwarg.
urlpatterns = [path('api/<version>/', include('api.urls'))]
# views.py
class ArticleViewSet(viewsets.ModelViewSet):
    def get_serializer_class(self):
        """Use ArticleSerializerV1 for 'v1' requests and ArticleSerializerV2 otherwise."""
        return ArticleSerializerV1 if self.request.version == 'v1' else ArticleSerializerV2
7.2 其他版本控制方式
# Alternative values for REST_FRAMEWORK['DEFAULT_VERSIONING_CLASS'] — use exactly one.
# Query-parameter versioning
# GET /api/articles/?version=v1
'DEFAULT_VERSIONING_CLASS': 'rest_framework.versioning.QueryParameterVersioning',
# Accept-header versioning
# Accept: application/json; version=1.0
'DEFAULT_VERSIONING_CLASS': 'rest_framework.versioning.AcceptHeaderVersioning',
# Namespace versioning
'DEFAULT_VERSIONING_CLASS': 'rest_framework.versioning.NamespaceVersioning',
# Hostname versioning
# v1.api.example.com
'DEFAULT_VERSIONING_CLASS': 'rest_framework.versioning.HostNameVersioning',
八、异常处理
8.1 自定义异常处理
import logging

from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import exception_handler, set_rollback

logger = logging.getLogger('api')


def custom_exception_handler(exc, context):
    """Wrap every API error in a uniform ``{'success': False, 'error': {...}}`` body.

    Exceptions DRF's default handler recognises keep their status code, and
    the original payload moves under ``error.detail``. Anything else is
    logged and turned into a generic 500 response.
    """
    # Let DRF build the standard response first.
    response = exception_handler(exc, context)

    if response is not None:
        response.data = {
            'success': False,
            'error': {
                'code': response.status_code,
                'message': str(exc),
                'detail': response.data,
            }
        }
        return response

    # Unhandled exception. We return a response instead of re-raising, so any
    # ATOMIC_REQUESTS transaction must be marked for rollback here; DRF's
    # default handler only does this for exceptions it handles itself.
    set_rollback()
    logger.error('未处理的异常: %s', exc, exc_info=exc)
    return Response({
        'success': False,
        'error': {
            'code': 500,
            'message': '服务器内部错误',
        }
    }, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
# Custom exceptions
from rest_framework.exceptions import APIException


class ServiceUnavailable(APIException):
    """HTTP 503: the service is temporarily unavailable."""

    default_code = 'service_unavailable'
    default_detail = '服务暂时不可用'
    status_code = 503


class ArticleNotPublished(APIException):
    """HTTP 403: the requested article has not been published."""

    default_code = 'article_not_published'
    default_detail = '文章未发布'
    status_code = 403
# Usage
class ArticleViewSet(viewsets.ModelViewSet):
    def retrieve(self, request, *args, **kwargs):
        """Return one article; unpublished articles are shown only to their author.

        Accepts *args/**kwargs so router-supplied URL kwargs (``pk``, and
        ``version`` under URLPathVersioning) are accepted.

        Raises:
            ArticleNotPublished: the article is unpublished and the requester
                is not its author.
        """
        article = self.get_object()
        if article.status != 'published' and article.author != request.user:
            raise ArticleNotPublished()
        # Serialize the object already fetched. super().retrieve() would call
        # get_object() again: a second query plus repeated permission checks.
        serializer = self.get_serializer(article)
        return Response(serializer.data)
九、API 文档
9.1 OpenAPI/Swagger
# Install first: pip install drf-yasg
# settings.py
INSTALLED_APPS = [
    # ... existing apps ...
    'drf_yasg',
]
# urls.py
from django.urls import path
from drf_yasg import openapi
from drf_yasg.views import get_schema_view
from rest_framework.permissions import AllowAny

# Schema view shared by the Swagger UI, ReDoc and raw-JSON routes below.
schema_view = get_schema_view(
    openapi.Info(
        title="Blog API",
        default_version='v1',
        description="博客系统 API 文档",
        terms_of_service="https://www.example.com/terms/",
        contact=openapi.Contact(email="contact@example.com"),
        license=openapi.License(name="MIT License"),
    ),
    public=True,
    # The docs pages are reachable without authentication.
    permission_classes=[AllowAny],
)

urlpatterns = [
    path('swagger/', schema_view.with_ui('swagger', cache_timeout=0)),
    path('redoc/', schema_view.with_ui('redoc', cache_timeout=0)),
    path('swagger.json', schema_view.without_ui(cache_timeout=0)),
]
9.2 为视图添加文档
from drf_yasg import openapi
from drf_yasg.utils import swagger_auto_schema


class ArticleViewSet(viewsets.ModelViewSet):
    @swagger_auto_schema(
        operation_description="获取文章列表",
        manual_parameters=[
            openapi.Parameter(
                'category',
                openapi.IN_QUERY,
                description="按分类筛选",
                type=openapi.TYPE_STRING
            ),
            openapi.Parameter(
                'status',
                openapi.IN_QUERY,
                description="按状态筛选",
                type=openapi.TYPE_STRING,
                enum=['draft', 'published']
            ),
        ],
        responses={
            200: ArticleSerializer(many=True),
            401: '未认证',
        }
    )
    def list(self, request, *args, **kwargs):
        """Documented passthrough to ModelViewSet.list.

        *args/**kwargs must be forwarded: the router passes URL kwargs such
        as ``version`` (URLPathVersioning with 'api/<version>/'), and a fixed
        ``list(self, request)`` signature would raise TypeError.
        """
        return super().list(request, *args, **kwargs)

    @swagger_auto_schema(
        operation_description="创建文章",
        request_body=ArticleCreateSerializer,
        responses={
            201: ArticleSerializer,
            400: '请求数据无效',
        }
    )
    def create(self, request, *args, **kwargs):
        """Documented passthrough to ModelViewSet.create (forwards URL kwargs)."""
        return super().create(request, *args, **kwargs)
十、测试
10.1 API 测试
from django.contrib.auth import get_user_model
from django.urls import reverse
from rest_framework import status
from rest_framework.test import APIClient, APITestCase

# NOTE(review): adjust this import to wherever Article/Category are defined.
from .models import Article, Category

User = get_user_model()


class ArticleAPITestCase(APITestCase):
    """API tests for the routes named 'article-list' and 'article-detail'."""

    def setUp(self):
        """Authenticated client, one category, and one published article by that user."""
        self.client = APIClient()
        self.user = User.objects.create_user(
            username='testuser',
            password='testpass123'
        )
        self.client.force_authenticate(user=self.user)
        self.category = Category.objects.create(name='技术', slug='tech')
        self.article = Article.objects.create(
            title='测试文章',
            content='测试内容',
            author=self.user,
            category=self.category,
            status='published',
        )

    def test_list_articles(self):
        """Listing returns the one article (response is paginated: 'results')."""
        url = reverse('article-list')
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data['results']), 1)

    def test_create_article(self):
        """Creating an article returns 201 and persists it."""
        url = reverse('article-list')
        data = {
            'title': '新文章',
            'content': '新内容',
            'category': self.category.id,
            'status': 'draft',
        }
        response = self.client.post(url, data)
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(Article.objects.count(), 2)
        self.assertEqual(response.data['title'], '新文章')

    def test_retrieve_article(self):
        """Retrieving the article returns its data."""
        url = reverse('article-detail', kwargs={'pk': self.article.pk})
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data['title'], '测试文章')

    def test_update_article(self):
        """PATCH updates only the supplied field."""
        url = reverse('article-detail', kwargs={'pk': self.article.pk})
        data = {'title': '更新标题'}
        response = self.client.patch(url, data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.article.refresh_from_db()
        self.assertEqual(self.article.title, '更新标题')

    def test_delete_article(self):
        """DELETE returns 204 and removes the article."""
        url = reverse('article-detail', kwargs={'pk': self.article.pk})
        response = self.client.delete(url)
        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
        self.assertEqual(Article.objects.count(), 0)

    def test_unauthenticated_create(self):
        """An unauthenticated create is rejected.

        DRF answers 401 only if the first authentication class sends a
        WWW-Authenticate header. With SessionAuthentication listed first (as
        in the settings of section 1.1), the response is 403, so accept either.
        """
        self.client.logout()  # also clears force_authenticate()
        url = reverse('article-list')
        data = {'title': '新文章', 'content': '内容'}
        response = self.client.post(url, data)
        self.assertIn(
            response.status_code,
            (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN),
        )

    def test_filter_by_category(self):
        """?category=<slug> keeps matching articles."""
        url = reverse('article-list')
        response = self.client.get(url, {'category': 'tech'})
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data['results']), 1)

    def test_search_articles(self):
        """?search= returns at least the matching article."""
        url = reverse('article-list')
        response = self.client.get(url, {'search': '测试'})
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertGreater(len(response.data['results']), 0)
十一、最佳实践
11.1 API 设计规范
# API design conventions. The JSON below is example response output, not executable code.
# 1. Use plural nouns for resource names
# GET /api/articles/ # good
# GET /api/article/ # bad
# 2. Express the operation through the HTTP method
# GET /api/articles/ list
# POST /api/articles/ create
# GET /api/articles/1/ retrieve
# PUT /api/articles/1/ full update
# PATCH /api/articles/1/ partial update
# DELETE /api/articles/1/ delete
# 3. Use nested routes to express relationships
# GET /api/articles/1/comments/
# 4. Use a uniform response envelope
# Success:
{
"success": true,
"data": {...},
"message": "操作成功"
}
# Error:
{
"success": false,
"error": {
"code": 400,
"message": "请求无效",
"detail": {...}
}
}
# 5. Use appropriate status codes
# 200 OK - success
# 201 Created - resource created
# 204 No Content - deleted
# 400 Bad Request - invalid request
# 401 Unauthorized - not authenticated
# 403 Forbidden - not permitted
# 404 Not Found - not found
# 500 Internal Server Error - server error
11.2 性能优化
from django.db.models import Count
from django.utils.decorators import method_decorator
from django.views.decorators.cache import cache_page


class ArticleViewSet(viewsets.ModelViewSet):
    def get_queryset(self):
        """Articles with author/category joined, tags prefetched, and like/comment counts annotated."""
        # distinct=True is required: two Count()s over different multi-valued
        # relations in one annotate() share the same JOINs, so without it
        # each count would be multiplied by the other.
        return Article.objects.select_related(
            'author', 'category'
        ).prefetch_related(
            'tags'
        ).annotate(
            likes_count=Count('likes', distinct=True),
            comments_count=Count('comments', distinct=True),
        )

    def get_serializer_class(self):
        """Slim serializer for lists, write serializer for writes, full detail otherwise."""
        if self.action == 'list':
            return ArticleListSerializer
        if self.action in ('create', 'update', 'partial_update'):
            # ArticleDetailSerializer renders author as a SerializerMethodField
            # and category/tags as nested serializers, which ModelSerializer
            # cannot write by default.
            return ArticleCreateSerializer
        return ArticleDetailSerializer

    # Cache list responses for 15 minutes. cache_page keys on the URL (and
    # Vary headers), not on the user, so this is only safe while the list is
    # the same for every user.
    @method_decorator(cache_page(60 * 15))
    def list(self, request, *args, **kwargs):
        return super().list(request, *args, **kwargs)