数据库操作
2026/3/20大约 10 分钟
数据库操作
第一章:数据库基础配置
SQLAlchemy 简介
SQLAlchemy 是 Python 中最流行的 ORM(对象关系映射)库,也是 FastAPI 项目中最常用的数据库方案(FastAPI 官方 SQL 教程使用的 SQLModel 就构建在 SQLAlchemy 之上)。SQLAlchemy 2.0 版本带来了更完善的异步支持和类型提示。
安装依赖
# Core dependency
pip install sqlalchemy
# Async support (quote the extra so shells such as zsh don't treat [] as a glob)
pip install "sqlalchemy[asyncio]"
# Database drivers
pip install asyncpg  # PostgreSQL async driver
pip install aiomysql  # MySQL async driver
pip install aiosqlite  # SQLite async driver
# Password hashing used by the user CRUD example (passlib CryptContext with bcrypt)
pip install "passlib[bcrypt]"
# Database migrations
pip install alembic
同步数据库配置
# app/db/base.py
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, DeclarativeBase
# Database URL
SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db"
# PostgreSQL: "postgresql://user:password@localhost/dbname"
# MySQL: "mysql+pymysql://user:password@localhost/dbname"

# check_same_thread is an argument of sqlite3.connect() only. Other DBAPI
# drivers (psycopg2, pymysql) reject unknown connect arguments, so pass it
# only when the URL is actually SQLite.
_connect_args = (
    {"check_same_thread": False}
    if SQLALCHEMY_DATABASE_URL.startswith("sqlite")
    else {}
)

# Create the engine
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args=_connect_args,
    echo=True,  # log every SQL statement that is emitted
    pool_size=5,
    max_overflow=10,
)

# Session factory. `autocommit` is not passed: in SQLAlchemy 2.0 it exists only
# for backwards compatibility and must stay at its default of False.
SessionLocal = sessionmaker(autoflush=False, bind=engine)


# Declarative base class shared by all ORM models
class Base(DeclarativeBase):
    pass
异步数据库配置
# app/db/base.py
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from sqlalchemy.orm import DeclarativeBase
# Async database URL
DATABASE_URL = "postgresql+asyncpg://user:password@localhost/dbname"
# SQLite: "sqlite+aiosqlite:///./sql_app.db"
# NOTE(review): pool_size/max_overflow only apply to queue-based pools; confirm
# the default pool of the chosen dialect accepts them before switching URLs.

# Async engine
engine = create_async_engine(
    DATABASE_URL,
    echo=True,
    pool_size=5,
    max_overflow=10,
    pool_pre_ping=True,  # test connections on checkout and replace dead ones
    pool_recycle=3600,  # replace connections older than this many seconds
)

# Async session factory.
# expire_on_commit=False keeps already-loaded attribute values readable after
# commit instead of marking them expired (which would require a reload).
# `autocommit` is not passed: in SQLAlchemy 2.0 it is a legacy keyword that
# must remain False, which is already its default.
async_session = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autoflush=False,
)


# Declarative base class shared by all ORM models
class Base(DeclarativeBase):
    pass
# Dependency injection: one session per request
async def get_db():
    """Yield an AsyncSession for the duration of one request.

    After the consumer finishes without error the transaction is committed.
    If an exception propagates (from the consumer or from the commit itself)
    the transaction is rolled back and the exception re-raised. Leaving the
    ``async with`` block closes the session.
    """
    db_session = async_session()
    async with db_session:
        try:
            yield db_session
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise
第二章:模型定义
基础模型
# app/models/user.py
from sqlalchemy import Column, Integer, String, Boolean, DateTime, Text, ForeignKey
from sqlalchemy.orm import relationship, Mapped, mapped_column
from sqlalchemy.sql import func
from datetime import datetime
from typing import Optional, List
from app.db.base import Base
class User(Base):
    """ORM model for the ``users`` table (SQLAlchemy 2.0 declarative style).

    Columns annotated without ``Optional`` (``Mapped[str]``, ``Mapped[bool]``,
    ``Mapped[datetime]``) are NOT NULL; the ``Optional[...]`` ones are nullable.
    """

    __tablename__ = "users"

    # New-style Mapped type annotations (SQLAlchemy 2.0)
    # NOTE(review): index=True on the primary key looks redundant; primary keys
    # are normally indexed by the database already.
    id: Mapped[int] = mapped_column(primary_key=True, index=True)
    # Unique, indexed login identifiers
    username: Mapped[str] = mapped_column(String(50), unique=True, index=True)
    email: Mapped[str] = mapped_column(String(100), unique=True, index=True)
    # Stores the password hash, not the plain-text password
    hashed_password: Mapped[str] = mapped_column(String(255))
    full_name: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
    bio: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    # Python-side defaults, applied by SQLAlchemy on INSERT when no value is given
    is_active: Mapped[bool] = mapped_column(default=True)
    is_superuser: Mapped[bool] = mapped_column(default=False)

    # Timestamps
    # created_at is filled in by the database (server_default = now()) on INSERT
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now()
    )
    # updated_at is NULL after INSERT; SQLAlchemy sets it to now() whenever it
    # emits an UPDATE for this row (onupdate)
    updated_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True),
        onupdate=func.now(),
        nullable=True
    )

    # Relationships: one user has many posts and many comments
    # (Post.author / Comment.author are the other side via back_populates)
    posts: Mapped[List["Post"]] = relationship("Post", back_populates="author")
    comments: Mapped[List["Comment"]] = relationship("Comment", back_populates="author")

    def __repr__(self) -> str:
        """Debug representation showing the id and username."""
        return f"<User(id={self.id}, username={self.username})>"
关系模型
# app/models/post.py
from sqlalchemy import String, Text, ForeignKey, Table, Column, Integer
from sqlalchemy.orm import relationship, Mapped, mapped_column
from datetime import datetime
from typing import Optional, List
from app.db.base import Base
# Many-to-many association table between posts and tags.
# The composite primary key (post_id, tag_id) prevents the same tag from being
# attached to the same post twice.
post_tags = Table(
    "post_tags",
    Base.metadata,
    Column("post_id", Integer, ForeignKey("posts.id"), primary_key=True),
    Column("tag_id", Integer, ForeignKey("tags.id"), primary_key=True)
)
class Post(Base):
    """ORM model for the ``posts`` table.

    NOTE: this model defines no created_at / updated_at columns.
    """

    __tablename__ = "posts"
    # NOTE(review): index=True on the primary key looks redundant.
    id: Mapped[int] = mapped_column(primary_key=True, index=True)
    title: Mapped[str] = mapped_column(String(200))
    # URL-friendly identifier; unique and indexed for lookups
    slug: Mapped[str] = mapped_column(String(200), unique=True, index=True)
    content: Mapped[str] = mapped_column(Text)
    # Unpublished by default (Python-side default on INSERT)
    is_published: Mapped[bool] = mapped_column(default=False)

    # Foreign key to users.id
    author_id: Mapped[int] = mapped_column(ForeignKey("users.id"))

    # Many-to-one: the post's author (User.posts is the other side)
    author: Mapped["User"] = relationship("User", back_populates="posts")

    # One-to-many: comments on this post. cascade="all, delete-orphan" deletes
    # the comments together with the post, and deletes a comment that is
    # removed from this collection.
    comments: Mapped[List["Comment"]] = relationship(
        "Comment",
        back_populates="post",
        cascade="all, delete-orphan"
    )

    # Many-to-many through the post_tags association table
    tags: Mapped[List["Tag"]] = relationship(
        "Tag",
        secondary=post_tags,
        back_populates="posts"
    )
class Tag(Base):
    """ORM model for the ``tags`` table; linked to posts through post_tags."""

    __tablename__ = "tags"
    id: Mapped[int] = mapped_column(primary_key=True, index=True)
    # Both the display name and the slug must be unique
    name: Mapped[str] = mapped_column(String(50), unique=True)
    slug: Mapped[str] = mapped_column(String(50), unique=True)

    # Reverse side of Post.tags (many-to-many via post_tags)
    posts: Mapped[List["Post"]] = relationship(
        "Post",
        secondary=post_tags,
        back_populates="tags"
    )
class Comment(Base):
    """ORM model for the ``comments`` table: a comment by a user on a post."""

    __tablename__ = "comments"
    id: Mapped[int] = mapped_column(primary_key=True, index=True)
    content: Mapped[str] = mapped_column(Text)

    # Foreign keys to the commented post and to the commenting user
    post_id: Mapped[int] = mapped_column(ForeignKey("posts.id"))
    author_id: Mapped[int] = mapped_column(ForeignKey("users.id"))

    # Relationships (Post.comments / User.comments are the other sides)
    post: Mapped["Post"] = relationship("Post", back_populates="comments")
    author: Mapped["User"] = relationship("User", back_populates="comments")
混入类和抽象模型
# app/models/mixins.py
from sqlalchemy import DateTime
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.sql import func
from datetime import datetime
from typing import Optional
class TimestampMixin:
    """Timestamp mixin: adds timezone-aware created_at / updated_at columns."""

    # Filled in by the database (server_default = now()) on INSERT
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now()
    )
    # NULL after INSERT; set to now() by SQLAlchemy when it emits an UPDATE (onupdate)
    updated_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True),
        onupdate=func.now(),
        nullable=True
    )
class SoftDeleteMixin:
    """Soft-delete mixin: rows are flagged as deleted instead of being removed."""

    # False for live rows (Python-side default on INSERT)
    is_deleted: Mapped[bool] = mapped_column(default=False)
    # When the row was soft-deleted; NULL while it is live
    deleted_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True),
        nullable=True
    )
# Using the mixins
# The mixins module above only imports DateTime, so the column types and the
# declarative Base used by this model are imported here.
from sqlalchemy import String, Text

from app.db.base import Base


class Article(Base, TimestampMixin, SoftDeleteMixin):
    """Article model; created_at/updated_at and is_deleted/deleted_at come from the mixins."""

    __tablename__ = "articles"
    id: Mapped[int] = mapped_column(primary_key=True)
    title: Mapped[str] = mapped_column(String(200))
    content: Mapped[str] = mapped_column(Text)
第三章:CRUD 操作
基础 CRUD 类
# app/crud/base.py
from typing import TypeVar, Generic, Type, Optional, List, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete
from sqlalchemy.orm import selectinload
from pydantic import BaseModel
from app.db.base import Base
# Type parameters: the ORM model and the Pydantic schemas used for create/update input
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)


class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Generic async CRUD helper for one ORM model.

    Every method only ``flush()``es; committing is left to the caller (for
    example a request-scoped session dependency). All methods assume the model
    has an ``id`` primary-key column.
    """

    def __init__(self, model: Type[ModelType]):
        """Bind the helper to ``model``, an ORM class derived from ``Base``."""
        self.model = model

    async def get(
        self,
        db: AsyncSession,
        id: int
    ) -> Optional[ModelType]:
        """Return the row whose ``id`` equals ``id``, or ``None`` if there is none."""
        result = await db.execute(
            select(self.model).where(self.model.id == id)
        )
        return result.scalar_one_or_none()

    async def get_multi(
        self,
        db: AsyncSession,
        *,
        skip: int = 0,
        limit: int = 100
    ) -> List[ModelType]:
        """Return up to ``limit`` rows after skipping ``skip`` rows, ordered by ``id``.

        OFFSET/LIMIT pagination needs an explicit ORDER BY to be stable: without
        one the database may return rows in any order, so consecutive pages can
        overlap or skip rows.
        """
        result = await db.execute(
            select(self.model)
            .order_by(self.model.id)
            .offset(skip)
            .limit(limit)
        )
        return list(result.scalars().all())

    async def create(
        self,
        db: AsyncSession,
        *,
        obj_in: CreateSchemaType
    ) -> ModelType:
        """Create a row from every field of ``obj_in`` and return it.

        The row is flushed (not committed) and refreshed so database-generated
        values such as ``id`` are populated.
        """
        obj_data = obj_in.model_dump()
        db_obj = self.model(**obj_data)
        db.add(db_obj)
        await db.flush()
        await db.refresh(db_obj)
        return db_obj

    async def update(
        self,
        db: AsyncSession,
        *,
        db_obj: ModelType,
        obj_in: UpdateSchemaType | dict[str, Any]
    ) -> ModelType:
        """Apply ``obj_in`` to ``db_obj``, flush, refresh and return it.

        A schema contributes only the fields that were explicitly set
        (``exclude_unset``); a dict is applied as-is. Keys that are not
        attributes of the model are ignored.
        """
        if isinstance(obj_in, dict):
            update_data = obj_in
        else:
            update_data = obj_in.model_dump(exclude_unset=True)

        for field, value in update_data.items():
            if hasattr(db_obj, field):
                setattr(db_obj, field, value)

        db.add(db_obj)
        await db.flush()
        await db.refresh(db_obj)
        return db_obj

    async def delete(
        self,
        db: AsyncSession,
        *,
        id: int
    ) -> Optional[ModelType]:
        """Delete the row with ``id`` and return it, or return ``None`` if it does not exist."""
        obj = await self.get(db, id)
        if obj:
            await db.delete(obj)
            await db.flush()
        return obj

    async def count(self, db: AsyncSession) -> int:
        """Return the total number of rows in the model's table."""
        from sqlalchemy import func

        result = await db.execute(
            select(func.count()).select_from(self.model)
        )
        # COUNT(*) always yields exactly one row
        return result.scalar_one()
具体 CRUD 实现
# app/crud/user.py
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.crud.base import CRUDBase
from app.models.user import User
from app.schemas.user import UserCreate, UserUpdate
from passlib.context import CryptContext
# Shared bcrypt context used to hash and verify user passwords
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
    """CRUD operations for ``User``: lookups, password-hashing create and authentication."""

    async def get_by_email(
        self,
        db: AsyncSession,
        *,
        email: str
    ) -> Optional[User]:
        """Return the user with this email, or ``None``."""
        result = await db.execute(
            select(User).where(User.email == email)
        )
        return result.scalar_one_or_none()

    async def get_by_username(
        self,
        db: AsyncSession,
        *,
        username: str
    ) -> Optional[User]:
        """Return the user with this username, or ``None``."""
        result = await db.execute(
            select(User).where(User.username == username)
        )
        return result.scalar_one_or_none()

    async def create(
        self,
        db: AsyncSession,
        *,
        obj_in: UserCreate
    ) -> User:
        """Create a user, storing a bcrypt hash instead of the plain-text password.

        Only username, email, full_name and the hashed password are taken from
        ``obj_in``. The row is flushed (not committed) and refreshed so that
        database-generated values such as ``id`` are populated.
        """
        hashed_password = pwd_context.hash(obj_in.password)
        db_obj = User(
            username=obj_in.username,
            email=obj_in.email,
            hashed_password=hashed_password,
            full_name=obj_in.full_name
        )
        db.add(db_obj)
        await db.flush()
        await db.refresh(db_obj)
        return db_obj

    async def authenticate(
        self,
        db: AsyncSession,
        *,
        username: str,
        password: str
    ) -> Optional[User]:
        """Return the user if ``username`` exists and ``password`` matches, else ``None``.

        For an unknown username a dummy hash verification is still performed,
        so the response time does not reveal whether the account exists.
        """
        user = await self.get_by_username(db, username=username)
        if user is None:
            # Spend about as long as a real bcrypt verification would; returning
            # immediately would let callers enumerate usernames by timing.
            pwd_context.dummy_verify()
            return None
        if not pwd_context.verify(password, user.hashed_password):
            return None
        return user

    async def is_active(self, user: User) -> bool:
        """Return ``user.is_active``. This is a coroutine: callers must ``await`` it."""
        return user.is_active

    async def is_superuser(self, user: User) -> bool:
        """Return ``user.is_superuser``. This is a coroutine: callers must ``await`` it."""
        return user.is_superuser


# Module-level singleton instance
user_crud = CRUDUser(User)
复杂查询
# app/crud/post.py
from typing import Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, func
from sqlalchemy.orm import selectinload, joinedload
from app.models.post import Post, Tag
from app.models.user import User
class CRUDPost:
    """Query helpers for ``Post``: eager-loaded fetch, search, stats and bulk writes.

    Methods only execute/flush; committing is left to the caller.
    """

    async def get_with_relations(
        self,
        db: AsyncSession,
        post_id: int
    ) -> Optional[Post]:
        """Return one post with its author, comments (with their authors) and tags eager-loaded."""
        # Comment is not among this module's top-level imports; import it here
        # so the loader option below can reference Comment.author.
        from app.models.post import Comment

        result = await db.execute(
            select(Post)
            .options(
                selectinload(Post.author),
                selectinload(Post.comments).selectinload(Comment.author),
                selectinload(Post.tags)
            )
            .where(Post.id == post_id)
        )
        return result.scalar_one_or_none()

    async def search(
        self,
        db: AsyncSession,
        *,
        keyword: Optional[str] = None,
        author_id: Optional[int] = None,
        tag_ids: Optional[List[int]] = None,
        is_published: Optional[bool] = None,
        skip: int = 0,
        limit: int = 20
    ) -> List[Post]:
        """Search posts; every filter that is given is combined with AND.

        ``keyword`` matches title or content case-insensitively; ``tag_ids``
        matches posts that carry at least one of the listed tags. Results are
        newest first (by ``id``) and paginated with ``skip``/``limit``.
        """
        query = select(Post).options(
            selectinload(Post.author),
            selectinload(Post.tags)
        )

        conditions = []
        if keyword:
            pattern = f"%{keyword}%"
            conditions.append(
                or_(
                    Post.title.ilike(pattern),
                    Post.content.ilike(pattern)
                )
            )
        # Compare with None so that an id of 0 is not silently ignored
        if author_id is not None:
            conditions.append(Post.author_id == author_id)
        if is_published is not None:
            conditions.append(Post.is_published == is_published)
        if tag_ids:
            # EXISTS subquery instead of a JOIN: a JOIN yields one row per
            # matching tag, so a post with several matching tags would take
            # several LIMIT slots and the page would come back short once the
            # duplicates were collapsed.
            conditions.append(Post.tags.any(Tag.id.in_(tag_ids)))

        if conditions:
            query = query.where(and_(*conditions))

        # Post has no created_at column; id is used as the creation-order proxy
        # (assumes ids increase with insertion order).
        query = query.order_by(Post.id.desc())
        query = query.offset(skip).limit(limit)

        result = await db.execute(query)
        return list(result.scalars().all())

    async def get_stats(
        self,
        db: AsyncSession,
        author_id: int
    ) -> dict:
        """Return ``{"total_posts": ..., "published_posts": ...}`` for one author."""
        # Total number of posts
        post_count = await db.execute(
            select(func.count(Post.id))
            .where(Post.author_id == author_id)
        )
        # Number of published posts
        published_count = await db.execute(
            select(func.count(Post.id))
            .where(and_(
                Post.author_id == author_id,
                Post.is_published == True  # noqa: E712 - SQL expression, not a Python comparison
            ))
        )
        # COUNT always yields exactly one row
        return {
            "total_posts": post_count.scalar_one(),
            "published_posts": published_count.scalar_one()
        }

    async def bulk_create(
        self,
        db: AsyncSession,
        posts_data: List[dict]
    ) -> List[Post]:
        """Create one Post per dict in ``posts_data``; flushed so ids are assigned."""
        posts = [Post(**data) for data in posts_data]
        db.add_all(posts)
        await db.flush()
        return posts

    async def bulk_update(
        self,
        db: AsyncSession,
        post_ids: List[int],
        update_data: dict
    ) -> int:
        """Apply ``update_data`` to every post in ``post_ids`` with one UPDATE; return rows matched.

        Returns 0 without touching the database when there is nothing to update:
        an UPDATE with no SET values cannot be executed.
        """
        from sqlalchemy import update

        if not post_ids or not update_data:
            return 0
        result = await db.execute(
            update(Post)
            .where(Post.id.in_(post_ids))
            .values(**update_data)
        )
        return result.rowcount


post_crud = CRUDPost()
第四章:数据库迁移
Alembic 配置
# Initialise Alembic with the async template ("-t async"). It generates an
# env.py built on async_engine_from_config + asyncio.run, like the one below;
# the default template generates a synchronous env.py instead.
alembic init -t async alembic
# alembic/env.py
import asyncio
from logging.config import fileConfig
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config
from alembic import context
# 导入模型
from app.db.base import Base
from app.models import user, post # 导入所有模型
# Alembic Config object: access to the values in alembic.ini
config = context.config
# Configure Python logging from the config file, when one is in use
if config.config_file_name is not None:
    fileConfig(config.config_file_name)
# Metadata that `alembic revision --autogenerate` compares with the database.
# Only tables from the model modules imported above are registered on it.
target_metadata = Base.metadata


def run_migrations_offline() -> None:
    """Offline migrations: configure the context from the URL alone (no connection).

    ``literal_binds=True`` renders parameter values inline into the SQL text.
    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )
    with context.begin_transaction():
        context.run_migrations()


def do_run_migrations(connection: Connection) -> None:
    """Run the migrations on an open synchronous ``connection``.

    Invoked through ``AsyncConnection.run_sync`` by ``run_async_migrations``.
    """
    context.configure(connection=connection, target_metadata=target_metadata)
    with context.begin_transaction():
        context.run_migrations()


async def run_async_migrations() -> None:
    """Async online migrations: build an engine from the ``sqlalchemy.*`` ini keys and migrate."""
    # NullPool: connections are not pooled for this one-shot process
    connectable = async_engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )
    async with connectable.connect() as connection:
        # Alembic's migration context is synchronous; run_sync passes it a sync Connection
        await connection.run_sync(do_run_migrations)
    await connectable.dispose()


def run_migrations_online() -> None:
    """Online migrations entry point: run the async migration coroutine to completion."""
    asyncio.run(run_async_migrations())


# Dispatch on the mode Alembic was invoked in
if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
迁移命令
# Create a migration (--autogenerate diffs the models' metadata against the database)
alembic revision --autogenerate -m "create users table"
# Apply all migrations up to the latest revision
alembic upgrade head
# Roll back one revision
alembic downgrade -1
# Show the revision history
alembic history
# Show the revision the database is currently at
alembic current
迁移脚本示例
# alembic/versions/001_create_users_table.py
"""create users table
Revision ID: 001
Create Date: 2024-01-01 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# Revision identifiers, used by Alembic to order the migration chain
revision: str = '001'
# None: this is the first migration in the chain
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the ``users`` table and its unique username/email indexes.

    The columns mirror the ``User`` model. Database-level defaults use
    ``server_default``: a plain ``default=`` is a Python-side default that is
    not rendered into CREATE TABLE, so it has no effect in a migration.
    """
    op.create_table(
        'users',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('username', sa.String(50), nullable=False),
        sa.Column('email', sa.String(100), nullable=False),
        sa.Column('hashed_password', sa.String(255), nullable=False),
        sa.Column('full_name', sa.String(100), nullable=True),
        sa.Column('bio', sa.Text(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=False, server_default=sa.true()),
        sa.Column('is_superuser', sa.Boolean(), nullable=False, server_default=sa.false()),
        sa.Column(
            'created_at',
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        ),
        sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index('ix_users_username', 'users', ['username'], unique=True)
    op.create_index('ix_users_email', 'users', ['email'], unique=True)
def downgrade() -> None:
    """Reverse the upgrade: drop the email and username indexes, then the ``users`` table."""
    op.drop_index('ix_users_email', table_name='users')
    op.drop_index('ix_users_username', table_name='users')
    op.drop_table('users')
第五章:事务管理
自动事务
from fastapi import FastAPI, Depends
from sqlalchemy.ext.asyncio import AsyncSession
# Import the session factory rather than get_db: get_db is defined below to show
# how it works, and an imported get_db would just be shadowed by it.
from app.db.base import async_session
from app.models.user import User

app = FastAPI()


# A dependency that manages the transaction automatically
async def get_db():
    """Yield a session; commit if the request succeeds, otherwise roll back and re-raise."""
    async with async_session() as session:
        try:
            yield session
            await session.commit()  # success: commit
        except Exception:
            await session.rollback()  # failure: roll back
            raise


@app.post("/users/")
async def create_user(db: AsyncSession = Depends(get_db)):
    # Any exception raised here propagates into get_db, which rolls the transaction back.
    # email and hashed_password are NOT NULL on User, so they must be supplied.
    user = User(
        username="test",
        email="test@example.com",
        hashed_password="<bcrypt hash>",
    )
    db.add(user)
    # No manual commit: get_db commits after the handler returns. flush() sends
    # the INSERT now so the database-generated id is available for the response.
    await db.flush()
    # Return plain data rather than the ORM instance, so response serialization
    # never has to walk SQLAlchemy internals.
    return {"id": user.id, "username": user.username}
手动事务控制
from sqlalchemy.ext.asyncio import AsyncSession
async def transfer_money(
    db: AsyncSession,
    from_account_id: int,
    to_account_id: int,
    amount: float
):
    """Transfer ``amount`` between two accounts atomically and record a Transaction.

    Raises:
        ValueError: if ``amount`` is not positive, an account does not exist, or
            the source balance is insufficient. Raising inside ``db.begin()``
            rolls the whole transaction back.

    NOTE(review): ``db.begin()`` raises if the session already has a transaction
    in progress, so pass a session that has not begun one yet. Money is kept as
    ``float`` as in the original; a Decimal/Numeric column is safer for currency.
    """
    if amount <= 0:
        raise ValueError("转账金额必须大于 0")

    async with db.begin():  # start the transaction
        # Lock both rows (SELECT ... FOR UPDATE) so concurrent transfers cannot
        # read the same balance and overwrite each other's result. Locking in
        # ascending id order gives every transfer the same lock order, which
        # avoids deadlocks between transfers in opposite directions.
        accounts = {}
        for account_id in sorted({from_account_id, to_account_id}):
            accounts[account_id] = await db.get(
                Account, account_id, with_for_update=True
            )
        from_account = accounts[from_account_id]
        to_account = accounts[to_account_id]

        if from_account is None or to_account is None:
            raise ValueError("账户不存在")
        if from_account.balance < amount:
            raise ValueError("余额不足")

        # Debit and credit
        from_account.balance -= amount
        to_account.balance += amount

        # Record the transfer
        transaction = Transaction(
            from_account_id=from_account_id,
            to_account_id=to_account_id,
            amount=amount
        )
        db.add(transaction)
    # Leaving the block commits; an exception inside it rolls back
嵌套事务(Savepoint)
from sqlalchemy.ext.asyncio import AsyncSession
async def complex_operation(db: AsyncSession):
    """Demonstrate a SAVEPOINT: undo an inner unit of work while keeping the outer one.

    NOTE(review): ``db.begin()`` raises if the session already has a transaction
    in progress, so this expects a session that has not begun one yet.
    """
    # Outer transaction: committed when this block exits without an exception
    async with db.begin():
        # email and hashed_password are NOT NULL on User. begin_nested() below
        # flushes pending objects before creating the savepoint, so they must be set.
        user = User(
            username="outer",
            email="outer@example.com",
            hashed_password="<bcrypt hash>",
        )
        db.add(user)

        try:
            # Inner SAVEPOINT: an exception escaping this block rolls back to it only
            async with db.begin_nested():
                # slug and content are NOT NULL on Post
                post = Post(title="inner", slug="inner", content="inner", author=user)
                db.add(post)
                # Simulated failure
                raise ValueError("内层错误")
        except ValueError:
            # Only the savepoint was rolled back; the outer transaction continues
            pass

        # The outer transaction continues; the user row is kept
        await db.flush()
第六章:性能优化
N+1 问题解决
from sqlalchemy.orm import selectinload, joinedload, subqueryload
# Anti-pattern: the N+1 problem
async def get_posts_bad(db: AsyncSession):
    """Anti-pattern: loads posts without eager-loading ``Post.author``.

    With a synchronous Session, each ``post.author`` access below would lazy-load
    and issue one extra SELECT per post (the "N+1" problem). With an AsyncSession
    implicit lazy loading cannot do I/O: unless the author object is already in
    the session, the access raises ``MissingGreenlet`` instead of querying.
    """
    result = await db.execute(select(Post))
    posts = result.scalars().all()
    for post in posts:
        # Not preloaded: a per-post lazy load (sync) or an error (async)
        print(post.author.username)
# Correct approach: eager-load the relationship with selectinload
async def get_posts_good(db: AsyncSession):
    """Load every post together with its author, then print each author's username.

    ``selectinload`` fetches all authors in one additional ``SELECT ... IN`` query,
    so reading ``post.author`` inside the loop triggers no further queries.
    """
    stmt = select(Post).options(selectinload(Post.author))
    loaded_posts = (await db.execute(stmt)).scalars().all()
    for p in loaded_posts:
        print(p.author.username)  # already loaded, no extra query
# joinedload: a single SELECT with a JOIN — suits many-to-one / one-to-one relationships
async def get_post_with_author(db: AsyncSession, post_id: int):
    """Return the post with ``post_id`` with its author loaded in the same query, or None."""
    stmt = (
        select(Post)
        .where(Post.id == post_id)
        .options(joinedload(Post.author))
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()
# selectinload: one extra SELECT ... IN per relationship — suits one-to-many collections
async def get_posts_with_comments(db: AsyncSession):
    """Return all posts with their comments, and each comment's author, eager-loaded."""
    comments_and_authors = selectinload(Post.comments).selectinload(Comment.author)
    rows = await db.execute(select(Post).options(comments_and_authors))
    return rows.scalars().all()
批量操作
from sqlalchemy import insert, update, delete
# Bulk insert
async def bulk_insert_users(db: AsyncSession, users_data: list[dict]):
    """Insert one User per dict in ``users_data`` with a single bulk INSERT, then commit.

    An empty list is a no-op. Calling execute() with an empty parameter list
    would not be a no-op: the INSERT would run without any row values.
    """
    if not users_data:
        return
    await db.execute(
        insert(User),
        users_data
    )
    await db.commit()
# Bulk update
async def bulk_update_status(db: AsyncSession, user_ids: list[int], status: bool):
    """Set ``is_active = status`` on every user whose id is in ``user_ids`` with one UPDATE, then commit."""
    stmt = update(User).where(User.id.in_(user_ids)).values(is_active=status)
    await db.execute(stmt)
    await db.commit()
# Bulk delete
async def bulk_delete_users(db: AsyncSession, user_ids: list[int]):
    """Delete every user whose id is in ``user_ids`` with a single DELETE, then commit."""
    stmt = delete(User).where(User.id.in_(user_ids))
    await db.execute(stmt)
    await db.commit()
连接池配置
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine(
    DATABASE_URL,
    # Number of connections the pool keeps open
    pool_size=20,
    # Extra connections allowed beyond pool_size under load
    max_overflow=10,
    # Seconds to wait for a free connection before raising
    pool_timeout=30,
    # Replace connections older than this many seconds
    pool_recycle=1800,
    # Test each connection on checkout and replace stale ones
    pool_pre_ping=True,
    # Do not log SQL statements
    echo=False,
    # Pool implementation. The keyword is `poolclass` (not `pool_class`);
    # NullPool disables pooling and cannot be combined with pool_size/max_overflow.
    # poolclass=NullPool  # disable connection pooling
)
查询优化
# Select only the columns you need
async def get_user_names(db: AsyncSession):
    """Return ``(id, username)`` rows for all users without loading full User objects."""
    stmt = select(User.id, User.username)
    rows = await db.execute(stmt)
    return rows.all()
# Pagination
async def get_users_paginated(db: AsyncSession, page: int, size: int):
    """Return page ``page`` (1-based) of users, ``size`` per page, ordered by id.

    Raises:
        ValueError: if ``page`` or ``size`` is less than 1. Such values would
            otherwise produce a negative OFFSET or a non-positive LIMIT, which
            the database rejects or answers with an empty page.
    """
    if page < 1:
        raise ValueError("page 必须大于等于 1")
    if size < 1:
        raise ValueError("size 必须大于等于 1")

    offset = (page - 1) * size
    result = await db.execute(
        select(User)
        .order_by(User.id)  # a stable order is required for OFFSET paging
        .offset(offset)
        .limit(size)
    )
    return result.scalars().all()
# Using indexes
# Declare composite indexes on the model via __table_args__.
# Index is not among the sqlalchemy names imported by the model module, so import it here.
from sqlalchemy import Index


class User(Base):
    # Fragment: only the index declaration is shown; the columns are as in the full User model.
    __tablename__ = "users"
    # Composite index on (email, is_active) for queries that filter on both columns
    __table_args__ = (
        Index("ix_users_email_active", "email", "is_active"),
    )
# Use EXISTS instead of COUNT
from sqlalchemy import exists


async def user_exists(db: AsyncSession, email: str) -> bool:
    """Return whether some user has ``email``; EXISTS can stop at the first match."""
    stmt = select(exists().where(User.email == email))
    outcome = await db.execute(stmt)
    return outcome.scalar()
常见问题
Q1:如何处理数据库连接泄露?
# Use the session's context manager to guarantee the connection is released
async def safe_db_operation():
    """Return all users using a short-lived session.

    ``async with async_session() as session`` closes the session, returning its
    connection to the pool, when the block exits, even on an exception. An
    extra ``try/finally: await session.close()`` is therefore redundant.
    """
    async with async_session() as session:
        result = await session.execute(select(User))
        return result.scalars().all()
# Pool timeouts: a leaked connection eventually surfaces as a checkout timeout
# error instead of an indefinite wait
engine = create_async_engine(
    DATABASE_URL,
    pool_timeout=30,  # seconds to wait for a free connection before raising
    pool_recycle=3600,  # periodically replace connections older than this (seconds)
)
Q2:如何处理并发更新?
from sqlalchemy import update
# Row locks use the Select.with_for_update() method; sqlalchemy.orm has no
# importable `with_for_update`, so importing it would raise ImportError.


# Pessimistic locking with SELECT ... FOR UPDATE
async def update_balance(db: AsyncSession, user_id: int, amount: float):
    """Add ``amount`` to the user's balance while holding a row lock, then commit.

    NOTE(review): the User model defined earlier has no ``balance`` column;
    this example assumes one exists.
    Raises ``NoResultFound`` (from ``scalar_one``) if no user has ``user_id``.
    """
    result = await db.execute(
        select(User)
        .where(User.id == user_id)
        .with_for_update()  # row lock, held until commit
    )
    user = result.scalar_one()
    user.balance += amount
    await db.commit()
# Optimistic locking: add a version column
class User(Base):
    # Version counter that optimistic_update compares in its WHERE clause and
    # increments. SQLAlchemy can also maintain such a counter automatically via
    # __mapper_args__ = {"version_id_col": version}.
    version: Mapped[int] = mapped_column(default=0)
class ConcurrencyError(Exception):
    """Raised when a row was modified by someone else between read and update."""


async def optimistic_update(db: AsyncSession, user_id: int, data: dict):
    """Apply ``data`` to the user only if its ``version`` is unchanged since it was read.

    The UPDATE's WHERE clause includes the version that was read and the
    statement bumps it by one. If a concurrent writer has already bumped it,
    the UPDATE matches zero rows.

    Raises:
        LookupError: no user has ``user_id``.
        ConcurrencyError: the row was modified concurrently.
    """
    from sqlalchemy import and_, update

    user = await db.get(User, user_id)
    if user is None:
        raise LookupError("用户不存在")
    current_version = user.version

    # "version" is managed here. Letting it through from data would also pass
    # the keyword to .values() twice (TypeError).
    values = {key: value for key, value in data.items() if key != "version"}

    result = await db.execute(
        update(User)
        .where(and_(User.id == user_id, User.version == current_version))
        .values(**values, version=current_version + 1)
    )
    if result.rowcount == 0:
        raise ConcurrencyError("数据已被其他用户修改")
Q3:如何优雅地处理软删除?
from sqlalchemy import event
from sqlalchemy.orm import Query
from datetime import datetime, timezone
from typing import Optional

from sqlalchemy import DateTime
from sqlalchemy.orm import Mapped, mapped_column


class SoftDeleteMixin:
    """Mixin adding soft-delete columns and a ``soft_delete()`` helper."""

    is_deleted: Mapped[bool] = mapped_column(default=False)
    # Timezone-aware column, matching the aware timestamp soft_delete() writes
    deleted_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True), nullable=True
    )

    def soft_delete(self):
        """Flag the row as deleted and record the time (UTC); the caller must flush/commit."""
        self.is_deleted = True
        # Aware UTC timestamp: unambiguous, unlike the server-local naive
        # datetime.now(), and suitable for the DateTime(timezone=True) column
        self.deleted_at = datetime.now(timezone.utc)
# Globally filter out soft-deleted rows.
# The Query "before_compile" event only fires for the legacy Query API; the
# 2.0-style select() + session.execute() used throughout this page never goes
# through it. The ORM execute hook plus with_loader_criteria covers 2.0 statements.
from sqlalchemy.orm import Session, with_loader_criteria


@event.listens_for(Session, "do_orm_execute")
def soft_delete_filter(execute_state):
    """Add ``is_deleted = false`` criteria for SoftDeleteMixin entities to ORM SELECTs.

    Opt out per statement with ``.execution_options(include_deleted=True)``.
    Registering on the sync ``Session`` class also covers ``AsyncSession``,
    which runs on top of a sync Session.
    """
    if (
        execute_state.is_select
        and not execute_state.is_column_load
        # The criteria added below is propagated to relationship loads already
        and not execute_state.is_relationship_load
        and not execute_state.execution_options.get("include_deleted", False)
    ):
        execute_state.statement = execute_state.statement.options(
            with_loader_criteria(
                SoftDeleteMixin,
                lambda cls: cls.is_deleted == False,  # noqa: E712 - SQL expression
                include_aliases=True,
            )
        )
学习资源
- SQLAlchemy 官方文档:https://docs.sqlalchemy.org/
- FastAPI 数据库教程:https://fastapi.tiangolo.com/tutorial/sql-databases/
- Alembic 文档:https://alembic.sqlalchemy.org/
- 异步 SQLAlchemy:https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html