依赖注入系统
2026/3/20大约 9 分钟
依赖注入系统
第一章:依赖注入基础
什么是依赖注入?
依赖注入(Dependency Injection, DI)是一种设计模式,用于实现控制反转(IoC)。在 FastAPI 中,依赖注入系统允许你声明函数所需的依赖,FastAPI 会自动解析并注入这些依赖。
依赖注入的优势
- 代码复用:共享逻辑可以被多个端点使用
- 关注点分离:业务逻辑与基础设施代码分离
- 易于测试:可以通过 app.dependency_overrides 轻松替换依赖进行单元测试
- 类型安全:完整的类型提示支持
基础依赖示例
from fastapi import FastAPI, Depends
from typing import Optional

app = FastAPI()


async def common_parameters(
    q: Optional[str] = None,
    skip: int = 0,
    limit: int = 100
):
    """Collect the query parameters shared by several list endpoints.

    FastAPI reads ``q``, ``skip`` and ``limit`` from the query string; the
    returned dict is injected into every endpoint that declares
    ``Depends(common_parameters)``.
    """
    shared = dict(q=q, skip=skip, limit=limit)
    return shared


# The same dependency is reused by more than one endpoint.
@app.get("/items/")
async def read_items(commons: dict = Depends(common_parameters)):
    """Echo the shared query parameters with an "items" message."""
    payload = {"message": "Read items"}
    payload["params"] = commons
    return payload


@app.get("/users/")
async def read_users(commons: dict = Depends(common_parameters)):
    """Echo the shared query parameters with a "users" message."""
    payload = {"message": "Read users"}
    payload["params"] = commons
    return payload
类作为依赖
import re
from typing import Optional

from fastapi import FastAPI, Depends, HTTPException

app = FastAPI()

# ``sort_by`` is interpolated into SQL as an identifier. Identifiers cannot be
# bound as query parameters, so the value must be validated before use. In a
# real application prefer an explicit whitelist of sortable columns.
_SORT_FIELD_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")
_SORT_ORDERS = ("asc", "desc")


class QueryParams:
    """Query-string parameters for list endpoints, used as a class dependency.

    FastAPI inspects ``__init__``'s signature, so every argument becomes a
    query parameter and the constructed instance is injected into the endpoint.

    Raises:
        HTTPException: 422 if ``sort_by`` is given but is not a plain
            identifier, or if ``order`` is not ``asc``/``desc``
            (case-insensitive).
    """

    def __init__(
        self,
        q: Optional[str] = None,
        skip: int = 0,
        limit: int = 100,
        sort_by: Optional[str] = None,
        order: str = "asc"
    ):
        # Reject anything that could break out of the ORDER BY clause.
        if sort_by and not _SORT_FIELD_RE.fullmatch(sort_by):
            raise HTTPException(status_code=422, detail=f"Invalid sort field: {sort_by!r}")
        if order.lower() not in _SORT_ORDERS:
            raise HTTPException(status_code=422, detail="order must be 'asc' or 'desc'")
        self.q = q
        self.skip = skip
        self.limit = limit
        self.sort_by = sort_by
        self.order = order

    def get_offset(self) -> int:
        """Return the number of rows to skip."""
        return self.skip

    def get_sql_order(self) -> str:
        """Return an ``ORDER BY`` clause, or ``""`` when no sort field was given.

        Safe to interpolate only because ``__init__`` validated both parts.
        """
        if self.sort_by:
            return f"ORDER BY {self.sort_by} {self.order.upper()}"
        return ""


@app.get("/items/")
async def read_items(params: QueryParams = Depends(QueryParams)):
    return {
        "query": params.q,
        "offset": params.get_offset(),
        "limit": params.limit,
        "order": params.get_sql_order()
    }


# Shorthand: when the annotation is the dependency class itself, a bare
# Depends() tells FastAPI to use that class.
@app.get("/items-short/")
async def read_items_short(params: QueryParams = Depends()):
    return {"params": params.__dict__}
第二章:依赖层级与嵌套
嵌套依赖
from fastapi import FastAPI, Depends, HTTPException
from typing import Optional

app = FastAPI()


# Level 1: open a database connection for the duration of the request.
async def get_db():
    """Yield a connection and always close it afterwards.

    ``DatabaseConnection`` is a placeholder that this example does not define.
    """
    connection = DatabaseConnection()
    try:
        yield connection
    finally:
        await connection.close()


# Level 2: resolve the current user (depends on the database).
async def get_current_user(
    token: str,
    db = Depends(get_db)
):
    """Look up the user that owns ``token``; unknown tokens get a 401."""
    found = await db.get_user_by_token(token)
    if found:
        return found
    raise HTTPException(status_code=401, detail="Invalid token")


# Level 3: require admin rights (depends on the current user).
async def verify_admin(
    current_user = Depends(get_current_user)
):
    """Pass the user through only when ``is_admin`` is truthy, otherwise 403."""
    if current_user.is_admin:
        return current_user
    raise HTTPException(status_code=403, detail="Admin required")


# Endpoint at the end of the chain get_db -> get_current_user -> verify_admin.
@app.get("/admin/dashboard")
async def admin_dashboard(admin = Depends(verify_admin)):
    greeting = "Welcome to admin dashboard"
    return {"admin": admin.username, "message": greeting}
依赖链条可视化
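请求 -> get_db() -> get_current_user() -> verify_admin() -> admin_dashboard()
(FastAPI 先解析最内层的依赖,再把结果逐层向外传递)
1. get_db():返回数据库实例
2. get_current_user():使用数据库实例,返回用户对象
3. verify_admin():使用用户对象,返回管理员对象
4. admin_dashboard():使用管理员对象,处理请求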
多重依赖
from fastapi import FastAPI, Depends, Query
from typing import Optional

app = FastAPI()


# Pagination dependency.
async def pagination(
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100)
):
    """Validate 1-based paging parameters and derive the row offset."""
    return {"page": page, "page_size": page_size, "offset": (page - 1) * page_size}


# Filtering dependency.
async def filtering(
    search: Optional[str] = None,
    status: Optional[str] = None,
    category: Optional[str] = None
):
    """Return only the filters that were supplied with a non-empty value."""
    candidates = {"search": search, "status": status, "category": category}
    return {key: value for key, value in candidates.items() if value}


# Sorting dependency.
async def sorting(
    sort_by: str = "created_at",
    # `pattern=` replaces the `regex=` argument, deprecated since FastAPI 0.100.
    order: str = Query("desc", pattern="^(asc|desc)$")
):
    """Return the sort column and direction.

    ``sort_by`` is not validated here; check it against a whitelist before
    using it in SQL.
    """
    return {"sort_by": sort_by, "order": order}


# Combine several dependencies in one endpoint.
# The parameter names shadow the dependency functions inside the body only;
# the Depends(...) defaults are evaluated at definition time and therefore
# still refer to the module-level functions above.
@app.get("/items/")
async def list_items(
    pagination: dict = Depends(pagination),
    filters: dict = Depends(filtering),
    sorting: dict = Depends(sorting)
):
    return {
        "pagination": pagination,
        "filters": filters,
        "sorting": sorting
    }
第三章:依赖的生命周期
yield 依赖(资源管理)
from fastapi import FastAPI, Depends
from typing import Generator
import asyncio

app = FastAPI()

# SyncDatabase, AsyncDatabase and Database below are placeholders that this
# example does not define.


# Synchronous yield dependency.
def get_sync_db() -> Generator:
    """Connect, yield the database to the endpoint, then always disconnect."""
    db = SyncDatabase()
    db.connect()
    try:
        yield db
    finally:
        db.disconnect()
        print("数据库连接已关闭")


# Asynchronous yield dependency.
async def get_async_db():
    """Async variant: connect, yield, then always disconnect."""
    db = AsyncDatabase()
    await db.connect()
    try:
        yield db
    finally:
        await db.disconnect()
        print("异步数据库连接已关闭")


# Yield dependency with transaction handling.
async def get_db_with_transaction():
    """Run the request inside a transaction.

    Commits when the endpoint finishes normally. Rolls back and re-raises when
    an exception reaches the ``yield``. The connection is always closed.
    """
    db = Database()
    await db.connect()
    try:
        # Begin inside the outer try so the connection is still closed if
        # starting the transaction itself fails.
        transaction = await db.begin_transaction()
        try:
            yield db
            await transaction.commit()
            print("事务已提交")
        except Exception as e:
            await transaction.rollback()
            print(f"事务已回滚: {e}")
            # Re-raise so FastAPI's exception handling still sees the error.
            raise
    finally:
        await db.close()


@app.post("/items/")
async def create_item(db = Depends(get_db_with_transaction)):
    # An exception raised here propagates to the dependency, which rolls the
    # transaction back.
    await db.execute("INSERT INTO items ...")
    return {"status": "created"}
依赖缓存
from fastapi import FastAPI, Depends
from uuid import uuid4

app = FastAPI()


# By default FastAPI caches each dependency's result for one request, so a
# dependency shared by several others runs only once per request.
async def get_request_id():
    """Generate a fresh random request ID and print it."""
    rid = str(uuid4())
    print(f"生成 request_id: {rid}")
    return rid


async def log_request(request_id: str = Depends(get_request_id)):
    """Print and pass through the request ID (logging step)."""
    print(f"记录请求: {request_id}")
    return request_id


async def process_request(request_id: str = Depends(get_request_id)):
    """Print and pass through the request ID (processing step)."""
    print(f"处理请求: {request_id}")
    return request_id


@app.get("/test-cache/")
async def test_cache(
    log_id: str = Depends(log_request),
    process_id: str = Depends(process_request)
):
    """Both IDs come from the single cached get_request_id call, so they match."""
    return dict(log_id=log_id, process_id=process_id)


@app.get("/test-no-cache/")
async def test_no_cache(
    log_id: str = Depends(log_request),
    process_id: str = Depends(get_request_id, use_cache=False)
):
    """use_cache=False forces a second, uncached get_request_id call.

    process_id is therefore a newly generated UUID and, barring a UUID
    collision, differs from log_id.
    """
    return dict(log_id=log_id, process_id=process_id)
第四章:路由级别依赖
路由器依赖
import secrets
from typing import Optional

from fastapi import FastAPI, APIRouter, Depends, HTTPException, Header

app = FastAPI()

# Demo API key. A real deployment should load it from configuration or a
# secret store rather than hard-coding it.
API_KEY = "secret-api-key"


async def verify_api_key(x_api_key: str = Header(...)):
    """Require an ``X-API-Key`` header equal to ``API_KEY``.

    Raises:
        HTTPException: 403 if the header value does not match.
    """
    # compare_digest takes time independent of where the values differ, so the
    # key cannot be recovered character by character from response timing.
    # Comparing bytes also tolerates non-ASCII header values, which
    # compare_digest rejects for str arguments.
    if not secrets.compare_digest(x_api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Invalid API Key")
    return x_api_key


async def verify_admin_role(x_user_role: str = Header(...)):
    """Require an ``X-User-Role: admin`` header.

    NOTE(review): this header is supplied by the client, so any caller can
    claim to be an admin. It only demonstrates router-level dependencies; a
    real check must derive the role from an authenticated identity.

    Raises:
        HTTPException: 403 if the role is not ``admin``.
    """
    if x_user_role != "admin":
        raise HTTPException(status_code=403, detail="Admin role required")
    return x_user_role


# Every route on this router runs both dependencies first.
admin_router = APIRouter(
    prefix="/admin",
    tags=["admin"],
    dependencies=[Depends(verify_api_key), Depends(verify_admin_role)]
)


@admin_router.get("/users")
async def list_users():
    return {"users": []}


@admin_router.get("/settings")
async def get_settings():
    return {"settings": {}}


# Public router: no dependencies.
public_router = APIRouter(prefix="/public", tags=["public"])


@public_router.get("/health")
async def health_check():
    return {"status": "ok"}


# Router that only requires the API key.
api_router = APIRouter(
    prefix="/api",
    tags=["api"],
    dependencies=[Depends(verify_api_key)]
)


@api_router.get("/data")
async def get_data():
    return {"data": []}


# Register the routers.
app.include_router(admin_router)
app.include_router(public_router)
app.include_router(api_router)
全局依赖
import logging
import time
import uuid
from typing import Optional

from fastapi import FastAPI, Depends, Request

logger = logging.getLogger(__name__)


# Request-logging dependency.
async def log_request(request: Request):
    """Log the start and end of a request, including its duration.

    The end is logged in ``finally`` so it is still recorded when the endpoint
    raises. The measured time covers at least the endpoint's execution.
    Whether it also includes sending the response depends on the FastAPI
    version's handling of yield dependencies.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    logger.info("开始请求: %s %s", request.method, request.url)
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        logger.info("完成请求: %s %s - %.3fs", request.method, request.url, elapsed)


# Global request ID.
async def add_request_id(request: Request):
    """Attach a unique ID to ``request.state`` and return it.

    A timestamp string is not unique across concurrent requests, so a random
    UUID is used instead.
    """
    request.state.request_id = uuid.uuid4().hex
    return request.state.request_id


# Dependencies passed to FastAPI(...) run for every route of the app.
app = FastAPI(
    dependencies=[Depends(log_request), Depends(add_request_id)]
)


@app.get("/")
async def root():
    return {"message": "Hello"}


@app.get("/items/")
async def list_items():
    return {"items": []}
第五章:高级依赖模式
参数化依赖
from typing import List, Callable, Optional

from fastapi import FastAPI, Depends, HTTPException, Query

app = FastAPI()


def permission_checker(required_permissions: List[str]):
    """Build a dependency that demands every permission in ``required_permissions``.

    Args:
        required_permissions: permissions the caller must hold. A snapshot is
            taken, so later mutation of the passed list has no effect.

    Returns:
        An async dependency that returns ``True`` or raises
        ``HTTPException(403)`` naming the first missing permission.
    """
    required = tuple(required_permissions)

    async def check_permissions(
        # A list parameter must be declared with Query(); otherwise FastAPI
        # reads it from the request body, which GET requests do not carry.
        # None replaces a shared mutable [] default.
        user_permissions: Optional[List[str]] = Query(None)
    ):
        # SECURITY NOTE(review): these permissions come from the query string,
        # i.e. the client chooses them. This only demonstrates parameterized
        # dependencies; real code must load permissions for the authenticated
        # user instead.
        granted = set(user_permissions or ())
        for perm in required:
            if perm not in granted:
                raise HTTPException(
                    status_code=403,
                    detail=f"Missing permission: {perm}"
                )
        return True

    return check_permissions


# Each endpoint gets its own checker configured with different permissions.
@app.get("/users/", dependencies=[Depends(permission_checker(["users:read"]))])
async def list_users():
    return {"users": []}


@app.post("/users/", dependencies=[Depends(permission_checker(["users:write"]))])
async def create_user():
    return {"status": "created"}


@app.delete(
    "/users/{user_id}",
    dependencies=[Depends(permission_checker(["users:write", "users:delete"]))]
)
async def delete_user(user_id: int):
    return {"status": "deleted"}
可调用类依赖
import time
from typing import Optional

from fastapi import FastAPI, Depends, HTTPException, Request

app = FastAPI()


class RateLimiter:
    """In-memory sliding-window rate limiter used as a callable dependency.

    Each instance tracks its own per-client timestamps, so separate instances
    (``rate_limiter`` and ``strict_limiter`` below) enforce independent limits.
    State is a plain dict in this process. It is not shared between worker
    processes and is lost on restart; use a shared store such as Redis in
    production. Entries for clients that stop sending requests are never
    evicted.
    """

    def __init__(self, max_requests: int, window_seconds: int):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # client host -> time.monotonic() timestamps of requests in the window
        self.requests = {}

    async def __call__(self, request: Request):
        """Record one request for the calling client, or reject it with 429.

        Raises:
            HTTPException: 429 when the client already made ``max_requests``
                requests within the last ``window_seconds`` seconds.
        """
        # Identify the client from the connection, not from a query parameter
        # the client could set to any value to dodge the limit. Behind a
        # reverse proxy this may be the proxy's address; configure trusted
        # forwarded headers in that case.
        client_ip = request.client.host if request.client is not None else "unknown"
        # monotonic() is unaffected by wall-clock adjustments.
        current_time = time.monotonic()
        window_start = current_time - self.window_seconds

        # Drop timestamps that have fallen out of the window.
        recent = [t for t in self.requests.get(client_ip, ()) if t > window_start]
        self.requests[client_ip] = recent

        if len(recent) >= self.max_requests:
            raise HTTPException(
                status_code=429,
                detail=f"Rate limit exceeded. Max {self.max_requests} requests per {self.window_seconds}s"
            )

        recent.append(current_time)
        return True


# Limiter instance for general endpoints.
rate_limiter = RateLimiter(max_requests=10, window_seconds=60)


@app.get("/api/data", dependencies=[Depends(rate_limiter)])
async def get_data():
    return {"data": "some data"}


# A stricter limit for an expensive endpoint.
strict_limiter = RateLimiter(max_requests=5, window_seconds=60)


@app.post("/api/expensive-operation", dependencies=[Depends(strict_limiter)])
async def expensive_operation():
    return {"status": "done"}
上下文管理器依赖
from fastapi import FastAPI, Depends
from contextlib import asynccontextmanager
from typing import AsyncGenerator

app = FastAPI()


@asynccontextmanager
async def get_db_session() -> AsyncGenerator:
    """Open a session in a transaction: commit on success, roll back on error.

    The session is closed in every case. ``AsyncSession`` is a placeholder
    that this example does not define.
    """
    sess = AsyncSession()
    await sess.begin()
    try:
        yield sess
        await sess.commit()
    except Exception:
        await sess.rollback()
        raise
    finally:
        await sess.close()


async def get_session():
    """Adapt the async context manager to FastAPI's yield-dependency protocol."""
    ctx = get_db_session()
    async with ctx as session:
        yield session


@app.post("/items/")
async def create_item(session = Depends(get_session)):
    """Placeholder endpoint: would use ``session``; currently returns None."""
条件依赖
from fastapi import FastAPI, Depends, Header, HTTPException
from typing import Optional
import os

app = FastAPI()

# Read once at import time: any value other than "true" (case-insensitive)
# disables the authentication check.
REQUIRE_AUTH = os.getenv("REQUIRE_AUTH", "true").lower() == "true"


async def optional_auth(
    authorization: Optional[str] = Header(None)
):
    """Return the Authorization header, or None when auth is disabled.

    Raises:
        HTTPException: 401 if auth is required and the header is missing or empty.
    """
    if REQUIRE_AUTH:
        if authorization:
            return authorization
        raise HTTPException(status_code=401, detail="Authorization required")
    return None  # development mode: authentication skipped


async def get_fast_db():
    return {"type": "fast", "connection": "fast_db"}


async def get_slow_db():
    return {"type": "slow", "connection": "slow_db"}


def get_db_by_priority(priority: str = "normal"):
    """Pick a database dependency: "high" gets the fast one, anything else the slow one."""
    return get_fast_db if priority == "high" else get_slow_db


@app.get("/data")
async def get_data(
    priority: str = "normal",
    auth = Depends(optional_auth)
):
    """Choose the database at request time from the ``priority`` query parameter."""
    chosen = get_db_by_priority(priority)
    return {"db": await chosen(), "auth": auth}
第六章:实际应用案例
完整的认证系统
import os
from datetime import datetime, timedelta
from typing import Optional

import jwt
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel

app = FastAPI()

# Configuration.
# The signing key is read from the environment so a real key need not be
# committed to source control; the literal fallback is for local demos only.
SECRET_KEY = os.getenv("SECRET_KEY", "your-secret-key")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# OAuth2 bearer scheme; clients obtain tokens from the "token" endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


# Models.
class User(BaseModel):
    """Public user representation (deliberately has no password field)."""

    id: int
    username: str
    email: str
    is_active: bool = True
    is_admin: bool = False


class TokenData(BaseModel):
    """Claims extracted from a decoded access token."""

    username: Optional[str] = None
    scopes: list[str] = []


# Fake user store. The demo "hash" is the prefix "fakehashed" followed by the
# plaintext password; real code must store a proper password hash.
fake_users_db = {
    "admin": {
        "id": 1,
        "username": "admin",
        "email": "admin@example.com",
        "hashed_password": "fakehashedsecret",
        "is_active": True,
        "is_admin": True
    }
}
# Utility functions.
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a password against the demo's fake hash format.

    The fake "hash" is the literal prefix ``fakehashed`` followed by the
    plaintext. Only that leading prefix is stripped. Real code must use a
    proper password-hashing library instead.
    """
    import secrets

    expected = hashed_password.removeprefix("fakehashed")
    # Constant-time comparison avoids leaking how many characters matched.
    return secrets.compare_digest(plain_password.encode("utf-8"), expected.encode("utf-8"))


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Encode ``data`` as a signed JWT with an ``exp`` claim.

    Args:
        data: claims to include; the dict is copied, not modified.
        expires_delta: token lifetime; defaults to 15 minutes when None.
    """
    from datetime import timezone

    to_encode = data.copy()
    # Only None selects the default, so an explicit timedelta(0) is honoured.
    lifetime = timedelta(minutes=15) if expires_delta is None else expires_delta
    # Timezone-aware UTC; datetime.utcnow() is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + lifetime
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
# Dependency functions.
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
    """Decode the bearer token and load the matching user.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` if the token
            fails to decode, its ``sub`` claim is missing or not a string,
            or no user with that name exists.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as exc:
        raise credentials_exception from exc

    username = payload.get("sub")
    # A non-string "sub" would otherwise fail TokenData validation and surface
    # as a 500 instead of a 401.
    if not isinstance(username, str):
        raise credentials_exception
    token_data = TokenData(username=username)

    user_dict = fake_users_db.get(token_data.username)
    if user_dict is None:
        raise credentials_exception
    return User(**user_dict)


async def get_current_active_user(
    current_user: User = Depends(get_current_user)
) -> User:
    """Return the authenticated user, rejecting inactive accounts with 400."""
    if not current_user.is_active:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user


async def get_current_admin_user(
    current_user: User = Depends(get_current_active_user)
) -> User:
    """Return the active user only if ``is_admin`` is set, otherwise 403."""
    if not current_user.is_admin:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Admin privileges required"
        )
    return current_user
# Endpoints.
@app.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange username/password form fields for a bearer access token.

    Raises:
        HTTPException: 400 with the same message whether the user is unknown
            or the password is wrong, so valid usernames are not revealed.
    """
    user_dict = fake_users_db.get(form_data.username)
    if not user_dict or not verify_password(form_data.password, user_dict["hashed_password"]):
        raise HTTPException(status_code=400, detail="Incorrect username or password")
    access_token = create_access_token(
        data={"sub": user_dict["username"]},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    )
    return {"access_token": access_token, "token_type": "bearer"}


@app.get("/users/me", response_model=User)
async def read_users_me(current_user: User = Depends(get_current_active_user)):
    """Return the authenticated, active user."""
    return current_user


@app.get("/admin/users")
async def admin_list_users(admin: User = Depends(get_current_admin_user)):
    """List all users (admin only).

    Records are converted to ``User`` models, which have no
    ``hashed_password`` field, so password hashes are never sent to clients.
    """
    return {"users": [User(**record) for record in fake_users_db.values()]}
数据库会话管理
from typing import Annotated, AsyncGenerator

from fastapi import FastAPI, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

# Database configuration.
DATABASE_URL = "postgresql+asyncpg://user:password@localhost/dbname"
engine = create_async_engine(DATABASE_URL, echo=True)  # echo=True logs every SQL statement
async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)


class Base(DeclarativeBase):
    pass


class Item(Base):
    """Minimal ORM model for the ``items`` table used by the endpoints below."""

    __tablename__ = "items"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column()


app = FastAPI()


# Session dependency.
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a session; commit if the request succeeds, roll back if it raises.

    ``async with`` closes the session on exit, so no explicit close is needed.
    """
    async with async_session() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise


# Reusable annotated dependency type.
DBSession = Annotated[AsyncSession, Depends(get_db)]


@app.get("/items/")
async def list_items(db: DBSession):
    """Return every item as ``{"id", "name"}`` dicts."""
    # SQLAlchemy 2.0 rejects plain SQL strings in execute(); use a select()
    # construct (or wrap raw SQL in text()).
    result = await db.execute(select(Item))
    items = result.scalars().all()
    return {"items": [{"id": item.id, "name": item.name} for item in items]}


@app.post("/items/")
async def create_item(db: DBSession, name: str):
    """Insert an item; the commit happens in get_db after the endpoint returns."""
    item = Item(name=name)
    db.add(item)
    # flush() sends the INSERT so the database-generated id is available now.
    await db.flush()
    return {"id": item.id, "name": item.name}
配置管理依赖
from functools import lru_cache
from typing import Annotated

from fastapi import FastAPI, Depends
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    """Application settings read from environment variables and ``.env``.

    Each field's default applies when no value is provided.
    """

    # pydantic-settings (Pydantic v2) configures models through model_config;
    # the nested ``class Config`` form is deprecated.
    model_config = SettingsConfigDict(env_file=".env")

    app_name: str = "My App"
    debug: bool = False
    database_url: str = "sqlite:///./test.db"
    redis_url: str = "redis://localhost:6379"
    secret_key: str = "secret"
    api_key: str = ""


@lru_cache()
def get_settings() -> Settings:
    """Build Settings once and return the cached instance on later calls.

    Caching avoids re-reading the environment and ``.env`` on every request;
    ``get_settings.cache_clear()`` forces a reload.
    """
    return Settings()


app = FastAPI()


@app.get("/info")
async def info(settings: Settings = Depends(get_settings)):
    return {
        "app_name": settings.app_name,
        "debug": settings.debug
    }


# Annotated alias so endpoints need not repeat Depends(get_settings).
ConfigDep = Annotated[Settings, Depends(get_settings)]


@app.get("/config")
async def get_config(settings: ConfigDep):
    return {"app_name": settings.app_name}
常见问题
Q1:依赖注入和中间件有什么区别?
- 依赖注入:作用于单个端点、路由器或整个应用(全局依赖),可以返回值给端点函数
- 中间件:作用于所有请求,不能直接返回值给端点函数(只能通过 request.state 等方式间接传递数据)
选择建议:
- 需要在端点中使用返回值 → 依赖注入
- 全局请求处理(日志、CORS 等) → 中间件
- 需要条件性应用 → 依赖注入
Q2:如何在依赖中访问请求对象?
from fastapi import Depends, Request


async def get_client_info(request: Request):
    """Return basic information about the incoming request.

    ``client_host`` is None when the server does not report the peer address.
    """
    # request.client is None when the ASGI scope has no client address, so
    # guard before reading .host to avoid an AttributeError (500).
    client = request.client
    return {
        "client_host": client.host if client is not None else None,
        "method": request.method,
        "path": request.url.path
    }


# Assumes an `app = FastAPI()` created as in the earlier examples.
@app.get("/info")
async def info(client_info: dict = Depends(get_client_info)):
    return client_info
Q3:如何测试带依赖的端点?
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock

# Assumes `app` and `get_db` from the database-session example above.


async def override_get_db():
    """Stand-in for get_db that hands the endpoint a mock session."""
    return AsyncMock()


# Override the dependency for the duration of the test.
app.dependency_overrides[get_db] = override_get_db
try:
    client = TestClient(app)
    response = client.get("/items/")
finally:
    # Always remove the override, even if the request raised, so later tests
    # see the real dependency again.
    app.dependency_overrides.clear()