Django 中间件与信号
2026/3/20 · 大约 11 分钟
Django 中间件与信号
一、中间件基础
1.1 中间件概念
中间件是 Django 请求/响应处理的钩子框架,它是一个轻量级的、底层的"插件"系统,用于全局改变 Django 的输入或输出。每个中间件组件负责执行某个特定的功能。
# settings.py
# Order matters: requests pass through this list from top to bottom and
# responses pass back through it from bottom to top.
MIDDLEWARE = [
    'django.middleware.security.SecurityMiddleware',
    'django.contrib.sessions.middleware.SessionMiddleware',
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
]
1.2 请求/响应处理流程
请求进入
↓
SecurityMiddleware.process_request()
↓
SessionMiddleware.process_request()
↓
CommonMiddleware.process_request()
↓
CsrfViewMiddleware.process_request()
↓
AuthenticationMiddleware.process_request()
↓
MessageMiddleware.process_request()
↓
XFrameOptionsMiddleware.process_request()
↓
URL 路由 → 各中间件的 process_view()(按上述顺序)→ 视图函数
↓
XFrameOptionsMiddleware.process_response()
↓
MessageMiddleware.process_response()
↓
AuthenticationMiddleware.process_response()
↓
CsrfViewMiddleware.process_response()
↓
CommonMiddleware.process_response()
↓
SessionMiddleware.process_response()
↓
SecurityMiddleware.process_response()
↓
响应返回
二、自定义中间件
2.1 函数式中间件
# middleware.py
def simple_middleware(get_response):
    """Build a function-based middleware that prints the path and response status."""

    def middleware(request):
        # Runs before the view (and any later middleware).
        print(f'请求路径: {request.path}')
        result = get_response(request)
        # Runs once the response has been produced.
        print(f'响应状态: {result.status_code}')
        return result

    return middleware
2.2 类式中间件
class SimpleMiddleware:
    """Class-based middleware template with request/response hooks.

    ``process_request`` may return a response to short-circuit the chain (the
    view is then not called). ``process_response`` must return the response
    to send, either the one it received or a replacement.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        # One-time configuration and initialisation.
        print('中间件初始化')

    def __call__(self, request):
        # Before the view: a non-None return value short-circuits.
        response = self.process_request(request)
        if response is None:
            response = self.get_response(request)
        # After the view: the hook may replace the response.
        return self.process_response(request, response)

    def process_request(self, request):
        """Handle the request; return None to continue or a response to stop."""
        return None

    def process_response(self, request, response):
        """Handle the response and return it (or a replacement)."""
        return response
2.3 完整的中间件钩子
class FullMiddleware:
    """Middleware that implements every optional hook."""

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        return self.get_response(request)

    def process_view(self, request, view_func, view_args, view_kwargs):
        """Called right before the view function.

        Return None to keep going, or an HttpResponse to short-circuit.
        """
        print(f'即将调用视图: {view_func.__name__}')
        return None

    def process_exception(self, request, exception):
        """Called when the view raises.

        Return None to let the exception propagate, or an HttpResponse to
        handle it.
        """
        print(f'发生异常: {exception}')
        return None

    def process_template_response(self, request, response):
        """Called when the view returns a TemplateResponse; must return a response."""
        return response
三、常用中间件实现
3.1 请求日志中间件
import time
import logging
logger = logging.getLogger('request')


class RequestLoggingMiddleware:
    """Log method, path, user, client IP, status and duration of every request."""

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        # perf_counter() is monotonic: unlike time.time(), the measured
        # duration cannot jump or go negative if the system clock is adjusted.
        start_time = time.perf_counter()
        # Capture request details before the view runs.
        method = request.method
        path = request.path
        user = str(request.user) if hasattr(request, 'user') else 'Anonymous'
        ip = self.get_client_ip(request)

        response = self.get_response(request)

        duration = time.perf_counter() - start_time
        # Lazy %-style arguments: nothing is formatted when INFO is disabled.
        logger.info(
            '%s %s | User: %s | IP: %s | Status: %s | Duration: %.3fs',
            method, path, user, ip, response.status_code, duration,
        )
        return response

    def get_client_ip(self, request):
        """Return the first X-Forwarded-For entry, falling back to REMOTE_ADDR.

        X-Forwarded-For can be set by the client, so treat the value as
        informational (acceptable for logging, not for access control).
        """
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            return x_forwarded_for.split(',')[0].strip()
        return request.META.get('REMOTE_ADDR')
3.2 性能监控中间件
import time
from django.db import connection, reset_queries
from django.conf import settings
class PerformanceMiddleware:
    """Measure request duration plus database query count and time.

    Queries are counted with ``connection.execute_wrapper`` rather than read
    from ``connection.queries``. ``connection.queries`` is only populated when
    DEBUG is True, so in production the slow-request warning would always
    report 0 queries.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        stats = {'count': 0, 'time': 0.0}

        def record_query(execute, sql, params, many, context):
            # Count and time each query issued on the default connection.
            query_start = time.perf_counter()
            try:
                return execute(sql, params, many, context)
            finally:
                stats['count'] += 1
                stats['time'] += time.perf_counter() - query_start

        # perf_counter() is monotonic, so durations are immune to clock changes.
        start_time = time.perf_counter()
        with connection.execute_wrapper(record_query):
            response = self.get_response(request)
        duration = time.perf_counter() - start_time
        query_count = stats['count']
        query_time = stats['time']

        # Add diagnostic headers only in development.
        if settings.DEBUG:
            response['X-Request-Duration'] = f'{duration:.3f}s'
            response['X-Query-Count'] = str(query_count)
            response['X-Query-Time'] = f'{query_time:.3f}s'

        # Warn about slow requests.
        if duration > 1.0:
            import logging
            logger = logging.getLogger('performance')
            logger.warning(
                f'慢请求: {request.path} | '
                f'Duration: {duration:.3f}s | '
                f'Queries: {query_count}'
            )
        return response
3.3 访问控制中间件
from django.http import HttpResponseForbidden
from django.conf import settings
import re
class IPAccessMiddleware:
    """Block blacklisted IPs and, when a whitelist is configured, all others.

    List entries may be exact addresses, ``*`` wildcard patterns such as
    ``192.168.*.*``, or CIDR networks such as ``10.0.0.0/8``.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        self.whitelist = getattr(settings, 'IP_WHITELIST', [])
        self.blacklist = getattr(settings, 'IP_BLACKLIST', [])
        # Exact IPs of reverse proxies whose X-Forwarded-For header is trusted.
        # The header is client-controlled, so believing it unconditionally lets
        # anyone bypass the lists with a forged value. When this list is empty,
        # REMOTE_ADDR is always used (behind a proxy, list the proxy here).
        self.trusted_proxies = getattr(settings, 'TRUSTED_PROXIES', [])

    def __call__(self, request):
        ip = self.get_client_ip(request)
        # Blacklist check.
        if self.is_in_list(ip, self.blacklist):
            return HttpResponseForbidden('您的 IP 已被禁止访问')
        # With a whitelist configured, only whitelisted IPs get through.
        if self.whitelist and not self.is_in_list(ip, self.whitelist):
            return HttpResponseForbidden('您的 IP 不在允许列表中')
        return self.get_response(request)

    def get_client_ip(self, request):
        """Return the client IP, honouring X-Forwarded-For only via trusted proxies."""
        remote_addr = request.META.get('REMOTE_ADDR')
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if not x_forwarded_for or remote_addr not in self.trusted_proxies:
            return remote_addr
        hops = [hop.strip() for hop in x_forwarded_for.split(',') if hop.strip()]
        # Each proxy appends the address it saw, so the rightmost hop that is
        # not one of our own proxies is the real client.
        for hop in reversed(hops):
            if hop not in self.trusted_proxies:
                return hop
        return hops[0] if hops else remote_addr

    def is_in_list(self, ip, ip_list):
        """Return True if *ip* matches any pattern in *ip_list*."""
        return any(self.match_ip(ip, pattern) for pattern in ip_list)

    def match_ip(self, ip, pattern):
        """Match *ip* against an exact address, a ``*`` wildcard pattern or a CIDR network."""
        import ipaddress

        if not ip:
            # REMOTE_ADDR may be absent from META; never treat that as a match.
            return False
        if '/' in pattern:
            try:
                return ipaddress.ip_address(ip) in ipaddress.ip_network(pattern, strict=False)
            except ValueError:
                return False
        # Escape every regex metacharacter, then turn the escaped '*' into a
        # run of digits (e.g. 192.168.*.*).
        regex = re.escape(pattern).replace(r'\*', r'\d+')
        return re.fullmatch(regex, ip) is not None
class MaintenanceMiddleware:
    """Serve a 503 maintenance page while ``MAINTENANCE_MODE`` is enabled.

    Superusers, and requests whose path starts with one of the allowed
    prefixes, still pass through. The prefixes default to ``/admin/`` and
    ``/maintenance/`` and can be overridden with ``MAINTENANCE_ALLOWED_PATHS``.
    """

    default_allowed_paths = ('/admin/', '/maintenance/')

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        # Settings are read per request (not cached in __init__) so that
        # override_settings() in tests takes effect.
        if not getattr(settings, 'MAINTENANCE_MODE', False):
            return self.get_response(request)
        # request.user may be absent (hence hasattr), e.g. when this runs
        # before the authentication middleware.
        if hasattr(request, 'user') and request.user.is_superuser:
            return self.get_response(request)
        allowed_paths = tuple(
            getattr(settings, 'MAINTENANCE_ALLOWED_PATHS', self.default_allowed_paths)
        )
        # str.startswith accepts a tuple: one call checks every prefix.
        if request.path.startswith(allowed_paths):
            return self.get_response(request)
        from django.shortcuts import render
        return render(request, 'maintenance.html', status=503)
3.4 请求限流中间件
from django.core.cache import cache
from django.http import HttpResponse
import time
class RateLimitMiddleware:
    """Sliding-window rate limiter keyed by user id or client IP.

    Defaults to 100 requests per 60 seconds; override with the ``RATE_LIMIT``
    and ``RATE_LIMIT_WINDOW`` settings.
    """

    def __init__(self, get_response):
        from django.conf import settings

        self.get_response = get_response
        self.rate_limit = getattr(settings, 'RATE_LIMIT', 100)  # max requests per window
        self.window = getattr(settings, 'RATE_LIMIT_WINDOW', 60)  # window length in seconds
        # Exact IPs of reverse proxies whose X-Forwarded-For header is trusted.
        self.trusted_proxies = getattr(settings, 'TRUSTED_PROXIES', [])

    def __call__(self, request):
        if self.is_rate_limited(request):
            response = HttpResponse(
                '请求过于频繁,请稍后再试',
                status=429
            )
            # Tell well-behaved clients how long to back off.
            response['Retry-After'] = str(self.window)
            return response
        return self.get_response(request)

    def get_client_identifier(self, request):
        """Return the limiter key: the user id when authenticated, else the client IP."""
        # request.user is missing when the auth middleware has not run.
        user = getattr(request, 'user', None)
        if user is not None and user.is_authenticated:
            return f'user_{user.id}'
        return f'ip_{self.get_client_ip(request)}'

    def get_client_ip(self, request):
        """Return REMOTE_ADDR, or the X-Forwarded-For client when behind a trusted proxy.

        Trusting a client-supplied X-Forwarded-For would let a client send a
        fresh value with every request and so evade the limit.
        """
        remote_addr = request.META.get('REMOTE_ADDR')
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if not x_forwarded_for or remote_addr not in self.trusted_proxies:
            return remote_addr
        hops = [hop.strip() for hop in x_forwarded_for.split(',') if hop.strip()]
        # The rightmost hop not added by one of our own proxies is the client.
        for hop in reversed(hops):
            if hop not in self.trusted_proxies:
                return hop
        return hops[0] if hops else remote_addr

    def is_rate_limited(self, request):
        """Return True if the client has used its quota in the current window.

        Timestamps of accepted requests are kept per client in the cache;
        rejected requests are not recorded.
        """
        identifier = self.get_client_identifier(request)
        key = f'rate_limit:{identifier}'
        timestamps = cache.get(key, [])
        now = time.time()
        # Drop timestamps that have left the window.
        timestamps = [t for t in timestamps if now - t < self.window]
        if len(timestamps) >= self.rate_limit:
            return True
        timestamps.append(now)
        # NOTE: this get/filter/set sequence is not atomic, so concurrent
        # requests from the same client can race and slightly exceed the limit.
        cache.set(key, timestamps, self.window)
        return False
3.5 异常处理中间件
from django.http import JsonResponse
from django.conf import settings
import traceback
import logging
logger = logging.getLogger('django.request')


class ExceptionHandlingMiddleware:
    """Log unhandled view exceptions and return a JSON 500 for /api/ paths."""

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        return self.get_response(request)

    def process_exception(self, request, exception):
        """Log *exception*; return a JSON 500 for /api/ requests, else None.

        Returning None hands the exception to Django's default handling.
        """
        # Format the traceback from the exception object itself instead of
        # relying on the interpreter's "currently handled exception" state.
        tb_text = ''.join(traceback.format_exception(
            type(exception), exception, exception.__traceback__))
        # request.user is absent if the auth middleware has not run; raising
        # AttributeError here would mask the original exception.
        user = getattr(request, 'user', 'Anonymous')
        logger.error(
            f'未处理的异常: {exception}\n'
            f'请求路径: {request.path}\n'
            f'用户: {user}\n'
            f'Traceback: {tb_text}'
        )
        # API requests get JSON.
        if request.path.startswith('/api/'):
            return JsonResponse({
                'error': str(exception) if settings.DEBUG else '服务器内部错误',
                'status': 500
            }, status=500)
        # Everything else falls back to the default exception handling.
        return None
3.6 跨域中间件
class CORSMiddleware:
    """Cross-origin resource sharing (CORS) middleware.

    Allowed origins default to http://localhost:3000 and https://example.com
    and can be overridden with the ``CORS_ALLOWED_ORIGINS`` setting.
    """

    default_allowed_origins = ('http://localhost:3000', 'https://example.com')

    def __init__(self, get_response):
        from django.conf import settings

        self.get_response = get_response
        # Built once; set lookup instead of rebuilding a list per request.
        self.allowed_origins = frozenset(
            getattr(settings, 'CORS_ALLOWED_ORIGINS', self.default_allowed_origins)
        )

    def __call__(self, request):
        from django.utils.cache import patch_vary_headers

        origin = request.META.get('HTTP_ORIGIN')
        # Answer only real CORS preflights (OPTIONS with Origin and
        # Access-Control-Request-Method); other OPTIONS requests reach the view.
        is_preflight = (
            request.method == 'OPTIONS'
            and origin
            and 'HTTP_ACCESS_CONTROL_REQUEST_METHOD' in request.META
        )
        if is_preflight:
            response = HttpResponse()
        else:
            response = self.get_response(request)

        # Add CORS headers for allowed origins.
        if origin and self.is_origin_allowed(origin):
            response['Access-Control-Allow-Origin'] = origin
            response['Access-Control-Allow-Methods'] = 'GET, POST, PUT, DELETE, OPTIONS'
            response['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, X-CSRFToken'
            response['Access-Control-Allow-Credentials'] = 'true'
            response['Access-Control-Max-Age'] = '86400'
        # The headers above depend on the Origin request header, so shared
        # caches must key on it; otherwise one origin's response could be
        # served to another.
        patch_vary_headers(response, ('Origin',))
        return response

    def is_origin_allowed(self, origin):
        """Return True if *origin* is in the configured allow-list."""
        return origin in self.allowed_origins
四、信号基础
4.1 信号概念
Django 信号机制允许解耦的应用在框架的其他地方发生操作时得到通知。信号允许特定的发送者通知一组接收者某些操作已经发生。
from django.db.models.signals import pre_save, post_save, pre_delete, post_delete
from django.core.signals import request_started, request_finished
from django.contrib.auth.signals import user_logged_in, user_logged_out, user_login_failed
4.2 内置信号
# 模型信号
from django.db.models.signals import (
pre_init, # 模型实例化前
post_init, # 模型实例化后
pre_save, # 保存前
post_save, # 保存后
pre_delete, # 删除前
post_delete, # 删除后
m2m_changed, # 多对多关系变化
)
# 请求信号
from django.core.signals import (
request_started, # 请求开始
request_finished, # 请求结束
got_request_exception, # 请求异常
)
# 认证信号
from django.contrib.auth.signals import (
user_logged_in, # 用户登录
user_logged_out, # 用户登出
user_login_failed, # 登录失败
)
# 数据库信号
from django.db.backends.signals import connection_created
# 迁移信号
from django.db.models.signals import (
pre_migrate,
post_migrate,
)
五、使用信号
5.1 连接信号
from django.db.models.signals import post_save
from django.dispatch import receiver
from .models import Article
# 方式一:装饰器
@receiver(post_save, sender=Article)
def article_post_save(sender, instance, created, **kwargs):
    """Print whether an Article was just created or updated."""
    message = f'新文章创建: {instance.title}' if created else f'文章更新: {instance.title}'
    print(message)
# 方式二:connect 方法
def article_handler(sender, instance, **kwargs):
    """Print a line for every saved Article; connected explicitly below."""
    print(f'文章操作: {instance.title}')


post_save.connect(receiver=article_handler, sender=Article)
# 在 apps.py 的 ready() 中导入信号模块,使上述接收器在启动时完成注册
from django.apps import AppConfig


class BlogConfig(AppConfig):
    """Blog app config; imports blog.signals so its receivers get registered."""

    name = 'blog'

    def ready(self):
        # Importing the module runs its @receiver decorators / connect() calls.
        import blog.signals  # noqa: F401
5.2 信号处理最佳实践
# blog/signals.py
from django.db.models.signals import post_save, post_delete
from django.dispatch import receiver
from django.core.cache import cache
from .models import Article
@receiver([post_save, post_delete], sender=Article)
def invalidate_article_cache(sender, instance, **kwargs):
    """Drop cached data for an Article after it is saved or deleted.

    Listening to post_delete as well keeps a deleted article from lingering
    in the detail, list and category caches.
    """
    cache_keys = [
        f'article_{instance.pk}',
        'article_list',
        # NOTE(review): if category_id changed, the previous category's list
        # is not invalidated here.
        f'category_{instance.category_id}_articles',
    ]
    cache.delete_many(cache_keys)
@receiver(post_save, sender=Article)
def update_search_index(sender, instance, created, **kwargs):
    """Queue a search-index update for the saved Article.

    The task is queued with transaction.on_commit, so the worker cannot run
    before the row is committed, nor for a save that gets rolled back.
    """
    from django.db import transaction

    from .tasks import index_article

    article_pk = instance.pk
    transaction.on_commit(lambda: index_article.delay(article_pk))
@receiver(post_delete, sender=Article)
def cleanup_article_files(sender, instance, **kwargs):
    """Delete an Article's cover file once the row deletion has committed.

    Deferring with on_commit means a rolled-back delete leaves the file intact.
    """
    from django.db import transaction

    if instance.cover:
        cover = instance.cover
        transaction.on_commit(lambda: cover.delete(save=False))
# blog/apps.py
from django.apps import AppConfig
class BlogConfig(AppConfig):
    """Configuration for the blog app; hooks up its signal receivers in ready()."""

    name = 'blog'

    def ready(self):
        # Imported for its side effect: loading blog/signals.py connects the
        # @receiver-decorated handlers.
        from . import signals  # noqa: F401
5.3 自定义信号
# signals.py
from django.dispatch import Signal
# 定义信号
article_published = Signal()  # article published
article_viewed = Signal()  # article viewed
comment_posted = Signal()  # comment posted
# 发送信号
def publish_article(article):
    """Mark *article* as published, save it and send ``article_published``.

    If the article is already published this does nothing, so receivers
    (such as per-author counters) are not triggered twice for one article.
    """
    if article.status == 'published':
        return
    article.status = 'published'
    article.save()
    # Notify receivers.
    article_published.send(
        sender=article.__class__,
        article=article,
        user=article.author,
    )
# 接收信号
@receiver(article_published)
def notify_subscribers(sender, article, user, **kwargs):
    """Send every subscriber of the article's author a notification."""
    for subscriber in article.author.subscribers.all():
        send_notification(subscriber, article)
@receiver(article_published)
def update_statistics(sender, article, **kwargs):
    """Increment the author's ``article_count``.

    The increment is an F() expression executed by the database. This keeps
    it atomic under concurrency, and only the article_count column is
    written: a full ``save()`` would overwrite every other author field with
    possibly stale values.
    """
    from django.db.models import F

    author = article.author
    type(author)._default_manager.filter(pk=author.pk).update(
        article_count=F('article_count') + 1
    )
    # Keep the in-memory instance in sync with the stored value.
    author.refresh_from_db(fields=['article_count'])
# views.py
def article_detail(request, pk):
    """Render one article and announce the view via ``article_viewed``."""
    article = get_object_or_404(Article, pk=pk)
    article_viewed.send(
        sender=Article,
        article=article,
        user=request.user,
        request=request,
    )
    context = {'article': article}
    return render(request, 'article_detail.html', context)
六、常用信号实现
6.1 用户相关信号
from django.contrib.auth.signals import user_logged_in, user_logged_out, user_login_failed
from django.dispatch import receiver
from django.utils import timezone
@receiver(user_logged_in)
def on_user_logged_in(sender, request, user, **kwargs):
    """After login: store the client IP on the user and write a LoginLog row."""
    client_ip = get_client_ip(request)
    user.last_login_ip = client_ip
    user.save(update_fields=['last_login_ip'])
    LoginLog.objects.create(
        user=user,
        ip=client_ip,
        user_agent=request.META.get('HTTP_USER_AGENT', ''),
        action='login',
    )
@receiver(user_logged_out)
def on_user_logged_out(sender, request, user, **kwargs):
    """After logout: write a LoginLog row (skipped when ``user`` is falsy, e.g. None)."""
    if not user:
        return
    LoginLog.objects.create(
        user=user,
        ip=get_client_ip(request),
        action='logout',
    )
@receiver(user_login_failed)
def on_user_login_failed(sender, credentials, request=None, **kwargs):
    """Log a failed login; block the IP after 5 failures within 30 minutes.

    ``request`` can be None (e.g. when authenticate() was called without a
    request). In that case there is no IP, and the lockout check is skipped
    instead of crashing the signal handler.
    """
    from datetime import timedelta

    username = credentials.get('username', '')
    ip = get_client_ip(request) if request is not None else None
    # NOTE(review): assumes LoginLog.ip accepts None — confirm against the model.
    LoginLog.objects.create(
        username=username,
        ip=ip,
        action='failed',
    )
    if ip is None:
        return
    # Count recent failures from this IP to decide whether to lock it.
    # datetime.timedelta is used because timezone.timedelta is not part of
    # django.utils.timezone's documented API.
    recent_failures = LoginLog.objects.filter(
        ip=ip,
        action='failed',
        created_at__gte=timezone.now() - timedelta(minutes=30),
    ).count()
    if recent_failures >= 5:
        block_ip(ip)
6.2 模型生命周期信号
from django.db.models.signals import pre_save, post_save, pre_delete, post_delete
from django.dispatch import receiver
@receiver(pre_save, sender=Article)
def article_pre_save(sender, instance, raw=False, **kwargs):
    """Fill in derived Article fields just before it is saved.

    * slug: generated from the title when empty.
    * reading_time: whitespace-separated word count // 200, at least 1.
    * published_at: set to now when the status becomes 'published',
      including for articles that are created already published.

    For raw saves (fixture loading) the previous-status lookup is skipped,
    so no other database rows are queried.
    """
    from django.utils import timezone
    from django.utils.text import slugify

    if not instance.slug:
        instance.slug = slugify(instance.title)

    # NOTE: split() counts whitespace-separated words, so text without spaces
    # (e.g. Chinese) yields a low count.
    word_count = len(instance.content.split())
    instance.reading_time = max(1, word_count // 200)

    if raw:
        return
    # values_list(...).first() returns None when no row exists yet (a new
    # article, even one with a preset pk) instead of raising DoesNotExist,
    # and loads only the status column.
    previous_status = None
    if instance.pk is not None:
        previous_status = (
            Article.objects.filter(pk=instance.pk)
            .values_list('status', flat=True)
            .first()
        )
    if previous_status != 'published' and instance.status == 'published':
        instance.published_at = timezone.now()
@receiver(post_save, sender=Article)
def article_post_save(sender, instance, created, **kwargs):
    """After an Article is saved, run the follow-up work.

    For a newly created article: notify the admins and, if there is no
    cover, generate a default one. In every case, refresh the author's
    article count.
    """
    if created:
        notify_admins(f'新文章: {instance.title}')
        needs_cover = not instance.cover
        if needs_cover:
            generate_default_cover(instance)
    instance.author.update_article_count()
@receiver(post_delete, sender=Article)
def article_post_delete(sender, instance, **kwargs):
    """After an Article is deleted, clean up after it.

    Removes the cover file, refreshes the author's article count and drops
    the article's cache entry. File removal is deferred with
    transaction.on_commit, so a rolled-back delete does not lose the file.
    """
    from django.core.cache import cache
    from django.db import transaction

    # Clean up the file.
    if instance.cover:
        cover = instance.cover
        transaction.on_commit(lambda: cover.delete(save=False))
    # Update statistics.
    instance.author.update_article_count()
    # Clear the cache.
    cache.delete(f'article_{instance.pk}')
6.3 多对多关系信号
from django.db.models.signals import m2m_changed
from django.dispatch import receiver
@receiver(m2m_changed, sender=Article.tags.through)
def article_tags_changed(sender, instance, action, pk_set, reverse=False, **kwargs):
    """Keep ``Tag.article_count`` in sync with the Article <-> Tag relation.

    With ``reverse=False``, *instance* is an Article and *pk_set* holds Tag
    pks. With ``reverse=True`` (the relation was changed from the Tag side),
    *instance* is a Tag and *pk_set* holds Article pks. ``clear()`` sends
    ``pk_set=None``, so the affected tags are captured in ``pre_clear`` and
    decremented in ``post_clear``.
    """
    if action == 'pre_add':
        print(f'即将添加标签: {pk_set}')
    elif action == 'post_add':
        print(f'已添加标签: {pk_set}')
        _adjust_tag_counts(instance, pk_set, reverse, 1)
    elif action == 'pre_remove':
        print(f'即将移除标签: {pk_set}')
    elif action == 'post_remove':
        print(f'已移除标签: {pk_set}')
        _adjust_tag_counts(instance, pk_set, reverse, -1)
    elif action == 'pre_clear':
        print('即将清空所有标签')
        if not reverse:
            # Remember which tags are about to lose this article.
            instance._cleared_tag_pks = set(instance.tags.values_list('pk', flat=True))
    elif action == 'post_clear':
        print('已清空所有标签')
        if reverse:
            # Every article was detached from this tag.
            Tag.objects.filter(pk=instance.pk).update(article_count=0)
        else:
            cleared = instance.__dict__.pop('_cleared_tag_pks', set())
            _adjust_tag_counts(instance, cleared, reverse, -1)


def _adjust_tag_counts(instance, pk_set, reverse, delta):
    """Apply *delta* per added/removed relation to the affected Tag counts."""
    from django.db.models import F

    if not pk_set:
        return
    if reverse:
        # instance is the Tag; it gained or lost len(pk_set) articles.
        Tag.objects.filter(pk=instance.pk).update(
            article_count=F('article_count') + delta * len(pk_set)
        )
    else:
        # One UPDATE for all affected tags instead of one per tag.
        Tag.objects.filter(pk__in=pk_set).update(
            article_count=F('article_count') + delta
        )
6.4 自动创建关联对象
from django.db.models.signals import post_save
from django.dispatch import receiver
from django.contrib.auth import get_user_model
User = get_user_model()


# Creating related rows is skipped for raw saves (fixture loading): there,
# related rows normally come from the fixture itself, and creating them here
# could duplicate them.

@receiver(post_save, sender=User)
def create_user_profile(sender, instance, created, raw=False, **kwargs):
    """Create the Profile for a newly created user."""
    if created and not raw:
        Profile.objects.create(user=instance)


@receiver(post_save, sender=User)
def create_user_settings(sender, instance, created, raw=False, **kwargs):
    """Create the UserSettings for a newly created user."""
    if created and not raw:
        UserSettings.objects.create(user=instance)


@receiver(post_save, sender=User)
def create_user_statistics(sender, instance, created, raw=False, **kwargs):
    """Create the UserStatistics for a newly created user."""
    if created and not raw:
        UserStatistics.objects.create(user=instance)
七、信号高级用法
7.1 条件信号处理
@receiver(post_save, sender=Article)
def conditional_handler(sender, instance, created, **kwargs):
    """Handle status changes on save, except for view-count-only updates."""
    update_fields = kwargs.get('update_fields')
    # Return early only when 'views' was the *only* field saved. A save that
    # also touched other fields (e.g. 'status') still needs processing.
    if update_fields is not None and set(update_fields) == {'views'}:
        return
    # NOTE(review): _original_status must be set elsewhere (e.g. when the
    # instance is loaded) — confirm where it is populated.
    if hasattr(instance, '_original_status'):
        if instance._original_status != instance.status:
            handle_status_change(instance)
7.2 临时禁用信号
import functools
from django.db.models.signals import post_save
class DisableSignals:
    """Context manager that temporarily detaches all receivers of some signals.

    WARNING: the receivers list lives on the shared signal object, so while
    the ``with`` block runs the receivers are detached for every thread in
    the process, not only for the code inside the block.
    """

    def __init__(self, disabled_signals=None):
        # De-duplicate (keeping order): a signal listed twice would otherwise
        # stash its already-emptied list and lose its receivers for good.
        self.disabled_signals = list(dict.fromkeys(disabled_signals or []))
        self.stashed_signals = {}

    def __enter__(self):
        for signal in self.disabled_signals:
            self.stashed_signals[signal] = signal.receivers
            signal.receivers = []
            self._clear_cache(signal)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        for signal in self.disabled_signals:
            signal.receivers = self.stashed_signals.pop(signal)
            self._clear_cache(signal)
        # Never swallow an exception raised inside the block.
        return False

    @staticmethod
    def _clear_cache(signal):
        """Drop the signal's per-sender receiver cache, if it has one.

        Signals that cache receivers per sender could otherwise keep serving
        the receivers that were in place before the swap.
        """
        cache = getattr(signal, 'sender_receivers_cache', None)
        if cache is not None:
            cache.clear()
# 使用
# Receivers are detached on the shared signal object for the duration of the block.
with DisableSignals([post_save]):
    article.save()  # no post_save receivers are called
# 装饰器版本
def disable_signals(disabled_signals):
    """Decorator factory: run the decorated function inside DisableSignals(disabled_signals)."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            guard = DisableSignals(disabled_signals)
            with guard:
                return func(*args, **kwargs)

        return wrapper

    return decorator
@disable_signals([post_save])
def bulk_update_articles():
    """Save each article in ``articles`` while post_save is disabled."""
    # NOTE(review): ``articles`` is not defined in this function — presumably
    # a module-level iterable; confirm.
    for article in articles:
        article.save()
7.3 异步信号处理
from django.db.models.signals import post_save
from django.dispatch import receiver
from celery import shared_task
@shared_task
def async_process_article(article_id):
    """Celery task: run the slow post-creation processing for one Article.

    Returns without doing anything if the article no longer exists (e.g. it
    was deleted before a worker picked up the task), instead of failing with
    DoesNotExist.
    """
    try:
        article = Article.objects.get(pk=article_id)
    except Article.DoesNotExist:
        return None
    # Slow operations.
    # NOTE(review): these helpers take an Article, unlike the signal receivers
    # named update_search_index / notify_subscribers elsewhere in this
    # document — presumably separate helpers; confirm.
    generate_summary(article)
    update_search_index(article)
    notify_subscribers(article)
    return None
@receiver(post_save, sender=Article)
def article_post_save_async(sender, instance, created, **kwargs):
    """Queue asynchronous processing for newly created Articles.

    The Celery task is queued with transaction.on_commit. Queuing it inside
    the transaction could let a worker look up the row before it is
    committed, or process an article whose creation is rolled back.
    """
    if not created:
        return
    from django.db import transaction

    article_pk = instance.pk
    transaction.on_commit(lambda: async_process_article.delay(article_pk))
7.4 信号调试
import logging
from django.db.models.signals import post_save
from django.dispatch import receiver
logger = logging.getLogger('signals')


class SignalDebugMiddleware:
    """Debug middleware that logs every model-signal ``send()`` at DEBUG level."""

    def __init__(self, get_response):
        self.get_response = get_response
        self.connect_debug_handlers()

    def __call__(self, request):
        return self.get_response(request)

    def connect_debug_handlers(self):
        """Wrap ``send`` on each Signal instance in django.db.models.signals.

        Only Signal *instances* are wrapped. Other module attributes that
        happen to have a ``send`` attribute (such as signal classes) are left
        alone. Each signal is wrapped at most once, even if this middleware
        is instantiated again.
        """
        from django.db.models import signals as model_signals
        from django.dispatch import Signal

        for name in dir(model_signals):
            signal = getattr(model_signals, name)
            if not isinstance(signal, Signal) or getattr(signal, '_debug_wrapped', False):
                continue
            signal.send = self._make_debug_send(signal.send, name)
            signal._debug_wrapped = True

    @staticmethod
    def _make_debug_send(original_send, name):
        """Return a ``send`` replacement that logs, then calls *original_send*.

        Binding *original_send* and *name* as arguments of this factory avoids
        the late-binding closure pitfall. If the wrapper were defined directly
        inside the loop, every wrapper would call the send() of whichever
        signal the loop visited last.
        """
        def wrapper(sender, **kwargs):
            logger.debug('Signal: %s, Sender: %s', name, sender)
            return original_send(sender, **kwargs)

        return wrapper
八、性能考虑
8.1 中间件性能优化
class OptimizedMiddleware:
    """Middleware that skips its work for static/media files and the favicon."""

    def __init__(self, get_response):
        self.get_response = get_response
        # Plain string checks are cheaper and clearer than regexes for fixed
        # prefixes and exact paths. str.startswith accepts a tuple, so a
        # single call tests every prefix.
        self.skip_prefixes = ('/static/', '/media/')
        self.skip_paths = frozenset({'/favicon.ico'})

    def __call__(self, request):
        # Pass requests that need no processing straight through.
        if self.should_skip(request):
            return self.get_response(request)
        return self.process_request(request)

    def should_skip(self, request):
        """Return True for static/media paths and the favicon."""
        path = request.path
        return path.startswith(self.skip_prefixes) or path in self.skip_paths

    def process_request(self, request):
        # The real processing logic goes here.
        return self.get_response(request)
8.2 信号性能优化
# 使用 dispatch_uid 防止重复注册
# A second connect() with the same dispatch_uid is ignored, so the receiver
# is registered only once even if this module is imported more than once.
@receiver(post_save, sender=Article, dispatch_uid='article_post_save_unique')
def article_post_save(sender, instance, **kwargs):
    """Placeholder post_save receiver demonstrating dispatch_uid."""
    pass
# 批量操作:bulk_create() 不会调用 save(),也不发送 pre_save/post_save 信号,需手动触发后续处理
def bulk_import_articles(articles_data):
    """Create Articles in bulk, then run the per-article follow-up manually.

    bulk_create() does not call save() and sends no pre_save/post_save
    signals, so disabling signals around it is unnecessary. Doing so would
    only detach the receivers process-wide, affecting saves in other threads.
    """
    articles = [Article(**data) for data in articles_data]
    Article.objects.bulk_create(articles)
    # Run the follow-up that post_save receivers would otherwise have done.
    # NOTE(review): if process_new_article needs article.pk, bulk_create only
    # sets it on backends that support returning primary keys — confirm.
    for article in articles:
        process_new_article(article)
# 延迟处理
from django.db import transaction
@receiver(post_save, sender=Article)
def article_post_save_deferred(sender, instance, created, **kwargs):
    """Queue process_article for a newly created Article after the transaction commits."""
    if not created:
        return
    # on_commit defers the callback until the surrounding transaction commits.
    transaction.on_commit(lambda: process_article.delay(instance.pk))
九、测试
9.1 测试中间件
from django.test import TestCase, RequestFactory, override_settings
class MiddlewareTestCase(TestCase):
    """Unit tests for the custom middleware, driven by RequestFactory."""

    def setUp(self):
        self.factory = RequestFactory()
        self.middleware = SimpleMiddleware(lambda r: HttpResponse('OK'))

    def test_middleware_processes_request(self):
        """A plain request passes through and yields a 200."""
        response = self.middleware(self.factory.get('/test/'))
        self.assertEqual(response.status_code, 200)

    @override_settings(MAINTENANCE_MODE=True)
    def test_maintenance_mode(self):
        """In maintenance mode, a non-allowed path gets a 503."""
        # NOTE(review): rendering needs a maintenance.html template to exist.
        middleware = MaintenanceMiddleware(lambda r: HttpResponse('OK'))
        response = middleware(self.factory.get('/test/'))
        self.assertEqual(response.status_code, 503)
9.2 测试信号
from django.test import TestCase
from unittest.mock import patch, MagicMock
from django.db.models.signals import post_save
class SignalTestCase(TestCase):
    """Tests for the article signals."""

    def test_article_published_signal(self):
        """publish_article() sends article_published to connected receivers.

        A mock receiver is connected directly. Patching a module attribute
        such as notify_subscribers does not replace a receiver that
        @receiver has already registered. Also, Article.objects.create()
        alone never sends this custom signal.
        """
        # NOTE(review): assumes the custom signal and publish_article live in
        # blog.signals, and that 'draft' is a valid status — adjust if not.
        from blog.signals import article_published, publish_article

        handler = MagicMock()
        article_published.connect(handler)
        # addCleanup runs even if an assertion fails, so the mock never leaks.
        self.addCleanup(article_published.disconnect, handler)
        article = Article.objects.create(title='Test', status='draft')
        publish_article(article)
        handler.assert_called_once()
        self.assertEqual(handler.call_args[1]['article'], article)

    def test_signal_with_mock_receiver(self):
        """post_save passes the created instance and created=True."""
        handler = MagicMock()
        post_save.connect(handler, sender=Article)
        # Disconnect even when an assertion below fails.
        self.addCleanup(post_save.disconnect, handler, sender=Article)
        article = Article.objects.create(title='Test')
        handler.assert_called_once()
        call_kwargs = handler.call_args[1]
        self.assertEqual(call_kwargs['instance'], article)
        self.assertTrue(call_kwargs['created'])