高级特性与最佳实践
2026/3/20 · 大约 7 分钟
高级特性与最佳实践
第一章:WebSocket 实时通信
WebSocket 基础
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from typing import List
app = FastAPI()


# Connection manager
class ConnectionManager:
    """Track open WebSocket connections and fan text messages out to them."""

    def __init__(self):
        # Connections that completed the handshake through connect().
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake and start tracking ``websocket``."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking ``websocket``; calling it more than once is harmless."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send ``message`` to a single connection."""
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        """Send ``message`` to every tracked connection.

        Iterates over a snapshot because other coroutines may connect or
        disconnect while a send is being awaited. A connection whose send
        fails is dropped instead of aborting delivery to the other clients.
        """
        for connection in list(self.active_connections):
            try:
                await connection.send_text(message)
            except Exception:
                # Closed/broken socket: forget it and keep broadcasting.
                self.disconnect(connection)
manager = ConnectionManager()


@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    """Chat endpoint: echo each text frame to its author and relay it to everyone.

    On disconnect the socket is removed from ``manager`` and the remaining
    clients are told that ``client_id`` left.
    """
    await manager.connect(websocket)
    try:
        while True:
            incoming = await websocket.receive_text()
            echo = f"You wrote: {incoming}"
            relayed = f"Client #{client_id} says: {incoming}"
            await manager.send_personal_message(echo, websocket)
            await manager.broadcast(relayed)
    except WebSocketDisconnect:
        manager.disconnect(websocket)
        farewell = f"Client #{client_id} left the chat"
        await manager.broadcast(farewell)
带认证的 WebSocket
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Query, HTTPException
from typing import Optional
import jwt
app = FastAPI()


async def verify_token(
    token: str,
    secret_key: str = "secret",
    algorithms: Optional[list] = None,
) -> Optional[dict]:
    """Decode and verify a JWT, returning its claims or ``None``.

    Args:
        token: The encoded JWT.
        secret_key: Key used to verify the signature. The ``"secret"``
            default is only a placeholder kept for backward compatibility;
            pass the real key from configuration (e.g. ``SECRET_KEY``).
        algorithms: Allow-list of accepted signing algorithms; defaults to
            ``["HS256"]``.

    Returns:
        The decoded payload dict, or ``None`` if decoding raised any
        ``jwt.PyJWTError`` (bad signature, malformed or expired token, ...).
    """
    try:
        return jwt.decode(token, secret_key, algorithms=algorithms or ["HS256"])
    except jwt.PyJWTError:
        return None
@app.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    token: str = Query(...)
):
    """Authenticated JSON echo: wrap each message with the sender's username.

    The JWT comes from the ``token`` query parameter. Invalid tokens get the
    connection closed with application close code 4001.
    """
    # Verify the token
    user = await verify_token(token)
    if not user:
        # Accept first, then close: a close sent before accept() makes the
        # ASGI server reject the handshake with HTTP 403, and the custom
        # 4001 code would never reach the client.
        await websocket.accept()
        await websocket.close(code=4001)
        return

    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_json()
            # Handle the message
            response = {"user": user["username"], "message": data}
            await websocket.send_json(response)
    except WebSocketDisconnect:
        print(f"User {user['username']} disconnected")
房间系统
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from typing import Dict, Set
from dataclasses import dataclass, field
app = FastAPI()


@dataclass
class Room:
    """A named chat room and the WebSocket connections currently in it."""

    name: str
    connections: Set[WebSocket] = field(default_factory=set)

    async def broadcast(self, message: dict, exclude: WebSocket = None):
        """Send ``message`` as JSON to every member except ``exclude``.

        Iterates over a snapshot of ``connections``: members may join or
        leave while a send is awaited, and mutating a set during iteration
        raises ``RuntimeError: Set changed size during iteration``.
        """
        for ws in list(self.connections):
            if ws is not exclude:
                await ws.send_json(message)
class RoomManager:
    """Registry of active rooms keyed by name.

    Rooms are created on first join and removed once their last member leaves.
    """

    def __init__(self):
        self.rooms: Dict[str, Room] = {}

    def get_or_create_room(self, room_name: str) -> Room:
        """Return the room called ``room_name``, creating it if needed."""
        room = self.rooms.get(room_name)
        if room is None:
            room = Room(name=room_name)
            self.rooms[room_name] = room
        return room

    async def join_room(self, room_name: str, websocket: WebSocket, username: str):
        """Add ``websocket`` and announce ``username`` to all members, newcomer included."""
        room = self.get_or_create_room(room_name)
        room.connections.add(websocket)
        await room.broadcast({"type": "join", "user": username})

    async def leave_room(self, room_name: str, websocket: WebSocket, username: str):
        """Remove ``websocket``, notify the others, and drop the room once empty."""
        room = self.rooms.get(room_name)
        if room is None:
            return
        room.connections.discard(websocket)
        await room.broadcast({"type": "leave", "user": username})
        if not room.connections:
            del self.rooms[room_name]

    async def send_message(self, room_name: str, message: dict, sender: WebSocket):
        """Relay ``message`` to everyone in the room except ``sender``."""
        room = self.rooms.get(room_name)
        if room is not None:
            await room.broadcast(message, exclude=sender)
room_manager = RoomManager()


@app.websocket("/ws/{room_name}/{username}")
async def room_websocket(
    websocket: WebSocket,
    room_name: str,
    username: str
):
    """Join ``room_name`` as ``username`` and relay chat messages to the room.

    Incoming frames are expected to be JSON objects like ``{"content": ...}``.
    """
    await websocket.accept()
    await room_manager.join_room(room_name, websocket, username)
    try:
        while True:
            data = await websocket.receive_json()
            # Skip payloads that are not JSON objects instead of crashing the
            # handler with AttributeError on .get().
            if not isinstance(data, dict):
                continue
            message = {
                "type": "message",
                "user": username,
                "content": data.get("content"),
            }
            await room_manager.send_message(room_name, message, websocket)
    except WebSocketDisconnect:
        pass  # normal client disconnect; cleanup happens in `finally`
    finally:
        # Runs on disconnect *and* on any unexpected error, so a dead socket
        # is never left in the room for later broadcasts to trip over.
        await room_manager.leave_room(room_name, websocket, username)
第二章:Server-Sent Events (SSE)
基础 SSE
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import asyncio
from datetime import datetime
# Application instance for this SSE example.
app = FastAPI()
async def event_generator():
    """Yield one Server-Sent Events ``data:`` frame per second, forever.

    Each frame's payload is a JSON object with ``time`` (ISO 8601) and
    ``message`` keys, terminated by the blank line SSE requires.
    """
    import json  # local: this example's imports do not include json

    while True:
        payload = {
            "time": datetime.now().isoformat(),
            "message": "Hello from server",
        }
        # Encode as JSON: f"{dict}" would emit a Python repr (single quotes)
        # that clients doing JSON.parse(event.data) cannot read.
        yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
        await asyncio.sleep(1)
@app.get("/events")
async def get_events():
    """Stream ``event_generator()`` to the client as Server-Sent Events."""
    # Disable caching and keep the HTTP connection open for the stream.
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
    }
    stream = event_generator()
    return StreamingResponse(stream, media_type="text/event-stream", headers=sse_headers)
带事件类型的 SSE
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import asyncio
import json
app = FastAPI()


async def notification_generator(user_id: str):
    """Poll ``user_id``'s notifications every 5 seconds and emit them as SSE events.

    Each notification becomes one event with an incrementing ``id``, an
    ``event`` name taken from its ``"type"`` key (``"message"`` when absent)
    and its JSON encoding as ``data``.
    """
    event_id = 0
    while True:
        # Fetch any new notifications for this user.
        for notification in await get_user_notifications(user_id):
            event_id += 1
            name = notification.get("type", "message")
            payload = json.dumps(notification)
            yield f"id: {event_id}\n"
            yield f"event: {name}\n"
            yield f"data: {payload}\n\n"
        await asyncio.sleep(5)
@app.get("/notifications/{user_id}")
async def get_notifications(user_id: str):
    """Stream ``user_id``'s notifications as Server-Sent Events.

    Sends ``Cache-Control: no-cache`` so the event stream is not served
    from a cache, and keeps the connection alive for the stream's lifetime.
    """
    return StreamingResponse(
        notification_generator(user_id),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
第三章:GraphQL 集成
Strawberry GraphQL
# pip install strawberry-graphql[fastapi]
from fastapi import FastAPI
from strawberry.fastapi import GraphQLRouter
import strawberry
from typing import AsyncGenerator, List, Optional


# Type definitions
@strawberry.type
class User:
    # GraphQL ``User`` object; ``posts`` uses a string forward reference
    # because ``Post`` is declared below.
    id: int
    username: str
    email: str
    posts: List["Post"]


@strawberry.type
class Post:
    # GraphQL ``Post`` object with its author embedded.
    id: int
    title: str
    content: str
    author: User


# Queries
@strawberry.type
class Query:
    # Root query fields; each resolver delegates to a data-access helper
    # defined elsewhere (get_all_users, get_user_by_id, get_posts).

    @strawberry.field
    async def users(self) -> List[User]:
        return await get_all_users()

    @strawberry.field
    async def user(self, id: int) -> Optional[User]:
        # ``id`` is the GraphQL argument name, so it keeps shadowing the builtin.
        return await get_user_by_id(id)

    @strawberry.field
    async def posts(self, limit: int = 10) -> List[Post]:
        return await get_posts(limit)


# Mutations
@strawberry.type
class Mutation:
    # NOTE: inside these methods ``create_user`` / ``create_post`` resolve to
    # module-level helpers (class attributes are not in scope in method
    # bodies), so this is not recursion -- but the identical names are easy
    # to misread; consider renaming the helpers (e.g. ``create_user_in_db``).

    @strawberry.mutation
    async def create_user(self, username: str, email: str) -> User:
        return await create_user(username, email)

    @strawberry.mutation
    async def create_post(self, title: str, content: str, author_id: int) -> Post:
        return await create_post(title, content, author_id)


# Subscriptions
@strawberry.type
class Subscription:
    @strawberry.subscription
    async def new_post(self) -> AsyncGenerator[Post, None]:
        # An async-generator resolver is annotated as AsyncGenerator[T, None];
        # ``-> Post`` misdescribes it (it returns a generator, not a Post).
        async for post in post_stream():
            yield post


schema = strawberry.Schema(
    query=Query,
    mutation=Mutation,
    subscription=Subscription,
)

app = FastAPI()
graphql_app = GraphQLRouter(schema)
app.include_router(graphql_app, prefix="/graphql")
DataLoader 解决 N+1 问题
from strawberry.dataloader import DataLoader
from typing import List
# Batch loading function
async def load_users(user_ids: List[int]) -> List[User]:
    """Load users for a batch of ids, returning one result per id in request order.

    Ids with no matching row map to ``None``, so the result always lines up
    position-by-position with ``user_ids`` as DataLoader expects.
    """
    # NOTE(review): assumes ``db`` is a synchronous SQLAlchemy Session and
    # ``User`` here is the ORM model (not the Strawberry type) -- confirm.
    # ``Query.all()`` returns a plain list, which cannot be awaited. This
    # call blocks the event loop; with an AsyncSession use
    # ``(await db.execute(select(User).where(User.id.in_(user_ids)))).scalars().all()``.
    users = db.query(User).filter(User.id.in_(user_ids)).all()
    user_map = {u.id: u for u in users}
    return [user_map.get(uid) for uid in user_ids]
# Create the DataLoader.
# NOTE(review): this instance lives at module level and is shared by every
# request; if its cache is enabled (Strawberry's DataLoader caches by
# default), loaded users persist across requests and can go stale. Consider
# creating one loader per request (e.g. in the GraphQL context) -- confirm.
user_loader = DataLoader(load_fn=load_users)
@strawberry.type
class Post:
    # Post type whose ``author`` is resolved through ``user_loader`` so that
    # author lookups are batched by the DataLoader instead of issuing one
    # query per post (the N+1 problem).
    id: int
    title: str
    author_id: int
    @strawberry.field
    async def author(self) -> User:
        # Queue this post's author_id on the shared loader and await the result.
        return await user_loader.load(self.author_id)
第四章:后台任务与调度
APScheduler 定时任务
from fastapi import FastAPI
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from contextlib import asynccontextmanager
scheduler = AsyncIOScheduler()


async def cleanup_expired_tokens():
    """Scheduled job: purge expired tokens through ``token_service``."""
    print("Cleaning up expired tokens...")
    await token_service.cleanup_expired()


async def send_daily_report():
    """Scheduled job: send the daily report through ``report_service``."""
    print("Sending daily report...")
    await report_service.send_daily()


async def sync_external_data():
    """Scheduled job: pull external data through ``sync_service``."""
    print("Syncing external data...")
    await sync_service.sync()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Register the scheduled jobs, run the scheduler while the app is up,
    and always shut it down on exit.

    ``replace_existing=True`` lets this lifespan run more than once in one
    process (e.g. several test clients against the module-level
    ``scheduler``) without a duplicate-job-id error. ``try/finally``
    guarantees the shutdown even when the app exits with an exception.
    """
    # Register scheduled jobs
    scheduler.add_job(
        cleanup_expired_tokens,
        CronTrigger(hour=0, minute=0),  # every day at midnight
        id="cleanup_tokens",
        replace_existing=True,
    )
    scheduler.add_job(
        send_daily_report,
        CronTrigger(hour=8, minute=0),  # every day at 08:00
        id="daily_report",
        replace_existing=True,
    )
    scheduler.add_job(
        sync_external_data,
        "interval",
        minutes=30,  # every 30 minutes
        id="sync_data",
        replace_existing=True,
    )
    scheduler.start()
    try:
        yield
    finally:
        scheduler.shutdown()


app = FastAPI(lifespan=lifespan)
# Manage jobs at runtime
def _require_job(job_id: str):
    """Return the scheduled job ``job_id`` or raise HTTP 404 if it does not exist."""
    from fastapi import HTTPException  # local: this example only imports FastAPI

    job = scheduler.get_job(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job {job_id!r} not found")
    return job


@app.post("/jobs/{job_id}/pause")
async def pause_job(job_id: str):
    """Pause ``job_id``; 404 instead of a 500 when the id is unknown."""
    _require_job(job_id)
    scheduler.pause_job(job_id)
    return {"status": "paused"}


@app.post("/jobs/{job_id}/resume")
async def resume_job(job_id: str):
    """Resume ``job_id``; 404 instead of a 500 when the id is unknown."""
    _require_job(job_id)
    scheduler.resume_job(job_id)
    return {"status": "resumed"}


@app.get("/jobs")
async def list_jobs():
    """List scheduled jobs with their next run time."""
    return [
        {
            "id": job.id,
            # Paused jobs have no next run time: report null, not the string "None".
            "next_run": str(job.next_run_time) if job.next_run_time is not None else None,
        }
        for job in scheduler.get_jobs()
    ]
第五章:事件驱动架构
事件发布订阅
from fastapi import FastAPI
from typing import Callable, Dict, List
import asyncio
# Application instance for the event-bus example.
app = FastAPI()
class EventBus:
    """Minimal in-process publish/subscribe bus for async handlers.

    Handlers are registered per event type and run concurrently by
    :meth:`publish`. A failing handler neither stops the other handlers nor
    propagates to the publisher; its exception is logged instead.
    """

    def __init__(self):
        # event type -> handlers, in subscription order
        self._handlers: Dict[str, List[Callable]] = {}

    def subscribe(self, event_type: str, handler: Callable):
        """Register ``handler`` (async callable taking the event dict) for ``event_type``."""
        self._handlers.setdefault(event_type, []).append(handler)

    def unsubscribe(self, event_type: str, handler: Callable):
        """Remove one registration of ``handler``; a no-op if it is not registered."""
        handlers = self._handlers.get(event_type)
        if handlers and handler in handlers:
            handlers.remove(handler)

    async def publish(self, event_type: str, data: dict):
        """Await every handler for ``event_type`` concurrently with ``data``.

        Handler exceptions are collected and logged so that, e.g., a failing
        welcome email cannot turn an already-completed request into an error.
        """
        import logging  # local: this example does not import logging

        # Snapshot so handlers that (un)subscribe while running don't affect
        # this dispatch.
        handlers = list(self._handlers.get(event_type, ()))
        if not handlers:
            return
        results = await asyncio.gather(
            *(handler(data) for handler in handlers), return_exceptions=True
        )
        for handler, result in zip(handlers, results):
            if isinstance(result, BaseException):
                logging.getLogger(__name__).error(
                    "Handler %r for event %r failed", handler, event_type, exc_info=result
                )
event_bus = EventBus()


# Event handlers
async def on_user_created(data: dict):
    """Log the newly created user and send them a welcome email."""
    print(f"User created: {data}")
    recipient = data["email"]
    await send_welcome_email(recipient)


async def on_user_created_analytics(data: dict):
    """Record a ``user_signup`` analytics event for the new user."""
    await analytics.track("user_signup", data)


# Register both handlers for "user.created".
event_bus.subscribe("user.created", on_user_created)
event_bus.subscribe("user.created", on_user_created_analytics)
# Decorator-based registration
def event_handler(event_type: str):
    """Return a decorator that subscribes the decorated handler to ``event_type``.

    Registration on ``event_bus`` happens at decoration time; the function
    itself is returned unchanged so it can still be called directly.
    """
    def register(handler):
        event_bus.subscribe(event_type, handler)
        return handler
    return register


@event_handler("order.completed")
async def on_order_completed(data: dict):
    """Forward the completed order's data to the warehouse."""
    await notify_warehouse(data)
# Publish an event from an endpoint
@app.post("/users/")
async def create_user(user: UserCreate):
    """Create a user, then publish ``user.created`` with the user's fields."""
    new_user = await user_service.create(user)
    # Pydantic v2: model_dump() replaces the deprecated .dict().
    # NOTE(review): assumes user_service.create returns a Pydantic v2 model.
    await event_bus.publish("user.created", new_user.model_dump())
    return new_user
使用 Redis Pub/Sub
import redis.asyncio as redis
import json
from fastapi import FastAPI
app = FastAPI()


class RedisPubSub:
    """Thin JSON publish/subscribe wrapper around a Redis client."""

    def __init__(self, redis_url: str):
        self.redis = redis.from_url(redis_url)
        # Kept for backward compatibility only: subscribe() opens its own
        # PubSub per call so concurrent subscriptions never share a reader.
        self.pubsub = self.redis.pubsub()

    async def publish(self, channel: str, message: dict):
        """Publish ``message`` JSON-encoded on ``channel``."""
        await self.redis.publish(channel, json.dumps(message))

    async def subscribe(self, channel: str, handler):
        """Listen on ``channel`` forever, awaiting ``handler(data)`` per message.

        Each call uses a dedicated PubSub: two subscribe() loops on one
        shared PubSub would read the same connection concurrently and each
        would also receive the other's channels. The PubSub is closed when
        the loop ends (e.g. when the listening task is cancelled).
        """
        async with self.redis.pubsub() as pubsub:
            await pubsub.subscribe(channel)
            async for message in pubsub.listen():
                if message["type"] == "message":
                    await handler(json.loads(message["data"]))
pubsub = RedisPubSub("redis://localhost:6379")


# Publish an event
@app.post("/orders/")
async def create_order(order: OrderCreate):
    """Create an order and announce it on the ``orders`` Redis channel."""
    new_order = await order_service.create(order)
    await pubsub.publish("orders", {
        "event": "order.created",
        # mode="json" converts datetimes/UUIDs/Decimals to JSON-safe values,
        # since publish() json.dumps()es the payload; .dict() is deprecated.
        # NOTE(review): assumes new_order is a Pydantic v2 model.
        "data": new_order.model_dump(mode="json"),
    })
    return new_order
第六章:最佳实践
项目结构
app/
├── __init__.py
├── main.py # 应用入口
├── config.py # 配置管理
├── api/
│ ├── __init__.py
│ ├── deps.py # 公共依赖
│ ├── v1/
│ │ ├── __init__.py
│ │ ├── router.py # 路由聚合
│ │ └── endpoints/
│ │ ├── __init__.py
│ │ ├── users.py
│ │ ├── items.py
│ │ └── auth.py
│ └── v2/ # API 版本控制
├── core/
│ ├── __init__.py
│ ├── security.py # 安全相关
│ ├── exceptions.py # 自定义异常
│ └── events.py # 事件处理
├── models/
│ ├── __init__.py
│ ├── base.py # 模型基类
│ ├── user.py
│ └── item.py
├── schemas/
│ ├── __init__.py
│ ├── user.py
│ └── item.py
├── crud/
│ ├── __init__.py
│ ├── base.py # CRUD 基类
│ ├── user.py
│ └── item.py
├── services/
│ ├── __init__.py
│ ├── user_service.py # 业务逻辑
│ └── email_service.py
├── db/
│ ├── __init__.py
│ ├── base.py # 数据库配置
│ └── session.py
├── middleware/
│ ├── __init__.py
│ └── logging.py
├── utils/
│ ├── __init__.py
│ └── helpers.py
└── tests/
├── __init__.py
├── conftest.py
└── test_api/
错误处理标准化
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import Optional, Any
from enum import Enum
class ErrorCode(str, Enum):
    # Machine-readable error codes placed in the "code" field of error
    # responses. Mixing in str makes members compare equal to their string
    # values and serialize as plain strings.
    VALIDATION_ERROR = "VALIDATION_ERROR"
    NOT_FOUND = "NOT_FOUND"
    UNAUTHORIZED = "UNAUTHORIZED"
    FORBIDDEN = "FORBIDDEN"
    INTERNAL_ERROR = "INTERNAL_ERROR"
    EXTERNAL_SERVICE_ERROR = "EXTERNAL_SERVICE_ERROR"
class ErrorResponse(BaseModel):
    # Uniform error envelope: `success` is always False, `error` carries the
    # code/message/details dict, and `request_id` (when known) lets the error
    # be matched with the request's log lines.
    success: bool = False
    error: dict
    request_id: Optional[str] = None
class AppException(Exception):
    """Base class for application errors rendered as JSON error responses.

    Args:
        code: Machine-readable error code.
        message: Human-readable description; also the exception's ``str()``.
        status_code: HTTP status to respond with (default 400).
        details: Optional extra context; must be JSON-serializable because it
            is placed in the response body.
    """

    def __init__(
        self,
        code: ErrorCode,
        message: str,
        status_code: int = 400,
        details: Optional[Any] = None
    ):
        # Pass the message to Exception so str(exc), exc.args, tracebacks and
        # logger.exception() show it instead of an empty string.
        super().__init__(message)
        self.code = code
        self.message = message
        self.status_code = status_code
        self.details = details


class NotFoundError(AppException):
    """404: the ``resource`` identified by ``id`` does not exist."""

    def __init__(self, resource: str, id: Any):
        super().__init__(
            code=ErrorCode.NOT_FOUND,
            message=f"{resource} with id {id} not found",
            status_code=404,
            details={"resource": resource, "id": id}
        )


class ValidationError(AppException):
    """422: request data failed an application-level validation rule.

    Not to be confused with ``pydantic.ValidationError``; import one of them
    under an alias if a module needs both.
    """

    def __init__(self, message: str, details: Optional[dict] = None):
        super().__init__(
            code=ErrorCode.VALIDATION_ERROR,
            message=message,
            status_code=422,
            details=details
        )
# Register the exception handler
app = FastAPI()


@app.exception_handler(AppException)
async def app_exception_handler(request: Request, exc: AppException):
    """Render any ``AppException`` as an ``ErrorResponse`` JSON body."""
    body = ErrorResponse(
        error={
            "code": exc.code.value,
            "message": exc.message,
            "details": exc.details,
        },
        # Present only if something (e.g. a middleware) set request.state.request_id.
        request_id=getattr(request.state, "request_id", None),
    )
    return JSONResponse(status_code=exc.status_code, content=body.model_dump())
日志最佳实践
import logging
import sys
from datetime import datetime
import json
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
import uuid
# Structured logging: one JSON object per log line.
class JSONFormatter(logging.Formatter):
    """Render each ``LogRecord`` as a single-line JSON document.

    Keys: ``timestamp`` (UTC ISO 8601, from the record's creation time),
    ``level``, ``message``, ``module``, ``function``, ``line``, plus
    ``request_id`` when the record carries one (e.g. via
    ``extra={"request_id": ...}``) and ``exception`` / ``stack`` when
    traceback information is attached.
    """

    def format(self, record: logging.LogRecord) -> str:
        from datetime import timezone

        log_data = {
            # Use the record's creation time (not formatting time) and make it
            # timezone-aware; datetime.utcnow() is naive and deprecated (3.12).
            "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
            "level": record.levelname,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno,
        }
        if hasattr(record, "request_id"):
            log_data["request_id"] = record.request_id
        if record.exc_info:
            log_data["exception"] = self.formatException(record.exc_info)
        if record.stack_info:
            log_data["stack"] = self.formatStack(record.stack_info)
        # ensure_ascii=False keeps non-ASCII (e.g. Chinese) text readable;
        # default=str keeps an unserializable `extra` value from losing the line.
        return json.dumps(log_data, ensure_ascii=False, default=str)
# Logging setup
def setup_logging():
    """Route all logging through a single stdout handler that emits JSON.

    Replaces any handlers already on the root logger, sets the root level to
    INFO, and raises chatty third-party loggers to WARNING.
    """
    json_handler = logging.StreamHandler(sys.stdout)
    json_handler.setFormatter(JSONFormatter())

    root = logging.getLogger()
    root.handlers = [json_handler]
    root.setLevel(logging.INFO)

    # Quiet noisy third-party libraries.
    for noisy in ("uvicorn.access", "sqlalchemy.engine"):
        logging.getLogger(noisy).setLevel(logging.WARNING)
# Request-logging middleware
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Tag each request with a UUID and log its start and its completion or failure.

    The id is stored on ``request.state.request_id``, attached to the log
    records through ``extra`` and echoed in the ``X-Request-ID`` header.
    """

    async def dispatch(self, request: Request, call_next):
        rid = str(uuid.uuid4())
        request.state.request_id = rid
        log = logging.getLogger(__name__)
        tag = {"request_id": rid}

        log.info(f"Request started: {request.method} {request.url.path}", extra=tag)
        try:
            response = await call_next(request)
            log.info(f"Request completed: {response.status_code}", extra=tag)
            response.headers["X-Request-ID"] = rid
            return response
        except Exception as e:
            log.exception(f"Request failed: {str(e)}", extra=tag)
            raise


app = FastAPI()
setup_logging()
app.add_middleware(RequestLoggingMiddleware)
API 版本控制
from fastapi import FastAPI, APIRouter
app = FastAPI()

# V1 API
v1_router = APIRouter(prefix="/api/v1")


@v1_router.get("/users/{user_id}")
async def get_user_v1(user_id: int):
    """v1 user lookup: flat body carrying the version and the id."""
    payload = {"version": "v1", "user_id": user_id}
    return payload
# V2 API (new version)
v2_router = APIRouter(prefix="/api/v2")


@v2_router.get("/users/{user_id}")
async def get_user_v2(user_id: int):
    """v2 user lookup: user fields are nested under ``data``."""
    user_data = {"user_id": user_id, "extra_field": "new"}
    return {"version": "v2", "data": user_data}


# Mount both API versions on the app.
app.include_router(v1_router)
app.include_router(v2_router)
# Header-based versioning
from fastapi import Header


@app.get("/users/{user_id}")
async def get_user(
    user_id: int,
    api_version: str = Header("v1", alias="X-API-Version")
):
    """Serve either response shape from one URL, selected by ``X-API-Version``.

    ``"v2"`` selects the nested v2 shape; any other value, or a missing
    header, falls back to v1.
    """
    if api_version != "v2":
        return {"version": "v1", "user_id": user_id}
    return {"version": "v2", "data": {"user_id": user_id}}
配置管理
from pydantic_settings import BaseSettings
from functools import lru_cache
from typing import Optional
class Settings(BaseSettings):
    """Application settings read from environment variables and ``.env``.

    Variable names are matched case-sensitively. ``DATABASE_URL`` and
    ``SECRET_KEY`` have no default and therefore must be provided.
    """

    # pydantic-settings v2 configuration; replaces the deprecated inner
    # `class Config` (a plain dict is accepted where SettingsConfigDict is).
    model_config = {"env_file": ".env", "case_sensitive": True}

    # Application
    APP_NAME: str = "FastAPI App"
    DEBUG: bool = False
    ENVIRONMENT: str = "production"
    # Database
    DATABASE_URL: str
    DATABASE_POOL_SIZE: int = 20
    DATABASE_MAX_OVERFLOW: int = 10
    # Redis
    REDIS_URL: str = "redis://localhost:6379"
    # Security
    SECRET_KEY: str
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
    # External services
    EMAIL_SERVICE_URL: Optional[str] = None
    SENTRY_DSN: Optional[str] = None

    @property
    def is_development(self) -> bool:
        """True when ENVIRONMENT is exactly ``"development"``."""
        return self.ENVIRONMENT == "development"

    @property
    def is_production(self) -> bool:
        """True when ENVIRONMENT is exactly ``"production"``."""
        return self.ENVIRONMENT == "production"
@lru_cache()
def get_settings() -> Settings:
    """Build ``Settings`` on first call and return the same cached instance afterwards."""
    settings = Settings()
    return settings
# Usage
from fastapi import Depends


# NOTE(review): this snippet assumes an `app` created elsewhere.
@app.get("/config")
async def get_config(settings: Settings = Depends(get_settings)):
    """Expose non-sensitive configuration (currently only the app name)."""
    app_name = settings.APP_NAME
    return {"app_name": app_name}
常见问题
Q1:如何优雅地处理大文件上传?
from fastapi import FastAPI, UploadFile, File, BackgroundTasks
import aiofiles
CHUNK_SIZE = 1024 * 1024  # 1 MB


@app.post("/upload/large")
async def upload_large_file(
    file: UploadFile = File(...),
    background_tasks: BackgroundTasks = None
):
    """Stream an upload to disk in CHUNK_SIZE pieces, then queue processing.

    The client-supplied filename is untrusted: only its final path component
    is kept, and it is prefixed with a random id. Names like
    ``../../etc/passwd`` therefore cannot escape the upload directory, and
    concurrent uploads with the same name do not overwrite each other.
    """
    import os
    import uuid

    # Strip any directory part (POSIX or Windows separators) from the name.
    client_name = (file.filename or "").replace("\\", "/")
    safe_name = os.path.basename(client_name) or "upload"
    file_path = os.path.join("/tmp", f"{uuid.uuid4().hex}_{safe_name}")

    async with aiofiles.open(file_path, "wb") as f:
        # Read fixed-size chunks so the whole file never sits in memory.
        while chunk := await file.read(CHUNK_SIZE):
            await f.write(chunk)

    # Process the file in the background
    if background_tasks:
        background_tasks.add_task(process_file, file_path)

    return {"filename": file.filename, "status": "uploaded"}
Q2:如何实现请求追踪?
from opentelemetry import trace
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
tracer = trace.get_tracer(__name__)

# Automatic tracing of every route (assumes an `app` created elsewhere).
FastAPIInstrumentor.instrument_app(app)


# Manual tracing
@app.get("/traced")
async def traced_endpoint():
    """Run ``some_operation()`` inside a ``custom_operation`` span and return its result."""
    with tracer.start_as_current_span("custom_operation"):
        outcome = await some_operation()
    return outcome
Q3:如何处理长轮询?
import asyncio
from fastapi import FastAPI
@app.get("/long-poll")
async def long_poll(timeout: int = 30):
    """Wait up to ``timeout`` seconds for ``wait_for_update()`` to produce a result.

    ``timeout`` comes from the query string, so it is clamped to [0, 60]
    seconds. Without a cap, a client could hold the connection open
    indefinitely (e.g. ``?timeout=100000000``).

    Returns ``{"data": result}`` when an update arrives, otherwise
    ``{"data": None, "timeout": True}``.
    """
    max_timeout = 60  # upper bound on how long a single request may be held
    timeout = max(0, min(timeout, max_timeout))
    try:
        result = await asyncio.wait_for(wait_for_update(), timeout=timeout)
    except asyncio.TimeoutError:
        return {"data": None, "timeout": True}
    return {"data": result}
学习资源
- FastAPI 官方文档:https://fastapi.tiangolo.com/
- Starlette 文档:https://www.starlette.io/
- Pydantic 文档:https://docs.pydantic.dev/
- SQLAlchemy 文档:https://docs.sqlalchemy.org/
- Strawberry GraphQL:https://strawberry.rocks/
- FastAPI 最佳实践:https://github.com/zhanymkanov/fastapi-best-practices