FastAPI OAuth2与JWT鉴权完全指南

📂 所属阶段:第四阶段 — 安全与认证(安全篇)
🔗 相关章节:FastAPI依赖注入系统 · FastAPI密码哈希与安全实践

目录

为什么选择JWT认证?

传统Session认证的局限性

方案优势劣势适用场景
Session认证简单易用,服务器可控有状态,扩展性差,跨域困难单体应用,小型项目
JWT认证无状态,可扩展,跨域友好Token较大,无法主动失效分布式系统,微服务,API

JWT的核心价值

  1. 无状态:服务器无需存储Session信息
  2. 可扩展:适合水平扩展的微服务架构
  3. 跨域友好:天然支持跨域认证
  4. 自包含:Token本身包含用户信息
  5. 标准化:遵循RFC 7519标准

项目依赖安装

# 安装JWT和密码处理依赖
pip install python-jose[cryptography] passlib[bcrypt] bcrypt

# python-jose → JWT编解码
# passlib → 密码哈希
# bcrypt → 推荐哈希算法

完整依赖示例

# requirements.txt
fastapi==0.104.1
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4
bcrypt==4.1.2
sqlalchemy==2.0.23
asyncpg==0.29.0
pydantic==2.5.0
python-multipart==0.0.6
uvicorn==0.24.0
redis==5.0.1

1. JWT 基础

1.1 JWT 结构

eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.
eyJzdWIiOiIxIiwiZW1haWwiOiJhbGljZUBleGFtcGxlLmNvbSIsInJvbGUiOiJhZG1pbiIsImV4cCI6MTc0MzEyNzIwMH0.
SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c
     ↓                ↓                ↓
  Header          Payload           Signature
  • Header:算法和类型 {"alg":"HS256","typ":"JWT"}
  • Payload:声明(用户信息、过期时间等)
  • Signature:签名,防止伪造

1.2 安装依赖

pip install python-jose[cryptography] passlib[bcrypt] bcrypt
# python-jose → JWT 编解码
# passlib → 密码哈希
# bcrypt → 推荐哈希算法

JWT工具模块实现

完整JWT工具模块

# auth/jwt.py
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any
from jose import jwt, JWTError
from passlib.context import CryptContext
from fastapi import HTTPException, status
from config import get_settings

settings = get_settings()

# 密码哈希配置
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# JWT算法配置
ALGORITHM = "HS256"

class JWTManager:
    """JWT管理器 - 集中管理JWT相关操作"""
    
    def __init__(self):
        self.secret_key = settings.jwt_secret
        self.algorithm = ALGORITHM
        self.access_token_expire = timedelta(minutes=settings.jwt_expire_minutes)
        self.refresh_token_expire = timedelta(days=settings.jwt_refresh_expire_days or 30)
    
    def verify_password(self, plain_password: str, hashed_password: str) -> bool:
        """验证密码"""
        return pwd_context.verify(plain_password, hashed_password)
    
    def hash_password(self, password: str) -> str:
        """哈希密码"""
        return pwd_context.hash(password)
    
    def create_access_token(
        self,
        data: Dict[str, Any],
        expires_delta: Optional[timedelta] = None
    ) -> str:
        """创建访问令牌"""
        to_encode = data.copy()
        
        # 设置过期时间
        expire = datetime.now(timezone.utc) + (
            expires_delta or self.access_token_expire
        )
        
        # 添加JWT标准声明
        to_encode.update({
            "exp": expire.timestamp(),  # 过期时间戳
            "iat": datetime.now(timezone.utc).timestamp(),  # 签发时间戳
            "jti": str(uuid.uuid4()),  # 令牌唯一ID (JWT ID)
            "type": "access",  # 令牌类型
        })
        
        # 编码JWT
        encoded_jwt = jwt.encode(
            to_encode, 
            self.secret_key, 
            algorithm=self.algorithm
        )
        return encoded_jwt
    
    def create_refresh_token(
        self, 
        user_id: int,
        additional_claims: Optional[Dict[str, Any]] = None
    ) -> str:
        """创建刷新令牌"""
        claims = {
            "sub": str(user_id),
            "type": "refresh",
        }
        
        if additional_claims:
            claims.update(additional_claims)
        
        return self.create_access_token(
            data=claims,
            expires_delta=self.refresh_token_expire
        )
    
    def decode_token(self, token: str) -> Dict[str, Any]:
        """解码并验证令牌"""
        try:
            payload = jwt.decode(
                token, 
                self.secret_key, 
                algorithms=[self.algorithm],
                options={
                    "verify_exp": True,  # 验证过期时间
                    "verify_iat": True,  # 验证签发时间
                }
            )
            return payload
        except JWTError as e:
            raise self._create_credentials_exception(f"Token无效: {str(e)}")
    
    def validate_token_type(self, token: str, expected_type: str) -> Dict[str, Any]:
        """验证令牌类型"""
        payload = self.decode_token(token)
        token_type = payload.get("type")
        
        if token_type != expected_type:
            raise self._create_credentials_exception(
                f"期望{expected_type}令牌,但收到{token_type}令牌"
            )
        
        return payload
    
    def get_user_id_from_token(self, token: str) -> int:
        """从令牌中获取用户ID"""
        payload = self.decode_token(token)
        user_id = payload.get("sub")
        
        if not user_id:
            raise self._create_credentials_exception("令牌中缺少用户ID")
        
        try:
            return int(user_id)
        except ValueError:
            raise self._create_credentials_exception("无效的用户ID")
    
    def _create_credentials_exception(self, detail: str = "无法验证凭据") -> HTTPException:
        """创建认证异常"""
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

# 全局JWT管理器实例
jwt_manager = JWTManager()

# 便捷函数(保持向后兼容)
def verify_password(plain_password: str, hashed_password: str) -> bool:
    return jwt_manager.verify_password(plain_password, hashed_password)

def hash_password(password: str) -> str:
    return jwt_manager.hash_password(password)

def create_access_token(
    data: dict,
    expires_delta: timedelta | None = None,
) -> str:
    return jwt_manager.create_access_token(data, expires_delta)

def create_refresh_token(user_id: int) -> str:
    return jwt_manager.create_refresh_token(user_id)

def decode_token(token: str) -> dict:
    return jwt_manager.decode_token(token)

credentials_exception = jwt_manager._create_credentials_exception

JWT配置管理

# config.py (扩展)
from pydantic import BaseSettings
import os

class Settings(BaseSettings):
    # 应用配置
    app_name: str = "FastAPI Auth App"
    version: str = "1.0.0"
    debug: bool = False
    
    # 数据库配置
    database_url: str = "sqlite:///./test.db"
    
    # JWT配置
    jwt_secret: str = os.getenv("JWT_SECRET_KEY", "")
    jwt_expire_minutes: int = 30
    jwt_refresh_expire_days: int = 7
    jwt_algorithm: str = "HS256"
    
    # 密码配置
    password_min_length: int = 8
    password_require_uppercase: bool = True
    password_require_lowercase: bool = True
    password_require_numbers: bool = True
    password_require_symbols: bool = False
    
    class Config:
        env_file = ".env"

# 验证JWT密钥
def get_settings():
    settings = Settings()
    
    if not settings.jwt_secret:
        raise ValueError("JWT_SECRET_KEY环境变量未设置")
    
    if len(settings.jwt_secret) < 32:
        raise ValueError("JWT_SECRET_KEY至少需要32个字符")
    
    return settings

settings = get_settings()

安全的JWT配置

# auth/security.py
import secrets
import string
from typing import Optional

def generate_secure_secret(length: int = 64) -> str:
    """生成安全的JWT密钥"""
    alphabet = string.ascii_letters + string.digits + "!@#$%^&*"
    return ''.join(secrets.choice(alphabet) for _ in range(length))

def validate_password_strength(password: str) -> tuple[bool, str]:
    """验证密码强度"""
    settings = get_settings()
    
    if len(password) < settings.password_min_length:
        return False, f"密码长度至少需要{settings.password_min_length}个字符"
    
    if settings.password_require_uppercase and not any(c.isupper() for c in password):
        return False, "密码必须包含大写字母"
    
    if settings.password_require_lowercase and not any(c.islower() for c in password):
        return False, "密码必须包含小写字母"
    
    if settings.password_require_numbers and not any(c.isdigit() for c in password):
        return False, "密码必须包含数字"
    
    if settings.password_require_symbols and not any(c in "!@#$%^&*" for c in password):
        return False, "密码必须包含特殊字符"
    
    return True, "密码强度符合要求"

OAuth2密码模式实现

完整的认证模型定义

# models/user.py
from sqlalchemy import Column, Integer, String, Boolean, DateTime
from sqlalchemy.orm import declarative_base
from datetime import datetime

Base = declarative_base()

class User(Base):
    __tablename__ = "users"
    
    id = Column(Integer, primary_key=True, index=True)
    email = Column(String, unique=True, index=True, nullable=False)
    hashed_password = Column(String, nullable=False)
    full_name = Column(String, nullable=True)
    is_active = Column(Boolean, default=True, nullable=False)
    role = Column(String, default="user", nullable=False)  # user, admin, moderator
    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    
    def to_dict(self):
        return {
            "id": self.id,
            "email": self.email,
            "full_name": self.full_name,
            "is_active": self.is_active,
            "role": self.role,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
        }

认证相关Schema定义

# schemas/auth.py
from pydantic import BaseModel, EmailStr, validator
from typing import Optional
from datetime import datetime

class Token(BaseModel):
    """令牌响应模型"""
    access_token: str
    refresh_token: str
    token_type: str = "bearer"
    expires_in: int  # 令牌过期时间(秒)

class TokenData(BaseModel):
    """令牌数据模型"""
    user_id: Optional[int] = None
    email: Optional[str] = None
    role: Optional[str] = None
    exp: Optional[int] = None

class LoginRequest(BaseModel):
    """登录请求模型"""
    email: EmailStr
    password: str
    
    @validator('password')
    def validate_password(cls, v):
        if len(v) < 6:
            raise ValueError('密码长度至少6位')
        return v

class RegisterRequest(BaseModel):
    """注册请求模型"""
    email: EmailStr
    password: str
    full_name: Optional[str] = None
    
    @validator('password')
    def validate_password_strength(cls, v):
        is_valid, message = validate_password_strength(v)
        if not is_valid:
            raise ValueError(message)
        return v

class RefreshTokenRequest(BaseModel):
    """刷新令牌请求"""
    refresh_token: str

class LogoutRequest(BaseModel):
    """登出请求"""
    token: str

认证服务层

高级认证服务实现

# services/auth_service.py
from typing import Optional, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update
from fastapi import HTTPException, status
from datetime import datetime
import logging

from models.user import User
from auth.jwt import jwt_manager
from schemas.auth import LoginRequest, RegisterRequest
from utils.security import rate_limiter

logger = logging.getLogger(__name__)

class AuthService:
    """认证服务 - 处理用户认证相关业务逻辑"""
    
    def __init__(self, db: AsyncSession):
        self.db = db

    async def authenticate_user(self, email: str, password: str) -> Optional[User]:
        """验证用户凭据"""
        # 防暴力破解:检查IP限制
        client_ip = getattr(self, '_client_ip', 'unknown')
        if await rate_limiter.is_rate_limited(client_ip, 'login'):
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail="登录尝试次数过多,请稍后再试"
            )
        
        # 查询用户
        result = await self.db.execute(
            select(User).where(User.email == email)
        )
        user = result.scalar_one_or_none()
        
        if not user:
            # 记录失败尝试(防暴力破解)
            await rate_limiter.record_failure(client_ip, 'login')
            return None
            
        if not jwt_manager.verify_password(password, user.hashed_password):
            # 记录失败尝试
            await rate_limiter.record_failure(client_ip, 'login')
            return None
            
        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="用户账户已被禁用"
            )
        
        # 重置失败计数
        await rate_limiter.reset_failures(client_ip, 'login')
        return user

    async def register_user(self, register_data: RegisterRequest) -> User:
        """用户注册"""
        # 检查邮箱是否已存在
        result = await self.db.execute(
            select(User).where(User.email == register_data.email)
        )
        existing_user = result.scalar_one_or_none()
        
        if existing_user:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="邮箱已被注册"
            )
        
        # 创建新用户
        hashed_password = jwt_manager.hash_password(register_data.password)
        user = User(
            email=register_data.email,
            hashed_password=hashed_password,
            full_name=register_data.full_name,
            role="user"  # 默认普通用户角色
        )
        
        self.db.add(user)
        await self.db.commit()
        await self.db.refresh(user)
        
        logger.info(f"新用户注册: {user.email}")
        return user

    async def get_user_by_id(self, user_id: int) -> Optional[User]:
        """根据ID获取用户"""
        result = await self.db.execute(
            select(User).where(User.id == user_id)
        )
        return result.scalar_one_or_none()

    async def get_user_by_email(self, email: str) -> Optional[User]:
        """根据邮箱获取用户"""
        result = await self.db.execute(
            select(User).where(User.email == email)
        )
        return result.scalar_one_or_none()

    async def deactivate_user(self, user_id: int) -> bool:
        """停用用户账户"""
        result = await self.db.execute(
            update(User)
            .where(User.id == user_id)
            .values(is_active=False, updated_at=datetime.utcnow())
        )
        await self.db.commit()
        return result.rowcount > 0

    async def activate_user(self, user_id: int) -> bool:
        """激活用户账户"""
        result = await self.db.execute(
            update(User)
            .where(User.id == user_id)
            .values(is_active=True, updated_at=datetime.utcnow())
        )
        await self.db.commit()
        return result.rowcount > 0

    async def change_password(self, user_id: int, old_password: str, new_password: str) -> bool:
        """修改用户密码"""
        user = await self.get_user_by_id(user_id)
        if not user:
            return False
            
        if not jwt_manager.verify_password(old_password, user.hashed_password):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="原密码错误"
            )
        
        # 验证新密码强度
        is_valid, message = validate_password_strength(new_password)
        if not is_valid:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=message
            )
        
        # 更新密码
        new_hashed = jwt_manager.hash_password(new_password)
        result = await self.db.execute(
            update(User)
            .where(User.id == user_id)
            .values(hashed_password=new_hashed, updated_at=datetime.utcnow())
        )
        await self.db.commit()
        return result.rowcount > 0

    async def update_user_role(self, user_id: int, new_role: str) -> bool:
        """更新用户角色"""
        valid_roles = ["user", "admin", "moderator"]
        if new_role not in valid_roles:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"无效的角色: {new_role}"
            )
        
        result = await self.db.execute(
            update(User)
            .where(User.id == user_id)
            .values(role=new_role, updated_at=datetime.utcnow())
        )
        await self.db.commit()
        return result.rowcount > 0

完整认证路由实现

# routers/auth.py
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import timedelta
import logging

from dependencies import get_db
from services.auth_service import AuthService
from auth.jwt import jwt_manager
from schemas.auth import (
    Token, LoginRequest, RegisterRequest, 
    RefreshTokenRequest, LogoutRequest
)
from config import get_settings

logger = logging.getLogger(__name__)
settings = get_settings()
router = APIRouter(prefix="/auth", tags=["认证"])

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")

@router.post("/login", response_model=Token)
async def login(
    request: Request,
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: AsyncSession = Depends(get_db),
):
    """
    OAuth2 密码模式登录
    支持用户名(邮箱)+ 密码认证
    """
    service = AuthService(db)
    # 设置客户端IP用于限流
    service._client_ip = request.client.host if request.client else 'unknown'
    
    user = await service.authenticate_user(form_data.username, form_data.password)

    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="邮箱或密码错误",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # 创建访问令牌(短期)
    access_token_data = {
        "sub": str(user.id),
        "email": user.email,
        "role": user.role,
        "user_info": {
            "id": user.id,
            "email": user.email,
            "full_name": user.full_name,
            "role": user.role
        }
    }
    
    access_token = jwt_manager.create_access_token(
        data=access_token_data,
        expires_delta=timedelta(minutes=settings.jwt_expire_minutes)
    )
    
    # 创建刷新令牌(长期)
    refresh_token = jwt_manager.create_refresh_token(user.id)

    logger.info(f"用户登录成功: {user.email}")
    
    return Token(
        access_token=access_token,
        refresh_token=refresh_token,
        token_type="bearer",
        expires_in=settings.jwt_expire_minutes * 60
    )

@router.post("/register", response_model=dict)
async def register(
    request: Request,
    register_data: RegisterRequest,
    db: AsyncSession = Depends(get_db),
):
    """用户注册"""
    service = AuthService(db)
    user = await service.register_user(register_data)
    
    logger.info(f"新用户注册成功: {user.email}")
    return {"message": "注册成功", "user_id": user.id}

@router.post("/refresh", response_model=Token)
async def refresh_token(
    refresh_request: RefreshTokenRequest,
    db: AsyncSession = Depends(get_db)
):
    """刷新访问令牌"""
    try:
        # 验证刷新令牌
        payload = jwt_manager.validate_token_type(
            refresh_request.refresh_token, 
            expected_type="refresh"
        )
        
        user_id = jwt_manager.get_user_id_from_token(refresh_request.refresh_token)
        service = AuthService(db)
        user = await service.get_user_by_id(user_id)
        
        if not user or not user.is_active:
            raise jwt_manager._create_credentials_exception("用户不存在或已被禁用")
        
        # 创建新的访问令牌
        access_token_data = {
            "sub": str(user.id),
            "email": user.email,
            "role": user.role,
        }
        
        new_access_token = jwt_manager.create_access_token(
            data=access_token_data,
            expires_delta=timedelta(minutes=settings.jwt_expire_minutes)
        )
        
        # 创建新的刷新令牌(滚动刷新令牌)
        new_refresh_token = jwt_manager.create_refresh_token(user.id)

        logger.info(f"令牌刷新成功: {user.email}")
        
        return Token(
            access_token=new_access_token,
            refresh_token=new_refresh_token,
            token_type="bearer",
            expires_in=settings.jwt_expire_minutes * 60
        )
        
    except Exception as e:
        logger.error(f"令牌刷新失败: {str(e)}")
        raise jwt_manager._create_credentials_exception("刷新令牌无效")

@router.post("/logout")
async def logout(
    logout_request: LogoutRequest,
    token: str = Depends(oauth2_scheme)
):
    """用户登出(可选:加入黑名单)"""
    # 这里可以实现令牌黑名单逻辑
    # await add_to_blacklist(logout_request.token)
    logger.info("用户登出")
    return {"message": "登出成功"}

@router.post("/change-password")
async def change_password(
    old_password: str,
    new_password: str,
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db)
):
    """修改密码"""
    service = AuthService(db)
    success = await service.change_password(
        current_user.id, old_password, new_password
    )
    
    if not success:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="密码修改失败"
        )
    
    logger.info(f"用户修改密码成功: {current_user.email}")
    return {"message": "密码修改成功"}

依赖注入与用户获取

完整的依赖注入实现

# dependencies.py
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
import logging

from database import get_db_session
from models.user import User
from services.auth_service import AuthService
from auth.jwt import jwt_manager

logger = logging.getLogger(__name__)

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")

async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db_session),
) -> User:
    """获取当前认证用户"""
    try:
        # 解码并验证令牌
        payload = jwt_manager.decode_token(token)
        user_id: int = jwt_manager.get_user_id_from_token(token)
        
        if user_id is None:
            raise jwt_manager._create_credentials_exception("无法获取用户ID")
            
        # 从数据库获取用户
        service = AuthService(db)
        user = await service.get_user_by_id(user_id)
        
        if user is None:
            raise jwt_manager._create_credentials_exception("用户不存在")
        
        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="用户账户已被禁用",
                headers={"WWW-Authenticate": "Bearer"},
            )
        
        return user
        
    except HTTPException:
        # 重新抛出HTTP异常
        raise
    except Exception as e:
        logger.error(f"获取当前用户失败: {str(e)}")
        raise jwt_manager._create_credentials_exception("无法验证用户身份")

async def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """获取当前活跃用户"""
    if not current_user.is_active:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="用户账户未激活"
        )
    return current_user

async def get_current_admin_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """获取当前管理员用户"""
    if current_user.role != "admin":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="权限不足:需要管理员权限"
        )
    return current_user

async def get_current_moderator_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """获取当前版主用户"""
    if current_user.role not in ["admin", "moderator"]:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="权限不足:需要管理员或版主权限"
        )
    return current_user

# 角色检查装饰器
def require_role(*roles: str):
    """角色权限装饰器"""
    async def role_checker(current_user: User = Depends(get_current_user)) -> User:
        if current_user.role not in roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"权限不足:需要 {', '.join(roles)} 权限"
            )
        return current_user
    return role_checker

受保护路由实现

完整的受保护路由示例

# routers/protected.py
from fastapi import APIRouter, Depends
from typing import List

from dependencies import (
    get_current_user, get_current_active_user, 
    get_current_admin_user, require_role
)
from models.user import User

router = APIRouter(prefix="/protected", tags=["受保护路由"])

@router.get("/profile")
async def get_profile(current_user: User = Depends(get_current_active_user)):
    """获取用户个人信息 - 需要认证用户"""
    return current_user.to_dict()

@router.get("/admin/dashboard")
async def admin_dashboard(current_user: User = Depends(get_current_admin_user)):
    """管理员仪表板 - 需要管理员权限"""
    return {
        "message": "管理员仪表板",
        "user": current_user.email,
        "role": current_user.role,
        "admin_data": "敏感管理数据"
    }

@router.get("/moderator/reports")
async def moderator_reports(
    current_user: User = Depends(require_role("admin", "moderator"))
):
    """版主报告 - 需要管理员或版主权限"""
    return {
        "message": "版主报告页面",
        "user": current_user.email,
        "role": current_user.role
    }

@router.get("/user/preferences")
async def user_preferences(current_user: User = Depends(get_current_active_user)):
    """用户偏好设置 - 仅认证用户"""
    return {
        "user_id": current_user.id,
        "email": current_user.email,
        "preferences": {
            "theme": "light",
            "language": "zh-CN",
            "notifications": True
        }
    }

@router.put("/user/update")
async def update_user_info(
    full_name: str = None,
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db_session)
):
    """更新用户信息 - 仅认证用户"""
    if full_name:
        current_user.full_name = full_name
        await db.commit()
        await db.refresh(current_user)
    
    return {"message": "用户信息更新成功", "user": current_user.to_dict()}

令牌管理与安全策略

高级令牌管理实现

# auth/token_manager.py
import redis
from typing import Optional, Set
from datetime import datetime
import logging

from auth.jwt import jwt_manager

logger = logging.getLogger(__name__)

class TokenManager:
    """令牌管理器 - 处理令牌生命周期管理"""
    
    def __init__(self, redis_client: redis.Redis = None):
        self.redis = redis_client
        self.blacklist_prefix = "token:blacklist:"
        self.active_tokens_prefix = "token:active:"
        self.refresh_tokens_prefix = "token:refresh:"
    
    async def blacklist_token(self, token: str) -> bool:
        """将令牌加入黑名单(登出时使用)"""
        try:
            payload = jwt_manager.decode_token(token)
            exp = payload.get("exp")
            
            if not exp:
                return False
                
            # 计算剩余过期时间
            ttl = int(exp - datetime.now().timestamp())
            
            if ttl > 0:
                # 加入黑名单
                key = f"{self.blacklist_prefix}{token}"
                await self.redis.setex(key, ttl, "1")
                
                # 从活跃令牌集合中移除
                active_key = f"{self.active_tokens_prefix}{payload.get('sub', '')}"
                await self.redis.srem(active_key, token)
                
                logger.info(f"令牌已加入黑名单: {token[:20]}...")
                return True
                
        except Exception as e:
            logger.error(f"令牌黑名单失败: {str(e)}")
            return False
    
    async def is_token_blacklisted(self, token: str) -> bool:
        """检查令牌是否在黑名单中"""
        try:
            key = f"{self.blacklist_prefix}{token}"
            return await self.redis.exists(key) > 0
        except Exception:
            # Redis不可用时返回False(保守策略)
            return False
    
    async def add_active_token(self, user_id: int, token: str) -> bool:
        """添加活跃令牌到用户令牌集合"""
        try:
            payload = jwt_manager.decode_token(token)
            exp = payload.get("exp")
            
            if not exp:
                return False
                
            ttl = int(exp - datetime.now().timestamp())
            
            if ttl > 0:
                # 将令牌添加到用户活跃令牌集合
                key = f"{self.active_tokens_prefix}{user_id}"
                await self.redis.sadd(key, token)
                await self.redis.expire(key, ttl)
                
                # 同时在令牌到用户映射中记录
                token_key = f"{self.active_tokens_prefix}token:{token}"
                await self.redis.setex(token_key, ttl, str(user_id))
                
                return True
        except Exception as e:
            logger.error(f"添加活跃令牌失败: {str(e)}")
            return False
    
    async def invalidate_user_tokens(self, user_id: int) -> bool:
        """使用户的所有令牌失效"""
        try:
            key = f"{self.active_tokens_prefix}{user_id}"
            tokens = await self.redis.smembers(key)
            
            # 将所有令牌加入黑名单
            for token in tokens:
                await self.blacklist_token(token.decode('utf-8'))
            
            # 清空用户令牌集合
            await self.redis.delete(key)
            
            logger.info(f"用户 {user_id} 的所有令牌已失效")
            return True
            
        except Exception as e:
            logger.error(f"使用户令牌失效失败: {str(e)}")
            return False
    
    async def get_user_active_tokens(self, user_id: int) -> Set[str]:
        """获取用户的所有活跃令牌"""
        try:
            key = f"{self.active_tokens_prefix}{user_id}"
            tokens = await self.redis.smembers(key)
            return {token.decode('utf-8') for token in tokens}
        except Exception:
            return set()
    
    async def cleanup_expired_tokens(self) -> int:
        """清理过期令牌"""
        # 这个方法通常在定时任务中调用
        # 实现略...
        pass

# 全局令牌管理器实例
token_manager = TokenManager()

安全中间件实现

# middleware/auth_middleware.py
from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import logging

from auth.jwt import jwt_manager
from auth.token_manager import token_manager

logger = logging.getLogger(__name__)

class AuthMiddleware(BaseHTTPMiddleware):
    """认证中间件 - 在请求级别处理认证相关逻辑"""
    
    async def dispatch(self, request: Request, call_next):
        # 检查路径是否需要认证
        protected_paths = ["/protected/", "/admin/", "/api/v1/private/"]
        is_protected = any(request.url.path.startswith(path) for path in protected_paths)
        
        # 如果是受保护路径,检查令牌
        auth_header = request.headers.get("authorization")
        if is_protected and auth_header and auth_header.startswith("Bearer "):
            token = auth_header[7:]  # 移除 "Bearer " 前缀
            
            # 检查令牌是否在黑名单中
            if await token_manager.is_token_blacklisted(token):
                return JSONResponse(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    content={"detail": "令牌已被撤销"}
                )
            
            # 验证令牌有效性
            try:
                payload = jwt_manager.decode_token(token)
                # 将用户信息添加到请求状态中
                request.state.current_user_id = payload.get("sub")
                request.state.current_user_payload = payload
            except Exception:
                return JSONResponse(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    content={"detail": "无效的认证令牌"}
                )
        
        response = await call_next(request)
        return response

安全最佳实践

企业级安全配置

# security/config.py
import os
from typing import List, Optional

class SecurityConfig:
    """安全配置类"""
    
    # JWT配置
    JWT_ALGORITHM: str = "RS256"  # 生产环境使用非对称加密
    JWT_PRIVATE_KEY_PATH: Optional[str] = os.getenv("JWT_PRIVATE_KEY_PATH")
    JWT_PUBLIC_KEY_PATH: Optional[str] = os.getenv("JWT_PUBLIC_KEY_PATH")
    
    # 密码策略
    PASSWORD_MIN_LENGTH: int = 12
    PASSWORD_REQUIRE_UPPERCASE: bool = True
    PASSWORD_REQUIRE_LOWERCASE: bool = True
    PASSWORD_REQUIRE_NUMBERS: bool = True
    PASSWORD_REQUIRE_SYMBOLS: bool = True
    PASSWORD_MAX_REUSE: int = 5  # 最大密码重用次数
    
    # 速率限制
    LOGIN_RATE_LIMIT: str = "5/minute"  # 每分钟最多5次登录尝试
    API_RATE_LIMIT: str = "100/minute"  # 每分钟最多100次API调用
    
    # 会话管理
    MAX_ACTIVE_SESSIONS: int = 5  # 每用户最大活跃会话数
    SESSION_TIMEOUT: int = 3600  # 会话超时时间(秒)
    
    # IP白名单/黑名单
    ALLOWED_IP_RANGES: List[str] = []
    BLOCKED_IPS: List[str] = []
    
    # 安全日志
    LOG_SENSITIVE_OPERATIONS: bool = True
    AUDIT_LOG_RETENTION_DAYS: int = 90

security_config = SecurityConfig()

安全审计日志

# security/audit_logger.py
import json
from datetime import datetime
from enum import Enum
from typing import Dict, Any

class AuditEventType(Enum):
    USER_LOGIN = "user_login"
    USER_LOGOUT = "user_logout"
    FAILED_LOGIN = "failed_login"
    PASSWORD_CHANGE = "password_change"
    ROLE_CHANGE = "role_change"
    DATA_ACCESS = "data_access"
    PERMISSION_DENIED = "permission_denied"

class AuditLogger:
    """安全审计日志记录器"""
    
    def __init__(self, log_handler):
        self.log_handler = log_handler
    
    def log_event(
        self, 
        event_type: AuditEventType, 
        user_id: int = None, 
        ip_address: str = None, 
        details: Dict[str, Any] = None
    ):
        """记录审计事件"""
        log_entry = {
            "timestamp": datetime.utcnow().isoformat(),
            "event_type": event_type.value,
            "user_id": user_id,
            "ip_address": ip_address,
            "details": details or {},
            "session_id": None  # 可以从请求上下文中获取
        }
        
        self.log_handler.info(json.dumps(log_entry))
    
    def log_login_success(self, user_id: int, ip_address: str):
        """记录登录成功"""
        self.log_event(
            AuditEventType.USER_LOGIN,
            user_id=user_id,
            ip_address=ip_address,
            details={"success": True}
        )
    
    def log_failed_login(self, email: str, ip_address: str, reason: str = None):
        """记录登录失败"""
        self.log_event(
            AuditEventType.FAILED_LOGIN,
            details={
                "email": email,
                "reason": reason or "invalid_credentials"
            },
            ip_address=ip_address
        )
    
    def log_permission_denied(self, user_id: int, endpoint: str, ip_address: str):
        """记录权限拒绝"""
        self.log_event(
            AuditEventType.PERMISSION_DENIED,
            user_id=user_id,
            ip_address=ip_address,
            details={"endpoint": endpoint}
        )

# 全局审计日志记录器
audit_logger = AuditLogger(logging.getLogger("audit"))

常见安全陷阱与解决方案

安全漏洞防范

# security/vulnerability_protection.py
from functools import wraps
from typing import Callable

def prevent_reauthentication_attack(func: Callable) -> Callable:
    """防止重新认证攻击装饰器"""
    @wraps(func)
    async def wrapper(*args, **kwargs):
        # 在令牌刷新时验证用户凭据
        # 确保刷新令牌与访问令牌来自同一设备/会话
        return await func(*args, **kwargs)
    return wrapper

def enforce_token_binding(func: Callable) -> Callable:
    """强制令牌绑定装饰器"""
    @wraps(func)
    async def wrapper(*args, **kwargs):
        # 将令牌与特定设备指纹绑定
        # 防止令牌在不同设备间共享
        return await func(*args, **kwargs)
    return wrapper

def prevent_brute_force(func: Callable) -> Callable:
    """防暴力破解装饰器"""
    @wraps(func)
    async def wrapper(*args, **kwargs):
        # 实现登录尝试限制
        # IP地址锁定机制
        return await func(*args, **kwargs)
    return wrapper

# 安全头设置
SECURE_HEADERS = {
    "X-Content-Type-Options": "nosniff",
    "X-Frame-Options": "DENY",
    "X-XSS-Protection": "1; mode=block",
    "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
    "Content-Security-Policy": "default-src 'self'",
}

与其他认证方式对比

特性JWT认证Session认证OAuth2API Key
状态管理无状态有状态无状态无状态
扩展性
跨域支持优秀困难优秀优秀
安全性中等中等
令牌管理复杂简单复杂简单
性能中等
适用场景微服务、API传统Web应用第三方授权服务间认证

总结

FastAPI中的OAuth2与JWT认证提供了强大而灵活的安全机制:

  1. 无状态设计:JWT令牌自包含用户信息,适合分布式系统
  2. 灵活的角色控制:通过依赖注入实现精细的权限管理
  3. 安全的令牌管理:支持令牌刷新、黑名单等高级功能
  4. 可扩展的架构:易于集成第三方认证服务

通过遵循安全最佳实践,可以构建企业级的安全认证系统。

💡 关键要点:始终使用HTTPS传输令牌,定期轮换JWT密钥,实现适当的速率限制和审计日志。


🔗 扩展阅读

</final_file_content>