依赖注入系统(Dependency Injection):实现代码复用、权限校验与数据库连接

📂 所属阶段:第二阶段 — 进阶黑科技(核心篇)
🔗 相关章节:异步编程深度解析 · 中间件应用


1. 什么是依赖注入?

1.1 不用依赖注入的痛

假设你需要查询数据库获取当前用户:

# ❌ 没有依赖注入:每个路由都要重复这些代码
@app.get("/users/me")
def get_current_user():
    token = request.headers.get("Authorization")
    user = db.query(User).filter(User.token == token).first()
    if not user:
        raise HTTPException(401, "Unauthorized")
    return user

@app.get("/orders")
def get_orders():
    token = request.headers.get("Authorization")
    user = db.query(User).filter(User.token == token).first()
    if not user:
        raise HTTPException(401, "Unauthorized")
    # ... 更多逻辑
    return orders

问题:认证逻辑在每个路由中重复,修改一处要改 N 处。

1.2 用依赖注入解决

# ✅ 依赖注入:把公共逻辑抽成"依赖项",按需注入
async def get_current_user(token: str = Depends(get_token)):
    """认证依赖:自动提取 token 并验证用户"""
    user = await db.get_user_by_token(token)
    if not user:
        raise HTTPException(401, "Unauthorized")
    return user

@app.get("/users/me")
async def get_me(user: User = Depends(get_current_user)):
    return user

@app.get("/orders")
async def get_orders(user: User = Depends(get_current_user)):
    # user 已经通过注入获得了,无需重复认证
    return await get_user_orders(user.id)

FastAPI 的依赖注入就是让路由函数的参数"声明式"地声明它需要什么,系统自动解析并注入值。


2. 依赖注入基础

2.1 Depends 函数

Depends 是 FastAPI 依赖注入的核心,告诉 FastAPI"这个参数的值从另一个可调用对象获取"。

from fastapi import Depends, FastAPI

app = FastAPI()

# 简单的依赖函数
def get_query_param(q: str = "default"):
    return f"查询内容: {q}"

@app.get("/search")
async def search(q: str = Depends(get_query_param)):
    return {"q": q}

2.2 带参数的依赖函数

from fastapi import Query

# 有参数的依赖工厂
def pagination_params(
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100)
):
    return {"page": page, "page_size": page_size}

@app.get("/items")
async def list_items(params: dict = Depends(pagination_params)):
    # params = {"page": 1, "page_size": 10}
    return params

2.3 类作为依赖

类比函数更直观,FastAPI 会自动调用类来实例化:

from fastapi import Header

class AuthService:
    def __init__(self, authorization: str = Header(...)):
        if not authorization.startswith("Bearer "):
            raise HTTPException(401, "Invalid token")
        self.token = authorization.replace("Bearer ", "")
        self.user = self._verify_token(self.token)

    def _verify_token(self, token: str):
        # 模拟验证
        if token == "admin-secret":
            return {"id": 1, "role": "admin", "name": "Admin"}
        elif token == "user-secret":
            return {"id": 2, "role": "user", "name": "Bob"}
        raise HTTPException(401, "Invalid token")

# 注入为一个实例
@app.get("/profile")
async def get_profile(auth: AuthService = Depends()):
    return auth.user

3. 依赖注入与数据库连接

3.1 创建数据库依赖

from fastapi import Depends, FastAPI
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base

# 数据库配置
DATABASE_URL = "sqlite+aiosqlite:///./app.db"

engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
AsyncSessionLocal = sessionmaker(
    engine, class_=AsyncSession, expire_on_commit=False
)
Base = declarative_base()

# 数据库依赖(异步)
async def get_db() -> AsyncSession:
    async with AsyncSessionLocal() as session:
        yield session  # 使用 with 块结束后自动关闭连接

# 使用方式
@app.get("/users/{user_id}")
async def get_user(
    user_id: int,
    db: AsyncSession = Depends(get_db)
):
    from sqlalchemy import select
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(404, "User not found")
    return user

3.2 连接池复用

依赖注入天然支持连接池复用:

# 数据库连接池由 FastAPI 的 lifespan 管理
# 依赖函数每次被调用时获取一个连接,用完归还
async def get_db():
    async with AsyncSessionLocal() as session:
        yield session  # yield 之后 FastAPI 负责清理

4. 依赖链(级联依赖)

依赖可以依赖其他依赖,形成链式结构:

# 第一层:解析 token
def get_token(authorization: str = Header(...)):
    if not authorization:
        raise HTTPException(401, "Missing Authorization header")
    return authorization.replace("Bearer ", "")

# 第二层:根据 token 查用户(依赖第一层)
def get_current_user(token: str = Depends(get_token)):
    user = fake_verify_token(token)
    if not user:
        raise HTTPException(401, "Invalid token")
    return user

# 第三层:检查管理员权限(依赖第二层)
def require_admin(current_user: dict = Depends(get_current_user)):
    if current_user.get("role") != "admin":
        raise HTTPException(403, "Admin access required")
    return current_user

# 使用:普通用户路由
@app.get("/profile")
async def profile(user: dict = Depends(get_current_user)):
    return user

# 使用:管理员路由
@app.delete("/users/{user_id}")
async def delete_user(
    user_id: int,
    admin: dict = Depends(require_admin)
):
    return {"deleted": user_id, "by": admin["name"]}

5. 依赖的可选参数与默认值

5.1 可选依赖

from typing import Optional

async def get_optional_user(
    token: Optional[str] = Header(None, alias="Authorization")
) -> Optional[dict]:
    if not token:
        return None
    return verify_token(token)

# 使用 Optional 依赖
@app.get("/personalized")
async def personalized(
    user: Optional[dict] = Depends(get_optional_user)
):
    if user:
        return {"message": f"Welcome {user['name']}"}
    return {"message": "Welcome, guest!"}

5.2 有条件地应用依赖

from fastapi import Query, Depends

# 根据查询参数决定是否需要认证
async def optional_auth(
    token: Optional[str] = Query(None, alias="token")
):
    if not token:
        return None
    return verify_token(token)

@app.get("/items")
async def list_items(
    user: Optional[dict] = Depends(optional_auth),
    admin: Optional[dict] = Depends(optional_auth)
):
    return {"authenticated": user is not None}

6. 依赖的缓存与复用

6.1 默认行为:同一请求内缓存

默认情况下,同一请求中多次使用同一个依赖,只执行一次

def get_user():
    print("🔍 查询用户(只执行一次)")
    return {"id": 1, "name": "Alice"}

@app.get("/a")
@app.get("/b")
@app.get("/c")
async def endpoints(user: dict = Depends(get_user)):
    # 三个路由同时调用 user 依赖
    # 在同一个请求中只打印一次 "查询用户"
    return {"user": user}

6.2 使用 use_cache=False 禁用缓存

async def get_request_id(request: Request):
    # 每次都重新生成
    return request.headers.get("X-Request-ID")

# 强制不使用缓存
async def always_fresh(
    req_id: str = Depends(get_request_id, use_cache=False)
):
    return req_id

7. 实战:构建完整认证体系

7.1 完整的依赖分层

请求 → 中间件 → 依赖注入链

         1. get_db          (数据库连接)
         2. get_token       (提取 token)
         3. get_current_user(验证用户)
         4. require_role    (权限检查)

             路由处理器

7.2 完整示例代码

"""
auth dependencies.py — 认证依赖模块
"""
from fastapi import Depends, HTTPException, Header, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from typing import Optional, List
from functools import wraps
import jwt

# 简易 JWT 验证
SECRET_KEY = "super-secret-key"
ALGORITHM = "HS256"

security = HTTPBearer(auto_error=False)

def decode_token(token: str) -> dict:
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.ExpiredSignatureError:
        raise HTTPException(401, "Token expired")
    except jwt.InvalidTokenError:
        raise HTTPException(401, "Invalid token")

# ── 依赖 1:获取当前用户(可选)─────────────────────
async def get_current_user_optional(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[dict]:
    if not credentials:
        return None
    return decode_token(credentials.credentials)

# ── 依赖 2:获取当前用户(必需)─────────────────────
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security)
) -> dict:
    return decode_token(credentials.credentials)

# ── 依赖 3:角色权限校验─────────────────────────────
def require_roles(allowed_roles: List[str]):
    """工厂函数:生成特定角色的依赖"""
    async def checker(user: dict = Depends(get_current_user)) -> dict:
        if user.get("role") not in allowed_roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Requires one of roles: {allowed_roles}"
            )
        return user
    return checker

# ── 预定义常用角色依赖───────────────────────────────
require_admin = require_roles(["admin"])
require_editor = require_roles(["admin", "editor"])
require_any_user = require_roles(["admin", "editor", "user"])
"""
main.py — 使用认证依赖
"""
from fastapi import FastAPI, Depends
from auth import (
    get_current_user,
    get_current_user_optional,
    require_admin,
    require_editor,
)

app = FastAPI()

@app.get("/public")
async def public():
    return {"public": True}

@app.get("/me")
async def me(user: dict = Depends(get_current_user)):
    return user

@app.get("/dashboard")
async def dashboard(user: dict = Depends(require_any_user)):
    return {"dashboard": True, "user": user["sub"]}

@app.post("/articles")
async def create_article(user: dict = Depends(require_editor)):
    return {"created_by": user["sub"]}

@app.delete("/users/{user_id}")
async def delete_user(user_id: int, user: dict = Depends(require_admin)):
    return {"deleted": user_id, "by": user["sub"]}

8. 依赖注入的进阶用法

8.1 异步依赖

# 异步依赖
async def get_async_db():
    async with AsyncSessionLocal() as session:
        yield session

@app.get("/items")
async def list_items(db: AsyncSession = Depends(get_async_db)):
    ...

8.2 多层依赖合并

# 同时注入多个依赖
@app.get("/admin-panel")
async def admin_panel(
    user: dict = Depends(require_admin),
    db: AsyncSession = Depends(get_db)
):
    result = await db.execute(select(Log).order_by(desc(Log.created_at)))
    return result.scalars().all()

8.3 全局依赖(应用于所有路由)

app = FastAPI()

# 所有路由都会先执行这个依赖
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    ...

# 或者用依赖注入(更推荐)
from fastapi import FastAPI, Request

async def common_params(request: Request):
    return {"path": request.url.path}

# 在 app 层级添加依赖(需要覆盖 add_api_route)
app.add_api_route(
    "/items",
    list_items,
    dependencies=[Depends(require_admin)]
)

9. 小结

模式说明适用场景
Depends(func)最基础的依赖注入参数提取、认证
类依赖 Depends(Class)自动实例化服务类、数据库连接
依赖链依赖间可嵌套认证 → 用户 → 权限
use_cache=False禁用请求级缓存每次请求需要新实例
全局 dependencies应用于所有路由全局认证、日志
Optional + Depends可选依赖公开/私有混合接口

💡 核心思想:依赖注入让代码变得可测试(可以 mock 依赖)、可复用(一处定义,到处使用)、声明式(代码意图一目了然)。


🔗 扩展阅读