中间件与安全认证
2026/3/20 · 阅读时间约 9 分钟
中间件与安全认证
第一章:中间件基础
什么是中间件?
中间件是在每个请求被处理之前和每个响应返回之前执行的代码。它可以用于:
- 请求/响应日志记录
- 认证和授权
- CORS 处理
- 请求限流
- 错误处理
- 响应压缩
基础中间件
import logging
import time

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response

# Application instance used by the middleware examples in this section.
app = FastAPI()
# Module-level logger used by the request-logging middleware below.
logger = logging.getLogger(__name__)
# Function-style middleware, registered with the @app.middleware decorator.
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Report how long the request took, in seconds, in ``X-Process-Time``.

    ``time.perf_counter`` is monotonic and high-resolution. ``time.time`` is
    wall-clock time, so NTP or manual clock changes can make a measured
    duration skewed or even negative.
    """
    start = time.perf_counter()
    response = await call_next(request)
    elapsed = time.perf_counter() - start
    response.headers["X-Process-Time"] = str(elapsed)
    return response
# Request/response logging middleware.
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log the request line before handling and the status code afterwards."""
    logger.info(f"Request: {request.method} {request.url}")
    outgoing = await call_next(request)
    status_code = outgoing.status_code
    logger.info(f"Response: {status_code}")
    return outgoing
类形式中间件
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from typing import Callable
import time
class TimingMiddleware(BaseHTTPMiddleware):
    """Put the request handling time (seconds, 4 decimals) in ``X-Process-Time``."""

    async def dispatch(
        self,
        request: Request,
        call_next: Callable,
    ) -> Response:
        # perf_counter is monotonic, so clock adjustments cannot distort the value.
        start = time.perf_counter()
        response = await call_next(request)
        response.headers["X-Process-Time"] = f"{time.perf_counter() - start:.4f}"
        return response
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Give every request a fresh UUID4 identifier.

    The ID is stored on ``request.state.request_id`` so handlers can read it,
    and it is returned to the client in the ``X-Request-ID`` response header.
    """

    async def dispatch(
        self,
        request: Request,
        call_next: Callable,
    ) -> Response:
        from uuid import uuid4

        rid = f"{uuid4()}"
        request.state.request_id = rid
        response = await call_next(request)
        response.headers["X-Request-ID"] = rid
        return response


# Register both class-based middlewares on the application.
app.add_middleware(TimingMiddleware)
app.add_middleware(RequestIDMiddleware)
纯 ASGI 中间件
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.datastructures import MutableHeaders
class PureASGIMiddleware:
    """Raw ASGI middleware that adds ``X-Custom-Header: value`` to HTTP responses.

    Scopes other than ``http`` are forwarded to the wrapped app unchanged.
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        async def send_with_header(message):
            # Headers travel on the "http.response.start" message, so edit them there.
            if message["type"] == "http.response.start":
                MutableHeaders(scope=message).append("X-Custom-Header", "value")
            await send(message)

        if scope["type"] == "http":
            await self.app(scope, receive, send_with_header)
        else:
            await self.app(scope, receive, send)


# Install the raw ASGI middleware on the application.
app.add_middleware(PureASGIMiddleware)
第二章:常用内置中间件
CORS 中间件
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
# CORS configuration: which browser origins may call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000",
        "https://example.com",
    ],
    allow_origin_regex=r"https://.*\.example\.com",  # regex: any https subdomain of example.com
    allow_credentials=True,  # allow cookies/credentials on cross-origin requests
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["*"],
    expose_headers=["X-Custom-Header"],  # response headers that frontend JS may read
    max_age=600,  # how long (seconds) browsers may cache the preflight response
)
# Development only: allow every origin.
# WARNING(review): do not ship this. The CORS spec does not allow a literal "*"
# origin on credentialed requests. Any configuration that effectively accepts
# every origin while allow_credentials=True lets arbitrary websites make
# authenticated calls with the user's cookies. Confirm how your Starlette
# version handles this combination before using it even locally.
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["*"],
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )
GZip 压缩中间件
from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware

app = FastAPI()

# Compress response bodies; minimum_size is the smallest body size (bytes) to compress.
app.add_middleware(GZipMiddleware, minimum_size=1000)
受信任主机中间件(TrustedHostMiddleware)
from fastapi import FastAPI
from starlette.middleware.trustedhost import TrustedHostMiddleware

app = FastAPI()

# Only serve requests addressed to these host names; "*.example.com" covers subdomains.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["example.com", "*.example.com"],
)
HTTPS 重定向中间件
from fastapi import FastAPI
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware

app = FastAPI()

# Enforce HTTPS: plain-HTTP requests are redirected to their https:// URL.
# NOTE(review): behind a TLS-terminating proxy the app may see every request as
# HTTP. Make sure forwarded-proto headers are trusted, or this will loop.
app.add_middleware(HTTPSRedirectMiddleware)
第三章:OAuth2 认证
OAuth2 密码模式
from datetime import datetime, timedelta
from typing import Optional

import jwt
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel

app = FastAPI()

# --- Configuration --------------------------------------------------------
# Placeholder signing key: in a real deployment load it from the environment
# or a secret store and never commit it.
SECRET_KEY = "your-super-secret-key-keep-it-safe"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# OAuth2 password flow: clients obtain tokens from the "token" endpoint.
# `scopes` maps each scope name to its description.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="token",
    scopes={
        "read": "读取权限",
        "write": "写入权限",
        "admin": "管理员权限",
    },
)
# --- Models ---------------------------------------------------------------
class Token(BaseModel):
    """Response body of the token endpoint."""

    access_token: str
    token_type: str
    expires_in: int  # token lifetime in seconds


class TokenData(BaseModel):
    """Claims extracted from a decoded access token."""

    username: Optional[str] = None
    scopes: list[str] = []


class User(BaseModel):
    """Public user profile."""

    username: str
    email: str
    full_name: Optional[str] = None
    disabled: Optional[bool] = None
    scopes: list[str] = []  # scopes this user is allowed to request


class UserInDB(User):
    """User record as stored, including the password hash."""

    hashed_password: str
# In-memory stand-in for a user database, keyed by username.
# Both demo accounts store "fakehashedsecret". verify_password prepends
# "fakehashed" to the plain password, so the demo password is "secret".
# "scopes" lists the OAuth2 scopes each user may request at login.
fake_users_db = {
    "admin": {
        "username": "admin",
        "email": "admin@example.com",
        "full_name": "Admin User",
        "disabled": False,
        "hashed_password": "fakehashedsecret",
        "scopes": ["read", "write", "admin"]
    },
    "user": {
        "username": "user",
        "email": "user@example.com",
        "full_name": "Normal User",
        "disabled": False,
        "hashed_password": "fakehashedsecret",
        "scopes": ["read"]
    }
}
# --- Helpers --------------------------------------------------------------
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against the demo "fake hash" scheme.

    Demo only: a real application must use a proper password hasher
    (e.g. passlib with bcrypt).

    The comparison is constant-time (``hmac.compare_digest``), so response
    timing does not reveal how many leading characters matched. Both sides are
    UTF-8 encoded first because ``compare_digest`` raises ``TypeError`` for
    non-ASCII ``str`` input, and the password is user-controlled.
    """
    import hmac

    expected = f"fakehashed{plain_password}".encode("utf-8")
    return hmac.compare_digest(expected, hashed_password.encode("utf-8"))
def get_user(db: dict, username: str) -> Optional[UserInDB]:
    """Return *username*'s record from *db* as a ``UserInDB``, or ``None`` if absent."""
    try:
        record = db[username]
    except KeyError:
        return None
    return UserInDB(**record)
def authenticate_user(db: dict, username: str, password: str) -> Optional[UserInDB]:
    """Return the user when *username* exists in *db* and *password* matches, else ``None``."""
    candidate = get_user(db, username)
    if candidate and verify_password(password, candidate.hashed_password):
        return candidate
    return None
def create_access_token(
    data: dict,
    expires_delta: Optional[timedelta] = None,
) -> str:
    """Sign *data* as a JWT that expires after *expires_delta* (default 15 minutes).

    *data* is copied, so the caller's dict is not mutated. A falsy
    *expires_delta* also falls back to the default.

    The expiry is a timezone-aware UTC datetime. ``datetime.utcnow()`` is
    deprecated since Python 3.12, and PyJWT converts aware datetimes to the
    same epoch timestamp.
    """
    from datetime import timezone

    to_encode = data.copy()
    expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=15))
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
# Dependency: resolve the bearer token to the current user.
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
    """Decode the bearer token and load the user it names.

    Raises:
        HTTPException: 401 when the token fails JWT verification, has no
            ``sub`` claim, or names an unknown user.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        subject = claims.get("sub")
        granted = claims.get("scopes", [])
        if subject is None:
            raise unauthorized
        token_data = TokenData(username=subject, scopes=granted)
    except jwt.PyJWTError:
        raise unauthorized

    user = get_user(fake_users_db, token_data.username)
    if user is None:
        raise unauthorized
    return user
async def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """Return the authenticated user; disabled accounts are rejected with 400."""
    if not current_user.disabled:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
# Login endpoint (OAuth2 password grant): exchange username/password for an
# access token. Every scope the client requests must be granted to the user.
# When no scopes are requested, the token carries all of the user's scopes.
@app.post("/token", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    user = authenticate_user(fake_users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Reject the first requested scope the user does not hold.
    scope = next((s for s in form_data.scopes if s not in user.scopes), None)
    if scope is not None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Not authorized for scope: {scope}",
        )

    granted = form_data.scopes or user.scopes
    token = create_access_token(
        data={"sub": user.username, "scopes": granted},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return Token(
        access_token=token,
        token_type="bearer",
        expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
    )


# Protected endpoint: returns the profile of the authenticated, active user.
@app.get("/users/me", response_model=User)
async def read_users_me(current_user: User = Depends(get_current_active_user)):
    return current_user
OAuth2 Scopes(权限范围)
from fastapi import Security
from fastapi.security import SecurityScopes


# Dependency: authenticate the token and enforce the route's required scopes.
async def get_current_user_with_scopes(
    security_scopes: SecurityScopes,
    token: str = Depends(oauth2_scheme),
) -> User:
    """Return the token's user if the token grants every scope the route requires.

    ``security_scopes`` holds the scopes declared via ``Security(..., scopes=[...])``.

    Raises:
        HTTPException: 401 for an invalid token or unknown user;
            403 when a required scope is missing from the token.
    """
    authenticate_value = (
        f'Bearer scope="{security_scopes.scope_str}"'
        if security_scopes.scopes
        else "Bearer"
    )
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": authenticate_value},
    )

    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError:
        raise unauthorized
    username = claims.get("sub")
    token_scopes = claims.get("scopes", [])
    if username is None:
        raise unauthorized

    user = get_user(fake_users_db, username)
    if user is None:
        raise unauthorized

    # The first required scope absent from the token triggers a 403.
    scope = next((s for s in security_scopes.scopes if s not in token_scopes), None)
    if scope is not None:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Not enough permissions. Required: {scope}",
            headers={"WWW-Authenticate": authenticate_value},
        )
    return user
# Each route declares the scope it needs via Security(..., scopes=[...]).

# List items: requires the "read" scope.
@app.get("/items/")
async def read_items(
    current_user: User = Security(get_current_user_with_scopes, scopes=["read"]),
):
    items: list = []
    return {"items": items}


# Create an item: requires the "write" scope.
@app.post("/items/")
async def create_item(
    current_user: User = Security(get_current_user_with_scopes, scopes=["write"]),
):
    result = {"status": "created"}
    return result


# Delete an item: requires the "admin" scope.
@app.delete("/items/{item_id}")
async def delete_item(
    item_id: int,
    current_user: User = Security(get_current_user_with_scopes, scopes=["admin"]),
):
    result = {"status": "deleted"}
    return result
第四章:JWT 深入
完整的 JWT 实现
from datetime import datetime, timedelta
from typing import Optional

import jwt
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from pydantic import BaseModel

app = FastAPI()


class Settings:
    """Token configuration. The values are placeholders; load real ones from the environment."""

    SECRET_KEY: str = "your-super-secret-key"
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
    REFRESH_TOKEN_EXPIRE_DAYS: int = 7


settings = Settings()

# Password hashing context (bcrypt).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
class TokenPair(BaseModel):
    """Access token plus refresh token, as returned to the client."""

    access_token: str
    refresh_token: str
    token_type: str = "bearer"


class TokenPayload(BaseModel):
    """Validated claims of a decoded JWT."""

    sub: str
    exp: datetime
    type: str  # "access" or "refresh"
    scopes: list[str] = []
def create_access_token(
    subject: str,
    scopes: Optional[list[str]] = None,
    expires_delta: Optional[timedelta] = None,
) -> str:
    """Create a signed access token for *subject*.

    Args:
        subject: value of the ``sub`` claim.
        scopes: granted scopes. ``None`` (the default) means no scopes; the list
            is copied, so the caller's list is never shared.
        expires_delta: token lifetime. When missing or falsy, it defaults to
            ``settings.ACCESS_TOKEN_EXPIRE_MINUTES``.

    ``iat`` and ``exp`` come from a single timezone-aware UTC "now", so the
    encoded lifetime is exact. This also avoids the deprecated ``datetime.utcnow``.
    """
    from datetime import timezone

    now = datetime.now(timezone.utc)
    lifetime = expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = {
        "sub": subject,
        "exp": now + lifetime,
        "type": "access",
        "scopes": list(scopes or []),
        "iat": now,
    }
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def create_refresh_token(
    subject: str,
    expires_delta: Optional[timedelta] = None,
) -> str:
    """Create a signed refresh token for *subject*. Refresh tokens carry no scopes.

    The lifetime is *expires_delta*. When it is missing or falsy, the lifetime
    is ``settings.REFRESH_TOKEN_EXPIRE_DAYS`` days. ``iat`` and ``exp`` share
    one timezone-aware UTC "now".
    """
    from datetime import timezone

    now = datetime.now(timezone.utc)
    lifetime = expires_delta or timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
    to_encode = {
        "sub": subject,
        "exp": now + lifetime,
        "type": "refresh",
        "iat": now,
    }
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def create_token_pair(subject: str, scopes: Optional[list[str]] = None) -> TokenPair:
    """Issue a new access token (carrying *scopes*) and refresh token for *subject*.

    *scopes* defaults to ``None``, meaning no scopes. This avoids a shared
    mutable default list.
    """
    return TokenPair(
        access_token=create_access_token(subject, scopes or []),
        refresh_token=create_refresh_token(subject),
    )
def decode_token(token: str) -> TokenPayload:
    """Verify *token*'s signature and expiry, then parse its claims.

    Raises:
        HTTPException: 401 "Token has expired" for an expired token.
            401 "Invalid token" when the signature or format is bad, or when
            the claims do not match ``TokenPayload`` (for example a validly
            signed token without a ``type`` claim). Without that last check,
            such a token would escape as an unhandled validation error (500).
    """
    from pydantic import ValidationError

    # Bearer-auth 401 responses should tell the client which scheme to use.
    auth_headers = {"WWW-Authenticate": "Bearer"}
    try:
        payload = jwt.decode(
            token,
            settings.SECRET_KEY,
            algorithms=[settings.ALGORITHM],
        )
    except jwt.ExpiredSignatureError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has expired",
            headers=auth_headers,
        ) from None
    except jwt.PyJWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers=auth_headers,
        ) from None

    try:
        return TokenPayload(**payload)
    except ValidationError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers=auth_headers,
        ) from None
# Token refresh endpoint: exchange a valid refresh token for a new token pair.
# NOTE(review): a plain `str` parameter is read from the query string, so the
# refresh token may end up in access logs; a request body is usually preferable.
# The new access token has no scopes because refresh tokens here carry none.
@app.post("/token/refresh")
async def refresh_token(refresh_token: str):
    claims = decode_token(refresh_token)
    if claims.type == "refresh":
        return create_token_pair(claims.sub)
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid token type",
    )
Token 黑名单
from datetime import datetime
from typing import Set

import redis.asyncio as redis

# Shared async Redis client used to store revoked tokens until they expire.
redis_client = redis.from_url("redis://localhost:6379")
class TokenBlacklist:
    """Redis-backed store of revoked tokens.

    Each entry lives under ``blacklist:<token>`` with a TTL equal to the token's
    remaining lifetime. Redis drops it once the token would have expired anyway.
    """

    @staticmethod
    def _ttl_seconds(exp: datetime) -> int:
        """Return whole seconds from now until *exp*, or a value <= 0 if it has passed.

        Accepts both aware datetimes and naive ones (treated as UTC). This
        matters because pydantic parses a numeric ``exp`` claim into an aware
        UTC datetime, and subtracting a naive ``utcnow()`` from it raises
        ``TypeError``.
        """
        from datetime import timezone

        if exp.tzinfo is None:
            exp = exp.replace(tzinfo=timezone.utc)
        return int((exp - datetime.now(timezone.utc)).total_seconds())

    @staticmethod
    async def add(token: str, exp: datetime):
        """Revoke *token* until *exp*. Tokens that have already expired are skipped."""
        ttl = TokenBlacklist._ttl_seconds(exp)
        if ttl > 0:
            await redis_client.setex(f"blacklist:{token}", ttl, "1")

    @staticmethod
    async def is_blacklisted(token: str) -> bool:
        """Return True if *token* has been revoked and its entry has not yet expired."""
        result = await redis_client.get(f"blacklist:{token}")
        return result is not None
# Dependency: reject revoked tokens before decoding and validating them.
async def validate_token(token: str = Depends(oauth2_scheme)) -> TokenPayload:
    """Return the token's claims unless it has been revoked (401) or is invalid."""
    revoked = await TokenBlacklist.is_blacklisted(token)
    if not revoked:
        return decode_token(token)
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Token has been revoked",
    )
# Logout endpoint: revoke the presented token for the rest of its lifetime.
@app.post("/logout")
async def logout(token: str = Depends(oauth2_scheme)):
    claims = decode_token(token)
    await TokenBlacklist.add(token, claims.exp)
    message = "Successfully logged out"
    return {"message": message}
第五章:API Key 认证
Header API Key
from fastapi import FastAPI, HTTPException, Security, status
from fastapi.security import APIKeyHeader

app = FastAPI()

# Placeholder key: load it from the environment or a secret store in practice.
API_KEY = "your-api-key-here"
API_KEY_NAME = "X-API-Key"

# Reads the key from the X-API-Key header; auto_error=True makes a missing header an error.
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=True)
async def verify_api_key(api_key: str = Security(api_key_header)):
    """Return the header API key if it equals ``API_KEY``; otherwise raise 403.

    The comparison is constant-time, so response timing does not reveal how
    much of a guessed key was right. Both values are UTF-8 encoded first,
    because ``secrets.compare_digest`` raises ``TypeError`` for non-ASCII
    ``str`` input and header values are client-controlled.
    """
    import secrets

    if not secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API Key",
        )
    return api_key
# Protected route: verify_api_key runs first and rejects missing or bad keys.
# `Security` is used because this section imports it and not `Depends`; for a
# dependency without scopes the two behave the same.
@app.get("/secure-data")
async def secure_data(api_key: str = Security(verify_api_key)):
    return {"data": "This is protected data"}
Query API Key
from fastapi.security import APIKeyQuery

# API key supplied as `?api_key=...`; auto_error=True makes a missing parameter an error.
# NOTE(review): query strings are often logged by servers and proxies, so prefer a header.
api_key_query = APIKeyQuery(name="api_key", auto_error=True)


async def verify_api_key_query(api_key: str = Security(api_key_query)):
    """Return the query-string API key if it equals ``API_KEY``; otherwise raise 403.

    Uses a constant-time comparison on UTF-8 bytes: it hides timing
    differences, and ``compare_digest`` rejects non-ASCII ``str`` input.
    """
    import secrets

    if not secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API Key",
        )
    return api_key
Cookie API Key
from fastapi.security import APIKeyCookie

# API key supplied in the `api_key` cookie; auto_error=True makes a missing cookie an error.
api_key_cookie = APIKeyCookie(name="api_key", auto_error=True)


async def verify_api_key_cookie(api_key: str = Security(api_key_cookie)):
    """Return the cookie API key if it equals ``API_KEY``; otherwise raise 403.

    Uses a constant-time comparison on UTF-8 bytes: it hides timing
    differences, and ``compare_digest`` rejects non-ASCII ``str`` input.
    """
    import secrets

    if not secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API Key",
        )
    return api_key
多种认证方式组合
from typing import Optional

from fastapi import Depends, FastAPI, HTTPException, Security
from fastapi.security import APIKeyHeader, APIKeyQuery, OAuth2PasswordBearer

app = FastAPI()

# auto_error=False: a missing credential yields None instead of an immediate
# error, so get_api_key can fall back from one method to the next.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)


async def get_api_key(
    api_key_header: Optional[str] = Security(api_key_header),
    api_key_query: Optional[str] = Security(api_key_query),
    oauth_token: Optional[str] = Security(oauth2_scheme),
):
    """Authenticate with a bearer token, a header API key, or a query API key.

    Precedence: OAuth2 bearer token, then the ``X-API-Key`` header, then the
    ``api_key`` query parameter.

    Returns:
        ``{"type": "oauth2", "token": ...}`` or ``{"type": "api_key", "key": ...}``.

    Raises:
        HTTPException: 401 if a bearer token is present but fails JWT
            verification; 403 if no valid credential is supplied.
    """
    import secrets

    import jwt

    if oauth_token:
        # Verify signature and expiry before trusting the token. Accepting any
        # non-empty bearer string would let anyone through.
        # NOTE(review): uses SECRET_KEY/ALGORITHM from the OAuth2 section of
        # this module; use whatever settings your token issuer signs with.
        try:
            jwt.decode(oauth_token, SECRET_KEY, algorithms=[ALGORITHM])
        except jwt.PyJWTError:
            raise HTTPException(
                status_code=401,
                detail="Invalid OAuth2 token",
                headers={"WWW-Authenticate": "Bearer"},
            ) from None
        return {"type": "oauth2", "token": oauth_token}

    # Header key first, then query key. Constant-time comparison on bytes.
    expected = "valid-api-key".encode("utf-8")
    for candidate in (api_key_header, api_key_query):
        if candidate and secrets.compare_digest(candidate.encode("utf-8"), expected):
            return {"type": "api_key", "key": candidate}

    raise HTTPException(
        status_code=403,
        detail="No valid authentication provided",
    )
# Route that accepts any supported credential type and echoes which one was used.
@app.get("/flexible-auth")
async def flexible_auth(auth=Depends(get_api_key)):
    body = {"auth": auth}
    return body
第六章:高级安全特性
请求限流
import asyncio
import time
from collections import defaultdict

from fastapi import FastAPI, HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

app = FastAPI()


class RateLimitMiddleware(BaseHTTPMiddleware):
    """In-memory rate limiter keyed by client IP.

    Enforces two limits per client: ``requests_per_minute`` over a sliding
    60-second window, and ``burst_size`` within the last second. State lives in
    this process only, so limits are not shared across workers.

    Args:
        app: the wrapped ASGI application.
        requests_per_minute: max admitted requests per client in the last 60 s.
        burst_size: max admitted requests per client in the last 1 s.
    """

    def __init__(
        self,
        app,
        requests_per_minute: int = 60,
        burst_size: int = 10,
    ):
        super().__init__(app)
        self.requests_per_minute = requests_per_minute
        self.burst_size = burst_size
        # client IP -> time.time() stamps of admitted requests
        self.requests = defaultdict(list)
        self.lock = asyncio.Lock()
        self._last_sweep = time.time()

    def _reject(self, detail: str, retry_after: int) -> JSONResponse:
        """Build a 429 response. ``retry_after`` is a conservative upper bound in seconds.

        dispatch must *return* a Response. An HTTPException object is not one,
        and raising it here would skip FastAPI's HTTPException handler, which
        sits inside the user middleware stack.
        """
        return JSONResponse(
            status_code=429,
            content={"detail": detail},
            headers={
                "X-RateLimit-Limit": str(self.requests_per_minute),
                "X-RateLimit-Remaining": "0",
                "Retry-After": str(retry_after),
            },
        )

    def _sweep(self, now: float) -> None:
        """Forget clients with no request in the last minute, so memory stays bounded."""
        stale = [
            ip for ip, stamps in self.requests.items()
            if not stamps or max(stamps) <= now - 60
        ]
        for ip in stale:
            del self.requests[ip]
        self._last_sweep = now

    async def dispatch(self, request: Request, call_next):
        # request.client can be None (e.g. under some ASGI servers or tests).
        client_ip = request.client.host if request.client else "unknown"
        now = time.time()

        async with self.lock:
            if now - self._last_sweep >= 60:
                self._sweep(now)

            # Sliding 60 s window: keep only timestamps newer than a minute ago.
            window = [t for t in self.requests[client_ip] if t > now - 60]
            self.requests[client_ip] = window

            if len(window) >= self.requests_per_minute:
                return self._reject("Too many requests", retry_after=60)

            if sum(1 for t in window if t > now - 1) >= self.burst_size:
                return self._reject("Too many requests (burst limit)", retry_after=1)

            window.append(now)
            # Snapshot under the lock so concurrent requests can't skew it.
            remaining = self.requests_per_minute - len(window)

        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(self.requests_per_minute)
        response.headers["X-RateLimit-Remaining"] = str(max(0, remaining))
        return response


app.add_middleware(RateLimitMiddleware, requests_per_minute=100)
基于 Redis 的分布式限流
import time
import uuid

import redis.asyncio as redis
from fastapi import Depends, FastAPI, HTTPException, Request

app = FastAPI()
redis_client = redis.from_url("redis://localhost:6379")


async def rate_limit(
    request: Request,
    limit: int = 100,
    window: int = 60,
):
    """Sliding-window rate limit per client IP, backed by a Redis sorted set.

    Each request adds one member scored by its timestamp. Members older than
    ``window`` seconds are trimmed, and the set's size is the number of
    requests in the current window (rejected requests are counted too).

    Members must be unique per request. If they were keyed by whole-second
    timestamps, all requests within one second would collapse into a single
    member, the count could never exceed ``window``, and the limit would
    never trigger.

    NOTE(security): when used directly as ``Depends(rate_limit)``, FastAPI
    exposes ``limit`` and ``window`` as query parameters, so a client can
    raise its own limit (``?limit=1000000``). Prefer
    ``Depends(make_rate_limit(limit=..., window=...))``.

    Returns:
        dict with ``limit``, ``remaining`` and ``reset`` (epoch seconds).

    Raises:
        HTTPException: 429 with ``Retry-After`` when the limit is exceeded.
    """
    client_ip = request.client.host if request.client else "unknown"
    key = f"rate_limit:{client_ip}"
    now = time.time()
    member = f"{now}:{uuid.uuid4().hex}"

    pipe = redis_client.pipeline()
    pipe.zremrangebyscore(key, 0, now - window)  # drop entries outside the window
    pipe.zadd(key, {member: now})                # record this request
    pipe.zcard(key)                              # requests in the current window
    pipe.expire(key, window)                     # idle keys clean themselves up
    results = await pipe.execute()
    request_count = results[2]

    if request_count > limit:
        raise HTTPException(
            status_code=429,
            detail="Rate limit exceeded",
            headers={
                "X-RateLimit-Limit": str(limit),
                "X-RateLimit-Remaining": "0",
                "Retry-After": str(window),
            },
        )
    return {
        "limit": limit,
        "remaining": limit - request_count,
        "reset": int(now) + window,
    }


def make_rate_limit(limit: int = 100, window: int = 60):
    """Build a rate-limit dependency whose *limit* and *window* are fixed on the server.

    The returned dependency takes only the request, so clients cannot override
    the limits through query parameters.
    """

    async def _dependency(request: Request):
        return await rate_limit(request, limit=limit, window=window)

    return _dependency
# Route guarded by the Redis-backed limiter; echoes the current limit info.
@app.get("/api/data")
async def get_data(rate_info: dict = Depends(rate_limit)):
    body = {"data": "success", "rate_limit": rate_info}
    return body
CSRF 保护
import secrets

from fastapi import Depends, FastAPI, Form, HTTPException, Request
from fastapi.responses import HTMLResponse
from starlette.middleware.sessions import SessionMiddleware

app = FastAPI()

# Session middleware; the per-user CSRF token is stored in request.session.
# Placeholder secret: load it from the environment in real deployments.
app.add_middleware(SessionMiddleware, secret_key="your-session-secret")
def generate_csrf_token() -> str:
    """Return a new unpredictable CSRF token (32 random bytes, URL-safe base64)."""
    token = secrets.token_urlsafe(32)
    return token
async def get_csrf_token(request: Request) -> str:
    """Return the session's CSRF token, creating and storing one on first use."""
    session = request.session
    if "csrf_token" not in session:
        session["csrf_token"] = generate_csrf_token()
    return session["csrf_token"]
async def verify_csrf_token(
    request: Request,
    csrf_token: str = Form(...),
):
    """Check the submitted form token against the session token, then rotate it.

    Returns:
        True when the tokens match. The session then receives a fresh token,
        so each token is single-use.

    Raises:
        HTTPException: 403 if the session has no token or the tokens differ.
    """
    session_token = request.session.get("csrf_token")
    # Constant-time comparison on bytes: it does not leak the matching prefix
    # length through timing, and compare_digest rejects non-ASCII str input
    # (the form value is attacker-controlled).
    if not session_token or not secrets.compare_digest(
        csrf_token.encode("utf-8"), session_token.encode("utf-8")
    ):
        raise HTTPException(
            status_code=403,
            detail="CSRF token validation failed",
        )
    # Rotate after successful use.
    request.session["csrf_token"] = generate_csrf_token()
    return True
# Render the form with the session's CSRF token in a hidden field. The token
# comes from secrets.token_urlsafe, so it needs no HTML escaping.
@app.get("/form", response_class=HTMLResponse)
async def get_form(csrf_token: str = Depends(get_csrf_token)):
    page = f"""
<html>
<body>
<form method="post" action="/submit">
<input type="hidden" name="csrf_token" value="{csrf_token}">
<input type="text" name="data">
<button type="submit">Submit</button>
</form>
</body>
</html>
"""
    return page
# Form handler. The CSRF dependency must pass before this body runs.
@app.post("/submit")
async def submit_form(
    data: str = Form(...),
    _: bool = Depends(verify_csrf_token),
):
    result = {"data": data, "status": "success"}
    return result
安全响应头
from fastapi import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

app = FastAPI()


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Add common security-related response headers to every response."""

    # Header name -> value.
    # X-XSS-Protection is "0": the legacy browser XSS auditor is deprecated and
    # "1; mode=block" can itself introduce vulnerabilities, so OWASP recommends
    # disabling it and relying on Content-Security-Policy instead.
    SECURITY_HEADERS = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        "X-XSS-Protection": "0",
        # Browsers honor HSTS only when it is received over HTTPS.
        "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
        # NOTE(review): "default-src 'self'" blocks FastAPI's /docs page, which
        # loads Swagger UI assets from a CDN. Relax it for those routes if needed.
        "Content-Security-Policy": "default-src 'self'",
        "Referrer-Policy": "strict-origin-when-cross-origin",
        "Permissions-Policy": "geolocation=(), microphone=()",
    }

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        for name, value in self.SECURITY_HEADERS.items():
            response.headers[name] = value
        return response


app.add_middleware(SecurityHeadersMiddleware)
常见问题
Q1:如何实现单点登录(SSO)?
# Single sign-on with OpenID Connect (Google) via Authlib.
from authlib.integrations.starlette_client import OAuth

oauth = OAuth()
# Let Authlib discover endpoints and JWKS from Google's OIDC metadata, and put
# the scope in client_kwargs. Authlib reads the scope there to decide this is an
# OpenID Connect flow; it then sends a nonce and parses the returned ID token
# into token["userinfo"]. With the scope only in authorize_params that parsing
# does not happen. (TODO: confirm against your Authlib version's docs.)
# NOTE(review): the Starlette integration keeps OAuth state in request.session,
# so SessionMiddleware must be installed on this app.
# Placeholder credentials: load real ones from the environment.
oauth.register(
    name='google',
    client_id='YOUR_CLIENT_ID',
    client_secret='YOUR_CLIENT_SECRET',
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    client_kwargs={'scope': 'openid email profile'},
)
# Start Google login: redirect the browser to Google's consent page, with
# auth_callback as the return URL.
@app.get("/login/google")
async def login_google(request: Request):
    callback_url = request.url_for('auth_callback')
    return await oauth.google.authorize_redirect(request, callback_url)
# OAuth callback: exchange the authorization code for tokens and return the user info.
@app.get("/auth/callback")
async def auth_callback(request: Request):
    token = await oauth.google.authorize_access_token(request)
    # TODO: create or update the local user and issue our own JWT here.
    return {"user": token.get('userinfo')}
Q2:如何处理 Token 自动续期?
@app.middleware("http")
async def refresh_token_middleware(request: Request, call_next):
    """Best-effort sliding session: issue a fresh access token shortly before expiry.

    If the request carries a valid, non-revoked *access* token that expires
    within 5 minutes, a new access token with the same subject and scopes is
    returned in the ``X-New-Token`` response header. Any problem with the token
    means no header is added; authentication itself is enforced by the routes.
    (For browser clients, the header must also be listed in CORS expose_headers.)
    """
    from datetime import timezone

    response = await call_next(request)

    auth_header = request.headers.get("Authorization")
    if not auth_header or not auth_header.startswith("Bearer "):
        return response
    token = auth_header[len("Bearer "):]

    try:
        # validate_token checks the revocation list, then decodes the token.
        payload = await validate_token(token)
    except Exception:
        # Deliberately best-effort: an invalid, expired, revoked or
        # unverifiable token just means no refresh. The try is limited to this
        # step so that bugs below still surface.
        return response

    # Never mint an access token from a refresh token presented as a bearer.
    if payload.type != "access":
        return response

    # pydantic yields an aware UTC datetime for `exp`; treat naive values as UTC.
    # Subtracting a naive utcnow() from an aware value would raise TypeError.
    exp = payload.exp if payload.exp.tzinfo else payload.exp.replace(tzinfo=timezone.utc)
    if (exp - datetime.now(timezone.utc)).total_seconds() < 300:
        response.headers["X-New-Token"] = create_access_token(payload.sub, payload.scopes)
    return response
Q3:如何记录认证审计日志?
import logging
from datetime import datetime, timezone
from typing import Optional

audit_logger = logging.getLogger("audit")


async def audit_log(
    event_type: str,
    user_id: Optional[str] = None,
    ip_address: Optional[str] = None,
    details: Optional[dict] = None,
):
    """Emit one audit record on the ``audit`` logger at INFO level.

    Args:
        event_type: event name, e.g. ``"LOGIN_FAILED"`` or ``"LOGIN_SUCCESS"``.
        user_id: the user involved (for failed logins, the attempted username).
        ip_address: client address, if known.
        details: extra context; defaults to an empty dict.

    The record is passed to the logger as a dict. ``timestamp`` is naive UTC in
    ISO-8601 format, computed without the deprecated ``datetime.utcnow``.
    """
    log_entry = {
        "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
        "event_type": event_type,
        "user_id": user_id,
        "ip_address": ip_address,
        "details": details or {},
    }
    audit_logger.info(log_entry)
# Password login with audit logging of both failed and successful attempts.
@app.post("/token")
async def login(request: Request, form_data: OAuth2PasswordRequestForm = Depends()):
    # request.client can be None (e.g. under some ASGI servers or tests).
    client_ip = request.client.host if request.client else None
    # authenticate_user takes the user store as its first argument.
    user = authenticate_user(fake_users_db, form_data.username, form_data.password)
    if not user:
        await audit_log(
            "LOGIN_FAILED",
            user_id=form_data.username,
            ip_address=client_ip,
            details={"reason": "Invalid credentials"},
        )
        raise HTTPException(
            status_code=401,
            detail="Invalid credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
    await audit_log(
        "LOGIN_SUCCESS",
        user_id=user.username,
        ip_address=client_ip,
    )
    # Carry the user's scopes into the access token.
    return create_token_pair(user.username, user.scopes)
学习资源
- FastAPI 安全文档:https://fastapi.tiangolo.com/tutorial/security/
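- OAuth 2.0 概览:https://oauth.net/2/(正式规范为 RFC 6749:https://datatracker.ietf.org/doc/html/rfc6749)
- JWT 介绍与调试工具:https://jwt.io/(正式规范为 RFC 7519:https://datatracker.ietf.org/doc/html/rfc7519)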
- OWASP 安全指南:https://owasp.org/www-project-web-security-testing-guide/