SQLAlchemy 2.0 实战:异步 ORM 的配置与模型定义

📂 所属阶段:第三阶段 — 数据持久化(数据库篇)
🔗 相关章节:依赖注入系统 · Redis 集成 · Alembic 数据库迁移


1. 为什么选择 SQLAlchemy 2.0?

1.1 SQLAlchemy 2.0 的核心改进

  • 原生异步支持:AsyncSession + asyncpg / aiomysql,无需同步线程池
  • 统一查询语法:Core 和 ORM 使用同一套 select() API
  • 更严格的类型提示:Mapped[int] 等注解让代码即文档
  • 迁移友好:与 Alembic 深度集成

1.2 项目依赖安装

pip install sqlalchemy[asyncio] asyncpg aiosqlite
# sqlalchemy[asyncio] 包含 asyncio 扩展
# asyncpg → PostgreSQL(生产推荐)
# aiosqlite → SQLite(开发/轻量)

2. 基础配置

2.1 数据库连接配置

# database.py
from sqlalchemy.ext.asyncio import (
    AsyncSession, create_async_engine, async_sessionmaker
)
from sqlalchemy.orm import DeclarativeBase
from config import get_settings

settings = get_settings()

# PostgreSQL URL format: postgresql+asyncpg://
DATABASE_URL = settings.database_url  # "postgresql+asyncpg://user:pass@localhost/mydb"
# SQLite URL format: sqlite+aiosqlite:///
# DATABASE_URL = "sqlite+aiosqlite:///./app.db"

engine = create_async_engine(
    DATABASE_URL,
    echo=settings.debug,  # log every emitted SQL statement (development aid)
    # pool_size / max_overflow configure the default queue pool: up to 20
    # persistent connections plus 10 temporary ones under load. Remove them if
    # the URL selects a different pool class (e.g. in-memory SQLite).
    pool_size=20,
    max_overflow=10,
    pool_pre_ping=True,  # check that a pooled connection is alive before use
)

# Session factory: each call returns a new AsyncSession bound to ``engine``.
# SQLAlchemy 2.0 sessions never autocommit, so the legacy ``autocommit=False``
# argument (accepted only for backwards compatibility) is omitted.
AsyncSessionLocal = async_sessionmaker(
    engine,
    class_=AsyncSession,
    # Keep loaded attribute values after commit(). Under asyncio, reading an
    # expired attribute would need an implicit lazy refresh, which fails.
    expire_on_commit=False,
    autoflush=False,
)


class Base(DeclarativeBase):
    """Declarative base for all ORM models; ``Base.metadata`` collects their tables."""

2.2 在 FastAPI 中管理数据库生命周期

# main.py
from contextlib import asynccontextmanager

from fastapi import FastAPI

from database import Base, engine

# Importing the model modules registers their tables on Base.metadata;
# create_all below only creates tables it can see there.
import models.post  # noqa: F401
import models.user  # noqa: F401


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create missing tables on startup and dispose of the pool on shutdown.

    ``create_all`` only creates tables that do not exist yet; it never alters
    existing ones, so schema changes should go through Alembic migrations.
    """
    async with engine.begin() as conn:
        # MetaData.create_all is synchronous; run_sync runs it on this connection.
        await conn.run_sync(Base.metadata.create_all)
    yield
    # Shutdown: close all pooled connections.
    await engine.dispose()


app = FastAPI(lifespan=lifespan)

2.3 数据库依赖

# dependencies.py
from collections.abc import AsyncIterator

from sqlalchemy.ext.asyncio import AsyncSession

from database import AsyncSessionLocal


async def get_db() -> AsyncIterator[AsyncSession]:
    """FastAPI dependency that yields one AsyncSession per request.

    This is an async generator, so its return type is an async iterator of
    sessions, not a session. Leaving the ``async with`` block closes the
    session, which releases its connection and rolls back anything left
    uncommitted.
    """
    async with AsyncSessionLocal() as session:
        yield session

3. 模型定义

3.1 定义用户模型

# models/user.py
from sqlalchemy import String, Integer, Boolean, DateTime, ForeignKey, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func
from database import Base
from datetime import datetime

class User(Base):
    """Application user account (table ``users``)."""

    __tablename__ = "users"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    name: Mapped[str] = mapped_column(String(50), nullable=False)
    email: Mapped[str] = mapped_column(String(100), unique=True, index=True, nullable=False)
    hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
    is_active: Mapped[bool] = mapped_column(Boolean, default=True)
    is_superuser: Mapped[bool] = mapped_column(Boolean, default=False)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
    # Mapped[datetime] (not Optional) makes this column NOT NULL, so it needs an
    # insert-time default too: onupdate only fires on UPDATE statements.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )

    # Posts are deleted together with their author. With the default cascade the
    # ORM would instead set posts.author_id to NULL, which that NOT NULL column
    # rejects. passive_deletes=True leaves rows that are not loaded to the
    # database's ON DELETE CASCADE.
    posts: Mapped[list["Post"]] = relationship(
        "Post",
        back_populates="author",
        lazy="selectin",
        cascade="all, delete-orphan",
        passive_deletes=True,
    )

    def __repr__(self) -> str:
        return f"<User {self.email}>"

3.2 定义帖子模型(多对多关系)

# models/post.py
from sqlalchemy import String, Integer, Boolean, DateTime, ForeignKey, Text, Table, Column
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func
from database import Base
from datetime import datetime

# Association table for the Post <-> Tag many-to-many relationship.
# It needs its own table name: "tags" already belongs to the Tag model, and
# defining it twice on the same MetaData raises an error.
post_tags = Table(
    "post_tags",
    Base.metadata,
    Column("post_id", Integer, ForeignKey("posts.id", ondelete="CASCADE"), primary_key=True),
    Column("tag_id", Integer, ForeignKey("tags.id", ondelete="CASCADE"), primary_key=True),
)


class Tag(Base):
    """A label that can be attached to many posts (table ``tags``)."""

    __tablename__ = "tags"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(30), unique=True, nullable=False)


class Post(Base):
    """A post written by one user and labelled with any number of tags."""

    __tablename__ = "posts"

    id: Mapped[int] = mapped_column(primary_key=True)
    title: Mapped[str] = mapped_column(String(200), nullable=False)
    content: Mapped[str] = mapped_column(Text, nullable=False)
    published: Mapped[bool] = mapped_column(Boolean, default=False)
    author_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"))
    # Timezone-aware like the users table's timestamps.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )

    # Relationships
    author: Mapped["User"] = relationship("User", back_populates="posts")
    tags: Mapped[list["Tag"]] = relationship("Tag", secondary=post_tags, lazy="selectin")

4. CRUD 操作

4.1 Pydantic 数据模型(Schema)

# schemas/user.py
from pydantic import BaseModel, ConfigDict, EmailStr
from datetime import datetime


# Fields shared by the user input and output schemas.
class UserBase(BaseModel):
    name: str
    email: EmailStr


# Request body for creating a user; ``password`` is plain text and must only
# be stored hashed.
class UserCreate(UserBase):
    password: str


# Partial-update body: every field is optional.
class UserUpdate(BaseModel):
    name: str | None = None
    email: EmailStr | None = None
    password: str | None = None


# Response schema, populated from ORM objects by attribute access.
class UserSchema(UserBase):
    # Pydantic v2 style; the nested ``class Config`` form is deprecated in v2.
    model_config = ConfigDict(from_attributes=True)

    id: int
    is_active: bool
    created_at: datetime

4.2 Repository 模式封装

# repositories/user_repository.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_
from sqlalchemy.orm import selectinload
from models.user import User
from models.post import Post
from typing import Optional

class UserRepository:
    """Data access for ``User`` rows through one AsyncSession.

    Write methods commit immediately and then ``refresh`` the instance so that
    server-generated columns are loaded before the caller reads them.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def get_by_id(self, user_id: int) -> Optional[User]:
        """Return the user whose primary key is ``user_id``, or None."""
        result = await self.db.execute(select(User).where(User.id == user_id))
        return result.scalar_one_or_none()

    async def get_by_email(self, email: str) -> Optional[User]:
        """Return the user whose email equals ``email``, or None."""
        result = await self.db.execute(select(User).where(User.email == email))
        return result.scalar_one_or_none()

    async def get_multi(
        self, skip: int = 0, limit: int = 20, search: str | None = None
    ) -> list[User]:
        """Return one page of users, newest first.

        Args:
            skip: number of rows to skip (OFFSET).
            limit: maximum number of rows to return (LIMIT).
            search: optional text matched case-insensitively as a literal
                substring of ``name`` or ``email``. ``%`` and ``_`` in it are
                escaped instead of acting as LIKE wildcards.
        """
        query = select(User)
        if search:
            query = query.where(
                or_(
                    User.name.icontains(search, autoescape=True),
                    User.email.icontains(search, autoescape=True),
                )
            )
        # created_at is not unique, so id breaks ties. Otherwise rows with equal
        # timestamps can repeat or disappear between OFFSET pages.
        query = (
            query.order_by(User.created_at.desc(), User.id.desc())
            .offset(skip)
            .limit(limit)
        )
        result = await self.db.execute(query)
        return list(result.scalars().all())

    async def create(self, user_data: dict) -> User:
        """Insert a user built from ``user_data`` (column -> value), commit, and return it."""
        user = User(**user_data)
        self.db.add(user)
        await self.db.commit()
        # Load server-side defaults (id, created_at, ...) onto the instance.
        await self.db.refresh(user)
        return user

    async def update(self, user: User, update_data: dict) -> User:
        """Set each non-None value of ``update_data`` on ``user``, commit, and return it.

        ``None`` values are skipped, so this method cannot clear a column.
        """
        for key, value in update_data.items():
            if value is not None:
                setattr(user, key, value)
        await self.db.commit()
        await self.db.refresh(user)
        return user

    async def delete(self, user: User) -> None:
        """Delete ``user`` and commit."""
        await self.db.delete(user)
        await self.db.commit()

    async def get_user_posts(self, user_id: int) -> list[Post]:
        """Return the posts authored by ``user_id``, newest first, with tags eager-loaded."""
        result = await self.db.execute(
            select(Post)
            .where(Post.author_id == user_id)
            .options(selectinload(Post.tags))
            .order_by(Post.created_at.desc())
        )
        return list(result.scalars().all())

5. 在路由中使用

5.1 用户 CRUD 路由

# routers/users.py
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
from dependencies import get_db
from repositories.user_repository import UserRepository
from schemas.user import UserSchema, UserCreate, UserUpdate
from passlib.context import CryptContext

router = APIRouter(tags=["用户管理"], prefix="/users")
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])


def get_repo(db: AsyncSession = Depends(get_db)) -> UserRepository:
    """Dependency: wrap the request's session in a ``UserRepository``."""
    repository = UserRepository(db)
    return repository


def hash_password(password: str) -> str:
    """Return the hash of ``password`` produced by the module's ``pwd_context`` (bcrypt).

    NOTE(review): bcrypt hashing is CPU-bound and is called from async routes;
    consider running it in a thread pool if it shows up in latency.
    """
    digest = pwd_context.hash(password)
    return digest

@router.get("/", response_model=List[UserSchema])
async def list_users(
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    search: str | None = None,
    repo: UserRepository = Depends(get_repo),
):
    # Paged user listing; ``search`` is forwarded to the repository's name/email filter.
    return await repo.get_multi(search=search, skip=skip, limit=limit)

@router.get("/{user_id}", response_model=UserSchema)
async def get_user(user_id: int, repo: UserRepository = Depends(get_repo)):
    # Look up one user by id; unknown ids produce a 404.
    found = await repo.get_by_id(user_id)
    if found is None:
        raise HTTPException(status_code=404, detail="用户不存在")
    return found

@router.post("/", response_model=UserSchema, status_code=status.HTTP_201_CREATED)
async def create_user(
    data: UserCreate,
    repo: UserRepository = Depends(get_repo),
):
    # Register a user; an email that is already taken yields 409.
    # NOTE(review): two concurrent requests can both pass this check; the unique
    # index on email then rejects the second insert with an IntegrityError.
    if await repo.get_by_email(data.email):
        raise HTTPException(status_code=409, detail="邮箱已被注册")

    # Store only the hash, never the plain-text password.
    payload = data.model_dump(exclude={"password"})
    payload["hashed_password"] = hash_password(data.password)
    return await repo.create(payload)

@router.put("/{user_id}", response_model=UserSchema)
async def update_user(
    user_id: int,
    data: UserUpdate,
    repo: UserRepository = Depends(get_repo),
):
    # Partial update: only fields present in the request body are applied.
    # 404 if the user does not exist; 409 if the new email belongs to someone else.
    user = await repo.get_by_id(user_id)
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")

    update_dict = data.model_dump(exclude_unset=True)

    # Same duplicate check as create_user. Without it, the unique index on
    # users.email surfaces as an unhandled IntegrityError (HTTP 500).
    new_email = update_dict.get("email")
    if new_email is not None and new_email != user.email:
        owner = await repo.get_by_email(new_email)
        if owner is not None and owner.id != user.id:
            raise HTTPException(status_code=409, detail="邮箱已被注册")

    if "password" in update_dict:
        new_password = update_dict.pop("password")
        # An explicit "password": null must not reach the hasher. Like other
        # null fields, it leaves the stored value unchanged.
        if new_password is not None:
            update_dict["hashed_password"] = hash_password(new_password)

    updated = await repo.update(user, update_dict)
    return updated

@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user(user_id: int, repo: UserRepository = Depends(get_repo)):
    # Remove a user; unknown ids produce a 404, success returns an empty 204.
    target = await repo.get_by_id(user_id)
    if target is None:
        raise HTTPException(status_code=404, detail="用户不存在")
    await repo.delete(target)

6. 事务与批量操作

6.1 事务控制

async def bulk_create_posts(db: AsyncSession, posts_data: list[dict]):
    """Insert one Post per dict in ``posts_data`` atomically: all rows or none.

    ``db.begin()`` commits when the block exits normally and rolls back if it
    raises, so no separate ``commit()`` is needed afterwards. It requires that
    no transaction is already in progress on ``db`` (for example one started
    implicitly by an earlier query on the same session); otherwise
    ``begin()`` raises.
    """
    async with db.begin():
        db.add_all([Post(**data) for data in posts_data])

6.2 乐观锁(版本控制)

from sqlalchemy.orm.exc import StaleDataError


class User(Base):
    # ... other columns
    version: Mapped[int] = mapped_column(Integer, default=1)

    # Let the ORM maintain the counter. Every UPDATE becomes
    # "... WHERE id = :id AND version = :loaded_version" and increments version,
    # so a concurrent committed write makes it match 0 rows and raise
    # StaleDataError instead of silently overwriting the other change.
    __mapper_args__ = {"version_id_col": version}


async def update_with_optimistic_lock(
    repo: UserRepository, user_id: int, expected_version: int, data: dict
):
    """Apply ``data`` to the user only if it is still at ``expected_version``.

    Raises:
        HTTPException(404): the user does not exist.
        HTTPException(409): the client's version is stale, or another writer
            committed between this read and this UPDATE.
    """
    user = await repo.get_by_id(user_id)
    if user is None:
        raise HTTPException(404, "用户不存在")
    if user.version != expected_version:
        raise HTTPException(409, "数据已被其他人修改,请刷新后重试")
    # The mapper bumps the version itself; never take it from caller input,
    # and do not mutate the caller's dict.
    changes = {key: value for key, value in data.items() if key != "version"}
    try:
        return await repo.update(user, changes)
    except StaleDataError:
        await repo.db.rollback()
        raise HTTPException(409, "数据已被其他人修改,请刷新后重试") from None

7. 常用查询模式速查

# 条件查询
select(User).where(User.is_active == True)

# 模糊搜索
User.name.ilike("%alice%")

# 分页
select(User).offset(0).limit(10)

# 排序
select(User).order_by(User.created_at.desc())

# 预加载关联(避免 N+1 问题)
select(Post).options(selectinload(Post.author), selectinload(Post.tags))

# 聚合统计
from sqlalchemy import func, select
stmt = select(User.role, func.count(User.id)).group_by(User.role)
result = await db.execute(stmt)

# 联表查询
stmt = select(Post.title, User.name).join(User, Post.author_id == User.id)
result = await db.execute(stmt)

8. 小结

| 组件 | 作用 |
|------|------|
| create_async_engine | 创建异步数据库引擎 |
| AsyncSessionLocal | 异步会话工厂 |
| Mapped[T] + mapped_column | SQLAlchemy 2.0 类型注解 |
| DeclarativeBase | 所有模型的基类 |
| select() + execute() | 异步查询 |
| selectinload | 预加载关联,防止 N+1 |
| lifespan 管理器 | 在 FastAPI 启动时建表、关闭时释放连接池 |

💡 最佳实践:使用 Repository 模式封装数据访问,路由只处理 HTTP 逻辑,数据库操作交给 Repository。好处是代码可测试,业务逻辑与框架解耦。


🔗 扩展阅读