#FastAPI与SQLAlchemy 2.0完全指南
📂 所属阶段:第三阶段 — 数据持久化(数据库篇)
🔗 相关章节:async-await-原理与实战 · Redis-集成 · 数据库迁移工具Alembic
#目录
- 为什么选择SQLAlchemy 2.0
- 异步数据库配置
- 模型定义与关系映射
- CRUD操作与Repository模式
- 在路由中使用数据库
- 事务处理与并发控制
- 性能优化与最佳实践
- 常见陷阱与避坑指南
- 与其他ORM对比
- 总结
#为什么选择SQLAlchemy 2.0?
#SQLAlchemy 2.0的核心改进
SQLAlchemy 2.0代表了Python ORM领域的重要里程碑,带来了革命性的变化:
| 改进点 | 详细说明 | 优势 |
|---|---|---|
| 原生异步支持 | AsyncSession + asyncpg / aiomysql | 无需同步线程池,真正的异步性能 |
| 统一查询语法 | Core和ORM使用同一套select() API | 学习成本降低,代码一致性提高 |
| 更严格的类型提示 | Mapped[int]等注解 | IDE支持更好,代码即文档 |
| 现代化API设计 | 更简洁的语法,更好的错误处理 | 开发体验显著提升 |
| 性能优化 | 更高效的查询执行,更好的内存管理 | 应用性能提升 |
#项目依赖安装
# SQLAlchemy 2.0 异步支持
pip install sqlalchemy[asyncio]
# 数据库驱动(选择其中一个)
# PostgreSQL (生产环境推荐)
pip install asyncpg
# SQLite (开发/测试环境)
pip install aiosqlite
# MySQL
pip install aiomysql
# 其他数据库支持
pip install asyncmy # MySQL alternative
#完整依赖示例
# requirements.txt
fastapi==0.104.1
sqlalchemy[asyncio]==2.0.23
asyncpg==0.29.0 # PostgreSQL
aiosqlite==0.19.0 # SQLite
aiomysql==0.2.0 # MySQL
alembic==1.13.1 # 数据库迁移
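pydantic[email]==2.5.0  # EmailStr requires the email-validator extra
pydantic-settings==2.1.0  # pydantic v2 moved BaseSettings into this package
python-multipart==0.0.6
passlib[bcrypt]==1.7.4
cryptography==41.0.8
#异步数据库配置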
#数据库连接配置详解
# database.py
from sqlalchemy.ext.asyncio import (
AsyncSession,
create_async_engine,
async_sessionmaker,
AsyncEngine
)
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.pool import QueuePool
from config import get_settings
import logging
from typing import AsyncGenerator
import asyncio
settings = get_settings()
# Module logger
logger = logging.getLogger(__name__)


class DatabaseManager:
    """Own the async engine and the session factory for the whole app.

    ``initialize()`` must be awaited (normally from the FastAPI lifespan)
    before ``get_session()`` is used; ``dispose()`` closes pooled connections.
    """

    def __init__(self):
        self.engine: "AsyncEngine | None" = None
        self.sessionmaker: "async_sessionmaker | None" = None
        self._initialized = False

    async def initialize(self) -> None:
        """Create the engine and session factory once; repeated calls are no-ops."""
        if self._initialized:
            return

        # No explicit poolclass: SQLAlchemy refuses a plain QueuePool for an
        # asyncio engine and picks an asyncio-compatible pool by default.
        self.engine = create_async_engine(
            settings.database_url,
            pool_size=settings.database_pool_size,  # persistent connections
            max_overflow=settings.database_max_overflow,  # extra burst connections
            pool_pre_ping=True,  # detect dropped connections before use
            pool_recycle=settings.database_pool_recycle,  # seconds
            pool_timeout=settings.database_pool_timeout,  # seconds to wait for a connection
            echo=settings.database_echo,  # SQL logging
            echo_pool=settings.database_echo,  # pool logging only when debugging SQL
        )

        self.sessionmaker = async_sessionmaker(
            self.engine,
            class_=AsyncSession,
            # Keep attributes loaded after commit: under asyncio, reloading an
            # expired attribute would need implicit IO on attribute access.
            expire_on_commit=False,
            autoflush=False,  # flush explicitly / at commit
        )
        self._initialized = True
        logger.info("Database initialized successfully")

    async def dispose(self) -> None:
        """Close all pooled connections and reset the manager."""
        if self.engine is not None:
            await self.engine.dispose()
            self.engine = None
            self.sessionmaker = None
        self._initialized = False
        logger.info("Database connections disposed")

    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield one session; roll back if the consumer raises.

        Raises:
            RuntimeError: if ``initialize()`` has not been awaited yet.
        """
        if self.sessionmaker is None:
            raise RuntimeError("DatabaseManager.initialize() must be awaited before get_session()")
        # ``async with`` closes the session on exit, so no explicit close().
        async with self.sessionmaker() as session:
            try:
                yield session
            except Exception:
                await session.rollback()
                raise
# Single process-wide manager shared by the lifespan hook and dependencies.
db_manager = DatabaseManager()


class Base(DeclarativeBase):
    """Declarative base class that every ORM model in the app inherits from."""
# Convenience dependency for FastAPI routes.
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a request-scoped session obtained from ``db_manager``.

    Rollback and cleanup are handled inside ``db_manager.get_session()``.
    """
    session_source = db_manager.get_session()
    async for db_session in session_source:
        yield db_session
#配置管理
# config.py
from pydantic import BaseSettings
from functools import lru_cache
from typing import List, Optional
class Settings(BaseSettings):
    """Application settings; defaults below can be overridden via the environment / ``.env``.

    NOTE(review): requirements pin pydantic==2.5.0. Under pydantic v2,
    ``BaseSettings`` lives in the separate ``pydantic-settings`` package
    (``from pydantic_settings import BaseSettings``), and the nested ``Config``
    class is the deprecated v1 spelling of ``model_config``. Confirm the
    import at the top of config.py matches the installed version.
    """
    # Application
    app_name: str = "FastAPI SQLAlchemy 2.0 App"
    version: str = "1.0.0"
    debug: bool = False
    # Database (the default URL is a placeholder; supply real credentials via env)
    database_url: str = "postgresql+asyncpg://user:password@localhost/dbname"
    database_pool_size: int = 20
    database_max_overflow: int = 10
    database_echo: bool = False  # set True in development to log SQL
    database_pool_timeout: int = 30  # seconds
    database_pool_recycle: int = 3600  # seconds
    # Redis (optional)
    redis_url: str = "redis://localhost:6379/0"

    class Config:
        # Read overrides from a UTF-8 encoded .env file.
        env_file = ".env"
        env_file_encoding = "utf-8"
@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Build Settings on first call and return the same instance afterwards."""
    loaded = Settings()
    return loaded


settings = get_settings()
#在FastAPI中管理数据库生命周期
# main.py
from fastapi import FastAPI
from contextlib import asynccontextmanager
from database import db_manager, Base
import logging
logger = logging.getLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the database on startup and always dispose it on shutdown.

    ``create_all`` only creates tables that do not exist yet; altering
    existing tables still requires a migration tool.
    """
    logger.info("Initializing database...")
    await db_manager.initialize()
    try:
        # Create all tables (if they do not exist)
        async with db_manager.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        logger.info("Database initialized successfully")
        yield
    finally:
        # Runs on normal shutdown and also when table creation or the app
        # raised, so the connection pool is never leaked.
        logger.info("Shutting down database...")
        await db_manager.dispose()
        logger.info("Database connections disposed")
# Application instance; `lifespan` wires database startup/shutdown.
app = FastAPI(
    lifespan=lifespan,
    title="FastAPI SQLAlchemy 2.0 API",
    description="使用SQLAlchemy 2.0异步ORM的FastAPI应用",
    version="1.0.0",
)
# Health check endpoint
@app.get("/health")
async def health_check():
    """Report service health based on a real ``SELECT 1`` round-trip.

    Returns a dict with ``status`` ("healthy"/"unhealthy"), ``database``
    ("connected"/"disconnected") and the current UTC ``timestamp``.
    """
    from datetime import datetime, timezone

    from sqlalchemy import text

    database_status = "connected"
    try:
        if db_manager.engine is None:
            raise RuntimeError("database engine is not initialized")
        async with db_manager.engine.connect() as conn:
            await conn.execute(text("SELECT 1"))
    except Exception:
        # A health endpoint must report failures, not raise them.
        logger.exception("Database health check failed")
        database_status = "disconnected"

    return {
        "status": "healthy" if database_status == "connected" else "unhealthy",
        "database": database_status,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
# Root path
@app.get("/")
async def root():
    """Describe the API and point to the interactive docs."""
    info = {"message": "FastAPI SQLAlchemy 2.0 API"}
    info["version"] = "1.0.0"
    info["docs"] = "/docs"
    return info
#数据库依赖注入
# dependencies.py
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
# Database session dependency
async def get_database_session() -> AsyncSession:
    """Yield a request-scoped database session.

    ``DatabaseManager.get_session()`` is an async *generator*, not an async
    context manager, so ``async with`` cannot be used on it. This delegates
    to ``get_db`` (imported above), which iterates it correctly. The return
    annotation is kept for compatibility; strictly, this yields sessions.
    """
    async for session in get_db():
        yield session
# 在路由中使用
from fastapi import APIRouter
from sqlalchemy.ext.asyncio import AsyncSession
router = APIRouter()


@router.get("/users")
async def get_users(db: AsyncSession = Depends(get_database_session)):
    """Demonstrate session injection; currently returns an empty user list."""
    # `db` is a live AsyncSession and can be used for queries here.
    response = {"users": []}
    return response
#模型定义与关系映射
#用户模型定义
# models/user.py
from sqlalchemy import (
String, Integer, Boolean, DateTime, ForeignKey, Text,
Index, UniqueConstraint, CheckConstraint
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func
from database import Base
from datetime import datetime
from typing import List, TYPE_CHECKING
if TYPE_CHECKING:
from .post import Post
from .comment import Comment
class User(Base):
    """User account model."""

    __tablename__ = "users"

    # Basic fields
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True, comment="用户ID")
    name: Mapped[str] = mapped_column(String(50), nullable=False, comment="用户姓名")
    # Uniqueness (and the index backing it) comes from the named
    # UniqueConstraint in __table_args__. Also passing unique=True/index=True
    # and a separate Index created three redundant indexes on this column.
    email: Mapped[str] = mapped_column(String(100), nullable=False, comment="邮箱地址")
    hashed_password: Mapped[str] = mapped_column(String(255), nullable=False, comment="哈希密码")

    # Status flags
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, comment="是否活跃")
    is_superuser: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否超级用户")
    is_verified: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否已验证")

    # Timestamps (set by the database server)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        comment="创建时间"
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        comment="更新时间"
    )

    # Relationships. lazy="selectin" loads every post/comment of a user
    # whenever the User itself is loaded; consider per-query loader options
    # if users with many posts are common.
    posts: Mapped[List["Post"]] = relationship(
        "Post",
        back_populates="author",
        lazy="selectin",
        cascade="all, delete-orphan"
    )
    comments: Mapped[List["Comment"]] = relationship(
        "Comment",
        back_populates="author",
        lazy="selectin",
        cascade="all, delete-orphan"
    )

    # Indexes and constraints
    __table_args__ = (
        Index("idx_users_created_at", "created_at"),
        UniqueConstraint("email", name="unique_email"),
        CheckConstraint("length(name) >= 2", name="name_length_check"),
    )

    def __repr__(self):
        return f"<User(id={self.id}, email='{self.email}')>"

    def to_dict(self) -> dict:
        """Return a JSON-friendly dict (timestamps as ISO strings; no password hash)."""
        return {
            "id": self.id,
            "name": self.name,
            "email": self.email,
            "is_active": self.is_active,
            "is_superuser": self.is_superuser,
            "is_verified": self.is_verified,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
        }
#帖子模型(一对多关系)
# models/post.py
from sqlalchemy import (
String, Integer, Boolean, DateTime, ForeignKey, Text,
Index, UniqueConstraint
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func
from database import Base
from datetime import datetime
from typing import List, TYPE_CHECKING
if TYPE_CHECKING:
from .user import User
from .comment import Comment
from .tag import Tag
class Post(Base):
    """Blog post model; every post belongs to exactly one author (User)."""

    __tablename__ = "posts"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True, comment="帖子ID")
    title: Mapped[str] = mapped_column(String(200), nullable=False, comment="标题")
    # Uniqueness comes from the named UniqueConstraint in __table_args__;
    # unique=True plus a separate Index duplicated it.
    slug: Mapped[str] = mapped_column(String(255), nullable=False, comment="URL友好标识")
    content: Mapped[str] = mapped_column(Text, nullable=False, comment="内容")
    # Explicitly nullable: a bare Mapped[str] makes the column NOT NULL, which
    # would force every post to supply an excerpt.
    excerpt: Mapped[str] = mapped_column(Text, nullable=True, comment="摘要")

    # Status flags
    published: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否发布")
    featured: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否特色")

    # Owning author; the database deletes the post when the user row is deleted.
    author_id: Mapped[int] = mapped_column(
        ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False,
        comment="作者ID"
    )

    # Timestamps
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        comment="创建时间"
    )
    # Non-Optional Mapped[datetime] means NOT NULL, so an insert-time default
    # is required in addition to onupdate, or INSERTs fail.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        comment="更新时间"
    )
    published_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
        comment="发布时间"
    )

    # Relationships
    author: Mapped["User"] = relationship("User", back_populates="posts")
    comments: Mapped[List["Comment"]] = relationship(
        "Comment",
        back_populates="post",
        lazy="selectin",
        cascade="all, delete-orphan"
    )
    tags: Mapped[List["Tag"]] = relationship(
        "Tag",
        secondary="post_tags",
        back_populates="posts",
        lazy="selectin"
    )

    # Indexes and constraints
    __table_args__ = (
        Index("idx_posts_author", "author_id"),
        Index("idx_posts_published", "published"),
        Index("idx_posts_created_at", "created_at"),
        Index("idx_posts_featured", "featured"),
        UniqueConstraint("slug", name="unique_post_slug"),
    )

    def __repr__(self):
        return f"<Post(id={self.id}, title='{self.title}')>"

    def to_dict(self) -> dict:
        """Return a JSON-friendly dict (timestamps as ISO strings, None when unset)."""
        return {
            "id": self.id,
            "title": self.title,
            "slug": self.slug,
            "content": self.content,
            "excerpt": self.excerpt,
            "published": self.published,
            "featured": self.featured,
            "author_id": self.author_id,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
            "published_at": self.published_at.isoformat() if self.published_at else None,
        }
#标签模型(多对多关系)
# models/tag.py
from sqlalchemy import String, Integer, DateTime, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func
from database import Base
from datetime import datetime
from typing import List, TYPE_CHECKING
if TYPE_CHECKING:
from .post import Post
# Association table for the Post <-> Tag many-to-many relationship.
# The composite primary key prevents linking the same tag to a post twice;
# ON DELETE CASCADE drops links when either side is deleted.
from sqlalchemy import Column, ForeignKey, Table

post_tags = Table(
    "post_tags",
    Base.metadata,
    Column(
        "post_id",
        Integer,
        ForeignKey("posts.id", ondelete="CASCADE"),
        primary_key=True,
    ),
    Column(
        "tag_id",
        Integer,
        ForeignKey("tags.id", ondelete="CASCADE"),
        primary_key=True,
    ),
)
class Tag(Base):
    """Tag model, linked to posts through the ``post_tags`` association table."""

    __tablename__ = "tags"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True, comment="标签ID")
    name: Mapped[str] = mapped_column(String(50), unique=True, nullable=False, comment="标签名称")
    slug: Mapped[str] = mapped_column(String(50), unique=True, nullable=False, comment="URL友好标识")
    # Explicitly nullable: a bare Mapped[str] would make the description mandatory.
    description: Mapped[str] = mapped_column(Text, nullable=True, comment="标签描述")

    # Timestamps
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        comment="创建时间"
    )
    # NOT NULL (non-Optional Mapped), so it needs an insert-time default as well.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        comment="更新时间"
    )

    # Relationships
    posts: Mapped[List["Post"]] = relationship(
        "Post",
        secondary=post_tags,
        back_populates="tags",
        lazy="selectin"
    )

    def __repr__(self):
        return f"<Tag(id={self.id}, name='{self.name}')>"

    def to_dict(self) -> dict:
        """Return a JSON-friendly dict (timestamps as ISO strings, None when unset)."""
        return {
            "id": self.id,
            "name": self.name,
            "slug": self.slug,
            "description": self.description,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
        }
#评论模型
# models/comment.py
from sqlalchemy import Boolean, DateTime, ForeignKey, Index, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func
from database import Base
from datetime import datetime
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .user import User
from .post import Post
class Comment(Base):
    """Comment on a post; ``parent_id`` links a reply to the comment it answers."""

    __tablename__ = "comments"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True, comment="评论ID")
    content: Mapped[str] = mapped_column(Text, nullable=False, comment="评论内容")
    approved: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否审核通过")

    # Foreign keys (ON DELETE CASCADE at the database level)
    author_id: Mapped[int] = mapped_column(
        ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False,
        comment="作者ID"
    )
    post_id: Mapped[int] = mapped_column(
        ForeignKey("posts.id", ondelete="CASCADE"),
        nullable=False,
        comment="帖子ID"
    )
    # NULL for top-level comments.
    parent_id: Mapped[int] = mapped_column(
        ForeignKey("comments.id", ondelete="CASCADE"),
        nullable=True,
        comment="父评论ID(用于嵌套评论)"
    )

    # Timestamps
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        comment="创建时间"
    )
    # NOT NULL (non-Optional Mapped), so an insert-time default is required
    # in addition to onupdate, otherwise INSERTs fail.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        comment="更新时间"
    )

    # Relationships
    author: Mapped["User"] = relationship("User", back_populates="comments")
    post: Mapped["Post"] = relationship("Post", back_populates="comments")
    # Self-referential many-to-one; the backref adds a `replies` collection.
    parent: Mapped["Comment"] = relationship("Comment", remote_side=[id], backref="replies")

    # Indexes
    __table_args__ = (
        Index("idx_comments_post", "post_id"),
        Index("idx_comments_author", "author_id"),
        Index("idx_comments_approved", "approved"),
        Index("idx_comments_parent", "parent_id"),
        Index("idx_comments_created_at", "created_at"),
    )

    def __repr__(self):
        return f"<Comment(id={self.id}, post_id={self.post_id})>"

    def to_dict(self) -> dict:
        """Return a JSON-friendly dict (timestamps as ISO strings, None when unset)."""
        return {
            "id": self.id,
            "content": self.content,
            "approved": self.approved,
            "author_id": self.author_id,
            "post_id": self.post_id,
            "parent_id": self.parent_id,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
        }
#CRUD操作与Repository模式
#Pydantic模型定义
# schemas/user.py
from pydantic import BaseModel, EmailStr, validator
from datetime import datetime
from typing import Optional, List
from enum import Enum
class UserRole(str, Enum):
    """Enumeration of user role names.

    The ``str`` mix-in makes every member compare equal to its raw value,
    e.g. ``UserRole.ADMIN == "admin"``.
    """

    USER = "user"
    ADMIN = "admin"
    SUPERUSER = "superuser"
class UserBase(BaseModel):
    """Fields shared by the user input and output schemas."""
    name: str
    # EmailStr validates the address format (requires the email-validator package).
    email: EmailStr
class UserCreate(UserBase):
    """Registration payload; carries the plain-text password to be hashed."""
    password: str

    @validator('name')
    def name_must_be_valid(cls, v):
        """Strip surrounding whitespace and require at least 2 remaining characters."""
        cleaned = v.strip()
        if len(cleaned) >= 2:
            return cleaned
        raise ValueError('姓名长度至少为2个字符')
class UserUpdate(BaseModel):
    """Partial-update payload: every field is optional; omitted fields stay unchanged."""
    name: Optional[str] = None
    email: Optional[EmailStr] = None
    is_active: Optional[bool] = None
    # Plain text; the route hashes it before storing.
    password: Optional[str] = None
class UserInDB(UserBase):
    """Complete user representation built from an ORM object.

    ``hashed_password`` is intentionally not declared, so it is never serialized.
    """
    id: int
    is_active: bool
    is_superuser: bool
    is_verified: bool
    created_at: datetime
    updated_at: Optional[datetime] = None

    class Config:
        # Allow building the schema from ORM attributes.
        # NOTE(review): pydantic v2 spells this model_config = ConfigDict(from_attributes=True);
        # the nested Config class is the deprecated form there.
        from_attributes = True
class UserPublic(BaseModel):
    """Publicly visible user fields.

    NOTE(review): this includes ``email``, and it is returned by unauthenticated
    list/detail routes; confirm exposing addresses is intended.
    """
    id: int
    name: str
    email: EmailStr
    created_at: datetime

    class Config:
        # Allow building the schema from ORM attributes.
        from_attributes = True
#Repository模式实现
# repositories/base_repository.py
from typing import TypeVar, Generic, Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete, func
from sqlalchemy.orm import selectinload
from database import Base
from pydantic import BaseModel
T = TypeVar('T', bound=Base)
S = TypeVar('S', bound=BaseModel)


class BaseRepository(Generic[T]):
    """Generic async CRUD helpers for one ORM model class.

    Note: ``create``, ``update`` and ``delete`` each commit immediately, so a
    sequence of those calls is not one atomic transaction.
    """

    def __init__(self, model: "type[T]", db_session: AsyncSession):
        self.model = model
        self.db = db_session

    def _apply_filters(self, stmt, filters: Optional[Dict[str, Any]]):
        """Add one ``column == value`` WHERE clause per entry in ``filters``."""
        for field, value in (filters or {}).items():
            stmt = stmt.where(getattr(self.model, field) == value)
        return stmt

    async def get_by_id(self, id: int) -> Optional[T]:
        """Return the row with primary key ``id``, or None."""
        stmt = select(self.model).where(self.model.id == id)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_multi(
        self,
        skip: int = 0,
        limit: int = 100,
        filters: Optional[Dict[str, Any]] = None,
        order_by: Optional[str] = None,
        descending: bool = False
    ) -> List[T]:
        """Return up to ``limit`` rows after skipping ``skip``.

        ``filters`` maps attribute names to required values; ``order_by`` names
        the attribute to sort on. Unknown attribute names raise AttributeError.
        """
        stmt = self._apply_filters(select(self.model), filters)
        if order_by:
            attr = getattr(self.model, order_by)
            stmt = stmt.order_by(attr.desc() if descending else attr.asc())
        stmt = stmt.offset(skip).limit(limit)
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def count(self, filters: Optional[Dict[str, Any]] = None) -> int:
        """Return the number of rows matching ``filters`` (all rows if None)."""
        stmt = self._apply_filters(select(func.count()).select_from(self.model), filters)
        result = await self.db.execute(stmt)
        return result.scalar_one()

    async def create(self, obj_in: Dict[str, Any]) -> T:
        """Insert a row built from ``obj_in``, commit, and return it refreshed."""
        db_obj = self.model(**obj_in)
        self.db.add(db_obj)
        await self.db.commit()
        await self.db.refresh(db_obj)  # load server-generated values
        return db_obj

    async def update(self, db_obj: T, obj_in: Dict[str, Any]) -> T:
        """Apply ``obj_in`` to ``db_obj`` and commit.

        Keys whose value is None are skipped, so a column cannot be cleared
        to NULL through this method.
        """
        for field, value in obj_in.items():
            if value is not None:
                setattr(db_obj, field, value)
        await self.db.commit()
        await self.db.refresh(db_obj)
        return db_obj

    async def delete(self, id: int) -> bool:
        """Delete the row with primary key ``id``; return True if a row was removed.

        This is a bulk DELETE statement: ORM-level cascades are not run, and
        only database ON DELETE rules apply.
        """
        stmt = delete(self.model).where(self.model.id == id)
        result = await self.db.execute(stmt)
        await self.db.commit()
        return result.rowcount > 0

    async def exists(self, **kwargs) -> bool:
        """Return True if at least one row matches ``kwargs`` (attribute == value).

        Uses ``SELECT EXISTS(...)``: no objects are loaded, and several
        matching rows no longer raise MultipleResultsFound.
        """
        stmt = select(select(self.model).filter_by(**kwargs).exists())
        result = await self.db.execute(stmt)
        return bool(result.scalar())
#用户Repository实现
# repositories/user_repository.py
from typing import Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_, and_
from sqlalchemy.orm import selectinload
from repositories.base_repository import BaseRepository
from models.user import User
from models.post import Post
from models.comment import Comment
from passlib.context import CryptContext
import logging
logger = logging.getLogger(__name__)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


class UserRepository(BaseRepository[User]):
    """User repository: BaseRepository CRUD plus lookup, auth and statistics helpers."""

    def __init__(self, db_session: AsyncSession):
        super().__init__(User, db_session)

    async def get_by_email(self, email: str) -> Optional[User]:
        """Return the user with exactly this email, or None."""
        stmt = select(User).where(User.email == email)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_email_with_posts(self, email: str) -> Optional[User]:
        """Return the user with this email with ``posts`` eagerly loaded, or None."""
        stmt = (
            select(User)
            .options(selectinload(User.posts))
            .where(User.email == email)
        )
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def authenticate(self, email: str, password: str) -> Optional[User]:
        """Return the user if ``password`` matches the stored hash, else None."""
        user = await self.get_by_email(email)
        if user is None:
            # Spend about one hash verification's worth of time so response
            # latency does not reveal whether the email is registered.
            pwd_context.dummy_verify()
            return None
        if not pwd_context.verify(password, user.hashed_password):
            return None
        return user

    async def get_multi_with_filters(
        self,
        skip: int = 0,
        limit: int = 20,
        search: Optional[str] = None,
        active_only: bool = True,
        role: Optional[str] = None
    ) -> List[User]:
        """Return a page of users, newest first.

        ``search`` matches name or email case-insensitively (substring).
        ``role`` accepts "admin" (superusers) or "verified"; any other value
        applies no role filter.
        """
        stmt = select(User)
        if search:
            pattern = f"%{search}%"
            stmt = stmt.where(or_(User.name.ilike(pattern), User.email.ilike(pattern)))
        if active_only:
            stmt = stmt.where(User.is_active == True)  # noqa: E712 (SQL expression)
        if role == "admin":
            stmt = stmt.where(User.is_superuser == True)  # noqa: E712
        elif role == "verified":
            stmt = stmt.where(User.is_verified == True)  # noqa: E712
        stmt = stmt.order_by(User.created_at.desc()).offset(skip).limit(limit)
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def create_with_hashed_password(self, user_data: dict) -> User:
        """Create a user, replacing a plain ``password`` key with ``hashed_password``.

        The caller's dict is not mutated.
        """
        if 'password' in user_data:
            user_data = user_data.copy()
            user_data['hashed_password'] = pwd_context.hash(user_data.pop('password'))
        return await self.create(user_data)

    async def update_password(self, user_id: int, new_password: str) -> User:
        """Hash and store a new password.

        Raises:
            ValueError: if the user does not exist.
        """
        user = await self.get_by_id(user_id)
        if not user:
            raise ValueError("User not found")
        hashed_password = pwd_context.hash(new_password)
        return await self.update(user, {"hashed_password": hashed_password})

    async def deactivate_user(self, user_id: int) -> bool:
        """Set ``is_active`` to False; return False if the user does not exist."""
        user = await self.get_by_id(user_id)
        if not user:
            return False
        await self.update(user, {"is_active": False})
        return True

    async def get_user_statistics(self, user_id: int) -> dict:
        """Return basic profile fields plus post/comment counts ({} if no such user)."""
        # `func` is not imported at the top of this module.
        from sqlalchemy import func

        user = await self.get_by_id(user_id)
        if not user:
            return {}

        post_count_stmt = select(func.count(Post.id)).where(Post.author_id == user_id)
        post_count = (await self.db.execute(post_count_stmt)).scalar_one()

        comment_count_stmt = select(func.count(Comment.id)).where(Comment.author_id == user_id)
        comment_count = (await self.db.execute(comment_count_stmt)).scalar_one()

        return {
            "id": user.id,
            "name": user.name,
            "email": user.email,
            "post_count": post_count,
            "comment_count": comment_count,
            "created_at": user.created_at,
            "is_active": user.is_active,
        }
#帖子Repository实现
# repositories/post_repository.py
from typing import Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, desc, asc
from sqlalchemy.orm import selectinload
from repositories.base_repository import BaseRepository
from models.post import Post
from models.tag import Tag
from models.user import User
from datetime import datetime
class PostRepository(BaseRepository[Post]):
    """Repository for posts, including published-post queries and tag management."""

    def __init__(self, db_session: AsyncSession):
        super().__init__(Post, db_session)

    async def get_by_slug(self, slug: str) -> Optional[Post]:
        """Return the post with this slug, or None."""
        stmt = select(Post).where(Post.slug == slug)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_published_posts(
        self,
        skip: int = 0,
        limit: int = 20,
        featured_only: bool = False,
        tag_filter: Optional[str] = None
    ) -> List[Post]:
        """Return published posts, newest ``published_at`` first.

        ``featured_only`` restricts to featured posts. ``tag_filter`` restricts
        to posts carrying the tag with that exact name.
        """
        stmt = (
            select(Post)
            .join(Post.author)
            .options(selectinload(Post.author), selectinload(Post.tags))
            .where(Post.published == True)  # noqa: E712 (SQL expression)
        )
        if featured_only:
            stmt = stmt.where(Post.featured == True)  # noqa: E712
        if tag_filter:
            stmt = stmt.join(Post.tags).where(Tag.name == tag_filter)
        stmt = stmt.order_by(Post.published_at.desc()).offset(skip).limit(limit)
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def search_posts(
        self,
        query: str,
        skip: int = 0,
        limit: int = 20
    ) -> List[Post]:
        """Case-insensitive substring search over title/content/excerpt of published posts."""
        pattern = f"%{query}%"
        stmt = (
            select(Post)
            .join(Post.author)
            .options(selectinload(Post.author), selectinload(Post.tags))
            .where(
                or_(
                    Post.title.ilike(pattern),
                    Post.content.ilike(pattern),
                    Post.excerpt.ilike(pattern)
                )
            )
            .where(Post.published == True)  # noqa: E712
            .order_by(desc(Post.published_at))
            .offset(skip).limit(limit)
        )
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def get_posts_by_author(
        self,
        author_id: int,
        skip: int = 0,
        limit: int = 20,
        include_drafts: bool = False
    ) -> List[Post]:
        """Return an author's posts, newest first; drafts only if ``include_drafts``."""
        stmt = (
            select(Post)
            .options(selectinload(Post.tags))
            .where(Post.author_id == author_id)
        )
        if not include_drafts:
            stmt = stmt.where(Post.published == True)  # noqa: E712
        stmt = stmt.order_by(desc(Post.created_at)).offset(skip).limit(limit)
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def _load_tags(self, tag_ids: List[int]) -> List[Tag]:
        """Fetch Tag rows for ``tag_ids``; raise ValueError if any id does not exist."""
        wanted = set(tag_ids)
        result = await self.db.execute(select(Tag).where(Tag.id.in_(wanted)))
        tags = list(result.scalars().all())
        missing = wanted - {tag.id for tag in tags}
        if missing:
            raise ValueError(f"Tags not found: {sorted(missing)}")
        return tags

    async def create_with_tags(self, post_data: dict, tag_ids: Optional[List[int]] = None) -> Post:
        """Create a post and attach tags in a single transaction.

        Tags are assigned through the ``Post.tags`` relationship, not through
        raw association-table inserts, so the returned object's ``tags`` are
        current. Duplicate ids are ignored.

        Raises:
            ValueError: if any id in ``tag_ids`` does not exist.
        """
        post = Post(**post_data)
        if tag_ids:
            post.tags = await self._load_tags(tag_ids)
        self.db.add(post)
        await self.db.commit()
        await self.db.refresh(post)  # load server-generated values (id, timestamps)
        return post

    async def update_with_tags(self, post_id: int, post_data: dict, tag_ids: Optional[List[int]] = None) -> Post:
        """Update a post and, when ``tag_ids`` is not None, replace its tags.

        None values in ``post_data`` are skipped (same rule as
        ``BaseRepository.update``). ``tag_ids=[]`` removes all tags. Everything
        is committed in one transaction.

        Raises:
            ValueError: if the post or any requested tag does not exist.
        """
        post = await self.get_by_id(post_id)
        if post is None:
            raise ValueError("Post not found")
        for field, value in post_data.items():
            if value is not None:
                setattr(post, field, value)
        if tag_ids is not None:
            # Replacing the collection keeps the session's view of the
            # association in sync; raw DELETE/INSERT on the table would not.
            post.tags = await self._load_tags(tag_ids) if tag_ids else []
        await self.db.commit()
        await self.db.refresh(post)
        return post

    async def get_featured_posts(self, limit: int = 5) -> List[Post]:
        """Return up to ``limit`` published featured posts, newest first."""
        stmt = (
            select(Post)
            .join(Post.author)
            .options(selectinload(Post.author), selectinload(Post.tags))
            .where(and_(Post.published == True, Post.featured == True))  # noqa: E712
            .order_by(desc(Post.published_at))
            .limit(limit)
        )
        result = await self.db.execute(stmt)
        return list(result.scalars().all())
#在路由中使用数据库
#依赖注入配置
# dependencies.py
from fastapi import Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from repositories.user_repository import UserRepository
from repositories.post_repository import PostRepository
from models.user import User
from typing import Optional
async def get_user_repository(db: AsyncSession = Depends(get_db)) -> UserRepository:
    """Dependency: a UserRepository bound to the request's session."""
    repository = UserRepository(db)
    return repository


async def get_post_repository(db: AsyncSession = Depends(get_db)) -> PostRepository:
    """Dependency: a PostRepository bound to the request's session."""
    repository = PostRepository(db)
    return repository
async def get_current_user(
    db: AsyncSession = Depends(get_db),
    token: str = Depends(oauth2_scheme)
) -> User:
    """Resolve the bearer token to a User, or raise HTTP 401.

    NOTE(review): ``oauth2_scheme``, ``jwt``, ``JWTError``, ``SECRET_KEY`` and
    ``ALGORITHM`` are neither imported nor defined in this module as shown.
    They presumably come from the authentication setup; import them before use.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception from None

    subject = payload.get("sub")
    if subject is None:
        raise credentials_exception
    # The JWT "sub" claim is a string. Convert it so the integer primary-key
    # comparison does not receive a str parameter, which strict drivers such
    # as asyncpg reject.
    try:
        user_id = int(subject)
    except (TypeError, ValueError):
        raise credentials_exception from None

    user = await UserRepository(db).get_by_id(user_id)
    if user is None:
        raise credentials_exception
    return user
async def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
    """Return the authenticated user, rejecting deactivated accounts with HTTP 400."""
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
#用户路由实现
# routers/users.py
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
from schemas.user import UserInDB, UserCreate, UserUpdate, UserPublic
from repositories.user_repository import UserRepository
from dependencies import get_user_repository, get_current_active_user
from models.user import User
from passlib.context import CryptContext
import logging
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/users", tags=["用户管理"])
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


@router.get("/", response_model=List[UserPublic])
async def list_users(
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    search: Optional[str] = Query(None, min_length=1),
    active_only: bool = Query(True),
    repo: UserRepository = Depends(get_user_repository)
):
    """List users, optionally filtered by a name/email search term."""
    try:
        users = await repo.get_multi_with_filters(
            skip=skip,
            limit=limit,
            search=search,
            active_only=active_only
        )
        # model_validate is the pydantic v2 replacement for deprecated from_orm.
        return [UserPublic.model_validate(user) for user in users]
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"Error listing users: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/me", response_model=UserInDB)
async def get_current_user_profile(
    current_user: User = Depends(get_current_active_user)
):
    """Return the authenticated user's full profile."""
    # model_validate replaces the deprecated pydantic v2 from_orm.
    return UserInDB.model_validate(current_user)
@router.get("/{user_id}", response_model=UserPublic)
async def get_user(
    user_id: int,
    repo: UserRepository = Depends(get_user_repository)
):
    """Return a user's public profile by ID, or 404."""
    user = await repo.get_by_id(user_id)
    if user is None:
        raise HTTPException(status_code=404, detail="用户不存在")
    return UserPublic.model_validate(user)
@router.post("/", response_model=UserInDB, status_code=status.HTTP_201_CREATED)
async def create_user(
    user_data: UserCreate,
    repo: UserRepository = Depends(get_user_repository)
):
    """Register a new user.

    Returns 409 if the email is already registered, including when a
    concurrent request wins the race and the unique constraint rejects the
    INSERT.
    """
    from sqlalchemy.exc import IntegrityError

    existing_user = await repo.get_by_email(user_data.email)
    if existing_user:
        raise HTTPException(status_code=409, detail="邮箱已被注册")

    try:
        user = await repo.create_with_hashed_password(user_data.model_dump())
        logger.info(f"User created successfully: {user.email}")
        return UserInDB.model_validate(user)
    except IntegrityError:
        # The failed commit leaves the session unusable until rolled back.
        await repo.db.rollback()
        raise HTTPException(status_code=409, detail="邮箱已被注册")
    except Exception as e:
        await repo.db.rollback()
        logger.exception(f"Error creating user: {e}")
        raise HTTPException(status_code=500, detail="创建用户失败")
@router.put("/{user_id}", response_model=UserInDB)
async def update_user(
    user_id: int,
    user_data: UserUpdate,
    current_user: User = Depends(get_current_active_user),
    repo: UserRepository = Depends(get_user_repository)
):
    """Update a user; allowed for the user themself or a superuser.

    A non-empty ``password`` is hashed before storing. Returns 409 if the new
    email belongs to another account.
    """
    from sqlalchemy.exc import IntegrityError

    # Permission: self-service, or a superuser acting on anyone.
    if current_user.id != user_id and not current_user.is_superuser:
        raise HTTPException(status_code=403, detail="无权修改他人信息")

    user = await repo.get_by_id(user_id)
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")

    try:
        update_data = user_data.model_dump(exclude_unset=True)
        # Always remove the plain-text key. Previously an empty or None
        # password stayed in the dict and was set as a stray attribute.
        new_password = update_data.pop("password", None)
        if new_password:
            update_data["hashed_password"] = pwd_context.hash(new_password)
        updated_user = await repo.update(user, update_data)
        logger.info(f"User updated successfully: {updated_user.email}")
        return UserInDB.model_validate(updated_user)
    except IntegrityError:
        # Unique email violated.
        await repo.db.rollback()
        raise HTTPException(status_code=409, detail="邮箱已被注册")
    except Exception as e:
        await repo.db.rollback()
        logger.exception(f"Error updating user: {e}")
        raise HTTPException(status_code=500, detail="更新用户失败")
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user(
    user_id: int,
    current_user: User = Depends(get_current_active_user),
    repo: UserRepository = Depends(get_user_repository)
):
    """Delete a user account.

    Users may delete themselves; superusers may delete anyone except their
    own account.
    """
    deleting_self = current_user.id == user_id
    if not (deleting_self or current_user.is_superuser):
        raise HTTPException(status_code=403, detail="无权删除用户")
    if deleting_self and current_user.is_superuser:
        raise HTTPException(status_code=400, detail="不能删除自己的超级用户账号")

    removed = await repo.delete(user_id)
    if removed:
        logger.info(f"User deleted successfully: {user_id}")
        return None
    raise HTTPException(status_code=404, detail="用户不存在")
@router.get("/{user_id}/stats", response_model=dict)
async def get_user_stats(
    user_id: int,
    current_user: User = Depends(get_current_active_user),
    repo: UserRepository = Depends(get_user_repository)
):
    """Return post/comment statistics; visible to the user themself or a superuser."""
    may_view = current_user.id == user_id or current_user.is_superuser
    if not may_view:
        raise HTTPException(status_code=403, detail="无权查看用户统计数据")

    stats = await repo.get_user_statistics(user_id)
    if stats:
        return stats
    raise HTTPException(status_code=404, detail="用户不存在")
#帖子路由实现
# routers/posts.py
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
from schemas.post import PostCreate, PostUpdate, PostResponse
from repositories.post_repository import PostRepository
from repositories.user_repository import UserRepository
from dependencies import get_post_repository, get_current_active_user
from models.user import User
from models.post import Post
import logging
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/posts", tags=["帖子管理"])


@router.get("/", response_model=List[PostResponse])
async def list_posts(
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    published: bool = Query(True),
    featured: Optional[bool] = Query(None),
    tag: Optional[str] = Query(None),
    search: Optional[str] = Query(None),
    repo: PostRepository = Depends(get_post_repository)
):
    """List posts.

    - ``search`` set: substring search over published posts (other filters ignored).
    - ``published=True`` (default): published posts, newest first, narrowed by
      ``featured=True`` and/or ``tag``. Previously the default path returned
      every post, including drafts, and ignored ``tag``.
    - ``published=False``: drafts only, newest first.
    """
    try:
        if search:
            posts = await repo.search_posts(search, skip, limit)
        elif published:
            posts = await repo.get_published_posts(
                skip, limit, featured_only=bool(featured), tag_filter=tag
            )
        else:
            # NOTE(review): this exposes drafts to unauthenticated callers;
            # consider requiring authentication for this branch.
            posts = await repo.get_multi(
                skip,
                limit,
                filters={"published": False},
                order_by="created_at",
                descending=True,
            )
        return [PostResponse.model_validate(post) for post in posts]
    except Exception as e:
        logger.exception(f"Error listing posts: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/featured", response_model=List[PostResponse])
async def get_featured_posts(
    limit: int = Query(5, ge=1, le=20),
    repo: PostRepository = Depends(get_post_repository)
):
    """Return up to ``limit`` published featured posts."""
    try:
        posts = await repo.get_featured_posts(limit)
        # model_validate replaces the deprecated pydantic v2 from_orm.
        return [PostResponse.model_validate(post) for post in posts]
    except Exception as e:
        # logger.exception records the traceback as well.
        logger.exception(f"Error getting featured posts: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/{post_id}", response_model=PostResponse)
async def get_post(
    post_id: int,
    repo: PostRepository = Depends(get_post_repository)
):
    """Return a single post by ID.

    Raises:
        HTTPException(404): no post with ``post_id``.
    """
    post = await repo.get_by_id(post_id)
    if not post:
        raise HTTPException(status_code=404, detail="帖子不存在")
    # Pydantic v2: model_validate replaces the deprecated from_orm.
    return PostResponse.model_validate(post)
@router.post("/", response_model=PostResponse, status_code=status.HTTP_201_CREATED)
async def create_post(
    post_data: PostCreate,
    current_user: User = Depends(get_current_active_user),
    post_repo: PostRepository = Depends(get_post_repository),
    user_repo: UserRepository = Depends(get_user_repository)
):
    """Create a post authored by the current user.

    If the post is created as published, ``published_at`` is stamped.

    Raises:
        HTTPException(400): the author row is missing or inactive. This is
            re-raised as-is rather than being converted into a 500 by the
            generic handler.
        HTTPException(500): any other failure while creating the post (logged).
    """
    try:
        # Re-check the author against the database.
        author = await user_repo.get_by_id(current_user.id)
        if not author or not author.is_active:
            raise HTTPException(status_code=400, detail="作者不存在或已被禁用")

        post_dict = post_data.model_dump()
        post_dict["author_id"] = current_user.id
        # Posts created in the published state get a publish timestamp.
        if post_dict.get("published", False):
            from datetime import datetime
            # NOTE(review): naive UTC timestamp, kept as-is; confirm the column
            # type before switching to datetime.now(timezone.utc).
            post_dict["published_at"] = datetime.utcnow()

        post = await post_repo.create_with_tags(post_dict, post_data.tag_ids)
        logger.info(f"Post created successfully: {post.title}")
        return PostResponse.model_validate(post)
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must keep their status code.
        raise
    except Exception as e:
        logger.exception(f"Error creating post: {e}")
        raise HTTPException(status_code=500, detail="创建帖子失败")
@router.put("/{post_id}", response_model=PostResponse)
async def update_post(
    post_id: int,
    post_data: PostUpdate,
    current_user: User = Depends(get_current_active_user),
    repo: PostRepository = Depends(get_post_repository)
):
    """Update a post; only its author or a superuser may do so.

    Only fields present in the request body are applied. When a draft is
    published for the first time, ``published_at`` is stamped.

    Raises:
        HTTPException(404): post not found.
        HTTPException(403): caller is neither the author nor a superuser.
        HTTPException(500): unexpected failure while updating (logged).
    """
    post = await repo.get_by_id(post_id)
    if not post:
        raise HTTPException(status_code=404, detail="帖子不存在")
    # Permission check
    if post.author_id != current_user.id and not current_user.is_superuser:
        raise HTTPException(status_code=403, detail="无权修改此帖子")
    try:
        update_data = post_data.model_dump(exclude_unset=True)
        # First publication of a previously unpublished post: stamp the time.
        if update_data.get("published") and not post.published:
            from datetime import datetime
            # NOTE(review): naive UTC timestamp, kept as-is; confirm the column
            # type before switching to an aware datetime.
            update_data["published_at"] = datetime.utcnow()
        updated_post = await repo.update_with_tags(post_id, update_data, post_data.tag_ids)
        logger.info(f"Post updated successfully: {updated_post.title}")
        return PostResponse.model_validate(updated_post)
    except HTTPException:
        # Any HTTPException raised inside the block keeps its status code
        # instead of being turned into a 500.
        raise
    except Exception as e:
        logger.exception(f"Error updating post: {e}")
        raise HTTPException(status_code=500, detail="更新帖子失败")
@router.delete("/{post_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_post(
    post_id: int,
    current_user: User = Depends(get_current_active_user),
    repo: PostRepository = Depends(get_post_repository)
):
    """Delete a post; only its author or a superuser may do so.

    Responds 204 with no body. Raises 404 when the post is not found (either
    at lookup or when ``repo.delete`` reports failure) and 403 when the caller
    lacks permission.
    """
    target = await repo.get_by_id(post_id)
    if not target:
        raise HTTPException(status_code=404, detail="帖子不存在")

    may_delete = target.author_id == current_user.id or current_user.is_superuser
    if not may_delete:
        raise HTTPException(status_code=403, detail="无权删除此帖子")

    deleted = await repo.delete(post_id)
    if not deleted:
        # repo.delete signals "nothing deleted" with a falsy result.
        raise HTTPException(status_code=404, detail="帖子不存在")

    logger.info(f"Post deleted successfully: {post_id}")
    return None
@router.get("/author/{author_id}", response_model=List[PostResponse])
async def get_posts_by_author(
    author_id: int,
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    include_drafts: bool = Query(False),
    repo: PostRepository = Depends(get_post_repository)
):
    """Return a page of posts written by ``author_id``.

    NOTE(review): ``include_drafts`` is honoured for any caller, even an
    unauthenticated one, so anyone can list another author's drafts. Consider
    restricting it to the author or a superuser.
    """
    posts = await repo.get_posts_by_author(author_id, skip, limit, include_drafts)
    # Pydantic v2: model_validate replaces the deprecated from_orm.
    return [PostResponse.model_validate(post) for post in posts]
#事务处理与并发控制
#事务管理基础
# services/transaction_service.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import IntegrityError
from typing import Callable, Any
import logging
logger = logging.getLogger(__name__)
class TransactionService:
    """Helpers that wrap work on one AsyncSession in commit/rollback handling.

    Every public method commits on success. On failure it rolls the session
    back, logs the error and re-raises the original exception.
    """

    def __init__(self, db_session: AsyncSession):
        self.db = db_session

    async def execute_in_transaction(self, operation: Callable) -> Any:
        """Await ``operation()`` and commit; roll back and re-raise on error.

        Args:
            operation: zero-argument callable returning an awaitable. Its
                result is returned after a successful commit.
        """
        try:
            result = await operation()
            await self.db.commit()
            return result
        except Exception as e:
            await self.db.rollback()
            logger.error(f"Transaction failed: {e}")
            raise

    async def batch_create(self, model_class, data_list: list) -> list:
        """Create ``model_class(**data)`` for each dict in ``data_list`` and commit once.

        After the commit each object is refreshed so database-generated values
        (such as primary keys) are loaded. Returns the objects in input order.

        Raises:
            IntegrityError: a constraint was violated (session rolled back).
            Exception: any other failure (session rolled back).
        """
        try:
            objects = [model_class(**data) for data in data_list]
            self.db.add_all(objects)
            await self.db.commit()
            # Refresh so generated values (e.g. new IDs) are loaded.
            for obj in objects:
                await self.db.refresh(obj)
            return objects
        except IntegrityError as e:
            await self.db.rollback()
            logger.error(f"Batch create failed due to integrity constraint: {e}")
            raise
        except Exception as e:
            await self.db.rollback()
            logger.error(f"Batch create failed: {e}")
            raise

    async def atomic_update(self, model_class, conditions: dict, updates: dict):
        """Run one ``UPDATE ... SET updates WHERE col == value AND ...`` and commit.

        Args:
            model_class: mapped class to update.
            conditions: attribute name -> required value; all must match.
            updates: attribute name -> new value.

        Returns:
            The number of rows matched by the UPDATE.

        Raises:
            ValueError: ``conditions`` is empty. ``where()`` with no arguments
                adds no WHERE clause, so the statement would rewrite every
                row of the table.
            Exception: any database error (session rolled back, re-raised).
        """
        from sqlalchemy import update

        if not conditions:
            raise ValueError("atomic_update requires at least one condition")
        try:
            stmt = (
                update(model_class)
                .where(*[getattr(model_class, k) == v for k, v in conditions.items()])
                .values(**updates)
            )
            result = await self.db.execute(stmt)
            await self.db.commit()
            return result.rowcount
        except Exception as e:
            await self.db.rollback()
            logger.error(f"Atomic update failed: {e}")
            raise
#乐观锁实现
# models/base.py (extended base model)
from sqlalchemy import Integer
# Mapped is required by the annotation below; without it the class body
# raises NameError.
from sqlalchemy.orm import Mapped, mapped_column
from database import Base


class VersionedBase(Base):
    """Abstract base that adds an integer ``version`` column to subclasses.

    New rows start at version 1. Incrementing the version is left to the code
    that performs updates.
    """

    __abstract__ = True  # no table of its own; subclasses inherit the column

    version: Mapped[int] = mapped_column(Integer, default=1, nullable=False)


# models/user.py (uses version control)
class User(VersionedBase):
    # ... other fields
    pass
# services/optimistic_lock_service.py
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy import and_
# AsyncSession is used in the __init__ annotation below and must be imported.
from sqlalchemy.ext.asyncio import AsyncSession


class OptimisticLockService:
    """Optimistic-lock updates driven by an integer ``version`` column."""

    def __init__(self, db_session: AsyncSession):
        self.db = db_session

    async def update_with_version_check(self, model_instance, new_data: dict, expected_version: int):
        """Update ``model_instance``'s row only if its version equals ``expected_version``.

        The version test is part of the UPDATE's WHERE clause, so the check and
        the write happen in a single statement. On success the version becomes
        ``expected_version + 1``, the session is committed and the matched row
        count is returned.

        Raises:
            ValueError: "数据已被其他用户修改,请刷新后重试" if the row exists
                but its version has changed; "记录不存在" if the row is gone.
        """
        from sqlalchemy import select, update

        model = type(model_instance)
        # The bumped version always wins, even if new_data carries "version"
        # (passing it twice as keyword arguments would raise TypeError).
        values = {**new_data, "version": expected_version + 1}
        stmt = (
            update(model)
            .where(
                and_(
                    model.id == model_instance.id,
                    model.version == expected_version,
                )
            )
            .values(**values)
        )
        result = await self.db.execute(stmt)
        if result.rowcount == 0:
            # Ask the database directly: session.get() may answer from the
            # identity map (model_instance may itself be there), which would make
            # a deleted row look like it still exists.
            still_exists = await self.db.scalar(
                select(model.id).where(model.id == model_instance.id)
            )
            if still_exists is not None:
                raise ValueError("数据已被其他用户修改,请刷新后重试")
            raise ValueError("记录不存在")
        await self.db.commit()
        return result.rowcount
#分布式锁(使用Redis)
# services/distributed_lock_service.py
import asyncio
import time
import uuid
from typing import Optional
import redis.asyncio as redis
class DistributedLockService:
    """Redis-backed lock built on ``SET NX EX`` plus an atomic compare-and-delete.

    Each acquisition stores a random token as the key's value. Only a caller
    presenting that token can release the lock.
    """

    def __init__(self, redis_client: redis.Redis, default_timeout: int = 30):
        self.redis = redis_client
        self.default_timeout = default_timeout

    async def acquire_lock(self, lock_key: str, timeout: Optional[int] = None) -> Optional[str]:
        """Try to acquire ``lock_key``, retrying for up to ``timeout`` seconds.

        ``timeout`` (defaults to ``self.default_timeout``) is used both as the
        retry window and as the key's expiry (``EX``) in Redis.

        Returns:
            The lock token on success, or ``None`` if the lock was not obtained
            in time.
        """
        if timeout is None:
            timeout = self.default_timeout
        lock_value = str(uuid.uuid4())
        # monotonic() cannot jump with wall-clock adjustments (NTP, manual
        # changes), which would otherwise stretch or cut the retry window.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            # SET with NX + EX acquires the lock and sets its expiry atomically.
            if await self.redis.set(lock_key, lock_value, nx=True, ex=timeout):
                return lock_value
            # Back off briefly before retrying.
            await asyncio.sleep(0.01)
        return None

    async def release_lock(self, lock_key: str, lock_value: str) -> bool:
        """Delete ``lock_key`` only if it still holds ``lock_value``.

        The Lua script runs the GET-compare-DEL atomically on the server.
        Returns True when the key was deleted, False otherwise (expired, or
        now held by someone else).
        """
        lua_script = """
        if redis.call("get", KEYS[1]) == ARGV[1] then
            return redis.call("del", KEYS[1])
        else
            return 0
        end
        """
        deleted = await self.redis.eval(lua_script, 1, lock_key, lock_value)
        return bool(deleted)

    async def execute_with_lock(self, lock_key: str, operation, timeout: Optional[int] = None):
        """Run ``await operation()`` while holding ``lock_key``; always release afterwards.

        Raises:
            TimeoutError: the lock could not be acquired within ``timeout``.

        NOTE: the key expires after ``timeout`` seconds. If ``operation`` runs
        longer than that, another client may acquire the lock concurrently.
        """
        lock_value = await self.acquire_lock(lock_key, timeout)
        if not lock_value:
            raise TimeoutError("Unable to acquire distributed lock")
        try:
            return await operation()
        finally:
            await self.release_lock(lock_key, lock_value)
#性能优化与最佳实践
#查询优化技巧
# services/query_optimization_service.py
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload, joinedload
from typing import List

from models.post import Post
from models.user import User


def _contains_pattern(text: str) -> str:
    r"""Build an ILIKE substring pattern in which ``text`` matches literally.

    Backslash, ``%`` and ``_`` are escaped with a backslash, so the pattern must
    be used with ``ilike(pattern, escape="\\")``.
    """
    escaped = text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    return f"%{escaped}%"


class QueryOptimizationService:
    """Query patterns that avoid N+1 loading and redundant work."""

    def __init__(self, db_session: AsyncSession):
        self.db = db_session

    async def get_user_with_posts_optimized(self, user_id: int):
        """Load a user and all of their posts up front (avoids N+1).

        Returns the ``User`` or ``None`` if it does not exist.
        """
        stmt = (
            select(User)
            .options(selectinload(User.posts))
            .where(User.id == user_id)
        )
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_posts_with_author_and_tags(self, skip: int = 0, limit: int = 20):
        """Return a page of published posts, newest first, with author and tags preloaded.

        ``joinedload`` adds its own JOIN for the author, so an explicit
        ``.join(Post.author)`` would only duplicate it. Tags are fetched with
        one extra ``selectinload`` query.
        """
        stmt = (
            select(Post)
            .options(joinedload(Post.author), selectinload(Post.tags))
            .where(Post.published == True)  # noqa: E712 - SQL expression
            .order_by(Post.published_at.desc())
            .offset(skip)
            .limit(limit)
        )
        result = await self.db.execute(stmt)
        return result.scalars().all()

    async def get_aggregated_post_stats(self):
        """Per-author post count and average content length.

        Returns rows of ``(name, post_count, avg_content_length)``. The inner
        join already excludes users without posts.
        """
        stmt = (
            select(
                User.name,
                func.count(Post.id).label('post_count'),
                func.avg(func.length(Post.content)).label('avg_content_length')
            )
            .join(User.posts)
            .group_by(User.id, User.name)
            .having(func.count(Post.id) > 0)
        )
        result = await self.db.execute(stmt)
        return result.all()

    async def paginated_search(self, query_text: str, page: int = 1, page_size: int = 20):
        """Search published posts whose title or content contains ``query_text``.

        Matching is case-insensitive. ``%``, ``_`` and backslashes in the user's
        text are matched literally instead of acting as LIKE wildcards.

        Returns:
            dict with ``posts``, ``total``, ``page``, ``page_size`` and ``total_pages``.

        Raises:
            ValueError: ``page`` or ``page_size`` is < 1. These would otherwise
                produce a negative OFFSET or a division by zero.
        """
        if page < 1 or page_size < 1:
            raise ValueError("page and page_size must be >= 1")
        offset = (page - 1) * page_size

        # One shared filter for the page query and the count query, so the
        # two can never drift apart.
        pattern = _contains_pattern(query_text)
        criteria = (
            Post.title.ilike(pattern, escape="\\") | Post.content.ilike(pattern, escape="\\"),
            Post.published == True,  # noqa: E712 - SQL expression
        )

        search_stmt = (
            select(Post)
            .where(*criteria)
            .order_by(Post.published_at.desc())
            .offset(offset)
            .limit(page_size)
        )
        count_stmt = select(func.count(Post.id)).where(*criteria)

        search_result = await self.db.execute(search_stmt)
        count_result = await self.db.execute(count_stmt)
        posts = search_result.scalars().all()
        total_count = count_result.scalar_one()
        return {
            "posts": posts,
            "total": total_count,
            "page": page,
            "page_size": page_size,
            "total_pages": (total_count + page_size - 1) // page_size
        }
#连接池优化
# config/database_config.py
# NOTE(review): requirements.txt pins pydantic==2.5.0. In Pydantic v2,
# `BaseSettings` moved to the separate `pydantic-settings` package, so this
# import raises an import error there. Under v2 use
# `from pydantic_settings import BaseSettings` (after installing that package).
from pydantic import BaseSettings
from typing import Optional
class DatabaseConfig(BaseSettings):
    """Database settings; values can also come from the `.env` file configured below."""
    database_url: str  # required: no default value
    database_pool_size: int = 20
    database_max_overflow: int = 10
    database_pool_timeout: int = 30
    database_pool_recycle: int = 3600
    database_pool_pre_ping: bool = True
    database_echo: bool = False
    # PostgreSQL-specific settings
    # NOTE(review): the postgresql_* and pool_health_check_threshold fields are
    # not read by EngineManager in database/engine_manager.py.
    postgresql_ssl_mode: str = "prefer"
    postgresql_statement_timeout: int = 30000  # 30 seconds (value in milliseconds)
    # Connection-pool health check
    pool_health_check_threshold: int = 10  # seconds
    class Config:
        env_file = ".env"
# database/engine_manager.py
from typing import Optional
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
from sqlalchemy.pool import AsyncAdaptedQueuePool
from config.database_config import DatabaseConfig
import logging

logger = logging.getLogger(__name__)


class EngineManager:
    """Lazily creates, caches and disposes one AsyncEngine built from DatabaseConfig."""

    def __init__(self, config: DatabaseConfig):
        self.config = config
        self.engine: Optional[AsyncEngine] = None

    def create_engine(self) -> AsyncEngine:
        """Return the engine, creating it on the first call.

        Notes:
            * Asyncio engines reject the plain ``QueuePool``.
              ``AsyncAdaptedQueuePool`` is its asyncio-compatible counterpart
              (and the default pool for async engines).
            * ``create_async_engine`` accepts no ``pool_recycle_callback``
              argument; unknown keyword arguments make it raise. Recycling is
              controlled by ``pool_recycle`` alone.
        """
        if self.engine is None:
            self.engine = create_async_engine(
                self.config.database_url,
                poolclass=AsyncAdaptedQueuePool,
                pool_size=self.config.database_pool_size,
                max_overflow=self.config.database_max_overflow,
                pool_timeout=self.config.database_pool_timeout,
                pool_recycle=self.config.database_pool_recycle,
                pool_pre_ping=self.config.database_pool_pre_ping,
                echo=self.config.database_echo,
            )
            logger.info("Database engine created successfully")
        return self.engine

    async def dispose_engine(self):
        """Close all pooled connections and forget the engine (no-op if none exists)."""
        if self.engine:
            await self.engine.dispose()
            self.engine = None
            logger.info("Database engine disposed")
#缓存策略
# services/cache_service.py
import json
import pickle
from typing import Optional, Any
from datetime import timedelta
import redis.asyncio as redis
class CacheService:
    """Redis cache for arbitrary Python objects, serialised with pickle.

    SECURITY NOTE(review): ``get`` calls ``pickle.loads`` on whatever bytes are
    stored under the key. Anyone able to write to this Redis instance can make
    this process execute arbitrary code. Only use it with a trusted,
    access-controlled Redis, or switch to a data-only format such as JSON.
    """

    def __init__(self, redis_client: redis.Redis):
        self.redis = redis_client

    async def get(self, key: str) -> Optional[Any]:
        """Return the unpickled value stored at ``key``, or ``None`` on a miss."""
        value = await self.redis.get(key)
        if value:
            return pickle.loads(value)
        return None

    async def set(self, key: str, value: Any, expire: timedelta = timedelta(hours=1)):
        """Pickle ``value`` and store it at ``key`` with a TTL of ``expire`` (truncated to whole seconds)."""
        pickled_value = pickle.dumps(value)
        await self.redis.setex(key, int(expire.total_seconds()), pickled_value)

    async def delete(self, key: str):
        """Remove ``key`` from the cache."""
        await self.redis.delete(key)

    async def get_or_set(self, key: str, fetch_func, expire: timedelta = timedelta(hours=1)):
        """Return the cached value for ``key``, computing and caching it on a miss.

        ``fetch_func`` is an async zero-argument callable. A ``None`` result is
        returned but not cached: ``get`` reports a stored ``None`` as a miss
        anyway, so caching it would only cost a Redis write on every call.
        """
        value = await self.get(key)
        if value is None:
            value = await fetch_func()
            if value is not None:
                await self.set(key, value, expire)
        return value
# 使用缓存的服务示例
class CachedUserService:
    """User lookups through a read-through Redis cache (30-minute TTL).

    NOTE(review): the cached value is the pickled ORM ``User`` instance itself.
    It therefore includes ``hashed_password`` (and any relationships already
    loaded on it), and a cache hit returns an unpickled object that is not
    attached to ``self.db``. Consider caching a plain dict or Pydantic schema
    instead.
    """

    def __init__(self, db_session: AsyncSession, cache_service: CacheService):
        self.db = db_session
        self.cache = cache_service

    async def get_user_by_id(self, user_id: int):
        """Return the user with ``user_id`` (cache first, then database), or ``None``."""
        # The imports of services/cache_service.py do not bring these into
        # scope, so import them here.
        from sqlalchemy import select
        from models.user import User

        cache_key = f"user:{user_id}"
        # Try the cache first.
        user = await self.cache.get(cache_key)
        if user:
            return user
        # Cache miss: load from the database.
        stmt = select(User).where(User.id == user_id)
        result = await self.db.execute(stmt)
        user = result.scalar_one_or_none()
        if user:
            await self.cache.set(cache_key, user, timedelta(minutes=30))
        return user

    async def invalidate_user_cache(self, user_id: int):
        """Drop the cached entry for ``user_id`` (call after updating the user)."""
        cache_key = f"user:{user_id}"
        await self.cache.delete(cache_key)
#常见陷阱与避坑指南
#陷阱1:N+1查询问题
# ❌ 错误:N+1查询
async def get_posts_with_authors_bad():
    """Anti-pattern example: load posts, then look up each author separately.

    This issues one query for the posts plus one ``user_repo.get_by_id`` call
    per post (the N+1 pattern). ``repo`` and ``user_repo`` are free names,
    assumed to be in scope for illustration.
    """
    posts = await repo.get_multi()  # fetch posts (repository defaults)
    result = []
    for post in posts:
        # Every iteration goes back to the database for the author.
        author = await user_repo.get_by_id(post.author_id)  # N queries
        result.append({
            "post": post,
            "author": author
        })
    return result
# ✅ 正确:使用预加载
async def get_posts_with_authors_good():
    """Load posts together with their authors in a fixed number of queries.

    ``selectinload`` issues one extra ``SELECT ... WHERE users.id IN (...)``
    for all authors at once, instead of one query per post. It does not use
    an explicit ``.join(Post.author)``; such a join would only add an unused
    inner JOIN to the posts query. ``db`` is a free name, assumed to be an
    AsyncSession in scope.
    """
    stmt = select(Post).options(selectinload(Post.author))  # preload authors
    result = await db.execute(stmt)
    posts = result.scalars().all()
    return posts
#陷阱2:会话生命周期管理
# ❌ 错误:会话在函数返回后关闭
async def get_user_incorrect(db: AsyncSession):
    """Anti-pattern example: touch a relationship without ensuring it was loaded.

    ``repo`` is a free name assumed to be in scope; the ``db`` argument is
    unused. ``user`` may also be ``None``, which is not checked.
    """
    user = await repo.get_by_id(1)
    # Caution: the session may already be closed here, so accessing user.posts can fail.
    return user.posts  # may raise DetachedInstanceError
# ✅ 正确:在会话范围内完成所有操作
async def get_user_correct(db: AsyncSession):
    """Load user 1 with its posts eagerly and serialise them while the session is open.

    Returns a list of per-post dicts, or an empty list if the user does not
    exist.

    NOTE(review): relies on ``Post.to_dict()``; the Post model in section 3.2
    defines no such method — confirm it exists on the model in use.
    """
    query = select(User).where(User.id == 1).options(selectinload(User.posts))
    user = (await db.execute(query)).scalar_one_or_none()
    if not user:
        return []
    # Relationship data is accessed inside the session scope.
    return [post.to_dict() for post in user.posts]
#陷阱3:事务边界不清
# ❌ 错误:事务边界不清晰
async def transfer_money_bad(from_account_id: int, to_account_id: int, amount: float):
    """Anti-pattern example: two dependent writes with no shared transaction.

    ``deduct_amount`` and ``add_amount`` receive no common session, so nothing
    makes the pair atomic.
    """
    # These operations are not in the same transaction
    await deduct_amount(from_account_id, amount)  # may succeed
    await add_amount(to_account_id, amount)  # may fail, so the deducted money is lost
# ✅ 正确:明确的事务边界
async def transfer_money_good(from_account_id: int, to_account_id: int, amount: float, db: AsyncSession):
    """Move ``amount`` between two accounts inside a single transaction.

    Both helper calls share ``db`` inside one ``db.begin()`` block, which
    commits when the block exits normally and rolls back if either call
    raises.

    NOTE(review): ``begin()`` presumably requires that ``db`` has no
    transaction in progress yet (confirm callers pass a fresh session).
    ``float`` is lossy for currency; consider ``Decimal``.
    """
    transaction = db.begin()  # explicit transaction boundary
    async with transaction:
        await deduct_amount(from_account_id, amount, db)
        await add_amount(to_account_id, amount, db)
#陷阱4:并发控制不当
# ❌ 错误:没有并发控制
async def update_stock_bad(product_id: int, quantity_change: int, db: AsyncSession):
    """Anti-pattern example: read-modify-write of stock with no concurrency control.

    Two concurrent calls can read the same ``stock`` value, both pass the
    check and both write, so the check does not protect the real stock.
    """
    product = await get_product(product_id, db)
    new_quantity = product.stock - quantity_change
    if new_quantity < 0:
        raise ValueError("库存不足")
    product.stock = new_quantity
    await db.commit()  # under high concurrency the stock can end up negative
# ✅ 正确:使用数据库锁
async def update_stock_good(product_id: int, quantity_change: int, db: AsyncSession):
    """Decrease a product's stock by ``quantity_change`` without a read-modify-write race.

    The stock check and the decrement are one conditional UPDATE
    (``WHERE stock >= quantity_change``), evaluated by the database against
    the current row. This is a guarded atomic update, not an explicit row
    lock.

    Raises:
        ValueError: "商品不存在" if no product has ``product_id``;
            "库存不足" if the stock is insufficient.
    """
    from sqlalchemy import select, update

    stmt = (
        update(Product)
        .where(Product.id == product_id, Product.stock >= quantity_change)
        .values(stock=Product.stock - quantity_change)
    )
    result = await db.execute(stmt)
    if result.rowcount == 0:
        # Zero matched rows means either "no such product" or "not enough
        # stock"; tell them apart so callers get an accurate error.
        exists = await db.scalar(select(Product.id).where(Product.id == product_id))
        if exists is None:
            raise ValueError("商品不存在")
        raise ValueError("库存不足")
    await db.commit()
#陷阱5:连接泄漏
# ❌ 错误:可能泄漏连接
async def get_data_bad():
    """Anti-pattern example: manual session handling that can leak a connection.

    The session is never closed, and ``except Exception: pass`` hides every
    error (the function then returns ``None``). The ``async_sessionmaker()``
    here is also created without an engine.
    """
    session = async_sessionmaker()()
    try:
        # If this raises, nothing closes the session or releases its connection
        result = await session.execute(select(User))
        return result.scalars().all()
    except Exception:
        # Exceptions are silently swallowed
        pass
# ✅ 正确:使用上下文管理器
async def get_data_good():
    """Fetch all users with a session whose lifetime is bound to ``async with``.

    Uses the application's configured ``AsyncSessionLocal`` factory from
    database.py. A bare ``async_sessionmaker()`` has no engine bound, so its
    sessions cannot execute queries. Leaving the ``async with`` block closes
    the session and returns its connection to the pool, even when
    ``execute`` raises.
    """
    from database import AsyncSessionLocal

    async with AsyncSessionLocal() as session:
        try:
            result = await session.execute(select(User))
            return result.scalars().all()
        except Exception:
            await session.rollback()
            raise
    # The session (and its connection) is released automatically.
#与其他ORM对比
#SQLAlchemy 2.0 vs Django ORM
| 特性 | SQLAlchemy 2.0 | Django ORM |
|---|---|---|
| 异步支持 | ✅ 原生异步 | ⚠️ 4.1+ 提供异步查询接口,但底层仍以同步执行为主 |
| 类型提示 | ✅ 优秀的类型支持 | ❌ 类型支持较弱 |
| 灵活性 | ✅ 高度灵活 | ❌ 约定优于配置 |
| 学习曲线 | ⚠️ 中等 | ✅ 简单易学 |
| 性能 | ✅ 高性能 | ⚠️ 中等 |
| 生态系统 | ✅ 丰富的生态系统 | ✅ 完整的Web框架 |
#SQLAlchemy 2.0 vs Peewee
| 特性 | SQLAlchemy 2.0 | Peewee |
|---|---|---|
| 异步支持 | ✅ 完整异步 | ❌ 主要是同步 |
| 查询API | ✅ 现代化API | ✅ 简洁API |
| 复杂查询 | ✅ 强大的复杂查询 | ⚠️ 简单查询 |
| 社区支持 | ✅ 活跃社区 | ⚠️ 小而精 |
| 文档质量 | ✅ 优秀文档 | ✅ 良好文档 |
| 企业级特性 | ✅ 丰富企业级特性 | ⚠️ 基础特性 |
#SQLAlchemy 2.0 vs Tortoise ORM
| 特性 | SQLAlchemy 2.0 | Tortoise ORM |
|---|---|---|
| 异步支持 | ✅ 原生异步 | ✅ 专为异步设计 |
| Django风格 | ❌ SQLAlchemy风格 | ✅ Django风格 |
| 迁移工具 | ✅ Alembic | ✅ 内置迁移 |
| 类型提示 | ✅ 优秀的类型支持 | ✅ Python类型提示 |
| 成熟度 | ✅ 非常成熟 | ⚠️ 相对年轻 |
| 生态系统 | ✅ 丰富的生态系统 | ⚠️ 生态较小 |
#相关教程
#总结
| 组件 | 作用 | 最佳实践 |
|---|---|---|
| create_async_engine | 创建异步数据库引擎 | 配置合适的连接池参数 |
| AsyncSession | 异步数据库会话 | 使用上下文管理器确保资源释放 |
| Mapped[T] + mapped_column | SQLAlchemy 2.0类型注解 | 使用类型提示提高代码质量 |
| DeclarativeBase | 所有模型的基类 | 统一模型管理 |
| select() + execute() | 异步查询 | 使用预加载避免N+1问题 |
| selectinload | 预加载关联,防止N+1 | 在需要关联数据时使用 |
| lifespan 管理器 | 管理数据库连接生命周期 | 确保应用启动和关闭时正确处理连接 |
| Repository模式 | 数据访问层抽象 | 分离业务逻辑和数据访问逻辑 |
💡 核心思想:SQLAlchemy 2.0的异步ORM为企业级应用提供了强大的数据持久化能力。通过合理的配置、优化的查询策略和适当的并发控制,可以构建高性能、可维护的数据库应用。
SQLAlchemy 2.0与FastAPI的结合为企业级应用提供了完整的异步数据库解决方案,通过Repository模式、事务管理、连接池优化等技术手段,可以构建出高性能、可扩展的现代化Web应用。
#1. 为什么选择 SQLAlchemy 2.0?
#1.1 SQLAlchemy 2.0 的核心改进
- 原生异步支持:`AsyncSession` + `asyncpg`/`aiomysql`,无需同步线程池
- 统一查询语法:Core 和 ORM 使用同一套 `select()` API
- 更严格的类型提示:`Mapped[int]` 等注解让代码即文档
- 迁移友好:与 Alembic 深度集成
#1.2 项目依赖安装
pip install sqlalchemy[asyncio] asyncpg aiosqlite
# sqlalchemy[asyncio] 包含 asyncio 扩展
# asyncpg → PostgreSQL(生产推荐)
# aiosqlite → SQLite(开发/轻量)#2. 基础配置
#2.1 数据库连接配置
# database.py
from sqlalchemy.ext.asyncio import (
    AsyncSession, create_async_engine, async_sessionmaker
)
from sqlalchemy.orm import DeclarativeBase
from config import get_settings

settings = get_settings()

# PostgreSQL URL format: postgresql+asyncpg://
DATABASE_URL = settings.database_url  # "postgresql+asyncpg://user:pass@localhost/mydb"
# SQLite URL format: sqlite+aiosqlite:///
# DATABASE_URL = "sqlite+aiosqlite:///./app.db"

engine = create_async_engine(
    DATABASE_URL,
    echo=settings.debug,  # log emitted SQL (development aid)
    pool_size=20,  # connections kept in the pool
    max_overflow=10,  # extra connections allowed beyond pool_size
    pool_pre_ping=True,  # check a connection is alive before using it
)

# Session factory.
# SQLAlchemy 2.0 sessions never autocommit; the legacy ``autocommit`` argument
# can only be False, so it is omitted.
AsyncSessionLocal = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,  # keep loaded attribute values after commit instead of expiring them
    autoflush=False,
)


class Base(DeclarativeBase):
    """Declarative base class shared by all ORM models."""
    pass
#2.2 在 FastAPI 中管理数据库生命周期
# main.py
from contextlib import asynccontextmanager

# FastAPI is used in the lifespan signature and to build the app below; it
# must be imported or the module raises NameError.
from fastapi import FastAPI

from database import engine, Base


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: create tables on startup, dispose the pool on shutdown.

    NOTE(review): ``create_all`` is convenient for development. For schema
    changes in production, a migration tool such as Alembic is recommended.
    """
    # Startup: create all tables
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    try:
        yield
    finally:
        # Shutdown: release pooled connections, even if the app exits with an error.
        await engine.dispose()


app = FastAPI(lifespan=lifespan)
#2.3 数据库依赖
# dependencies.py
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import AsyncSession

from database import AsyncSessionLocal


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency that yields one AsyncSession per request.

    The ``async with`` block closes the session once the request is done.
    The return type is an async generator, not a bare ``AsyncSession``.
    """
    async with AsyncSessionLocal() as session:
        yield session
#3. 模型定义
#3.1 定义用户模型
# models/user.py
from sqlalchemy import String, Integer, Boolean, DateTime, ForeignKey, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func
from database import Base
from datetime import datetime


class User(Base):
    """User account; owns many posts."""

    __tablename__ = "users"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    name: Mapped[str] = mapped_column(String(50), nullable=False)
    email: Mapped[str] = mapped_column(String(100), unique=True, index=True, nullable=False)
    hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
    is_active: Mapped[bool] = mapped_column(Boolean, default=True)
    is_superuser: Mapped[bool] = mapped_column(Boolean, default=False)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now()
    )
    # Mapped[datetime] (not Optional) makes this column NOT NULL. ``onupdate``
    # alone gives no value on INSERT, which would make every insert fail, so
    # a server-side default is needed as well.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now()
    )

    # Relationships
    posts: Mapped[list["Post"]] = relationship("Post", back_populates="author", lazy="selectin")

    def __repr__(self):
        return f"<User {self.email}>"
#3.2 定义帖子模型(多对多关系)
# models/post.py
from sqlalchemy import String, Integer, Boolean, DateTime, ForeignKey, Text, Table, Column
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func
from database import Base
from datetime import datetime

# Many-to-many association table between posts and tags.
# It must not be named "tags": that is Tag.__tablename__. Defining a second
# table with the same name on the same MetaData raises an error, and the
# "tags.id" foreign key would point back at this table itself.
post_tags = Table(
    "post_tags", Base.metadata,
    Column("post_id", Integer, ForeignKey("posts.id", ondelete="CASCADE"), primary_key=True),
    Column("tag_id", Integer, ForeignKey("tags.id", ondelete="CASCADE"), primary_key=True),
)
# Backward-compatible alias for code that imported the table as ``tags``.
tags = post_tags


class Tag(Base):
    """A unique tag name that can be attached to many posts."""

    __tablename__ = "tags"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(30), unique=True, nullable=False)


class Post(Base):
    """Blog post written by one user, tagged with any number of tags.

    NOTE(review): the routers and query services in this guide read
    ``Post.published_at``, which this model does not define.
    """

    __tablename__ = "posts"

    id: Mapped[int] = mapped_column(primary_key=True)
    title: Mapped[str] = mapped_column(String(200), nullable=False)
    content: Mapped[str] = mapped_column(Text, nullable=False)
    published: Mapped[bool] = mapped_column(Boolean, default=False)
    author_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"))
    # Timezone-aware, matching the timestamp columns of the User model.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )

    # Relationships
    author: Mapped["User"] = relationship("User", back_populates="posts")
    tags: Mapped[list["Tag"]] = relationship("Tag", secondary=post_tags, lazy="selectin")
#4. CRUD 操作
#4.1 基础查询
# schemas/user.py
from pydantic import BaseModel, ConfigDict, EmailStr
from datetime import datetime


class UserBase(BaseModel):
    """Fields shared by user input and output schemas."""
    name: str
    email: EmailStr


class UserCreate(UserBase):
    """Payload for creating a user; ``password`` is plain text and is hashed by the router."""
    password: str


class UserUpdate(BaseModel):
    """Partial-update payload; every field is optional."""
    name: str | None = None
    email: EmailStr | None = None
    password: str | None = None


class UserSchema(UserBase):
    """Response schema that can be built directly from ORM ``User`` objects."""

    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    model_config = ConfigDict(from_attributes=True)

    id: int
    is_active: bool
    created_at: datetime
#4.2 Repository 模式封装
# repositories/user_repository.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_
from sqlalchemy.orm import selectinload
from models.user import User
from models.post import Post
from typing import Optional
class UserRepository:
    """Data-access layer for ``User`` rows; every write commits immediately."""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def get_by_id(self, user_id: int) -> Optional[User]:
        """Return the user with ``user_id`` or ``None``."""
        result = await self.db.execute(
            select(User).where(User.id == user_id)
        )
        return result.scalar_one_or_none()

    async def get_by_email(self, email: str) -> Optional[User]:
        """Return the user with this exact email or ``None``."""
        result = await self.db.execute(
            select(User).where(User.email == email)
        )
        return result.scalar_one_or_none()

    async def get_multi(
        self, skip: int = 0, limit: int = 20, search: str | None = None
    ) -> list[User]:
        """Return a page of users, newest first, optionally filtered by ``search``.

        ``search`` is a case-insensitive substring match on name or email.
        ``%``, ``_`` and backslashes in it match literally rather than acting
        as LIKE wildcards.
        """
        query = select(User)
        if search:
            escaped = (
                search.replace("\\", "\\\\")
                .replace("%", "\\%")
                .replace("_", "\\_")
            )
            pattern = f"%{escaped}%"
            query = query.where(
                or_(
                    User.name.ilike(pattern, escape="\\"),
                    User.email.ilike(pattern, escape="\\"),
                )
            )
        query = query.order_by(User.created_at.desc()).offset(skip).limit(limit)
        result = await self.db.execute(query)
        return list(result.scalars().all())

    async def create(self, user_data: dict) -> User:
        """Insert a user built from ``user_data``, commit, and return it refreshed."""
        user = User(**user_data)
        self.db.add(user)
        await self.db.commit()
        await self.db.refresh(user)
        return user

    async def update(self, user: User, update_data: dict) -> User:
        """Apply ``update_data`` to ``user``, commit, and return it refreshed.

        Keys whose value is ``None`` are skipped, so this method cannot set a
        column to NULL.
        """
        for key, value in update_data.items():
            if value is not None:
                setattr(user, key, value)
        await self.db.commit()
        await self.db.refresh(user)
        return user

    async def delete(self, user: User) -> None:
        """Delete ``user`` and commit."""
        await self.db.delete(user)
        await self.db.commit()

    async def get_user_posts(self, user_id: int) -> list[Post]:
        """Return all posts by ``user_id``, newest first, with tags preloaded."""
        result = await self.db.execute(
            select(Post)
            .where(Post.author_id == user_id)
            .options(selectinload(Post.tags))
            .order_by(Post.created_at.desc())
        )
        return list(result.scalars().all())
#5. 在路由中使用
#5.1 用户 CRUD 路由
# routers/users.py
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
from dependencies import get_db
from repositories.user_repository import UserRepository
from schemas.user import UserSchema, UserCreate, UserUpdate
from passlib.context import CryptContext
# Every route in this module is mounted under /users.
router = APIRouter(prefix="/users", tags=["用户管理"])
# Password hashing context configured for bcrypt.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def get_repo(db: AsyncSession = Depends(get_db)) -> UserRepository:
    """Dependency that wraps the request's session in a UserRepository."""
    repository = UserRepository(db)
    return repository


def hash_password(password: str) -> str:
    """Hash a plain-text password with the module's bcrypt context."""
    hashed = pwd_context.hash(password)
    return hashed
@router.get("/", response_model=List[UserSchema])
async def list_users(
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    search: str | None = None,
    repo: UserRepository = Depends(get_repo),
):
    """List users page by page, optionally filtered by a search string."""
    return await repo.get_multi(skip=skip, limit=limit, search=search)
@router.get("/{user_id}", response_model=UserSchema)
async def get_user(user_id: int, repo: UserRepository = Depends(get_repo)):
    """Return a single user, or respond 404 if it does not exist."""
    found = await repo.get_by_id(user_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="用户不存在")
@router.post("/", response_model=UserSchema, status_code=status.HTTP_201_CREATED)
async def create_user(
    data: UserCreate,
    repo: UserRepository = Depends(get_repo),
):
    """Register a new user; respond 409 if the email is already taken.

    The email pre-check handles the common case, but two concurrent requests
    can both pass it. The unique constraint on ``users.email`` then makes the
    second commit raise IntegrityError. If that happens and the email now
    exists, the same 409 is returned instead of a 500. Any other integrity
    error is re-raised.
    """
    from sqlalchemy.exc import IntegrityError

    existing = await repo.get_by_email(data.email)
    if existing:
        raise HTTPException(status_code=409, detail="邮箱已被注册")
    try:
        user = await repo.create({
            **data.model_dump(exclude={"password"}),
            "hashed_password": hash_password(data.password),
        })
    except IntegrityError:
        # Roll back so the session can run the follow-up lookup.
        await repo.db.rollback()
        if await repo.get_by_email(data.email):
            raise HTTPException(status_code=409, detail="邮箱已被注册")
        raise
    return user
@router.put("/{user_id}", response_model=UserSchema)
async def update_user(
    user_id: int,
    data: UserUpdate,
    repo: UserRepository = Depends(get_repo),
):
    """Partially update a user with the fields present in the request body.

    * A new ``password`` is hashed before storage. An explicit ``null``
      password is ignored instead of crashing the hasher.
    * Changing ``email`` to an address owned by another user is rejected with
      409, rather than failing on the unique constraint with a 500.

    NOTE(review): this endpoint has no authentication/authorization
    dependency, so any caller can modify any user.
    """
    user = await repo.get_by_id(user_id)
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")
    update_dict = data.model_dump(exclude_unset=True)

    new_email = update_dict.get("email")
    if new_email is not None and new_email != user.email:
        if await repo.get_by_email(new_email):
            raise HTTPException(status_code=409, detail="邮箱已被注册")

    password = update_dict.pop("password", None)
    if password is not None:
        update_dict["hashed_password"] = hash_password(password)

    updated = await repo.update(user, update_dict)
    return updated
@router.delete("/{user_id}", status_code=204)
async def delete_user(user_id: int, repo: UserRepository = Depends(get_repo)):
    """Delete a user, or respond 404 if it does not exist; returns no body (204)."""
    target = await repo.get_by_id(user_id)
    if not target:
        raise HTTPException(status_code=404, detail="用户不存在")
    await repo.delete(target)
#6. 事务与批量操作
#6.1 事务控制
async def bulk_create_posts(db: AsyncSession, posts_data: list[dict]):
    """Insert one ``Post`` per dict in ``posts_data`` inside a single transaction.

    ``db.begin()`` commits when the block exits normally and rolls back if
    anything inside raises. Either all posts are stored or none are, and no
    separate ``commit()`` call is needed.

    NOTE(review): ``begin()`` presumably expects ``db`` to have no
    transaction in progress yet (e.g. no query run on it earlier in the same
    request) — confirm for callers.
    """
    async with db.begin():
        db.add_all([Post(**data) for data in posts_data])
#6.2 乐观锁(版本控制)
class User(Base):
    """Excerpt: the User model extended with a ``version`` column for optimistic locking."""
    # ... other fields
    version: Mapped[int] = mapped_column(Integer, default=1)  # row version; new rows start at 1
async def update_with_optimistic_lock(
    repo: UserRepository, user_id: int, expected_version: int, data: dict
):
    """Apply ``data`` to a user only if its stored version equals ``expected_version``.

    A separate "read, compare, then write" lets two concurrent requests both
    pass the check. Instead, the version test is part of the UPDATE's WHERE
    clause, so the check and the write are one atomic statement. On success
    the version is bumped to ``expected_version + 1``. ``None`` values in
    ``data`` are skipped (as ``UserRepository.update`` does), and ``data`` is
    not modified.

    Returns:
        The updated ``User``.

    Raises:
        HTTPException(404): the user does not exist.
        HTTPException(409): the stored version changed (concurrent modification).
    """
    from sqlalchemy import select, update

    values = {key: value for key, value in data.items() if value is not None}
    values["version"] = expected_version + 1
    stmt = (
        update(User)
        .where(User.id == user_id, User.version == expected_version)
        .values(**values)
    )
    result = await repo.db.execute(stmt)
    if result.rowcount == 0:
        await repo.db.rollback()
        exists = await repo.db.scalar(select(User.id).where(User.id == user_id))
        if exists is None:
            raise HTTPException(404, "用户不存在")
        raise HTTPException(409, "数据已被其他人修改,请刷新后重试")
    await repo.db.commit()
    return await repo.get_by_id(user_id)
#7. 常用查询模式速查
# 条件查询
select(User).where(User.is_active == True)
# 模糊搜索
User.name.ilike("%alice%")
# 分页
select(User).offset(0).limit(10)
# 排序
select(User).order_by(User.created_at.desc())
# 预加载关联(避免 N+1 问题)
select(Post).options(selectinload(Post.author), selectinload(Post.tags))
# 聚合统计
from sqlalchemy import func, select
stmt = select(User.role, func.count(User.id)).group_by(User.role)
result = await db.execute(stmt)
# 联表查询
stmt = select(Post.title, User.name).join(User, Post.author_id == User.id)
result = await db.execute(stmt)#8. 小结
| 组件 | 作用 |
|---|---|
| create_async_engine | 创建异步数据库引擎 |
| AsyncSessionLocal | 异步会话工厂 |
| Mapped[T] + mapped_column | SQLAlchemy 2.0 类型注解 |
| DeclarativeBase | 所有模型的基类 |
| select() + execute() | 异步查询 |
| selectinload | 预加载关联,防止 N+1 |
| lifespan 管理器 | 在 FastAPI 启动/关闭时管理连接池 |
💡 最佳实践:使用 Repository 模式封装数据访问,路由只处理 HTTP 逻辑,数据库操作交给 Repository。好处是代码可测试,业务逻辑与框架解耦。
🔗 扩展阅读

