FastAPI依赖注入系统完全指南

📂 所属阶段:第二阶段 — 进阶黑科技(核心篇)
🔗 相关章节:FastAPI异步编程深度解析 · FastAPI中间件应用

目录

什么是依赖注入?

不用依赖注入的痛点

假设你需要查询数据库获取当前用户:

# ❌ 没有依赖注入:每个路由都要重复这些代码
@app.get("/users/me")
def get_current_user():
    token = request.headers.get("Authorization")
    user = db.query(User).filter(User.token == token).first()
    if not user:
        raise HTTPException(401, "Unauthorized")
    return user

@app.get("/orders")
def get_orders():
    token = request.headers.get("Authorization")
    user = db.query(User).filter(User.token == token).first()
    if not user:
        raise HTTPException(401, "Unauthorized")
    # ... 更多逻辑
    return orders

问题:认证逻辑在每个路由中重复,修改一处要改 N 处。

用依赖注入解决问题

# ✅ 依赖注入:把公共逻辑抽成"依赖项",按需注入
async def get_current_user(token: str = Depends(get_token)):
    """认证依赖:自动提取 token 并验证用户"""
    user = await db.get_user_by_token(token)
    if not user:
        raise HTTPException(401, "Unauthorized")
    return user

@app.get("/users/me")
async def get_me(user: User = Depends(get_current_user)):
    return user

@app.get("/orders")
async def get_orders(user: User = Depends(get_current_user)):
    # user 已经通过注入获得了,无需重复认证
    return await get_user_orders(user.id)

FastAPI 的依赖注入就是让路由函数的参数"声明式"地声明它需要什么,系统自动解析并注入值。

依赖注入的核心优势

  1. 代码复用:公共逻辑只需定义一次
  2. 解耦:路由函数不需要知道依赖的具体实现
  3. 可测试性:依赖可以轻松替换为模拟对象
  4. 维护性:修改依赖逻辑不影响使用它的路由
  5. 性能:内置缓存机制,同一请求内的依赖只执行一次

依赖注入基础

Depends 函数详解

Depends 是 FastAPI 依赖注入的核心,告诉 FastAPI"这个参数的值从另一个可调用对象获取"。

from fastapi import Depends, FastAPI

app = FastAPI()

# 简单的依赖函数
def get_query_param(q: str = "default"):
    return f"查询内容: {q}"

@app.get("/search")
async def search(q: str = Depends(get_query_param)):
    return {"q": q}

带参数的依赖函数

from fastapi import Query

# 有参数的依赖工厂
def pagination_params(
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100)
):
    return {"page": page, "page_size": page_size}

@app.get("/items")
async def list_items(params: dict = Depends(pagination_params)):
    # params = {"page": 1, "page_size": 10}
    return params

类作为依赖

类比函数更直观,FastAPI 会自动调用类来实例化:

from fastapi import Header, HTTPException

class AuthService:
    def __init__(self, authorization: str = Header(...)):
        if not authorization.startswith("Bearer "):
            raise HTTPException(401, "Invalid token")
        self.token = authorization.replace("Bearer ", "")
        self.user = self._verify_token(self.token)

    def _verify_token(self, token: str):
        # 模拟验证
        if token == "admin-secret":
            return {"id": 1, "role": "admin", "name": "Admin"}
        elif token == "user-secret":
            return {"id": 2, "role": "user", "name": "Bob"}
        raise HTTPException(401, "Invalid token")

# 注入为一个实例
@app.get("/profile")
async def get_profile(auth: AuthService = Depends()):
    return auth.user

依赖函数的类型

类型特点适用场景
普通函数简单的参数提取或转换查询参数处理
异步函数支持 await 操作数据库查询、HTTP请求
类构造器创建复杂对象服务类、配置类
生成器函数支持资源清理数据库连接、文件操作

数据库连接与依赖

创建数据库依赖

from fastapi import Depends, FastAPI
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base

# 数据库配置
DATABASE_URL = "sqlite+aiosqlite:///./app.db"

engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
AsyncSessionLocal = sessionmaker(
    engine, class_=AsyncSession, expire_on_commit=False
)
Base = declarative_base()

# 数据库依赖(异步)
async def get_db() -> AsyncSession:
    async with AsyncSessionLocal() as session:
        try:
            yield session  # 使用 with 块结束后自动关闭连接
        finally:
            await session.close()

# 使用方式
@app.get("/users/{user_id}")
async def get_user(
    user_id: int,
    db: AsyncSession = Depends(get_db)
):
    from sqlalchemy import select
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(404, "User not found")
    return user

连接池复用

依赖注入天然支持连接池复用:

# 数据库连接池由 FastAPI 的 lifespan 管理
# 依赖函数每次被调用时获取一个连接,用完归还
async def get_db():
    async with AsyncSessionLocal() as session:
        try:
            yield session  # yield 之后 FastAPI 负责清理
        finally:
            await session.close()

# 在应用启动时初始化
@app.on_event("startup")
async def startup():
    # 初始化数据库连接池
    pass

@app.on_event("shutdown")
async def shutdown():
    # 关闭数据库连接池
    pass

高级数据库依赖模式

from contextlib import asynccontextmanager

@asynccontextmanager
async def get_db_session():
    """数据库会话上下文管理器"""
    async with AsyncSessionLocal() as session:
        try:
            yield session
        except Exception:
            await session.rollback()
            raise
        else:
            await session.commit()

# 使用上下文管理器
@app.post("/users/")
async def create_user(user_data: UserCreate, db: AsyncSession = Depends(lambda: get_db_session().__anext__())):
    # 更推荐使用传统的依赖注入模式
    pass

依赖链与级联依赖

基础依赖链

依赖可以依赖其他依赖,形成链式结构:

# 第一层:解析 token
def get_token(authorization: str = Header(...)):
    if not authorization:
        raise HTTPException(401, "Missing Authorization header")
    return authorization.replace("Bearer ", "")

# 第二层:根据 token 查用户(依赖第一层)
def get_current_user(token: str = Depends(get_token)):
    user = fake_verify_token(token)
    if not user:
        raise HTTPException(401, "Invalid token")
    return user

# 第三层:检查管理员权限(依赖第二层)
def require_admin(current_user: dict = Depends(get_current_user)):
    if current_user.get("role") != "admin":
        raise HTTPException(403, "Admin access required")
    return current_user

# 使用:普通用户路由
@app.get("/profile")
async def profile(user: dict = Depends(get_current_user)):
    return user

# 使用:管理员路由
@app.delete("/users/{user_id}")
async def delete_user(
    user_id: int,
    admin: dict = Depends(require_admin)
):
    return {"deleted": user_id, "by": admin["name"]}

复杂依赖链示例

# 用户认证链
async def authenticate_user(token: str = Depends(get_token)) -> dict:
    """用户认证:验证token并返回用户信息"""
    user = await verify_token_async(token)
    if not user:
        raise HTTPException(401, "Invalid token")
    return user

async def authorize_user(user: dict = Depends(authenticate_user)) -> dict:
    """用户授权:检查用户权限"""
    if user.get("is_active") is False:
        raise HTTPException(403, "User account is deactivated")
    return user

def require_permission(permission: str):
    """权限检查工厂函数"""
    async def check_permission(user: dict = Depends(authorize_user)) -> dict:
        if permission not in user.get("permissions", []):
            raise HTTPException(403, f"Permission '{permission}' required")
        return user
    return check_permission

# 使用复杂依赖链
@app.put("/posts/{post_id}")
async def update_post(
    post_id: int,
    user: dict = Depends(require_permission("edit_posts"))
):
    return {"message": f"Post {post_id} updated by {user['name']}"}

依赖链的性能考虑

# 优化:合并相关依赖以减少调用次数
async def get_user_with_permissions(token: str = Depends(get_token)):
    """一次性获取用户及其权限信息"""
    user = await verify_token_async(token)
    if not user:
        raise HTTPException(401, "Invalid token")
    
    # 同时获取权限信息,避免额外查询
    permissions = await get_user_permissions(user["id"])
    user["permissions"] = permissions
    
    return user

# 使用合并后的依赖
@app.get("/dashboard")
async def get_dashboard(user: dict = Depends(get_user_with_permissions)):
    return {"user": user["name"], "permissions": user["permissions"]}

可选参数与默认值

可选依赖

from typing import Optional

async def get_optional_user(
    token: Optional[str] = Header(None, alias="Authorization")
) -> Optional[dict]:
    if not token:
        return None
    return verify_token(token)

# 使用 Optional 依赖
@app.get("/personalized")
async def personalized(
    user: Optional[dict] = Depends(get_optional_user)
):
    if user:
        return {"message": f"Welcome {user['name']}"}
    return {"message": "Welcome, guest!"}

有条件地应用依赖

from fastapi import Query, Depends

# 根据查询参数决定是否需要认证
async def optional_auth(
    token: Optional[str] = Query(None, alias="token")
):
    if not token:
        return None
    return verify_token(token)

@app.get("/items")
async def list_items(
    user: Optional[dict] = Depends(optional_auth),
    admin: Optional[dict] = Depends(optional_auth)
):
    return {"authenticated": user is not None}

默认值依赖

# 依赖函数可以有默认参数
def get_settings(
    debug: bool = Query(False, description="启用调试模式"),
    timeout: int = Query(30, ge=1, le=300, description="请求超时时间")
) -> dict:
    return {
        "debug": debug,
        "timeout": timeout,
        "api_version": "1.0.0"
    }

@app.get("/status")
async def get_status(settings: dict = Depends(get_settings)):
    return settings

依赖缓存与复用

默认行为:同一请求内缓存

默认情况下,同一请求中多次使用同一个依赖,只执行一次

def get_user():
    print("🔍 查询用户(只执行一次)")
    return {"id": 1, "name": "Alice"}

@app.get("/a")
@app.get("/b")
@app.get("/c")
async def endpoints(user: dict = Depends(get_user)):
    # 三个路由同时调用 user 依赖
    # 在同一个请求中只打印一次 "查询用户"
    return {"user": user}

使用 use_cache=False 禁用缓存

async def get_request_id(request: Request):
    # 每次都重新生成
    return request.headers.get("X-Request-ID")

# 强制不使用缓存
async def always_fresh(
    req_id: str = Depends(get_request_id, use_cache=False)
):
    return req_id

缓存策略优化

from functools import lru_cache
import asyncio

# 对于计算密集型依赖,使用缓存
@lru_cache(maxsize=128)
def expensive_computation(param: str) -> str:
    # 模拟昂贵的计算
    result = param.upper()
    for _ in range(1000):
        result = hash(result)
    return str(result)

async def cached_dependency(param: str):
    """异步包装缓存函数"""
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(None, expensive_computation, param)

@app.get("/compute/{param}")
async def compute_endpoint(
    param: str,
    result: str = Depends(cached_dependency)
):
    return {"input": param, "result": result}

实战:构建完整认证体系

完整的依赖分层

请求 → 中间件 → 依赖注入链

         1. get_db          (数据库连接)
         2. get_token       (提取 token)
         3. get_current_user(验证用户)
         4. require_role    (权限检查)

             路由处理器

完整示例代码

"""
auth_dependencies.py — 认证依赖模块
"""
from fastapi import Depends, HTTPException, Header, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from typing import Optional, List
from functools import wraps
import jwt
import time
from datetime import datetime, timedelta

# 简易 JWT 验证
SECRET_KEY = "super-secret-key"
ALGORITHM = "HS256"

security = HTTPBearer(auto_error=False)

def decode_token(token: str) -> dict:
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.ExpiredSignatureError:
        raise HTTPException(401, "Token expired")
    except jwt.InvalidTokenError:
        raise HTTPException(401, "Invalid token")

# ── 依赖 1:获取当前用户(可选)─────────────────────
async def get_current_user_optional(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[dict]:
    if not credentials:
        return None
    return decode_token(credentials.credentials)

# ── 依赖 2:获取当前用户(必需)─────────────────────
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security)
) -> dict:
    return decode_token(credentials.credentials)

# ── 依赖 3:角色权限校验─────────────────────────────
def require_roles(allowed_roles: List[str]):
    """工厂函数:生成特定角色的依赖"""
    async def checker(user: dict = Depends(get_current_user)) -> dict:
        if user.get("role") not in allowed_roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Requires one of roles: {allowed_roles}"
            )
        return user
    return checker

# ── 依赖 4:权限级别校验─────────────────────────────
def require_permission_level(min_level: int):
    """工厂函数:基于权限级别的依赖"""
    async def checker(user: dict = Depends(get_current_user)) -> dict:
        user_level = user.get("permission_level", 0)
        if user_level < min_level:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Permission level {min_level} required, got {user_level}"
            )
        return user
    return checker

# ── 预定义常用角色依赖───────────────────────────────
require_admin = require_roles(["admin"])
require_editor = require_roles(["admin", "editor"])
require_any_user = require_roles(["admin", "editor", "user"])
require_system_admin = require_permission_level(10)

# ── 依赖 5:IP白名单校验─────────────────────────────
async def check_ip_whitelist(
    request: Request,
    user: dict = Depends(get_current_user)
):
    """IP白名单校验依赖"""
    client_ip = request.client.host
    allowed_ips = user.get("allowed_ips", [])
    
    if allowed_ips and client_ip not in allowed_ips:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="IP not in whitelist"
        )
    
    return user
"""
main.py — 使用认证依赖
"""
from fastapi import FastAPI, Depends, Request
from auth_dependencies import (
    get_current_user,
    get_current_user_optional,
    require_admin,
    require_editor,
    require_any_user,
    require_system_admin,
    check_ip_whitelist
)

app = FastAPI()

@app.get("/public")
async def public():
    return {"public": True}

@app.get("/me")
async def me(user: dict = Depends(get_current_user)):
    return user

@app.get("/dashboard")
async def dashboard(user: dict = Depends(require_any_user)):
    return {"dashboard": True, "user": user["sub"]}

@app.post("/articles")
async def create_article(user: dict = Depends(require_editor)):
    return {"created_by": user["sub"]}

@app.delete("/users/{user_id}")
async def delete_user(user_id: int, user: dict = Depends(require_admin)):
    return {"deleted": user_id, "by": user["sub"]}

@app.put("/system/config")
async def update_system_config(request: Request, user: dict = Depends(require_system_admin)):
    return {"updated": True, "by": user["sub"], "ip": request.client.host}

@app.get("/secure-area")
async def secure_area(user: dict = Depends(check_ip_whitelist)):
    return {"message": "Access granted from whitelisted IP", "user": user["sub"]}

进阶用法

异步依赖

# 异步依赖
async def get_async_db():
    async with AsyncSessionLocal() as session:
        try:
            yield session
        finally:
            await session.close()

@app.get("/items")
async def list_items(db: AsyncSession = Depends(get_async_db)):
    # 使用异步数据库连接
    pass

多层依赖合并

# 同时注入多个依赖
@app.get("/admin-panel")
async def admin_panel(
    user: dict = Depends(require_admin),
    db: AsyncSession = Depends(get_db)
):
    # 同时使用用户权限和数据库连接
    result = await db.execute(select(Log).order_by(desc(Log.created_at)))
    return result.scalars().all()

全局依赖(应用于所有路由)

from fastapi import FastAPI

app = FastAPI()

# 在路由级别添加全局依赖
@app.get("/api/v1/users", dependencies=[Depends(require_any_user)])
async def list_users():
    return {"users": []}

@app.post("/api/v1/posts", dependencies=[Depends(require_editor)])
async def create_post():
    return {"message": "Post created"}

依赖注入的类型提示

from typing import Callable, Any
from fastapi import Depends

# 依赖注入的类型注解
async def get_current_user() -> dict:
    return {"id": 1, "name": "user"}

def create_user_dependency() -> Callable[[], dict]:
    """创建用户依赖的工厂函数"""
    async def dependency() -> dict:
        return await get_current_user()
    return dependency

# 使用类型注解的依赖
@app.get("/profile")
async def get_profile(user: dict = Depends(create_user_dependency())):
    return user

性能优化建议

1. 依赖缓存策略

from functools import lru_cache
import asyncio

# 对于昂贵的计算或查询,使用缓存
@lru_cache(maxsize=128)
def get_cached_config(config_name: str) -> dict:
    """缓存配置获取"""
    # 模拟从数据库或文件获取配置
    return {"name": config_name, "value": "cached_value"}

async def get_config_dependency(config_name: str):
    """异步包装缓存函数"""
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(None, get_cached_config, config_name)

2. 数据库连接优化

# 连接池配置
from sqlalchemy.pool import QueuePool

engine = create_engine(
    DATABASE_URL,
    poolclass=QueuePool,
    pool_size=20,          # 连接池大小
    max_overflow=30,       # 最大溢出连接数
    pool_pre_ping=True,    # 连接前测试
    pool_recycle=3600      # 连接回收时间
)

3. 依赖执行顺序优化

# 优化依赖执行顺序,将快速的依赖前置
async def quick_check():
    """快速检查,优先执行"""
    return "quick_result"

async def heavy_operation(quick_result: str = Depends(quick_check)):
    """耗时操作,依赖快速检查"""
    # 只有快速检查通过才会执行耗时操作
    return f"heavy_{quick_result}"

@app.get("/optimized-endpoint")
async def optimized_endpoint(result: str = Depends(heavy_operation)):
    return {"result": result}

常见陷阱与避坑指南

陷阱1:依赖函数的副作用

# ❌ 错误:依赖函数有副作用
counter = 0
def bad_dependency():
    global counter
    counter += 1  # 副作用:修改全局变量
    return counter

# ✅ 正确:依赖函数应该是纯函数
def good_dependency():
    return {"timestamp": time.time()}

陷阱2:依赖链中的异常处理

# ❌ 问题:异常传播可能不明确
def get_user():
    # 可能抛出异常
    return risky_operation()

def get_user_profile(user: dict = Depends(get_user)):
    # 如果 get_user 抛出异常,这里不会执行
    return {"profile": user["id"]}

# ✅ 解决:明确的异常处理
def get_user():
    try:
        return risky_operation()
    except Exception as e:
        raise HTTPException(status_code=401, detail=str(e))

陷阱3:循环依赖

# ❌ 错误:循环依赖会导致运行时错误
def dep_a(b: str = Depends(dep_b)):
    return "a"

def dep_b(a: str = Depends(dep_a)):  # 循环依赖!
    return "b"

陷阱4:过度使用依赖

# ❌ 过度复杂:不必要的依赖嵌套
def get_config():
    return {"setting": "value"}

def get_service(config: dict = Depends(get_config)):
    return MyService(config)

def get_controller(service: MyService = Depends(get_service)):
    return Controller(service)

# ✅ 简洁:直接注入需要的依赖
@app.get("/endpoint")
async def endpoint(service: MyService = Depends(get_service)):
    return service.process()

陷阱5:依赖函数中的阻塞操作

# ❌ 错误:在依赖中使用阻塞操作
def blocking_dependency():
    time.sleep(1)  # 阻塞整个事件循环
    return "result"

# ✅ 正确:将阻塞操作移到线程池
async def non_blocking_dependency():
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(None, time.sleep, 1)

相关教程

依赖注入是FastAPI的核心特性之一,合理使用可以大大提高代码的可维护性和可测试性。建议将公共的认证、授权、数据库连接等逻辑抽取为依赖,避免在路由函数中重复编写相同的代码。 FastAPI的依赖注入系统内置了缓存机制,在同一次请求中,相同的依赖只会执行一次。这有助于提高性能,但也需要注意依赖函数应该是幂等的,不应该有副作用。

总结

模式说明适用场景
Depends(func)最基础的依赖注入参数提取、认证
类依赖 Depends(Class)自动实例化服务类、配置类
依赖链依赖间可嵌套认证 → 用户 → 权限
use_cache=False禁用请求级缓存每次请求需要新实例
全局 dependencies应用于所有路由全局认证、日志
Optional + Depends可选依赖公开/私有混合接口

💡 核心思想:依赖注入让代码变得可测试(可以 mock 依赖)、可复用(一处定义,到处使用)、声明式(代码意图一目了然)。

FastAPI的依赖注入系统是构建企业级应用的重要工具,它不仅提高了代码的可维护性,还能有效管理复杂的业务逻辑和权限控制。


1. 什么是依赖注入?

1.1 不用依赖注入的痛

假设你需要查询数据库获取当前用户:

# ❌ 没有依赖注入:每个路由都要重复这些代码
@app.get("/users/me")
def get_current_user():
    token = request.headers.get("Authorization")
    user = db.query(User).filter(User.token == token).first()
    if not user:
        raise HTTPException(401, "Unauthorized")
    return user

@app.get("/orders")
def get_orders():
    token = request.headers.get("Authorization")
    user = db.query(User).filter(User.token == token).first()
    if not user:
        raise HTTPException(401, "Unauthorized")
    # ... 更多逻辑
    return orders

问题:认证逻辑在每个路由中重复,修改一处要改 N 处。

1.2 用依赖注入解决

# ✅ 依赖注入:把公共逻辑抽成"依赖项",按需注入
async def get_current_user(token: str = Depends(get_token)):
    """认证依赖:自动提取 token 并验证用户"""
    user = await db.get_user_by_token(token)
    if not user:
        raise HTTPException(401, "Unauthorized")
    return user

@app.get("/users/me")
async def get_me(user: User = Depends(get_current_user)):
    return user

@app.get("/orders")
async def get_orders(user: User = Depends(get_current_user)):
    # user 已经通过注入获得了,无需重复认证
    return await get_user_orders(user.id)

FastAPI 的依赖注入就是让路由函数的参数"声明式"地声明它需要什么,系统自动解析并注入值。


2. 依赖注入基础

2.1 Depends 函数

Depends 是 FastAPI 依赖注入的核心,告诉 FastAPI"这个参数的值从另一个可调用对象获取"。

from fastapi import Depends, FastAPI

app = FastAPI()

# 简单的依赖函数
def get_query_param(q: str = "default"):
    return f"查询内容: {q}"

@app.get("/search")
async def search(q: str = Depends(get_query_param)):
    return {"q": q}

2.2 带参数的依赖函数

from fastapi import Query

# 有参数的依赖工厂
def pagination_params(
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100)
):
    return {"page": page, "page_size": page_size}

@app.get("/items")
async def list_items(params: dict = Depends(pagination_params)):
    # params = {"page": 1, "page_size": 10}
    return params

2.3 类作为依赖

类比函数更直观,FastAPI 会自动调用类来实例化:

from fastapi import Header

class AuthService:
    def __init__(self, authorization: str = Header(...)):
        if not authorization.startswith("Bearer "):
            raise HTTPException(401, "Invalid token")
        self.token = authorization.replace("Bearer ", "")
        self.user = self._verify_token(self.token)

    def _verify_token(self, token: str):
        # 模拟验证
        if token == "admin-secret":
            return {"id": 1, "role": "admin", "name": "Admin"}
        elif token == "user-secret":
            return {"id": 2, "role": "user", "name": "Bob"}
        raise HTTPException(401, "Invalid token")

# 注入为一个实例
@app.get("/profile")
async def get_profile(auth: AuthService = Depends()):
    return auth.user

3. 依赖注入与数据库连接

3.1 创建数据库依赖

from fastapi import Depends, FastAPI
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base

# 数据库配置
DATABASE_URL = "sqlite+aiosqlite:///./app.db"

engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
AsyncSessionLocal = sessionmaker(
    engine, class_=AsyncSession, expire_on_commit=False
)
Base = declarative_base()

# 数据库依赖(异步)
async def get_db() -> AsyncSession:
    async with AsyncSessionLocal() as session:
        yield session  # 使用 with 块结束后自动关闭连接

# 使用方式
@app.get("/users/{user_id}")
async def get_user(
    user_id: int,
    db: AsyncSession = Depends(get_db)
):
    from sqlalchemy import select
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(404, "User not found")
    return user

3.2 连接池复用

依赖注入天然支持连接池复用:

# 数据库连接池由 FastAPI 的 lifespan 管理
# 依赖函数每次被调用时获取一个连接,用完归还
async def get_db():
    async with AsyncSessionLocal() as session:
        yield session  # yield 之后 FastAPI 负责清理

4. 依赖链(级联依赖)

依赖可以依赖其他依赖,形成链式结构:

# 第一层:解析 token
def get_token(authorization: str = Header(...)):
    if not authorization:
        raise HTTPException(401, "Missing Authorization header")
    return authorization.replace("Bearer ", "")

# 第二层:根据 token 查用户(依赖第一层)
def get_current_user(token: str = Depends(get_token)):
    user = fake_verify_token(token)
    if not user:
        raise HTTPException(401, "Invalid token")
    return user

# 第三层:检查管理员权限(依赖第二层)
def require_admin(current_user: dict = Depends(get_current_user)):
    if current_user.get("role") != "admin":
        raise HTTPException(403, "Admin access required")
    return current_user

# 使用:普通用户路由
@app.get("/profile")
async def profile(user: dict = Depends(get_current_user)):
    return user

# 使用:管理员路由
@app.delete("/users/{user_id}")
async def delete_user(
    user_id: int,
    admin: dict = Depends(require_admin)
):
    return {"deleted": user_id, "by": admin["name"]}

5. 依赖的可选参数与默认值

5.1 可选依赖

from typing import Optional

async def get_optional_user(
    token: Optional[str] = Header(None, alias="Authorization")
) -> Optional[dict]:
    if not token:
        return None
    return verify_token(token)

# 使用 Optional 依赖
@app.get("/personalized")
async def personalized(
    user: Optional[dict] = Depends(get_optional_user)
):
    if user:
        return {"message": f"Welcome {user['name']}"}
    return {"message": "Welcome, guest!"}

5.2 有条件地应用依赖

from fastapi import Query, Depends

# 根据查询参数决定是否需要认证
async def optional_auth(
    token: Optional[str] = Query(None, alias="token")
):
    if not token:
        return None
    return verify_token(token)

@app.get("/items")
async def list_items(
    user: Optional[dict] = Depends(optional_auth),
    admin: Optional[dict] = Depends(optional_auth)
):
    return {"authenticated": user is not None}

6. 依赖的缓存与复用

6.1 默认行为:同一请求内缓存

默认情况下,同一请求中多次使用同一个依赖,只执行一次

def get_user():
    print("🔍 查询用户(只执行一次)")
    return {"id": 1, "name": "Alice"}

@app.get("/a")
@app.get("/b")
@app.get("/c")
async def endpoints(user: dict = Depends(get_user)):
    # 三个路由同时调用 user 依赖
    # 在同一个请求中只打印一次 "查询用户"
    return {"user": user}

6.2 使用 use_cache=False 禁用缓存

async def get_request_id(request: Request):
    # 每次都重新生成
    return request.headers.get("X-Request-ID")

# 强制不使用缓存
async def always_fresh(
    req_id: str = Depends(get_request_id, use_cache=False)
):
    return req_id

7. 实战:构建完整认证体系

7.1 完整的依赖分层

请求 → 中间件 → 依赖注入链

         1. get_db          (数据库连接)
         2. get_token       (提取 token)
         3. get_current_user(验证用户)
         4. require_role    (权限检查)

             路由处理器

7.2 完整示例代码

"""
auth dependencies.py — 认证依赖模块
"""
from fastapi import Depends, HTTPException, Header, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from typing import Optional, List
from functools import wraps
import jwt

# 简易 JWT 验证
SECRET_KEY = "super-secret-key"
ALGORITHM = "HS256"

security = HTTPBearer(auto_error=False)

def decode_token(token: str) -> dict:
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.ExpiredSignatureError:
        raise HTTPException(401, "Token expired")
    except jwt.InvalidTokenError:
        raise HTTPException(401, "Invalid token")

# ── 依赖 1:获取当前用户(可选)─────────────────────
async def get_current_user_optional(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[dict]:
    if not credentials:
        return None
    return decode_token(credentials.credentials)

# ── 依赖 2:获取当前用户(必需)─────────────────────
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security)
) -> dict:
    return decode_token(credentials.credentials)

# ── 依赖 3:角色权限校验─────────────────────────────
def require_roles(allowed_roles: List[str]):
    """工厂函数:生成特定角色的依赖"""
    async def checker(user: dict = Depends(get_current_user)) -> dict:
        if user.get("role") not in allowed_roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Requires one of roles: {allowed_roles}"
            )
        return user
    return checker

# ── 预定义常用角色依赖───────────────────────────────
require_admin = require_roles(["admin"])
require_editor = require_roles(["admin", "editor"])
require_any_user = require_roles(["admin", "editor", "user"])
"""
main.py — 使用认证依赖
"""
from fastapi import FastAPI, Depends
from auth import (
    get_current_user,
    get_current_user_optional,
    require_admin,
    require_editor,
)

app = FastAPI()

@app.get("/public")
async def public():
    return {"public": True}

@app.get("/me")
async def me(user: dict = Depends(get_current_user)):
    return user

@app.get("/dashboard")
async def dashboard(user: dict = Depends(require_any_user)):
    return {"dashboard": True, "user": user["sub"]}

@app.post("/articles")
async def create_article(user: dict = Depends(require_editor)):
    return {"created_by": user["sub"]}

@app.delete("/users/{user_id}")
async def delete_user(user_id: int, user: dict = Depends(require_admin)):
    return {"deleted": user_id, "by": user["sub"]}

8. 依赖注入的进阶用法

8.1 异步依赖

# 异步依赖
async def get_async_db():
    async with AsyncSessionLocal() as session:
        yield session

@app.get("/items")
async def list_items(db: AsyncSession = Depends(get_async_db)):
    ...

8.2 多层依赖合并

# 同时注入多个依赖
@app.get("/admin-panel")
async def admin_panel(
    user: dict = Depends(require_admin),
    db: AsyncSession = Depends(get_db)
):
    result = await db.execute(select(Log).order_by(desc(Log.created_at)))
    return result.scalars().all()

8.3 全局依赖(应用于所有路由)

app = FastAPI()

# 所有路由都会先执行这个依赖
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    ...

# 或者用依赖注入(更推荐)
from fastapi import FastAPI, Request

async def common_params(request: Request):
    return {"path": request.url.path}

# 在 app 层级添加依赖(需要覆盖 add_api_route)
app.add_api_route(
    "/items",
    list_items,
    dependencies=[Depends(require_admin)]
)

9. 小结

模式说明适用场景
Depends(func)最基础的依赖注入参数提取、认证
类依赖 Depends(Class)自动实例化服务类、数据库连接
依赖链依赖间可嵌套认证 → 用户 → 权限
use_cache=False禁用请求级缓存每次请求需要新实例
全局 dependencies应用于所有路由全局认证、日志
Optional + Depends可选依赖公开/私有混合接口

💡 核心思想:依赖注入让代码变得可测试(可以 mock 依赖)、可复用(一处定义,到处使用)、声明式(代码意图一目了然)。


🔗 扩展阅读