FastAPI与SQLAlchemy 2.0完全指南

📂 所属阶段:第三阶段 — 数据持久化(数据库篇)
🔗 相关章节:async-await-原理与实战 · Redis-集成 · 数据库迁移工具Alembic

目录

为什么选择SQLAlchemy 2.0?

SQLAlchemy 2.0的核心改进

SQLAlchemy 2.0代表了Python ORM领域的重要里程碑,带来了革命性的变化:

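| 改进点 | 详细说明 | 优势 |
| --- | --- | --- |
| 原生异步支持 | `AsyncSession` + asyncpg / aiosqlite / aiomysql(asyncio 扩展在 1.4 中引入,2.0 起正式稳定) | 无需同步线程池,真正的异步性能 |
| 统一查询语法 | Core 和 ORM 使用同一套 `select()` API | 学习成本降低,代码一致性提高 |
| 更严格的类型提示 | `Mapped[int]` 等注解 | IDE 支持更好,代码即文档 |
| 现代化 API 设计 | 更简洁的语法,更好的错误处理 | 开发体验显著提升 |
| 性能优化 | 更高效的查询执行,更好的内存管理 | 应用性能提升 |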

项目依赖安装

# SQLAlchemy 2.0 with asyncio support.
# Quote the extra so shells such as zsh do not treat [asyncio] as a glob.
pip install "sqlalchemy[asyncio]"

# Database drivers (install the one matching your database)
# PostgreSQL (recommended for production)
pip install asyncpg

# SQLite (development / testing)
pip install aiosqlite

# MySQL
pip install aiomysql

# Alternative async MySQL driver
pip install asyncmy

完整依赖示例

# requirements.txt
fastapi==0.104.1
sqlalchemy[asyncio]==2.0.23
asyncpg==0.29.0  # PostgreSQL
aiosqlite==0.19.0  # SQLite
aiomysql==0.2.0  # MySQL
alembic==1.13.1  # database migrations
pydantic[email]==2.5.0  # "email" extra installs email-validator, required by EmailStr
pydantic-settings==2.1.0  # Pydantic v2 moved BaseSettings into this package
python-multipart==0.0.6
passlib[bcrypt]==1.7.4
bcrypt==4.0.1  # pin: newer bcrypt releases are incompatible with passlib 1.7.4
cryptography==41.0.8

异步数据库配置

数据库连接配置详解

# database.py
from sqlalchemy.ext.asyncio import (
    AsyncSession, 
    create_async_engine, 
    async_sessionmaker,
    AsyncEngine
)
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.pool import QueuePool
from config import get_settings
import logging
from typing import AsyncGenerator
import asyncio

# Module logger for database lifecycle messages.
logger = logging.getLogger(__name__)

# Application settings loaded through config.get_settings().
settings = get_settings()

class DatabaseManager:
    """Own the async engine and the session factory in one place.

    ``initialize()`` must be awaited (e.g. from the application lifespan)
    before ``get_session()`` is used; ``dispose()`` releases pooled
    connections on shutdown.
    """

    def __init__(self):
        self.engine: AsyncEngine = None
        self.sessionmaker: async_sessionmaker = None
        self._initialized = False

    async def initialize(self):
        """Create the async engine and session factory; repeated calls are no-ops."""
        if self._initialized:
            return

        self.engine = create_async_engine(
            settings.database_url,
            # No ``poolclass``: create_async_engine already defaults to
            # AsyncAdaptedQueuePool, and the synchronous QueuePool is rejected
            # for an asyncio engine.
            pool_size=settings.database_pool_size,        # persistent connections
            max_overflow=settings.database_max_overflow,  # extra connections under load
            pool_pre_ping=True,                           # test connections before use
            pool_recycle=settings.database_pool_recycle,  # seconds before recycling a connection
            pool_timeout=settings.database_pool_timeout,  # seconds to wait for a free connection
            echo=settings.database_echo,                  # log SQL statements
            echo_pool=settings.database_echo,             # log pool events only alongside SQL echo
        )

        self.sessionmaker = async_sessionmaker(
            self.engine,
            class_=AsyncSession,
            # Keep attribute values after commit: under AsyncSession, touching an
            # expired attribute would need implicit IO, which is not allowed.
            expire_on_commit=False,
            autoflush=False,  # flush explicitly (or on commit)
        )

        self._initialized = True
        logger.info("Database initialized successfully")

    async def dispose(self):
        """Dispose of the engine's connection pool, if an engine exists."""
        if self.engine:
            await self.engine.dispose()
            self._initialized = False
            logger.info("Database connections disposed")

    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield one session, rolling back if the consumer raises.

        The ``async with`` block closes the session on exit, so no explicit
        ``close()`` call is needed.

        Raises:
            RuntimeError: if ``initialize()`` has not been awaited yet.
        """
        if self.sessionmaker is None:
            raise RuntimeError(
                "DatabaseManager.initialize() must be awaited before get_session()"
            )
        async with self.sessionmaker() as session:
            try:
                yield session
            except Exception:
                await session.rollback()
                raise

# Module-level database manager; initialize() must be awaited before use.
db_manager = DatabaseManager()


class Base(DeclarativeBase):
    """Declarative base class inherited by every ORM model."""

# FastAPI dependency that provides a request-scoped database session.
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield one ``AsyncSession`` per request and clean it up deterministically.

    The session is opened here directly instead of re-yielding from
    ``db_manager.get_session()`` inside ``async for``: an exception thrown
    into this generator at ``yield`` would not reach such an inner
    generator, leaving its rollback/close to whenever the abandoned
    generator is finalized.

    Raises:
        RuntimeError: if ``db_manager.initialize()`` has not been awaited.
    """
    if db_manager.sessionmaker is None:
        raise RuntimeError(
            "DatabaseManager.initialize() must be awaited before get_db()"
        )
    async with db_manager.sessionmaker() as session:
        try:
            yield session
        except Exception:
            # Undo partial work before the error propagates; the context
            # manager then closes the session.
            await session.rollback()
            raise

配置管理

# config.py
try:
    # Pydantic v2 moved BaseSettings into the separate pydantic-settings package.
    from pydantic_settings import BaseSettings
except ImportError:  # Pydantic v1 fallback
    from pydantic import BaseSettings
from functools import lru_cache
from typing import List, Optional


class Settings(BaseSettings):
    """Application settings; each field can be overridden via env vars or ``.env``."""

    # Application
    app_name: str = "FastAPI SQLAlchemy 2.0 App"
    version: str = "1.0.0"
    debug: bool = False

    # Database
    database_url: str = "postgresql+asyncpg://user:password@localhost/dbname"
    database_pool_size: int = 20
    database_max_overflow: int = 10
    database_echo: bool = False  # set True in development to log SQL
    database_pool_timeout: int = 30
    database_pool_recycle: int = 3600

    # Redis (optional)
    redis_url: str = "redis://localhost:6379/0"

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"


@lru_cache()
def get_settings() -> Settings:
    """Build Settings once and return the cached instance on later calls."""
    return Settings()


settings = get_settings()

在FastAPI中管理数据库生命周期

# main.py
from fastapi import FastAPI
from contextlib import asynccontextmanager
from database import db_manager, Base
import logging

logger = logging.getLogger(name=__name__)  # module logger for lifecycle messages

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the database on startup and always dispose of it on shutdown."""
    # Import every model module so its table is registered on Base.metadata;
    # create_all() only creates tables it knows about, and this module
    # otherwise imports nothing but Base.
    import models.user  # noqa: F401
    import models.post  # noqa: F401
    import models.tag  # noqa: F401
    import models.comment  # noqa: F401

    logger.info("Initializing database...")
    await db_manager.initialize()
    try:
        # create_all only adds missing tables and never alters existing ones;
        # use Alembic migrations for schema changes in real deployments.
        async with db_manager.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        logger.info("Database initialized successfully")
        yield
    finally:
        # Runs on normal shutdown and when startup or the app fails.
        logger.info("Shutting down database...")
        await db_manager.dispose()
        logger.info("Database connections disposed")

# Application instance; `lifespan` handles database startup/shutdown.
app = FastAPI(
    lifespan=lifespan,
    title="FastAPI SQLAlchemy 2.0 API",
    description="使用SQLAlchemy 2.0异步ORM的FastAPI应用",
    version="1.0.0",
)

# 健康检查端点
@app.get("/health")
async def health_check():
    """Report service health, probing the database with ``SELECT 1``.

    Returns HTTP 200 with ``status="healthy"`` and ``database="connected"``
    when the probe succeeds. Otherwise returns HTTP 503 with
    ``status="unhealthy"`` and ``database="disconnected"`` so load balancers
    can act on the status code. ``timestamp`` is the current UTC time in
    ISO 8601 format.
    """
    from datetime import datetime, timezone

    from fastapi.responses import JSONResponse
    from sqlalchemy import text

    try:
        async with db_manager.engine.connect() as conn:
            await conn.execute(text("SELECT 1"))
    except Exception:
        # Also covers an engine that was never initialized (engine is None).
        logger.exception("Database health check failed")
        return JSONResponse(
            status_code=503,
            content={
                "status": "unhealthy",
                "database": "disconnected",
                "timestamp": datetime.now(timezone.utc).isoformat(),
            },
        )

    return {
        "status": "healthy",
        "database": "connected",
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }

# 根路径
@app.get("/")
async def root():
    """Describe the API and point clients to the interactive docs."""
    info = {"message": "FastAPI SQLAlchemy 2.0 API"}
    info["version"] = "1.0.0"
    info["docs"] = "/docs"
    return info

数据库依赖注入

# dependencies.py
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db

# Database session dependency
async def get_database_session(db: AsyncSession = Depends(get_db)) -> AsyncSession:
    """Return the request-scoped session provided by ``get_db``.

    ``get_db`` is declared as a sub-dependency, so FastAPI owns the session's
    setup and teardown (including rollback on errors). Within one request,
    FastAPI's dependency cache hands every ``Depends(get_db)`` the same
    session.
    """
    return db

# 在路由中使用
from fastapi import APIRouter
from sqlalchemy.ext.asyncio import AsyncSession

router = APIRouter()  # example router; include it in the app to expose /users


@router.get("/users")
async def get_users(db: AsyncSession = Depends(get_database_session)):
    """Example route that receives a database session through Depends."""
    # `db` is ready for queries here; this example just returns a placeholder.
    payload = {"users": []}
    return payload

模型定义与关系映射

用户模型定义

# models/user.py
from sqlalchemy import (
    String, Integer, Boolean, DateTime, ForeignKey, Text, 
    Index, UniqueConstraint, CheckConstraint
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func
from database import Base
from datetime import datetime
from typing import List, TYPE_CHECKING

if TYPE_CHECKING:
    from .post import Post
    from .comment import Comment

class User(Base):
    """User account model.

    ``posts`` and ``comments`` use ``lazy="selectin"``, so every query that
    loads users also issues follow-up SELECTs for those collections (and for
    any selectin relationships the related models declare). Consider
    per-query loader options on list endpoints if that becomes expensive.
    """
    __tablename__ = "users"

    # Core columns
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True, comment="用户ID")
    name: Mapped[str] = mapped_column(String(50), nullable=False, comment="用户姓名")
    # Uniqueness (and the unique index databases build for it) comes from the
    # named "unique_email" constraint in __table_args__. unique/index are not
    # repeated here, since that would add duplicate indexes on this column.
    email: Mapped[str] = mapped_column(String(100), nullable=False, comment="邮箱地址")
    hashed_password: Mapped[str] = mapped_column(String(255), nullable=False, comment="哈希密码")

    # Status flags
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, comment="是否活跃")
    is_superuser: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否超级用户")
    is_verified: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否已验证")

    # Timestamps (set by the database)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        comment="创建时间"
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        comment="更新时间"
    )

    # Relationships
    posts: Mapped[List["Post"]] = relationship(
        "Post",
        back_populates="author",
        lazy="selectin",
        cascade="all, delete-orphan"
    )
    comments: Mapped[List["Comment"]] = relationship(
        "Comment",
        back_populates="author",
        lazy="selectin",
        cascade="all, delete-orphan"
    )

    __table_args__ = (
        Index("idx_users_created_at", "created_at"),
        UniqueConstraint("email", name="unique_email"),
        # NOTE: length() semantics differ by backend (e.g. MySQL LENGTH counts bytes).
        CheckConstraint("length(name) >= 2", name="name_length_check"),
    )

    def __repr__(self):
        return f"<User(id={self.id}, email='{self.email}')>"

    def to_dict(self) -> dict:
        """Return a JSON-friendly dict of public fields (timestamps as ISO strings)."""
        return {
            "id": self.id,
            "name": self.name,
            "email": self.email,
            "is_active": self.is_active,
            "is_superuser": self.is_superuser,
            "is_verified": self.is_verified,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
        }

帖子模型(一对多关系)

# models/post.py
from sqlalchemy import (
    String, Integer, Boolean, DateTime, ForeignKey, Text,
    Index, UniqueConstraint
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func
from database import Base
from datetime import datetime
from typing import List, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from .user import User
    from .comment import Comment
    from .tag import Tag

class Post(Base):
    """Blog post: many posts per author, many-to-many with tags."""
    __tablename__ = "posts"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True, comment="帖子ID")
    title: Mapped[str] = mapped_column(String(200), nullable=False, comment="标题")
    # Uniqueness comes from the named "unique_post_slug" constraint below; not
    # repeated here to avoid duplicate indexes.
    slug: Mapped[str] = mapped_column(String(255), nullable=False, comment="URL友好标识")
    content: Mapped[str] = mapped_column(Text, nullable=False, comment="内容")
    # Optional: in SQLAlchemy 2.0 a plain Mapped[str] makes the column NOT NULL.
    excerpt: Mapped[Optional[str]] = mapped_column(Text, nullable=True, comment="摘要")

    # Status flags
    published: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否发布")
    featured: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否特色")

    # Foreign keys
    author_id: Mapped[int] = mapped_column(
        ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False,
        comment="作者ID"
    )

    # Timestamps
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        comment="创建时间"
    )
    # NOT NULL (non-Optional Mapped), so INSERTs need the server_default;
    # onupdate refreshes it on UPDATE.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        comment="更新时间"
    )
    # NULL until the post is published.
    published_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
        comment="发布时间"
    )

    # Relationships
    author: Mapped["User"] = relationship("User", back_populates="posts")
    comments: Mapped[List["Comment"]] = relationship(
        "Comment",
        back_populates="post",
        lazy="selectin",
        cascade="all, delete-orphan"
    )
    tags: Mapped[List["Tag"]] = relationship(
        "Tag",
        secondary="post_tags",
        back_populates="posts",
        lazy="selectin"
    )

    __table_args__ = (
        Index("idx_posts_author", "author_id"),
        Index("idx_posts_published", "published"),
        Index("idx_posts_created_at", "created_at"),
        Index("idx_posts_featured", "featured"),
        UniqueConstraint("slug", name="unique_post_slug"),
    )

    def __repr__(self):
        return f"<Post(id={self.id}, title='{self.title}')>"

    def to_dict(self) -> dict:
        """Return a JSON-friendly dict of the post's columns (timestamps as ISO strings)."""
        return {
            "id": self.id,
            "title": self.title,
            "slug": self.slug,
            "content": self.content,
            "excerpt": self.excerpt,
            "published": self.published,
            "featured": self.featured,
            "author_id": self.author_id,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
            "published_at": self.published_at.isoformat() if self.published_at else None,
        }

标签模型(多对多关系)

# models/tag.py
from sqlalchemy import String, Integer, DateTime, Text
from sqlalchemy import Table, Column, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func
from database import Base
from datetime import datetime
from typing import List, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from .post import Post

# Many-to-many association table between posts and tags; the composite
# primary key prevents linking the same tag to a post twice.
post_tags = Table(
    "post_tags",
    Base.metadata,
    Column("post_id", Integer, ForeignKey("posts.id", ondelete="CASCADE"), primary_key=True),
    Column("tag_id", Integer, ForeignKey("tags.id", ondelete="CASCADE"), primary_key=True),
)

class Tag(Base):
    """Tag model, linked to posts through ``post_tags``."""
    __tablename__ = "tags"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True, comment="标签ID")
    name: Mapped[str] = mapped_column(String(50), unique=True, nullable=False, comment="标签名称")
    slug: Mapped[str] = mapped_column(String(50), unique=True, nullable=False, comment="URL友好标识")
    # Optional: a plain Mapped[str] would make the description NOT NULL.
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True, comment="标签描述")

    # Timestamps
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        comment="创建时间"
    )
    # NOT NULL (non-Optional Mapped), so INSERTs need the server_default.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        comment="更新时间"
    )

    # Relationships
    posts: Mapped[List["Post"]] = relationship(
        "Post",
        secondary=post_tags,
        back_populates="tags",
        lazy="selectin"
    )

    def __repr__(self):
        return f"<Tag(id={self.id}, name='{self.name}')>"

    def to_dict(self) -> dict:
        """Return a JSON-friendly dict of the tag's columns (timestamps as ISO strings)."""
        return {
            "id": self.id,
            "name": self.name,
            "slug": self.slug,
            "description": self.description,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
        }

评论模型

# models/comment.py
from sqlalchemy import Boolean, DateTime, ForeignKey, Index, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func
from database import Base
from datetime import datetime
from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from .user import User
    from .post import Post

class Comment(Base):
    """Comment on a post; ``parent_id`` allows nested replies."""
    __tablename__ = "comments"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True, comment="评论ID")
    content: Mapped[str] = mapped_column(Text, nullable=False, comment="评论内容")
    approved: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否审核通过")

    # Foreign keys
    author_id: Mapped[int] = mapped_column(
        ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False,
        comment="作者ID"
    )
    post_id: Mapped[int] = mapped_column(
        ForeignKey("posts.id", ondelete="CASCADE"),
        nullable=False,
        comment="帖子ID"
    )
    # NULL for top-level comments.
    parent_id: Mapped[Optional[int]] = mapped_column(
        ForeignKey("comments.id", ondelete="CASCADE"),
        nullable=True,
        comment="父评论ID(用于嵌套评论)"
    )

    # Timestamps
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        comment="创建时间"
    )
    # NOT NULL (non-Optional Mapped), so INSERTs need the server_default.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        comment="更新时间"
    )

    # Relationships; remote_side=[id] marks `parent` as the many-to-one side
    # of the self-reference, and the backref adds `replies` on the parent.
    author: Mapped["User"] = relationship("User", back_populates="comments")
    post: Mapped["Post"] = relationship("Post", back_populates="comments")
    parent: Mapped[Optional["Comment"]] = relationship("Comment", remote_side=[id], backref="replies")

    __table_args__ = (
        Index("idx_comments_post", "post_id"),
        Index("idx_comments_author", "author_id"),
        Index("idx_comments_approved", "approved"),
        Index("idx_comments_parent", "parent_id"),
        Index("idx_comments_created_at", "created_at"),
    )

    def __repr__(self):
        return f"<Comment(id={self.id}, post_id={self.post_id})>"

    def to_dict(self) -> dict:
        """Return a JSON-friendly dict of the comment's columns (timestamps as ISO strings)."""
        return {
            "id": self.id,
            "content": self.content,
            "approved": self.approved,
            "author_id": self.author_id,
            "post_id": self.post_id,
            "parent_id": self.parent_id,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
        }

CRUD操作与Repository模式

Pydantic模型定义

# schemas/user.py
from pydantic import BaseModel, EmailStr, validator
from datetime import datetime
from typing import Optional, List
from enum import Enum

class UserRole(str, Enum):
    """Closed set of user roles; members compare equal to their string values."""

    USER = "user"            # regular account
    ADMIN = "admin"          # administrator
    SUPERUSER = "superuser"  # unrestricted access

class UserBase(BaseModel):
    """Fields shared by the user request/response schemas."""

    name: str         # display name
    email: EmailStr   # validated email address

class UserCreate(UserBase):
    """Registration payload; carries the plain-text password to be hashed."""

    password: str

    @validator('name')
    def name_must_be_valid(cls, v):
        # Return the trimmed name if it has at least two characters.
        trimmed = v.strip()
        if len(trimmed) >= 2:
            return trimmed
        raise ValueError('姓名长度至少为2个字符')

class UserUpdate(BaseModel):
    """Partial-update payload; every field is optional and defaults to None."""
    name: Optional[str] = None
    email: Optional[EmailStr] = None
    is_active: Optional[bool] = None
    password: Optional[str] = None

    @validator('name')
    def name_must_be_valid(cls, v):
        # Apply the same rule as UserCreate so an update cannot violate the
        # database's name-length check constraint.
        if v is None:
            return v
        if len(v.strip()) < 2:
            raise ValueError('姓名长度至少为2个字符')
        return v.strip()

class UserInDB(UserBase):
    """Full user record as stored, buildable directly from an ORM object."""

    id: int
    is_active: bool
    is_superuser: bool
    is_verified: bool
    created_at: datetime
    updated_at: Optional[datetime] = None  # may be absent for never-updated rows

    class Config:
        # Read values from object attributes (ORM instances), not only dicts.
        from_attributes = True

class UserPublic(BaseModel):
    """User fields returned by public endpoints."""

    id: int
    name: str
    email: EmailStr
    created_at: datetime

    class Config:
        # Read values from object attributes (ORM instances), not only dicts.
        from_attributes = True

Repository模式实现

# repositories/base_repository.py
from typing import TypeVar, Generic, Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete, func
from sqlalchemy.orm import selectinload
from database import Base
from pydantic import BaseModel

T = TypeVar('T', bound=Base)       # ORM model type handled by a repository
S = TypeVar('S', bound=BaseModel)  # Pydantic schema type

class BaseRepository(Generic[T]):
    """Generic async CRUD repository for one mapped model class.

    The write methods (create/update/delete) commit immediately.
    """

    def __init__(self, model: T, db_session: AsyncSession):
        # NOTE: `model` is the mapped class itself (e.g. User), not an instance,
        # despite the `T` annotation.
        self.model = model
        self.db = db_session

    async def get_by_id(self, id: int) -> Optional[T]:
        """Return the row whose ``id`` equals ``id``, or None."""
        stmt = select(self.model).where(self.model.id == id)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_multi(
        self,
        skip: int = 0,
        limit: int = 100,
        filters: Optional[Dict[str, Any]] = None,
        order_by: Optional[str] = None,
        descending: bool = False
    ) -> List[T]:
        """Return a page of rows matching equality ``filters``, optionally ordered.

        ``filters`` keys and ``order_by`` are attribute names on the model;
        an unknown name raises AttributeError.
        """
        stmt = select(self.model)

        if filters:
            for field, value in filters.items():
                stmt = stmt.where(getattr(self.model, field) == value)

        if order_by:
            attr = getattr(self.model, order_by)
            stmt = stmt.order_by(attr.desc() if descending else attr.asc())

        stmt = stmt.offset(skip).limit(limit)
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def count(self, filters: Optional[Dict[str, Any]] = None) -> int:
        """Return the number of rows matching equality ``filters``."""
        stmt = select(func.count()).select_from(self.model)
        if filters:
            for field, value in filters.items():
                stmt = stmt.where(getattr(self.model, field) == value)
        result = await self.db.execute(stmt)
        return result.scalar_one()

    async def create(self, obj_in: Dict[str, Any]) -> T:
        """Insert a new row from ``obj_in``, commit, and return the refreshed object."""
        db_obj = self.model(**obj_in)
        self.db.add(db_obj)
        await self.db.commit()
        await self.db.refresh(db_obj)
        return db_obj

    async def update(self, db_obj: T, obj_in: Dict[str, Any]) -> T:
        """Apply ``obj_in`` to ``db_obj``, commit, and return the refreshed object.

        Values that are None are skipped, so they mean "leave unchanged"; this
        method cannot set a column to NULL.
        """
        for field, value in obj_in.items():
            if value is not None:
                setattr(db_obj, field, value)
        await self.db.commit()
        await self.db.refresh(db_obj)
        return db_obj

    async def delete(self, id: int) -> bool:
        """Delete the row with this ``id``; return True if a row was removed.

        This issues a single bulk DELETE statement, so ORM relationship
        cascades are not applied; dependent rows are handled only by
        database-level ``ON DELETE`` rules.
        """
        stmt = delete(self.model).where(self.model.id == id)
        result = await self.db.execute(stmt)
        await self.db.commit()
        return result.rowcount > 0

    async def exists(self, **kwargs) -> bool:
        """Return True if at least one row matches the ``filter_by`` criteria."""
        # LIMIT 1 + first(): scalar_one_or_none() would raise
        # MultipleResultsFound whenever more than one row matches.
        stmt = select(self.model).filter_by(**kwargs).limit(1)
        result = await self.db.execute(stmt)
        return result.scalars().first() is not None

用户Repository实现

# repositories/user_repository.py
from typing import Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_, and_
from sqlalchemy.orm import selectinload
from repositories.base_repository import BaseRepository
from models.user import User
from models.post import Post
from models.comment import Comment
from passlib.context import CryptContext
import logging

# bcrypt hashing context shared by this module's repository.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

logger = logging.getLogger(__name__)

class UserRepository(BaseRepository[User]):
    """User repository: BaseRepository CRUD plus user-specific queries.

    bcrypt hashing and verification are CPU-bound, so they run in the event
    loop's default thread pool instead of blocking the loop.
    """

    def __init__(self, db_session: AsyncSession):
        super().__init__(User, db_session)

    @staticmethod
    async def _run_blocking(fn, *args):
        """Run a blocking callable in the default executor and await its result."""
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, fn, *args)

    async def get_by_email(self, email: str) -> Optional[User]:
        """Return the user with this exact email, or None."""
        stmt = select(User).where(User.email == email)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_email_with_posts(self, email: str) -> Optional[User]:
        """Return the user with this email with ``posts`` eagerly loaded, or None."""
        stmt = (
            select(User)
            .options(selectinload(User.posts))
            .where(User.email == email)
        )
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def authenticate(self, email: str, password: str) -> Optional[User]:
        """Return the user if ``password`` matches the stored hash, else None."""
        user = await self.get_by_email(email)
        if user is None:
            # Spend time comparable to a real check so response timing does
            # not reveal whether the email is registered.
            await self._run_blocking(pwd_context.dummy_verify)
            return None
        if not await self._run_blocking(pwd_context.verify, password, user.hashed_password):
            return None
        return user

    async def get_multi_with_filters(
        self,
        skip: int = 0,
        limit: int = 20,
        search: Optional[str] = None,
        active_only: bool = True,
        role: Optional[str] = None
    ) -> List[User]:
        """Return a page of users, newest first, with optional filters.

        ``search`` matches name or email case-insensitively. ``role`` accepts
        "admin" (superusers) or "verified"; any other value is ignored.
        """
        stmt = select(User)

        if search:
            stmt = stmt.where(
                or_(
                    User.name.ilike(f"%{search}%"),
                    User.email.ilike(f"%{search}%")
                )
            )

        if active_only:
            stmt = stmt.where(User.is_active.is_(True))

        if role:
            if role == "admin":
                stmt = stmt.where(User.is_superuser.is_(True))
            elif role == "verified":
                stmt = stmt.where(User.is_verified.is_(True))

        stmt = stmt.offset(skip).limit(limit).order_by(User.created_at.desc())
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def create_with_hashed_password(self, user_data: dict) -> User:
        """Create a user, replacing a plain ``password`` key with ``hashed_password``.

        The caller's dict is not mutated.
        """
        if 'password' in user_data:
            user_data = user_data.copy()
            plain_password = user_data.pop('password')
            user_data['hashed_password'] = await self._run_blocking(pwd_context.hash, plain_password)

        return await self.create(user_data)

    async def update_password(self, user_id: int, new_password: str) -> User:
        """Hash and store a new password.

        Raises:
            ValueError: if no user has ``user_id``.
        """
        user = await self.get_by_id(user_id)
        if not user:
            raise ValueError("User not found")

        hashed_password = await self._run_blocking(pwd_context.hash, new_password)
        return await self.update(user, {"hashed_password": hashed_password})

    async def deactivate_user(self, user_id: int) -> bool:
        """Set ``is_active`` to False; return False if the user does not exist."""
        user = await self.get_by_id(user_id)
        if not user:
            return False

        await self.update(user, {"is_active": False})
        return True

    async def get_user_statistics(self, user_id: int) -> dict:
        """Return basic profile fields plus post and comment counts ({} if missing)."""
        # `func` is not imported at module level in this file.
        from sqlalchemy import func

        user = await self.get_by_id(user_id)
        if not user:
            return {}

        post_count_stmt = select(func.count(Post.id)).where(Post.author_id == user_id)
        post_count = (await self.db.execute(post_count_stmt)).scalar_one()

        comment_count_stmt = select(func.count(Comment.id)).where(Comment.author_id == user_id)
        comment_count = (await self.db.execute(comment_count_stmt)).scalar_one()

        return {
            "id": user.id,
            "name": user.name,
            "email": user.email,
            "post_count": post_count,
            "comment_count": comment_count,
            "created_at": user.created_at,
            "is_active": user.is_active,
        }

帖子Repository实现

# repositories/post_repository.py
from typing import Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, desc, asc
from sqlalchemy.orm import selectinload
from repositories.base_repository import BaseRepository
from models.post import Post
from models.tag import Tag
from models.user import User
from datetime import datetime

class PostRepository(BaseRepository[Post]):
    """Repository for ``Post``: publishing, search and tag association helpers."""

    def __init__(self, db_session: AsyncSession):
        super().__init__(Post, db_session)

    async def _get_by_id_fresh(self, post_id: int) -> Optional[Post]:
        """Reload a post, overwriting any stale state held in the identity map.

        Needed after ``post_tags`` rows are written with Core statements: the
        ORM does not see those writes, so a plain ``get_by_id`` could return
        the cached ``Post`` with an outdated ``tags`` collection.
        """
        stmt = (
            select(Post)
            .options(selectinload(Post.tags))
            .where(Post.id == post_id)
            .execution_options(populate_existing=True)
        )
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_slug(self, slug: str) -> Optional[Post]:
        """Return the post with this slug, or None."""
        stmt = select(Post).where(Post.slug == slug)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_published_posts(
        self,
        skip: int = 0,
        limit: int = 20,
        featured_only: bool = False,
        tag_filter: Optional[str] = None
    ) -> List[Post]:
        """Return published posts, newest ``published_at`` first.

        Optionally restrict to featured posts or to posts carrying the tag
        named ``tag_filter``.
        """
        stmt = (
            select(Post)
            .join(Post.author)
            .options(selectinload(Post.author))
            .options(selectinload(Post.tags))
            .where(Post.published == True)
        )

        if featured_only:
            stmt = stmt.where(Post.featured == True)

        if tag_filter:
            stmt = stmt.join(Post.tags).where(Tag.name == tag_filter)

        stmt = stmt.order_by(Post.published_at.desc()).offset(skip).limit(limit)
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def search_posts(
        self,
        query: str,
        skip: int = 0,
        limit: int = 20
    ) -> List[Post]:
        """Case-insensitive search of title, content and excerpt in published posts."""
        stmt = (
            select(Post)
            .join(Post.author)
            .options(selectinload(Post.author))
            .options(selectinload(Post.tags))
            .where(
                or_(
                    Post.title.ilike(f"%{query}%"),
                    Post.content.ilike(f"%{query}%"),
                    Post.excerpt.ilike(f"%{query}%")
                )
            )
            .where(Post.published == True)
            .order_by(desc(Post.published_at))
            .offset(skip).limit(limit)
        )
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def get_posts_by_author(
        self,
        author_id: int,
        skip: int = 0,
        limit: int = 20,
        include_drafts: bool = False
    ) -> List[Post]:
        """Return an author's posts, newest first; drafts only if ``include_drafts``."""
        stmt = (
            select(Post)
            .options(selectinload(Post.tags))
            .where(Post.author_id == author_id)
        )

        if not include_drafts:
            stmt = stmt.where(Post.published == True)

        stmt = stmt.order_by(desc(Post.created_at)).offset(skip).limit(limit)
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def create_with_tags(self, post_data: dict, tag_ids: List[int] = None) -> Post:
        """Create a post and link it to ``tag_ids`` in a single transaction.

        Duplicate ids in ``tag_ids`` are ignored. On any error the whole
        operation is rolled back and the exception re-raised.
        """
        from sqlalchemy import insert
        from models.tag import post_tags

        post = self.model(**post_data)
        self.db.add(post)
        try:
            await self.db.flush()  # emit the INSERT so the primary key is assigned
            post_id = post.id      # read before commit, which may expire attributes
            unique_tag_ids = list(dict.fromkeys(tag_ids or []))
            if unique_tag_ids:
                await self.db.execute(
                    insert(post_tags),
                    [{"post_id": post_id, "tag_id": tag_id} for tag_id in unique_tag_ids],
                )
            await self.db.commit()
        except Exception:
            await self.db.rollback()
            raise

        return await self._get_by_id_fresh(post_id)

    async def update_with_tags(self, post_id: int, post_data: dict, tag_ids: List[int] = None) -> Post:
        """Update a post's fields and (optionally) replace its tags atomically.

        As with ``BaseRepository.update``, None values in ``post_data`` leave
        fields unchanged. ``tag_ids=None`` keeps the current tags; an empty
        list removes all tags.

        Raises:
            ValueError: if the post does not exist.
        """
        from sqlalchemy import delete, insert
        from models.tag import post_tags

        post = await self.get_by_id(post_id)
        if not post:
            raise ValueError("Post not found")

        try:
            for field, value in post_data.items():
                if value is not None:
                    setattr(post, field, value)

            if tag_ids is not None:
                await self.db.execute(delete(post_tags).where(post_tags.c.post_id == post_id))
                unique_tag_ids = list(dict.fromkeys(tag_ids))
                if unique_tag_ids:
                    await self.db.execute(
                        insert(post_tags),
                        [{"post_id": post_id, "tag_id": tag_id} for tag_id in unique_tag_ids],
                    )

            await self.db.commit()
        except Exception:
            await self.db.rollback()
            raise

        return await self._get_by_id_fresh(post_id)

    async def get_featured_posts(self, limit: int = 5) -> List[Post]:
        """Return up to ``limit`` published, featured posts, newest first."""
        stmt = (
            select(Post)
            .join(Post.author)
            .options(selectinload(Post.author))
            .options(selectinload(Post.tags))
            .where(and_(Post.published == True, Post.featured == True))
            .order_by(desc(Post.published_at))
            .limit(limit)
        )
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

在路由中使用数据库

依赖注入配置

# dependencies.py
from fastapi import Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from repositories.user_repository import UserRepository
from repositories.post_repository import PostRepository
from models.user import User
from typing import Optional

async def get_user_repository(db: AsyncSession = Depends(get_db)) -> UserRepository:
    """Dependency: a UserRepository bound to the request's session."""
    repository = UserRepository(db)
    return repository

async def get_post_repository(db: AsyncSession = Depends(get_db)) -> PostRepository:
    """Dependency: a PostRepository bound to the request's session."""
    repository = PostRepository(db)
    return repository

async def get_current_user(
    db: AsyncSession = Depends(get_db),
    token: str = Depends(oauth2_scheme)
) -> User:
    """Resolve the authenticated user from a Bearer JWT.

    NOTE(review): ``oauth2_scheme``, ``jwt``, ``JWTError``, ``SECRET_KEY`` and
    ``ALGORITHM`` are neither defined nor imported in this module as shown;
    they must be provided (e.g. by an auth module) before it can be imported.

    Raises:
        HTTPException: 401 if the token is invalid, its ``sub`` claim is
            missing or not an integer id, or no such user exists.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception from None

    subject = payload.get("sub")
    if subject is None:
        raise credentials_exception
    try:
        # JWT "sub" claims are strings, but users.id is an integer column
        # (asyncpg rejects a str bound to an integer parameter).
        user_id = int(subject)
    except (TypeError, ValueError):
        raise credentials_exception from None

    user = await UserRepository(db).get_by_id(user_id)
    if user is None:
        raise credentials_exception
    return user

async def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
    """Return the authenticated user; deactivated accounts get HTTP 400."""
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")

用户路由实现

# routers/users.py
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
from schemas.user import UserInDB, UserCreate, UserUpdate, UserPublic
from repositories.user_repository import UserRepository
from dependencies import get_user_repository, get_current_active_user
from models.user import User
from passlib.context import CryptContext
import logging

logger = logging.getLogger(__name__)

# Every route in this module is mounted under /users
router = APIRouter(prefix="/users", tags=["用户管理"])

# bcrypt hashing context; update_user uses it to hash replacement passwords
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

@router.get("/", response_model=List[UserPublic])
async def list_users(
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    search: Optional[str] = Query(None, min_length=1),
    active_only: bool = Query(True),
    repo: UserRepository = Depends(get_user_repository)
):
    """List users with pagination, optional text search and an active filter.

    Raises:
        HTTPException: 500 if the repository call fails (logged with traceback).
    """
    try:
        users = await repo.get_multi_with_filters(
            skip=skip,
            limit=limit,
            search=search,
            active_only=active_only
        )
        # Pydantic v2: ``from_orm`` is deprecated; ``model_validate`` with
        # ``from_attributes=True`` reads the fields from ORM attributes.
        return [UserPublic.model_validate(user, from_attributes=True) for user in users]
    except Exception as e:
        logger.exception(f"Error listing users: {e}")
        raise HTTPException(status_code=500, detail="Internal server error") from e

@router.get("/me", response_model=UserInDB)
async def get_current_user_profile(
    current_user: User = Depends(get_current_active_user)
):
    """Return the authenticated, active user's own profile."""
    # Pydantic v2 replacement for the deprecated ``from_orm``.
    return UserInDB.model_validate(current_user, from_attributes=True)

@router.get("/{user_id}", response_model=UserPublic)
async def get_user(
    user_id: int,
    repo: UserRepository = Depends(get_user_repository)
):
    """Return the public profile of the user with ``user_id``.

    Raises:
        HTTPException: 404 if no such user exists.
    """
    user = await repo.get_by_id(user_id)
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")
    # Pydantic v2 replacement for the deprecated ``from_orm``.
    return UserPublic.model_validate(user, from_attributes=True)

@router.post("/", response_model=UserInDB, status_code=status.HTTP_201_CREATED)
async def create_user(
    user_data: UserCreate,
    repo: UserRepository = Depends(get_user_repository)
):
    """Register a new user.

    Raises:
        HTTPException: 409 if the e-mail is already registered, 500 if the
            repository fails to create the user.
    """
    # Fast-path duplicate check. A concurrent request can still slip past it;
    # the unique constraint on users.email is then the final guard and the
    # failure surfaces as the 500 below.
    existing_user = await repo.get_by_email(user_data.email)
    if existing_user:
        raise HTTPException(status_code=409, detail="邮箱已被注册")

    try:
        user = await repo.create_with_hashed_password(user_data.model_dump())
        logger.info(f"User created successfully: {user.email}")
        # Pydantic v2 replacement for the deprecated ``from_orm``.
        return UserInDB.model_validate(user, from_attributes=True)
    except Exception as e:
        logger.exception(f"Error creating user: {e}")
        raise HTTPException(status_code=500, detail="创建用户失败") from e

@router.put("/{user_id}", response_model=UserInDB)
async def update_user(
    user_id: int,
    user_data: UserUpdate,
    current_user: User = Depends(get_current_active_user),
    repo: UserRepository = Depends(get_user_repository)
):
    """Update a user's profile (the user themself, or any user for a superuser).

    A new password is hashed before it reaches the repository, and a changed
    e-mail address is checked for uniqueness first.

    Raises:
        HTTPException: 403 when editing another user without superuser rights,
            404 if the user does not exist, 409 if the new e-mail is taken,
            500 if the update itself fails.
    """
    # Permission check: users may edit themselves; superusers may edit anyone.
    if current_user.id != user_id and not current_user.is_superuser:
        raise HTTPException(status_code=403, detail="无权修改他人信息")

    user = await repo.get_by_id(user_id)
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")

    update_data = user_data.model_dump(exclude_unset=True)

    # Reject an e-mail owned by another account with the same 409 that
    # create_user uses, instead of letting the unique constraint become a 500.
    new_email = update_data.get("email")
    if new_email and new_email != user.email:
        owner = await repo.get_by_email(new_email)
        if owner and owner.id != user.id:
            raise HTTPException(status_code=409, detail="邮箱已被注册")

    try:
        # Always remove the plaintext "password" key (even an explicit null) so
        # it never reaches the model; only a non-empty value is hashed.
        new_password = update_data.pop("password", None)
        if new_password:
            update_data["hashed_password"] = pwd_context.hash(new_password)

        updated_user = await repo.update(user, update_data)
        logger.info(f"User updated successfully: {updated_user.email}")
        # Pydantic v2 replacement for the deprecated ``from_orm``.
        return UserInDB.model_validate(updated_user, from_attributes=True)
    except Exception as e:
        logger.exception(f"Error updating user: {e}")
        raise HTTPException(status_code=500, detail="更新用户失败") from e

@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user(
    user_id: int,
    current_user: User = Depends(get_current_active_user),
    repo: UserRepository = Depends(get_user_repository)
):
    """Delete a user account.

    Users may delete their own account and superusers may delete any account,
    except that a superuser may not delete their own account.

    Raises:
        HTTPException: 403 without permission, 400 for a superuser deleting
            themself, 404 if the repository reports nothing was deleted.
    """
    is_self = current_user.id == user_id

    if not (is_self or current_user.is_superuser):
        raise HTTPException(status_code=403, detail="无权删除用户")

    if is_self and current_user.is_superuser:
        raise HTTPException(status_code=400, detail="不能删除自己的超级用户账号")

    deleted = await repo.delete(user_id)
    if not deleted:
        raise HTTPException(status_code=404, detail="用户不存在")

    logger.info(f"User deleted successfully: {user_id}")
    return None

@router.get("/{user_id}/stats", response_model=dict)
async def get_user_stats(
    user_id: int,
    current_user: User = Depends(get_current_active_user),
    repo: UserRepository = Depends(get_user_repository)
):
    """Return aggregate statistics for a user (the user themself or a superuser).

    Raises:
        HTTPException: 403 for other non-superusers, 404 when the repository
            returns no statistics.
    """
    may_view = current_user.id == user_id or current_user.is_superuser
    if not may_view:
        raise HTTPException(status_code=403, detail="无权查看用户统计数据")

    statistics = await repo.get_user_statistics(user_id)
    if not statistics:
        raise HTTPException(status_code=404, detail="用户不存在")
    return statistics

帖子路由实现

# routers/posts.py
import logging
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession

from dependencies import get_post_repository, get_current_active_user, get_user_repository
from models.post import Post
from models.user import User
from repositories.post_repository import PostRepository
from repositories.user_repository import UserRepository
from schemas.post import PostCreate, PostUpdate, PostResponse

logger = logging.getLogger(__name__)

# Every route in this module is mounted under /posts
router = APIRouter(prefix="/posts", tags=["帖子管理"])

@router.get("/", response_model=List[PostResponse])
async def list_posts(
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    published: bool = Query(True),
    featured: Optional[bool] = Query(None),
    tag: Optional[str] = Query(None),
    search: Optional[str] = Query(None),
    repo: PostRepository = Depends(get_post_repository)
):
    """List posts.

    How the query parameters are routed:

    * ``search`` takes precedence and is delegated to ``repo.search_posts``.
    * Otherwise, when ``published`` is true (the default) or ``featured`` is
      given, ``repo.get_published_posts`` is used with the ``featured`` and
      ``tag`` filters.
    * ``published=false`` without ``featured`` falls back to the unfiltered
      ``repo.get_multi``.

    NOTE(review): ``search_posts`` receives neither ``published`` nor ``tag``;
    confirm it restricts itself to published posts.

    Raises:
        HTTPException: 500 if the repository call fails.
    """
    try:
        if search:
            posts = await repo.search_posts(search, skip, limit)
        elif published or featured is not None:
            # The default listing must use the published-only query: get_multi
            # applies neither the ``published`` nor the ``tag`` filter and
            # would return drafts. An unset ``featured`` means "no featured
            # filter", which matches ``featured_only=False``.
            posts = await repo.get_published_posts(
                skip, limit, featured_only=bool(featured), tag_filter=tag
            )
        else:
            posts = await repo.get_multi(skip, limit)

        # Pydantic v2 replacement for the deprecated ``from_orm``.
        return [PostResponse.model_validate(post, from_attributes=True) for post in posts]
    except Exception as e:
        logger.exception(f"Error listing posts: {e}")
        raise HTTPException(status_code=500, detail="Internal server error") from e

@router.get("/featured", response_model=List[PostResponse])
async def get_featured_posts(
    limit: int = Query(5, ge=1, le=20),
    repo: PostRepository = Depends(get_post_repository)
):
    """Return up to ``limit`` featured posts.

    Raises:
        HTTPException: 500 if the repository call fails.
    """
    try:
        posts = await repo.get_featured_posts(limit)
        # Pydantic v2 replacement for the deprecated ``from_orm``.
        return [PostResponse.model_validate(post, from_attributes=True) for post in posts]
    except Exception as e:
        logger.exception(f"Error getting featured posts: {e}")
        raise HTTPException(status_code=500, detail="Internal server error") from e

@router.get("/{post_id}", response_model=PostResponse)
async def get_post(
    post_id: int,
    repo: PostRepository = Depends(get_post_repository)
):
    """Return the post with ``post_id``.

    NOTE(review): no ``published`` check is made here, so unauthenticated
    callers can read drafts by ID — confirm this is intended.

    Raises:
        HTTPException: 404 if the post does not exist.
    """
    post = await repo.get_by_id(post_id)
    if not post:
        raise HTTPException(status_code=404, detail="帖子不存在")
    # Pydantic v2 replacement for the deprecated ``from_orm``.
    return PostResponse.model_validate(post, from_attributes=True)

@router.post("/", response_model=PostResponse, status_code=status.HTTP_201_CREATED)
async def create_post(
    post_data: PostCreate,
    current_user: User = Depends(get_current_active_user),
    post_repo: PostRepository = Depends(get_post_repository),
    user_repo: UserRepository = Depends(get_user_repository)
):
    """Create a post authored by the current user.

    If the post is created in the published state, ``published_at`` is set.

    Raises:
        HTTPException: 400 if the author is missing or inactive, 500 if the
            post cannot be created.
    """
    try:
        # Make sure the author still exists and is active
        author = await user_repo.get_by_id(current_user.id)
        if not author or not author.is_active:
            raise HTTPException(status_code=400, detail="作者不存在或已被禁用")

        post_dict = post_data.model_dump()
        post_dict["author_id"] = current_user.id

        # Published on creation: stamp the publication time
        if post_dict.get("published", False):
            from datetime import datetime
            post_dict["published_at"] = datetime.utcnow()

        post = await post_repo.create_with_tags(post_dict, post_data.tag_ids)
        logger.info(f"Post created successfully: {post.title}")
        # Pydantic v2 replacement for the deprecated ``from_orm``.
        return PostResponse.model_validate(post, from_attributes=True)
    except HTTPException:
        # Deliberate client errors (the 400 above) must reach the caller as-is
        # instead of being masked by the generic 500 handler below.
        raise
    except Exception as e:
        logger.exception(f"Error creating post: {e}")
        raise HTTPException(status_code=500, detail="创建帖子失败") from e

@router.put("/{post_id}", response_model=PostResponse)
async def update_post(
    post_id: int,
    post_data: PostUpdate,
    current_user: User = Depends(get_current_active_user),
    repo: PostRepository = Depends(get_post_repository)
):
    """Update a post (its author or a superuser only).

    Publishing a previously unpublished post stamps ``published_at``.

    Raises:
        HTTPException: 404 if the post does not exist, 403 without permission,
            500 if the update fails.
    """
    post = await repo.get_by_id(post_id)
    if not post:
        raise HTTPException(status_code=404, detail="帖子不存在")

    # Permission check
    if post.author_id != current_user.id and not current_user.is_superuser:
        raise HTTPException(status_code=403, detail="无权修改此帖子")

    try:
        update_data = post_data.model_dump(exclude_unset=True)

        # Transition from draft to published: stamp the publication time
        if update_data.get("published") and not post.published:
            from datetime import datetime
            update_data["published_at"] = datetime.utcnow()

        updated_post = await repo.update_with_tags(post_id, update_data, post_data.tag_ids)
        logger.info(f"Post updated successfully: {updated_post.title}")
        # Pydantic v2 replacement for the deprecated ``from_orm``.
        return PostResponse.model_validate(updated_post, from_attributes=True)
    except Exception as e:
        logger.exception(f"Error updating post: {e}")
        raise HTTPException(status_code=500, detail="更新帖子失败") from e

@router.delete("/{post_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_post(
    post_id: int,
    current_user: User = Depends(get_current_active_user),
    repo: PostRepository = Depends(get_post_repository)
):
    """Delete a post (its author or a superuser only).

    Raises:
        HTTPException: 404 if the post is missing (before or during deletion),
            403 without permission.
    """
    target = await repo.get_by_id(post_id)
    if not target:
        raise HTTPException(status_code=404, detail="帖子不存在")

    is_owner = target.author_id == current_user.id
    if not (is_owner or current_user.is_superuser):
        raise HTTPException(status_code=403, detail="无权删除此帖子")

    removed = await repo.delete(post_id)
    if not removed:
        raise HTTPException(status_code=404, detail="帖子不存在")

    logger.info(f"Post deleted successfully: {post_id}")
    return None

@router.get("/author/{author_id}", response_model=List[PostResponse])
async def get_posts_by_author(
    author_id: int,
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    include_drafts: bool = Query(False),
    repo: PostRepository = Depends(get_post_repository)
):
    """Return a page of posts written by ``author_id``.

    NOTE(review): ``include_drafts`` is honoured for any (even anonymous)
    caller, so anyone can list another author's drafts — consider restricting
    it to the author or a superuser.
    """
    posts = await repo.get_posts_by_author(author_id, skip, limit, include_drafts)
    # Pydantic v2 replacement for the deprecated ``from_orm``.
    return [PostResponse.model_validate(post, from_attributes=True) for post in posts]

事务处理与并发控制

事务管理基础

# services/transaction_service.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import IntegrityError
from typing import Callable, Any
import logging

# Module logger; TransactionService reports failed operations through it
logger = logging.getLogger(__name__)

class TransactionService:
    """Helpers that run session work and commit, or roll back on failure.

    Every method commits on success. On any exception it rolls the session
    back, logs the failure and re-raises the original exception.
    """

    def __init__(self, db_session: AsyncSession):
        # Session used by every helper; this service never closes it.
        self.db = db_session

    async def execute_in_transaction(self, operation: Callable) -> Any:
        """Await ``operation()``, commit, and return what it produced."""
        try:
            outcome = await operation()
            await self.db.commit()
        except Exception as exc:
            await self.db.rollback()
            logger.error(f"Transaction failed: {exc}")
            raise
        return outcome

    async def batch_create(self, model_class, data_list: list) -> list:
        """Instantiate ``model_class(**row)`` for each row, commit, and return them.

        Each created object is refreshed after the commit so database-generated
        values such as primary keys are loaded.
        """
        created = []
        try:
            for row in data_list:
                instance = model_class(**row)
                self.db.add(instance)
                created.append(instance)
            await self.db.commit()
            # Reload each object so database-generated values (e.g. IDs) are present
            for instance in created:
                await self.db.refresh(instance)
        except IntegrityError as exc:
            await self.db.rollback()
            logger.error(f"Batch create failed due to integrity constraint: {exc}")
            raise
        except Exception as exc:
            await self.db.rollback()
            logger.error(f"Batch create failed: {exc}")
            raise
        return created

    async def atomic_update(self, model_class, conditions: dict, updates: dict):
        """Run one bulk UPDATE where every ``column == value`` condition holds.

        Returns:
            The number of rows the database reports as matched/updated.
        """
        try:
            from sqlalchemy import update
            criteria = [
                getattr(model_class, column) == expected
                for column, expected in conditions.items()
            ]
            statement = update(model_class).where(*criteria).values(**updates)
            outcome = await self.db.execute(statement)
            await self.db.commit()
            return outcome.rowcount
        except Exception as exc:
            await self.db.rollback()
            logger.error(f"Atomic update failed: {exc}")
            raise

乐观锁实现

# models/base.py (扩展基础模型)
from sqlalchemy import Integer
# ``Mapped`` is needed for the typed column annotation below.
from sqlalchemy.orm import Mapped, mapped_column
from database import Base

class VersionedBase(Base):
    """Abstract model base that adds an integer ``version`` column.

    The column supports optimistic locking: an update succeeds only if the
    row still carries the version the caller read, and then bumps it.
    """
    __abstract__ = True

    # New rows start at version 1
    version: Mapped[int] = mapped_column(Integer, default=1, nullable=False)

# models/user.py (使用版本控制)
class User(VersionedBase):
    # Illustration only: the remaining columns of the user model are omitted here
    # ... other fields
    pass

# services/optimistic_lock_service.py
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy import and_, select, update
from sqlalchemy.ext.asyncio import AsyncSession

class OptimisticLockService:
    """Version-checked updates for models that carry a ``version`` column."""

    def __init__(self, db_session: AsyncSession):
        self.db = db_session

    async def update_with_version_check(self, model_instance, new_data: dict, expected_version: int):
        """Apply ``new_data`` only if the row still has ``expected_version``.

        On success the row's version becomes ``expected_version + 1`` and the
        session is committed.

        Returns:
            The number of updated rows (1 on success).

        Raises:
            ValueError: if the row's version changed since it was read, or if
                the row no longer exists.
        """
        model = type(model_instance)
        # Capture the key once; the instance's own attributes may be stale.
        instance_id = model_instance.id

        # The version condition makes the UPDATE a no-op if someone else
        # already changed the row.
        stmt = (
            update(model)
            .where(
                and_(
                    model.id == instance_id,
                    model.version == expected_version
                )
            )
            .values(**new_data, version=expected_version + 1)
        )

        result = await self.db.execute(stmt)

        if result.rowcount == 0:
            # Ask the database directly: session.get() would return the
            # instance already held in the identity map without querying, so a
            # row deleted by another transaction would look like a conflict.
            still_exists = await self.db.scalar(
                select(model.id).where(model.id == instance_id)
            )
            if still_exists is not None:
                raise ValueError("数据已被其他用户修改,请刷新后重试")
            raise ValueError("记录不存在")

        await self.db.commit()
        return result.rowcount

分布式锁(使用Redis)

# services/distributed_lock_service.py
import asyncio
import time
import uuid
from typing import Optional
import redis.asyncio as redis

class DistributedLockService:
    """Redis mutual-exclusion lock built on ``SET NX EX`` and a token-checked release."""

    def __init__(self, redis_client: redis.Redis, default_timeout: int = 30):
        self.redis = redis_client
        # Seconds; used both as the acquire wait limit and as the lock TTL.
        self.default_timeout = default_timeout

    async def acquire_lock(
        self,
        lock_key: str,
        timeout: Optional[int] = None,
        retry_interval: float = 0.01,
    ) -> Optional[str]:
        """Try to take ``lock_key``, polling until ``timeout`` seconds elapse.

        Args:
            lock_key: Redis key that represents the lock.
            timeout: Seconds to keep trying; also used as the lock's expiry.
                Defaults to ``default_timeout``.
            retry_interval: Seconds to sleep between attempts.

        Returns:
            A unique token to pass to ``release_lock``, or ``None`` on timeout.
        """
        if timeout is None:
            timeout = self.default_timeout

        token = str(uuid.uuid4())
        # monotonic() is immune to wall-clock adjustments (NTP sync, manual
        # changes) that could stretch or cut short a time.time()-based wait.
        deadline = time.monotonic() + timeout

        while time.monotonic() < deadline:
            # SET with NX and EX creates the key atomically only if it is absent
            if await self.redis.set(lock_key, token, nx=True, ex=timeout):
                return token
            await asyncio.sleep(retry_interval)

        return None

    async def release_lock(self, lock_key: str, lock_value: str) -> bool:
        """Delete ``lock_key`` only if it still holds ``lock_value``.

        Returns:
            True if this call removed the lock, False if it was already gone
            or owned by someone else.
        """
        # Compare-and-delete runs atomically inside Redis via Lua
        lua_script = """
        if redis.call("get", KEYS[1]) == ARGV[1] then
            return redis.call("del", KEYS[1])
        else
            return 0
        end
        """
        deleted = await self.redis.eval(lua_script, 1, lock_key, lock_value)
        # EVAL yields the DEL count (1 or 0); normalise to the annotated bool.
        return bool(deleted)

    async def execute_with_lock(self, lock_key: str, operation, timeout: Optional[int] = None):
        """Await ``operation()`` while holding ``lock_key``; always release afterwards.

        Raises:
            TimeoutError: if the lock could not be acquired within ``timeout``.
        """
        lock_value = await self.acquire_lock(lock_key, timeout)
        if not lock_value:
            raise TimeoutError("Unable to acquire distributed lock")

        try:
            return await operation()
        finally:
            await self.release_lock(lock_key, lock_value)

性能优化与最佳实践

查询优化技巧

# services/query_optimization_service.py
from sqlalchemy import select, func
from sqlalchemy.orm import selectinload, joinedload
from typing import List

class QueryOptimizationService:
    """Query patterns that avoid N+1 loading and redundant joins."""

    def __init__(self, db_session: AsyncSession):
        self.db = db_session

    async def get_user_with_posts_optimized(self, user_id: int):
        """Return the user with ``user_id`` and their posts preloaded, or ``None``.

        ``selectinload`` fetches ``User.posts`` in one extra batched SELECT
        instead of one query per post (avoids N+1).
        """
        stmt = (
            select(User)
            .options(selectinload(User.posts))
            .where(User.id == user_id)
        )
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_posts_with_author_and_tags(self, skip: int = 0, limit: int = 20):
        """Return a page of published posts, newest first, with author and tags loaded."""
        from sqlalchemy.orm import contains_eager

        stmt = (
            select(Post)
            .join(Post.author)
            # Populate Post.author from the explicit JOIN above. joinedload()
            # would add its own aliased JOIN, joining the users table twice.
            .options(contains_eager(Post.author))
            .options(selectinload(Post.tags))
            .where(Post.published == True)  # noqa: E712 - SQL expression, not a Python bool test
            .order_by(Post.published_at.desc())
            .offset(skip)
            .limit(limit)
        )
        result = await self.db.execute(stmt)
        return result.scalars().all()

    async def get_aggregated_post_stats(self):
        """Return ``(name, post_count, avg_content_length)`` rows for authors with posts."""
        stmt = (
            select(
                User.name,
                func.count(Post.id).label('post_count'),
                func.avg(func.length(Post.content)).label('avg_content_length')
            )
            .join(User.posts)
            .group_by(User.id, User.name)
            .having(func.count(Post.id) > 0)
        )
        result = await self.db.execute(stmt)
        return result.all()

    async def paginated_search(self, query_text: str, page: int = 1, page_size: int = 20):
        """Search published posts by title/content substring, one page at a time.

        Args:
            query_text: Text to find; matched case-insensitively and literally
                (``%`` and ``_`` are not treated as wildcards).
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Posts per page; values below 1 are treated as 1.

        Returns:
            Dict with ``posts``, ``total``, ``page``, ``page_size`` and
            ``total_pages``.
        """
        # page < 1 would produce a negative OFFSET, and page_size < 1 a
        # ZeroDivisionError in the total_pages computation below.
        page = max(page, 1)
        page_size = max(page_size, 1)
        offset = (page - 1) * page_size

        # Escape LIKE metacharacters ("/" is the escape character) so input
        # such as "100%" or "a_b" is matched literally.
        escaped = query_text.replace("/", "//").replace("%", "/%").replace("_", "/_")
        pattern = f"%{escaped}%"
        # Shared filter for both the page query and the count query
        text_match = (
            Post.title.ilike(pattern, escape="/")
            | Post.content.ilike(pattern, escape="/")
        )

        search_stmt = (
            select(Post)
            .where(text_match)
            .where(Post.published == True)  # noqa: E712
            .order_by(Post.published_at.desc())
            .offset(offset)
            .limit(page_size)
        )

        count_stmt = (
            select(func.count(Post.id))
            .where(text_match)
            .where(Post.published == True)  # noqa: E712
        )

        search_result = await self.db.execute(search_stmt)
        count_result = await self.db.execute(count_stmt)

        posts = search_result.scalars().all()
        total_count = count_result.scalar_one()

        return {
            "posts": posts,
            "total": total_count,
            "page": page,
            "page_size": page_size,
            "total_pages": (total_count + page_size - 1) // page_size
        }

连接池优化

# config/database_config.py
from pydantic import BaseSettings
from typing import Optional

class DatabaseConfig(BaseSettings):
    """Database settings, loaded from the environment and the ``.env`` file.

    NOTE(review): requirements.txt pins pydantic 2.x, where ``BaseSettings`` is
    no longer importable from ``pydantic`` (it lives in the separate
    ``pydantic-settings`` package) — confirm the import used for this class.
    """
    database_url: str
    database_pool_size: int = 20
    database_max_overflow: int = 10
    database_pool_timeout: int = 30
    database_pool_recycle: int = 3600
    database_pool_pre_ping: bool = True
    database_echo: bool = False
    
    # PostgreSQL-specific settings
    # NOTE(review): EngineManager.create_engine does not pass these two on.
    postgresql_ssl_mode: str = "prefer"
    postgresql_statement_timeout: int = 30000  # 30 seconds, expressed in milliseconds
    
    # Connection-pool health check
    # NOTE(review): not read by EngineManager either.
    pool_health_check_threshold: int = 10  # seconds
    
    class Config:
        # Additional source of settings values
        env_file = ".env"

# database/engine_manager.py
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
from sqlalchemy.pool import QueuePool
from config.database_config import DatabaseConfig
import logging

# Module logger for engine creation/disposal messages
logger = logging.getLogger(__name__)

class EngineManager:
    """Lazily creates, caches and disposes a single async database engine."""

    def __init__(self, config: DatabaseConfig):
        self.config = config
        self.engine: Optional[AsyncEngine] = None

    def create_engine(self) -> AsyncEngine:
        """Return the shared engine, creating it from ``self.config`` on first use.

        No ``poolclass`` is passed. Async engines need an asyncio-compatible
        pool (SQLAlchemy defaults to ``AsyncAdaptedQueuePool``), and they
        reject the synchronous ``QueuePool`` with an ArgumentError. There is
        also no ``pool_recycle_callback`` engine option; passing one makes
        engine creation fail.
        """
        if self.engine is None:
            self.engine = create_async_engine(
                self.config.database_url,
                pool_size=self.config.database_pool_size,
                max_overflow=self.config.database_max_overflow,
                pool_timeout=self.config.database_pool_timeout,
                # Connections older than this many seconds are replaced on checkout
                pool_recycle=self.config.database_pool_recycle,
                pool_pre_ping=self.config.database_pool_pre_ping,
                echo=self.config.database_echo,
            )
            logger.info("Database engine created successfully")

        return self.engine

    async def dispose_engine(self):
        """Close all pooled connections and forget the engine (no-op if none)."""
        if self.engine:
            await self.engine.dispose()
            self.engine = None
            logger.info("Database engine disposed")

缓存策略

# services/cache_service.py
import json
import pickle
from typing import Optional, Any
from datetime import timedelta
import redis.asyncio as redis

class CacheService:
    """Redis cache that serialises values with ``pickle``.

    SECURITY: ``pickle.loads`` can execute arbitrary code embedded in the
    payload. Only use this against a Redis instance that untrusted parties
    cannot write to, or switch to a data-only format such as JSON.
    """

    def __init__(self, redis_client: redis.Redis):
        self.redis = redis_client

    async def get(self, key: str) -> Optional[Any]:
        """Return the cached value for ``key``, or ``None`` on a miss."""
        value = await self.redis.get(key)
        if value:
            # Unpickling is only safe for trusted cache contents (see class note)
            return pickle.loads(value)
        return None

    async def set(self, key: str, value: Any, expire: timedelta = timedelta(hours=1)):
        """Store ``value`` under ``key`` with a whole-second expiry of ``expire``."""
        pickled_value = pickle.dumps(value)
        await self.redis.setex(key, int(expire.total_seconds()), pickled_value)

    async def delete(self, key: str):
        """Remove ``key`` from the cache."""
        await self.redis.delete(key)

    async def get_or_set(self, key: str, fetch_func, expire: timedelta = timedelta(hours=1)):
        """Return the cached value, or await ``fetch_func()`` and cache its result.

        ``None`` results are not stored. ``get`` reports a miss as ``None``, so a
        cached ``None`` could never be served and storing it would only cost an
        extra round trip.
        """
        value = await self.get(key)
        if value is None:
            value = await fetch_func()
            if value is not None:
                await self.set(key, value, expire)
        return value

# 使用缓存的服务示例
class CachedUserService:
    """User lookups that read through ``CacheService`` before hitting the database.

    NOTE(review): the cached value is the pickled ORM ``User`` instance itself,
    so every loaded column (including ``hashed_password``) is written to Redis,
    and an instance read back from the cache is not attached to ``self.db``.
    """

    def __init__(self, db_session: AsyncSession, cache_service: CacheService):
        self.db = db_session
        self.cache = cache_service

    @staticmethod
    def _key(user_id: int) -> str:
        """Cache key under which one user is stored."""
        return f"user:{user_id}"

    async def get_user_by_id(self, user_id: int):
        """Return the user with ``user_id`` (cache first), or ``None`` if absent.

        Database hits are cached for 30 minutes.
        """
        key = self._key(user_id)

        hit = await self.cache.get(key)
        if hit:
            return hit

        # Cache miss: load from the database
        rows = await self.db.execute(select(User).where(User.id == user_id))
        found = rows.scalar_one_or_none()
        if found:
            await self.cache.set(key, found, timedelta(minutes=30))
        return found

    async def invalidate_user_cache(self, user_id: int):
        """Drop the cached entry for ``user_id``."""
        await self.cache.delete(self._key(user_id))

常见陷阱与避坑指南

陷阱1:N+1查询问题

# ❌ 错误:N+1查询
async def get_posts_with_authors_bad():
    """Anti-pattern: one query for the posts plus one more query per post (N+1)."""
    posts = await repo.get_multi()  # fetch the posts (1 query)
    result = []
    for post in posts:
        # Every iteration goes back to the database for the author
        author = await user_repo.get_by_id(post.author_id)  # N queries
        result.append({
            "post": post,
            "author": author
        })
    return result

# ✅ 正确:使用预加载
async def get_posts_with_authors_good():
    """Load posts and their authors up front so no per-post query is needed."""
    stmt = (
        select(Post)
        .join(Post.author)
        .options(selectinload(Post.author))  # eager-load every post's author
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()

陷阱2:会话生命周期管理

# ❌ 错误:会话在函数返回后关闭
async def get_user_incorrect(db: AsyncSession):
    """Anti-pattern: reading a relationship that was never eagerly loaded.

    NOTE(review): under AsyncSession an implicit lazy load usually fails with
    ``MissingGreenlet`` rather than ``DetachedInstanceError``. Also,
    ``User.posts`` is declared ``lazy="selectin"`` in models/user.py, so with
    that model the collection would already be loaded — confirm which model
    this example is meant to illustrate.
    """
    user = await repo.get_by_id(1)
    # If the session is closed (or the relationship was not loaded), touching user.posts fails
    return user.posts  # may raise DetachedInstanceError

# ✅ 正确:在会话范围内完成所有操作
async def get_user_correct(db: AsyncSession):
    """Load user 1 with posts eagerly, then serialise them while the session is open."""
    stmt = select(User).options(selectinload(User.posts)).where(User.id == 1)
    user = (await db.execute(stmt)).scalar_one_or_none()
    if not user:
        return []
    # Relationship data is consumed inside the session's scope
    return [post.to_dict() for post in user.posts]

陷阱3:事务边界不清

# ❌ 错误:事务边界不清晰
async def transfer_money_bad(from_account_id: int, to_account_id: int, amount: float):
    """Anti-pattern: two dependent writes that do not share a transaction."""
    # These two operations are not in the same transaction
    await deduct_amount(from_account_id, amount)    # may succeed
    await add_amount(to_account_id, amount)        # may fail, leaving money deducted but never credited

# ✅ 正确:明确的事务边界
async def transfer_money_good(from_account_id: int, to_account_id: int, amount: float, db: AsyncSession):
    """Run both legs of the transfer inside one explicit transaction.

    Leaving the ``async with`` block normally commits both writes; an
    exception rolls both back together.

    NOTE(review): ``AsyncSession.begin()`` raises if the session has already
    auto-begun a transaction (e.g. after an earlier query on the same
    session) — confirm callers pass a session with no transaction in progress.
    """
    transaction = db.begin()
    async with transaction:
        await deduct_amount(from_account_id, amount, db)
        await add_amount(to_account_id, amount, db)

陷阱4:并发控制不当

# ❌ 错误:没有并发控制
async def update_stock_bad(product_id: int, quantity_change: int, db: AsyncSession):
    """Anti-pattern: read-modify-write of stock with no concurrency control."""
    product = await get_product(product_id, db)
    new_quantity = product.stock - quantity_change
    if new_quantity < 0:
        raise ValueError("库存不足")
    product.stock = new_quantity
    await db.commit()  # concurrent requests can all pass the check on the same stale stock value, overselling

# ✅ 正确:使用数据库锁
async def update_stock_good(product_id: int, quantity_change: int, db: AsyncSession):
    """Decrement stock with one conditional UPDATE instead of read-modify-write.

    The database applies the change only when enough stock remains, so the
    check and the write cannot interleave with another request.

    Raises:
        ValueError: when no row was updated (insufficient stock).
    """
    from sqlalchemy import update
    stmt = (
        update(Product)
        # The availability check is evaluated by the database, atomically with the write
        .where(Product.id == product_id, Product.stock >= quantity_change)
        .values(stock=Product.stock - quantity_change)
    )
    outcome = await db.execute(stmt)
    if outcome.rowcount == 0:
        raise ValueError("库存不足")
    await db.commit()

陷阱5:连接泄漏

# ❌ 错误:可能泄漏连接
async def get_data_bad():
    """Anti-pattern: manual session handling that never closes the session.

    NOTE(review): ``async_sessionmaker()`` is called without an engine here,
    so the session has no bind; real code needs the engine-bound factory.
    """
    session = async_sessionmaker()()
    try:
        # No path closes the session (no finally / async with), so its connection is not released
        result = await session.execute(select(User))
        return result.scalars().all()
    except Exception:
        # The error is swallowed and the function implicitly returns None
        pass

# ✅ 正确:使用上下文管理器
async def get_data_good():
    """Correct pattern: scope the session with ``async with`` so it is always closed.

    Uses the engine-bound ``AsyncSessionLocal`` factory from database.py. A
    bare ``async_sessionmaker()`` has no engine bound, so ``execute`` on its
    sessions would fail.
    """
    from database import AsyncSessionLocal

    # Leaving the async with block closes the session and releases its connection
    async with AsyncSessionLocal() as session:
        try:
            result = await session.execute(select(User))
            return result.scalars().all()
        except Exception:
            await session.rollback()
            raise

与其他ORM对比

SQLAlchemy 2.0 vs Django ORM

特性SQLAlchemy 2.0Django ORM
异步支持✅ 原生异步❌ 同步为主
类型提示✅ 优秀的类型支持❌ 类型支持较弱
灵活性✅ 高度灵活❌ 约定优于配置
学习曲线⚠️ 中等✅ 简单易学
性能✅ 高性能⚠️ 中等
生态系统✅ 丰富的生态系统✅ 完整的Web框架

SQLAlchemy 2.0 vs Peewee

特性SQLAlchemy 2.0Peewee
异步支持✅ 完整异步❌ 主要是同步
查询API✅ 现代化API✅ 简洁API
复杂查询✅ 强大的复杂查询⚠️ 简单查询
社区支持✅ 活跃社区⚠️ 小而精
文档质量✅ 优秀文档✅ 良好文档
企业级特性✅ 丰富企业级特性⚠️ 基础特性

SQLAlchemy 2.0 vs Tortoise ORM

特性SQLAlchemy 2.0Tortoise ORM
异步支持✅ 原生异步✅ 专为异步设计
Django风格❌ SQLAlchemy风格✅ Django风格
迁移工具✅ Alembic✅ Aerich(Tortoise 官方配套迁移工具)
类型提示✅ 优秀的类型支持✅ Python类型提示
成熟度✅ 非常成熟⚠️ 相对年轻
生态系统✅ 丰富的生态系统⚠️ 生态较小

相关教程

使用SQLAlchemy 2.0异步ORM时,建议采用Repository模式封装数据访问逻辑,这样可以提高代码的可测试性和可维护性。同时,注意使用适当的加载策略来避免N+1查询问题。 在高并发场景下,合理配置连接池参数和使用缓存策略可以显著提升应用性能。同时,注意在数据库层面进行适当的索引优化。

总结

组件作用最佳实践
create_async_engine创建异步数据库引擎配置合适的连接池参数
AsyncSession异步数据库会话使用上下文管理器确保资源释放
Mapped[T] + mapped_columnSQLAlchemy 2.0类型注解使用类型提示提高代码质量
DeclarativeBase所有模型的基类统一模型管理
select() + execute()异步查询使用预加载避免N+1问题
selectinload预加载关联,防止N+1在需要关联数据时使用
lifespan 管理器管理数据库连接生命周期确保应用启动和关闭时正确处理连接
Repository模式数据访问层抽象分离业务逻辑和数据访问逻辑

💡 核心思想:SQLAlchemy 2.0的异步ORM为企业级应用提供了强大的数据持久化能力。通过合理的配置、优化的查询策略和适当的并发控制,可以构建高性能、可维护的数据库应用。

SQLAlchemy 2.0与FastAPI的结合为企业级应用提供了完整的异步数据库解决方案,通过Repository模式、事务管理、连接池优化等技术手段,可以构建出高性能、可扩展的现代化Web应用。


1. 为什么选择 SQLAlchemy 2.0?

1.1 SQLAlchemy 2.0 的核心改进

  • 原生异步支持:AsyncSession + asyncpg / aiomysql,无需同步线程池
  • 统一查询语法:Core 和 ORM 使用同一套 select() API
  • 更严格的类型提示:Mapped[int] 等注解让代码即文档
  • 迁移友好:与 Alembic 深度集成

1.2 项目依赖安装

pip install sqlalchemy[asyncio] asyncpg aiosqlite
# sqlalchemy[asyncio] 包含 asyncio 扩展
# asyncpg → PostgreSQL(生产推荐)
# aiosqlite → SQLite(开发/轻量)

2. 基础配置

2.1 数据库连接配置

# database.py
from sqlalchemy.ext.asyncio import (
    AsyncSession, create_async_engine, async_sessionmaker
)
from sqlalchemy.orm import DeclarativeBase
from config import get_settings

settings = get_settings()

# PostgreSQL URL format: postgresql+asyncpg://
DATABASE_URL = settings.database_url  # "postgresql+asyncpg://user:pass@localhost/mydb"
# SQLite URL format: sqlite+aiosqlite:///
# DATABASE_URL = "sqlite+aiosqlite:///./app.db"

engine = create_async_engine(
    DATABASE_URL,
    echo=settings.debug,          # log emitted SQL (development aid)
    pool_size=20,                  # connections kept in the pool
    max_overflow=10,               # extra connections allowed beyond pool_size
    pool_pre_ping=True,            # test each connection before handing it out
)

# Session factory. ``autocommit`` is not passed: SQLAlchemy 2.0 removed
# autocommit mode and only accepts False, which is already the default.
AsyncSessionLocal = async_sessionmaker(
    engine,
    class_=AsyncSession,
    # Keep loaded attribute values after commit instead of expiring them;
    # with AsyncSession an expired attribute cannot be lazily refreshed on
    # plain attribute access.
    expire_on_commit=False,
    autoflush=False,
)

# 基础模型类
class Base(DeclarativeBase):
    """Declarative base shared by every ORM model; its ``metadata`` collects all tables."""

2.2 在 FastAPI 中管理数据库生命周期

# main.py
from contextlib import asynccontextmanager

# FastAPI is used for the lifespan annotation and the app itself below.
from fastapi import FastAPI

from database import engine, Base

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create missing tables on startup and dispose of the pool on shutdown."""
    async with engine.begin() as conn:
        # create_all only creates tables that do not exist yet; it never alters
        # existing ones, so schema changes still need Alembic migrations.
        await conn.run_sync(Base.metadata.create_all)
    try:
        yield
    finally:
        # Release pooled connections even if shutdown is interrupted by an error
        await engine.dispose()

app = FastAPI(lifespan=lifespan)

2.3 数据库依赖

# dependencies.py
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import AsyncSession

from database import AsyncSessionLocal

async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency yielding one session per request.

    The ``async with`` block closes the session once the request is done.
    """
    async with AsyncSessionLocal() as session:
        yield session

3. 模型定义

3.1 定义用户模型

# models/user.py
from sqlalchemy import String, Integer, Boolean, DateTime, ForeignKey, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func
from database import Base
from datetime import datetime

class User(Base):
    """Application user; owns many ``Post`` rows via ``posts``."""

    __tablename__ = "users"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    name: Mapped[str] = mapped_column(String(50), nullable=False)
    email: Mapped[str] = mapped_column(String(100), unique=True, index=True, nullable=False)
    hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
    is_active: Mapped[bool] = mapped_column(Boolean, default=True)
    is_superuser: Mapped[bool] = mapped_column(Boolean, default=False)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now()
    )
    # A non-Optional Mapped[...] makes the column NOT NULL, and ``onupdate``
    # only fires on UPDATE, so a server default is required for INSERTs.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now()
    )

    # Relationships
    posts: Mapped[list["Post"]] = relationship("Post", back_populates="author", lazy="selectin")

    def __repr__(self):
        return f"<User {self.email}>"

3.2 定义帖子模型(多对多关系)

# models/post.py
from sqlalchemy import String, Integer, Boolean, DateTime, ForeignKey, Text, Table, Column
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func
from database import Base
from datetime import datetime

# Many-to-many association table between posts and tags.
# It must not be called "tags": that is the Tag model's table name below, and
# defining two tables with the same name on one MetaData raises an error.
post_tags = Table(
    "post_tags", Base.metadata,
    Column("post_id", Integer, ForeignKey("posts.id", ondelete="CASCADE"), primary_key=True),
    Column("tag_id", Integer, ForeignKey("tags.id", ondelete="CASCADE"), primary_key=True),
)

# Backward-compatible alias for code that referenced the association table as ``tags``.
tags = post_tags

class Tag(Base):
    """A label that can be attached to many posts."""

    __tablename__ = "tags"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(30), unique=True, nullable=False)

class Post(Base):
    """A post written by one ``User`` and labelled with any number of ``Tag``s."""

    __tablename__ = "posts"

    id: Mapped[int] = mapped_column(primary_key=True)
    title: Mapped[str] = mapped_column(String(200), nullable=False)
    content: Mapped[str] = mapped_column(Text, nullable=False)
    published: Mapped[bool] = mapped_column(Boolean, default=False)
    author_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"))
    created_at: Mapped[datetime] = mapped_column(server_default=func.now())

    # Relationships
    author: Mapped["User"] = relationship("User", back_populates="posts")
    tags: Mapped[list["Tag"]] = relationship("Tag", secondary=post_tags, lazy="selectin")

4. CRUD 操作

4.1 基础查询

# schemas/user.py
from pydantic import BaseModel, EmailStr
from datetime import datetime

class UserBase(BaseModel):
    # Fields shared by the user input and output schemas
    name: str
    email: EmailStr

class UserCreate(UserBase):
    # Registration payload; the plaintext password must be hashed before storage
    password: str

class UserUpdate(BaseModel):
    # Partial-update payload: every field is optional
    name: str | None = None
    email: EmailStr | None = None
    password: str | None = None

class UserSchema(UserBase):
    id: int
    is_active: bool
    created_at: datetime

    class Config:
        from_attributes = True

4.2 Repository 模式封装

# repositories/user_repository.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_
from sqlalchemy.orm import selectinload
from models.user import User
from models.post import Post
from typing import Optional

class UserRepository:
    """Async data-access object for ``User`` rows and a user's ``Post`` rows.

    Routes use this class instead of the session directly. The write methods
    (``create``/``update``/``delete``) commit the session they were given and
    roll it back if that commit fails.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def _commit(self) -> None:
        """Commit ``self.db``; roll back and re-raise if the commit fails."""
        try:
            await self.db.commit()
        except Exception:
            # After a failed flush the session refuses further work until it is
            # rolled back, which would break any later use in the same request.
            await self.db.rollback()
            raise

    async def get_by_id(self, user_id: int) -> Optional[User]:
        """Return the user whose primary key is ``user_id``, or ``None``."""
        result = await self.db.execute(select(User).where(User.id == user_id))
        return result.scalar_one_or_none()

    async def get_by_email(self, email: str) -> Optional[User]:
        """Return the user with exactly this ``email``, or ``None``."""
        result = await self.db.execute(select(User).where(User.email == email))
        return result.scalar_one_or_none()

    async def get_multi(
        self, skip: int = 0, limit: int = 20, search: str | None = None
    ) -> list[User]:
        """Return one page of users, newest first.

        Args:
            skip: number of rows to skip (OFFSET).
            limit: maximum number of rows to return (LIMIT).
            search: optional case-insensitive substring matched against name
                or email. ``%`` and ``_`` in it are matched literally instead
                of acting as LIKE wildcards.
        """
        query = select(User)
        if search:
            # autoescape=True escapes LIKE metacharacters in the user-supplied
            # text, so e.g. "a_b" no longer also matches "axb".
            query = query.where(
                or_(
                    User.name.icontains(search, autoescape=True),
                    User.email.icontains(search, autoescape=True),
                )
            )
        # created_at alone is not unique (a server-side now() can give several
        # rows the same value), so the primary key breaks ties; without a total
        # order, OFFSET pages may overlap or skip rows.
        query = (
            query.order_by(User.created_at.desc(), User.id.desc())
            .offset(skip)
            .limit(limit)
        )
        result = await self.db.execute(query)
        return list(result.scalars().all())

    async def create(self, user_data: dict) -> User:
        """Insert a ``User`` built from ``user_data``, commit, and return it reloaded."""
        user = User(**user_data)
        self.db.add(user)
        await self._commit()
        # Reload server-generated values such as id and created_at.
        await self.db.refresh(user)
        return user

    async def update(self, user: User, update_data: dict) -> User:
        """Copy the non-``None`` values of ``update_data`` onto ``user`` and commit.

        Keys whose value is ``None`` are skipped, leaving that field unchanged.
        """
        for field, value in update_data.items():
            if value is not None:
                setattr(user, field, value)
        await self._commit()
        await self.db.refresh(user)
        return user

    async def delete(self, user: User) -> None:
        """Delete ``user`` and commit."""
        await self.db.delete(user)
        await self._commit()

    async def get_user_posts(self, user_id: int) -> list[Post]:
        """Return every post written by ``user_id``, newest first, with tags loaded."""
        result = await self.db.execute(
            select(Post)
            .where(Post.author_id == user_id)
            .options(selectinload(Post.tags))
            # id breaks ties between posts sharing the same created_at.
            .order_by(Post.created_at.desc(), Post.id.desc())
        )
        return list(result.scalars().all())

5. 在路由中使用

5.1 用户 CRUD 路由

# routers/users.py
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
from dependencies import get_db
from repositories.user_repository import UserRepository
from schemas.user import UserSchema, UserCreate, UserUpdate
from passlib.context import CryptContext

# All endpoints in this module are mounted under /users and share one OpenAPI tag.
router = APIRouter(
    prefix="/users",
    tags=["用户管理"],
)
# Password hashing context configured with the bcrypt scheme.
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)

def get_repo(db: AsyncSession = Depends(get_db)) -> UserRepository:
    """FastAPI dependency: wrap the session supplied by ``get_db`` in a UserRepository."""
    repository = UserRepository(db)
    return repository

def hash_password(password: str) -> str:
    """Return the hash of a plaintext password produced by ``pwd_context``."""
    digest = pwd_context.hash(password)
    return digest

@router.get("/", response_model=List[UserSchema])
async def list_users(
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    search: str | None = None,
    repo: UserRepository = Depends(get_repo),
):
    """List users one page at a time (1-100 per page), optionally filtered by ``search``."""
    return await repo.get_multi(skip=skip, limit=limit, search=search)

@router.get("/{user_id}", response_model=UserSchema)
async def get_user(user_id: int, repo: UserRepository = Depends(get_repo)):
    """Return a single user by id; respond 404 when it does not exist."""
    found = await repo.get_by_id(user_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="用户不存在")

@router.post("/", response_model=UserSchema, status_code=status.HTTP_201_CREATED)
async def create_user(
    data: UserCreate,
    repo: UserRepository = Depends(get_repo),
):
    """Register a new user and return it.

    Raises:
        HTTPException: 409 if the email is already registered.
    """
    import asyncio  # only needed to move bcrypt off the event loop

    # NOTE(review): two concurrent requests can both pass this check; only a
    # unique constraint on users.email (not shown here) fully prevents duplicates.
    existing = await repo.get_by_email(data.email)
    if existing:
        raise HTTPException(status_code=409, detail="邮箱已被注册")

    # bcrypt is deliberately slow CPU work; calling it directly inside this
    # async def would stall the event loop for every concurrent request.
    hashed = await asyncio.to_thread(hash_password, data.password)

    user = await repo.create({
        **data.model_dump(exclude={"password"}),
        "hashed_password": hashed,
    })
    return user

@router.put("/{user_id}", response_model=UserSchema)
async def update_user(
    user_id: int,
    data: UserUpdate,
    repo: UserRepository = Depends(get_repo),
):
    """Partially update a user.

    Only fields that the client sent with a non-null value are changed.

    Raises:
        HTTPException: 404 if the user does not exist; 409 if the new email
            already belongs to another user.
    """
    import asyncio  # only needed to move bcrypt off the event loop

    user = await repo.get_by_id(user_id)
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")

    # exclude_unset keeps only fields the client actually sent; explicit nulls
    # are dropped too, so hash_password below never receives None.
    update_dict = {
        key: value
        for key, value in data.model_dump(exclude_unset=True).items()
        if value is not None
    }

    # Mirror create_user's duplicate-email check; otherwise a taken address is
    # only rejected by the database (if email is unique there) as a 500.
    new_email = update_dict.get("email")
    if new_email is not None and new_email != user.email:
        if await repo.get_by_email(new_email):
            raise HTTPException(status_code=409, detail="邮箱已被注册")

    password = update_dict.pop("password", None)
    if password is not None:
        # bcrypt is slow CPU work; keep it off the event loop.
        update_dict["hashed_password"] = await asyncio.to_thread(hash_password, password)

    updated = await repo.update(user, update_dict)
    return updated

@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user(user_id: int, repo: UserRepository = Depends(get_repo)):
    """Delete a user: respond 204 with no body, or 404 when the id is unknown."""
    target = await repo.get_by_id(user_id)
    if target:
        await repo.delete(target)
        return
    raise HTTPException(status_code=404, detail="用户不存在")

6. 事务与批量操作

6.1 事务控制

async def bulk_create_posts(db: AsyncSession, posts_data: list[dict]):
    """Insert several posts atomically: either all are committed or none are.

    Any other pending changes already on ``db`` are committed, or rolled back,
    together with the posts.

    Args:
        db: session to write through; committed on success, rolled back on failure.
        posts_data: keyword arguments for each ``Post`` to create.

    Raises:
        Whatever the ``Post`` constructor or the database raises; nothing from
        this call is persisted in that case.
    """
    # Build every object first so a bad payload fails before the session is touched.
    posts = [Post(**data) for data in posts_data]
    # Not ``async with db.begin()``: that raises InvalidRequestError when the
    # session has already autobegun a transaction (e.g. after an earlier query
    # in the same request). An explicit commit/rollback works in both cases.
    try:
        db.add_all(posts)
        await db.commit()
    except Exception:
        await db.rollback()
        raise

6.2 乐观锁(版本控制)

class User(Base):
    # ... other fields
    version: Mapped[int] = mapped_column(Integer, nullable=False, default=1)

    # Delegate the counter to SQLAlchemy's built-in optimistic locking: each
    # UPDATE is emitted as "... WHERE id = :id AND version = :loaded_version"
    # and bumps version; if no row matches (someone else updated it first),
    # the flush raises sqlalchemy.orm.exc.StaleDataError.
    __mapper_args__ = {"version_id_col": version}


async def update_with_optimistic_lock(
    repo: UserRepository, user_id: int, expected_version: int, data: dict
):
    """Apply ``data`` to a user only if it is still at ``expected_version``.

    Comparing ``user.version`` in Python alone is not atomic: two requests can
    both read version N, both pass the check and both write, silently losing
    one update. With ``version_id_col`` the database repeats the comparison
    inside the UPDATE itself, so the losing request gets a 409 instead.

    Args:
        repo: repository whose session performs the update.
        user_id: primary key of the user to change.
        expected_version: the version the client last read.
        data: attribute values to apply. A ``version`` key is ignored because
            the mapper manages the counter; the dict itself is not modified.

    Raises:
        HTTPException: 404 if the user does not exist; 409 if the version no
            longer matches, either before or during the UPDATE.
    """
    from sqlalchemy.orm.exc import StaleDataError

    conflict_detail = "数据已被其他人修改,请刷新后重试"

    user = await repo.get_by_id(user_id)
    if user is None:
        raise HTTPException(404, "用户不存在")
    # Cheap early exit: the client's copy is already known to be stale.
    if user.version != expected_version:
        raise HTTPException(409, conflict_detail)

    changes = {key: value for key, value in data.items() if key != "version"}
    try:
        return await repo.update(user, changes)
    except StaleDataError:
        # The session cannot be reused after a failed flush until rolled back.
        await repo.db.rollback()
        raise HTTPException(409, conflict_detail) from None

7. 常用查询模式速查

# 条件查询
select(User).where(User.is_active == True)

# 模糊搜索
User.name.ilike("%alice%")

# 分页
select(User).offset(0).limit(10)

# 排序
select(User).order_by(User.created_at.desc())

# 预加载关联(避免 N+1 问题)
select(Post).options(selectinload(Post.author), selectinload(Post.tags))

# 聚合统计
from sqlalchemy import func, select
stmt = select(User.role, func.count(User.id)).group_by(User.role)
result = await db.execute(stmt)

# 联表查询
stmt = select(Post.title, User.name).join(User, Post.author_id == User.id)
result = await db.execute(stmt)

8. 小结

| 组件 | 作用 |
| --- | --- |
| `create_async_engine` | 创建异步数据库引擎 |
| `AsyncSessionLocal` | 异步会话工厂 |
| `Mapped[T]` + `mapped_column` | SQLAlchemy 2.0 类型注解 |
| `DeclarativeBase` | 所有模型的基类 |
| `select()` + `execute()` | 异步查询 |
| `selectinload` | 预加载关联,防止 N+1 |
| `lifespan` 管理器 | 在 FastAPI 启动/关闭时管理连接池 |

💡 最佳实践:使用 Repository 模式封装数据访问,路由只处理 HTTP 逻辑,数据库操作交给 Repository。好处是代码可测试,业务逻辑与框架解耦。


🔗 扩展阅读