FastAPI与Redis集成完全指南

📂 所属阶段:第三阶段 — 数据持久化(数据库篇)
🔗 相关章节:FastAPI SQLAlchemy 2.0实战 · FastAPI异步编程深度解析

目录

为什么选择Redis集成?

Redis在现代Web应用中的核心价值

Redis作为内存数据结构存储,为现代Web应用提供了关键的性能和功能优势:

| 应用场景 | 传统方案 | Redis方案 | 性能提升 |
| --- | --- | --- | --- |
| 热点数据缓存 | 数据库查询 | 内存读取 | 100x+ |
| Session管理 | 数据库存储 | 内存存储 | 50x+ |
| 消息队列 | 专用MQ | Redis List/ZSet | 10x+ |
| 实时排行榜 | SQL聚合 | Sorted Set | 100x+ |
| 限流控制 | 应用层计数 | Redis计数 | 50x+ |

高性能架构对比

传统架构:
请求 → FastAPI → 数据库查询 → 返回数据

    单点瓶颈,高延迟

Redis优化架构:
请求 → FastAPI → Redis缓存(命中)→ 直接返回
          ↓              ↓
    数据库查询 ← 未命中 ← 缓存未命中

    Redis写入 ← 更新缓存

项目依赖安装

# Redis异步客户端(redis-py ≥ 4.2 已内置 redis.asyncio)
pip install "redis[hiredis]"

# hiredis是C语言实现的高性能解析器,显著提升性能
# 注意:aioredis 项目已停止维护,其功能已并入 redis-py(redis>=4.2),
# 无需再单独安装 aioredis,直接使用 redis.asyncio 即可

完整依赖示例

# requirements.txt
fastapi==0.104.1
redis[hiredis]==5.0.1
# aioredis==2.0.1  # 已废弃:redis>=4.2 内置 asyncio 支持,无需此依赖
pydantic==2.5.0
python-multipart==0.0.6
uvicorn==0.24.0

Redis基础配置与连接管理

Redis连接配置详解

# redis_client.py
import redis.asyncio as redis
from redis.backoff import ExponentialBackoff
from redis.retry import Retry
from config import get_settings
import logging
from typing import Optional
from contextlib import asynccontextmanager

settings = get_settings()
logger = logging.getLogger(__name__)

class RedisManager:
    """Central manager for the application's Redis connection.

    Owns a single async client built from ``settings.redis_url`` and exposes
    it via :meth:`get_client`; :meth:`initialize` / :meth:`close` are driven
    by the FastAPI lifespan hooks.
    """

    def __init__(self):
        # Created lazily in initialize(); None until then.
        self.client: Optional[redis.Redis] = None
        self._initialized = False

    async def initialize(self):
        """Create the client and verify connectivity with PING. Idempotent."""
        if self._initialized:
            return

        try:
            # Create the async Redis connection.
            self.client = redis.from_url(
                settings.redis_url,
                # Connection pool / retry behaviour.
                max_connections=settings.redis_max_connections,
                # BUGFIX: the original passed retry_on_timeout twice
                # ("SyntaxError: keyword argument repeated"); keep a single
                # occurrence, driven by settings.
                retry_on_timeout=settings.redis_retry_on_timeout,
                retry=Retry(ExponentialBackoff(), attempts=3),

                # Serialization: decode bytes to str automatically.
                encoding="utf-8",
                decode_responses=True,

                # Keep-alive and periodic health checks — values come from
                # settings so config.py is the single source of truth.
                socket_keepalive=settings.redis_socket_keepalive,
                socket_keepalive_options={},  # TCP keepalive options
                health_check_interval=settings.redis_health_check_interval,

                # Timeouts.
                socket_connect_timeout=settings.redis_socket_connect_timeout,
                socket_timeout=settings.redis_socket_timeout,
            )

            # Fail fast if the server is unreachable.
            await self.client.ping()
            logger.info("Redis connection established successfully")
            self._initialized = True

        except Exception as e:
            logger.error(f"Failed to connect to Redis: {e}")
            raise

    async def close(self):
        """Close the connection and reset state so initialize() can run again."""
        if self.client:
            await self.client.close()
            self.client = None  # drop the stale handle
            self._initialized = False
            logger.info("Redis connection closed")

    def get_client(self) -> redis.Redis:
        """Return the live client, or raise if initialize() has not run."""
        if not self._initialized or not self.client:
            raise RuntimeError("Redis client not initialized")
        return self.client

# Global Redis manager instance (initialized via the FastAPI lifespan hook)
redis_manager = RedisManager()

# Convenience accessor for the shared client
def get_redis_client() -> redis.Redis:
    """Return the shared Redis client; raises RuntimeError if not initialized yet."""
    return redis_manager.get_client()

配置管理

# config.py (扩展配置)
from pydantic import BaseSettings
from functools import lru_cache
from typing import List, Optional

class Settings(BaseSettings):
    """Application settings, loaded from environment variables / a .env file.

    NOTE(review): requirements.txt pins pydantic==2.5.0, but ``BaseSettings``
    and the inner ``class Config`` are pydantic v1 idioms; under pydantic v2
    ``BaseSettings`` lives in the separate ``pydantic-settings`` package —
    confirm which major version this project actually targets.
    """
    # Application metadata
    app_name: str = "FastAPI Redis App"
    version: str = "1.0.0"
    debug: bool = False
    
    # Redis connection settings (consumed by redis_client.RedisManager)
    redis_url: str = "redis://localhost:6379/0"
    redis_max_connections: int = 20
    redis_retry_on_timeout: bool = True
    redis_socket_keepalive: bool = True
    redis_health_check_interval: int = 30
    redis_socket_connect_timeout: int = 5
    redis_socket_timeout: int = 30
    
    # Key namespace prefix for this application's Redis entries
    redis_namespace: str = "fastapi_app"
    
    class Config:
        # pydantic v1-style config: read overrides from .env
        env_file = ".env"
        env_file_encoding = "utf-8"

@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Build the Settings object once and reuse it for every caller."""
    return Settings()

# Resolve settings eagerly for module-level consumers.
settings = get_settings()

在FastAPI中管理Redis生命周期

# main.py (扩展)
from fastapi import FastAPI
from contextlib import asynccontextmanager
from redis_client import redis_manager
import logging

logger = logging.getLogger(__name__)

@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: connect Redis before serving, close it on shutdown."""
    logger.info("Initializing Redis connection...")
    
    # Startup: establish (and PING-verify) the Redis connection.
    await redis_manager.initialize()
    logger.info("Redis initialized successfully")
    yield
    
    # Shutdown: release the connection.
    # NOTE(review): if serving aborts abnormally this block may be skipped —
    # confirm whether cleanup should be wrapped in try/finally.
    logger.info("Shutting down Redis connection...")
    await redis_manager.close()
    logger.info("Redis connection closed")

# Application instance; `lifespan` wires Redis setup/teardown to startup/shutdown.
app = FastAPI(
    title="FastAPI Redis API",
    version="1.0.0",
    description="使用Redis缓存的FastAPI应用",
    lifespan=lifespan
)

# Health-check endpoint
@app.get("/health")
async def health_check():
    """Liveness probe: reports Redis connectivity plus the current UTC time."""
    from datetime import datetime, timezone

    # NOTE(review): get_redis_client must be imported from redis_client in
    # this module's header — confirm.
    redis_client = get_redis_client()
    try:
        await redis_client.ping()
        redis_status = "connected"
    except Exception:  # narrowed from bare `except:` (which also swallowed KeyboardInterrupt)
        redis_status = "disconnected"

    return {
        "status": "healthy",
        "redis": redis_status,
        # BUGFIX: the original returned a hard-coded literal timestamp;
        # emit the actual current time in ISO-8601 UTC.
        "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
    }

Redis依赖注入

# dependencies.py (扩展)
from fastapi import Depends
from redis_client import get_redis_client
import redis.asyncio as redis

# Redis client dependency
async def get_redis() -> redis.Redis:
    """FastAPI dependency that hands back the shared Redis client."""
    client = get_redis_client()
    return client

# 在路由中使用
from fastapi import APIRouter
import redis.asyncio as redis

router = APIRouter()

@router.get("/cache/health")
async def cache_health_check(redis_client: redis.Redis = Depends(get_redis)):
    """Report Redis server health via INFO (version, clients, memory)."""
    try:
        server_info = await redis_client.info()
    except Exception as e:
        return {"status": "unhealthy", "error": str(e)}

    return {
        "status": "healthy",
        "version": server_info.get("redis_version"),
        "connected_clients": server_info.get("connected_clients"),
        "used_memory": server_info.get("used_memory_human"),
    }

高性能缓存策略

缓存装饰器实现

# cache/decorators.py
import json
import hashlib
from functools import wraps
from typing import Any, Callable, Optional, Union
from redis.asyncio import Redis
from typing import get_type_hints
import inspect
import logging

logger = logging.getLogger(__name__)

def cached(
    key_prefix: str = "",
    ttl: int = 300,  # default: expire after 5 minutes
    key_func: Optional[Callable] = None,
    serializer: str = "json",
    namespace: str = ""
):
    """
    Async caching decorator.

    The cache key is either produced by ``key_func(*args, **kwargs)`` or
    derived automatically as ``{namespace}:{key_prefix or func name}:{md5(args)[:8]}``.
    Note the prefix is NOT template-formatted with the call arguments — an
    md5 hash of the bound arguments is appended instead.

    The wrapped function gains an ``invalidate`` attribute; it returns a
    coroutine and must be awaited.

    NOTE(review): relies on get_redis_client being imported in this module's
    header — confirm.
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            redis_client = get_redis_client()

            # ---- build the cache key --------------------------------------
            if key_func:
                cache_key = key_func(*args, **kwargs)
            else:
                sig = inspect.signature(func)
                bound_args = sig.bind(*args, **kwargs)
                bound_args.apply_defaults()

                # Stable hash over the sorted bound arguments.
                params_str = str(sorted(bound_args.arguments.items()))
                params_hash = hashlib.md5(params_str.encode()).hexdigest()[:8]

                stem = key_prefix if key_prefix else func.__name__
                cache_key = f"{namespace}:{stem}:{params_hash}"

            # ---- read path: a cache failure must not break the call -------
            try:
                cached_value = await redis_client.get(cache_key)
            except Exception as e:
                logger.error(f"Cache error for key {cache_key}: {e}")
                cached_value = None

            # `is not None`: an empty-string payload is still a valid hit.
            if cached_value is not None:
                logger.debug(f"Cache hit for key: {cache_key}")
                if serializer == "json":
                    return json.loads(cached_value)
                return cached_value

            logger.debug(f"Cache miss for key: {cache_key}")

            # BUGFIX: the original ran func() INSIDE the try and re-ran it in
            # the except handler, so a failing func executed twice and a
            # cache-write error silently recomputed the result. Run it
            # exactly once, outside any cache error handling.
            result = await func(*args, **kwargs)

            # ---- write path: best-effort, never recompute -----------------
            try:
                if serializer == "json":
                    serialized_result = json.dumps(result, default=str)
                else:
                    serialized_result = result
                await redis_client.setex(cache_key, ttl, serialized_result)
                logger.debug(f"Cache set for key: {cache_key}")
            except Exception as e:
                logger.error(f"Cache error for key {cache_key}: {e}")

            return result

        # Invalidation helper; returns a coroutine — callers must await it.
        wrapper.invalidate = lambda *args, **kwargs: invalidate_cache(key_prefix, *args, **kwargs)
        return wrapper
    return decorator

async def invalidate_cache(pattern: str, *args, **kwargs):
    """Delete cached entries matching *pattern*.

    With positional/keyword args the pattern is formatted into one exact key;
    otherwise every key containing *pattern* is removed.
    """
    redis_client = get_redis_client()

    if args or kwargs:
        # Specific key: format and delete it directly.
        cache_key = pattern.format(*args, **kwargs)
        await redis_client.delete(cache_key)
    else:
        # Wildcard sweep. Use SCAN instead of KEYS: KEYS is O(N) and blocks
        # the single-threaded Redis server on large keyspaces.
        doomed = [key async for key in redis_client.scan_iter(f"*{pattern}*")]
        if doomed:
            await redis_client.delete(*doomed)

# 批量缓存操作
class BatchCache:
    """批量缓存操作类"""
    
    def __init__(self, redis_client: Redis):
        self.redis = redis_client
    
    async def mget_cached(self, keys: list) -> list:
        """批量获取缓存"""
        results = await self.redis.mget(keys)
        return [json.loads(result) if result else None for result in results]
    
    async def mset_cached(self, mapping: dict, ttl: int = 300):
        """批量设置缓存"""
        # 序列化值
        serialized_mapping = {
            key: json.dumps(value, default=str) 
            for key, value in mapping.items()
        }
        
        pipe = self.redis.pipeline()
        pipe.mset(serialized_mapping)
        for key in mapping.keys():
            pipe.expire(key, ttl)
        await pipe.execute()

高级缓存策略

# cache/strategies.py
import asyncio
import json
import logging
import time
from enum import Enum
from typing import Any, Callable, Optional

from redis.asyncio import Redis

class CacheStrategy(Enum):
    """Enumeration of supported cache eviction/refresh strategies."""
    TTL = "ttl"  # fixed time-to-live
    LFU = "lfu"  # least frequently used
    LRU = "lru"  # least recently used
    SLIDING_WINDOW = "sliding_window"  # sliding window with early refresh

class AdvancedCache:
    """Advanced caching strategies: stale-fallback reads and sliding-window refresh."""

    def __init__(self, redis_client: Redis):
        self.redis = redis_client
        # Strong references to in-flight background refresh tasks, so they
        # are not garbage-collected before completion.
        self._refresh_tasks: set = set()

    async def get_with_fallback(
        self, 
        cache_key: str, 
        fetch_func: Callable, 
        ttl: int = 300,
        fallback_ttl: int = 60
    ):
        """
        Cached read with a degradation path.

        On the happy path the value is cached under *cache_key* for *ttl*
        seconds and mirrored under ``{cache_key}:fallback`` with a lifetime of
        ``ttl + fallback_ttl``, so a later Redis/fetch failure can still serve
        slightly stale data.
        """
        logger = logging.getLogger(__name__)
        try:
            cached_value = await self.redis.get(cache_key)
            if cached_value:
                return json.loads(cached_value)

            # Cache miss: fetch fresh data.
            fresh_data = await fetch_func()
            payload = json.dumps(fresh_data, default=str)

            pipe = self.redis.pipeline()
            pipe.setex(cache_key, ttl, payload)
            # BUGFIX: the original never wrote the fallback entry, so the
            # degradation branch below could never find anything. Keep a
            # stale copy around longer than the primary entry.
            pipe.setex(f"{cache_key}:fallback", ttl + fallback_ttl, payload)
            await pipe.execute()

            return fresh_data

        except Exception as e:
            # Degradation: try the longer-lived stale copy first.
            fallback_key = f"{cache_key}:fallback"
            try:
                fallback_value = await self.redis.get(fallback_key)
                if fallback_value:
                    logger.warning(f"Using fallback cache for {cache_key}: {e}")
                    return json.loads(fallback_value)
            except Exception:  # narrowed from a bare `except:`
                pass

            # Last resort: fetch directly without caching.
            return await fetch_func()

    async def sliding_window_cache(
        self, 
        cache_key: str, 
        fetch_func: Callable, 
        window_ttl: int = 300,
        refresh_threshold: int = 60
    ):
        """
        Sliding-window cache: serve from cache, but refresh in the background
        once the entry is within *refresh_threshold* seconds of expiry.
        """
        cache_info_key = f"{cache_key}:info"

        # NOTE(review): hgetall returning str keys assumes the client was
        # created with decode_responses=True — confirm.
        cache_info = await self.redis.hgetall(cache_info_key)
        if cache_info:
            cached_value = await self.redis.get(cache_key)
            if cached_value:
                created_at = int(cache_info.get('created_at', 0))
                if time.time() - created_at > (window_ttl - refresh_threshold):
                    # Refresh in the background; keep a reference so the task
                    # is not garbage-collected mid-flight.
                    task = asyncio.create_task(
                        self._refresh_cache(cache_key, fetch_func, window_ttl)
                    )
                    self._refresh_tasks.add(task)
                    task.add_done_callback(self._refresh_tasks.discard)

                return json.loads(cached_value)

        # Miss: fetch fresh data and store it together with its metadata.
        fresh_data = await fetch_func()

        pipe = self.redis.pipeline()
        pipe.setex(cache_key, window_ttl, json.dumps(fresh_data, default=str))
        pipe.hset(cache_info_key, 'created_at', int(time.time()))
        pipe.expire(cache_info_key, window_ttl + 60)  # metadata outlives the value slightly
        await pipe.execute()

        return fresh_data

    async def _refresh_cache(self, cache_key: str, fetch_func: Callable, ttl: int):
        """Background task: re-fetch and overwrite the cached value."""
        logger = logging.getLogger(__name__)
        try:
            fresh_data = await fetch_func()
            await self.redis.setex(
                cache_key, 
                ttl, 
                json.dumps(fresh_data, default=str)
            )
            logger.info(f"Cache refreshed for {cache_key}")
        except Exception as e:
            logger.error(f"Failed to refresh cache for {cache_key}: {e}")

# Cache pre-warming
class CacheWarmer:
    """Pre-populates the cache with entries that are known to be hot."""

    def __init__(self, redis_client: Redis):
        self.redis = redis_client

    async def warm_common_queries(self):
        """Seed profile entries for a hard-coded list of hot users."""
        for user_id in (1, 2, 3, 4, 5):  # example hot-user ids
            cache_key = f"user:profile:{user_id}"
            if await self.redis.exists(cache_key):
                continue  # already warm — skip the fetch
            # Load from the database layer and cache for one hour.
            user_data = await self._fetch_user_data(user_id)
            await self.redis.setex(
                cache_key,
                3600,
                json.dumps(user_data, default=str),
            )

    async def _fetch_user_data(self, user_id: int):
        """Stand-in for the real database lookup."""
        return {"id": user_id, "name": f"User_{user_id}"}

用户信息缓存示例

# services/user_cache_service.py
from cache.decorators import cached, invalidate_cache
from cache.strategies import AdvancedCache
from typing import Optional, Dict, Any
import json

class UserCacheService:
    """User caching service: @cached reads plus AdvancedCache strategies.

    NOTE(review): uses get_redis_client in __init__, but it is not among this
    module's visible imports — confirm.
    """
    
    def __init__(self):
        self.advanced_cache = AdvancedCache(get_redis_client())
    
    @cached("user:profile:{user_id}", ttl=900)  # 15 minutes
    async def get_user_profile(self, user_id: int) -> Optional[Dict[str, Any]]:
        """Return a serializable profile dict for *user_id* (None if missing).

        NOTE(review): @cached never formats the "{user_id}" placeholder — the
        stored key is the literal prefix plus an md5 hash of the bound
        arguments (including ``self``); confirm this matches the invalidation
        calls in update_user_profile below.
        """
        # The actual database lookup.
        from repositories.user_repository import UserRepository
        user_repo = UserRepository()
        user = await user_repo.get_by_id(user_id)
        if not user:
            return None
        
        return {
            "id": user.id,
            "name": user.name,
            "email": user.email,
            "avatar": user.avatar,
            "created_at": user.created_at.isoformat() if user.created_at else None,
            "updated_at": user.updated_at.isoformat() if user.updated_at else None,
        }
    
    async def update_user_profile(self, user_id: int, data: Dict[str, Any]) -> bool:
        """Persist profile changes, then drop the user's cached entries.

        NOTE(review): invalidate_cache() with a formatted key deletes that
        exact key, but @cached stores under hashed keys that contain the
        literal "{user_id}" placeholder — these invalidations likely never
        match a stored key; verify against cache/decorators.py.
        """
        from repositories.user_repository import UserRepository
        user_repo = UserRepository()
        
        # Write-through to the database first.
        success = await user_repo.update(user_id, data)
        
        if success:
            # Best-effort invalidation of every per-user key family.
            await invalidate_cache(f"user:profile:{user_id}")
            await invalidate_cache(f"user:basic:{user_id}")
            await invalidate_cache(f"user:stats:{user_id}")
        
        return success
    
    async def get_user_statistics(self, user_id: int) -> Dict[str, Any]:
        """User statistics with sliding-window caching (early background refresh)."""
        cache_key = f"user:stats:{user_id}"
        
        async def fetch_stats():
            # Compute fresh statistics from the database.
            from repositories.user_repository import UserRepository
            user_repo = UserRepository()
            return await user_repo.get_user_statistics(user_id)
        
        return await self.advanced_cache.sliding_window_cache(
            cache_key, 
            fetch_stats, 
            window_ttl=1800,  # 30 minutes
            refresh_threshold=300  # refresh 5 minutes before expiry
        )
    
    async def get_hot_users(self) -> list:
        """Hot-user list with a long TTL and stale-fallback on errors."""
        cache_key = "user:hot_list"
        
        async def fetch_hot_users():
            from repositories.user_repository import UserRepository
            user_repo = UserRepository()
            return await user_repo.get_hot_users()
        
        return await self.advanced_cache.get_with_fallback(
            cache_key, 
            fetch_hot_users, 
            ttl=3600,  # 1 hour
            fallback_ttl=600  # stale-fallback window: 10 minutes
        )

分布式Session管理

Session管理器实现

# session/manager.py
import uuid
import json
import time
from typing import Dict, Optional, Any
from redis.asyncio import Redis
from datetime import datetime, timedelta
import logging

logger = logging.getLogger(__name__)

class RedisSessionManager:
    """Redis-backed distributed session store.

    Sessions are JSON blobs stored under ``{namespace}:{uuid4}`` with a TTL;
    reads opportunistically slide the TTL forward at most once per refresh
    window.
    """
    
    def __init__(self, redis_client: Redis, namespace: str = "session"):
        self.redis = redis_client
        self.namespace = namespace
        self.session_prefix = f"{namespace}:"
        self.session_expire = 86400 * 7  # sessions live 7 days
        self.refresh_threshold = 3600    # slide the TTL at most once per hour
    
    async def create_session(
        self, 
        user_id: int, 
        extra_data: Optional[Dict[str, Any]] = None,
        custom_expire: Optional[int] = None
    ) -> str:
        """Create a session for *user_id* and return its id (a uuid4 string)."""
        session_id = str(uuid.uuid4())
        session_key = self.session_prefix + session_id
        
        # Session payload; extra_data entries are merged in on top.
        session_data = {
            "user_id": user_id,
            "created_at": time.time(),
            "last_accessed": time.time(),
            "ip_address": extra_data.get("ip_address") if extra_data else None,
            "user_agent": extra_data.get("user_agent") if extra_data else None,
            **(extra_data or {})
        }
        
        expire_time = custom_expire or self.session_expire
        await self.redis.setex(
            session_key,
            expire_time,
            json.dumps(session_data)
        )
        
        logger.info(f"Session created: {session_id} for user {user_id}")
        return session_id
    
    async def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return the session payload, or None if missing/expired.

        Also slides the TTL when the stored last-access time is old enough.
        """
        session_key = self.session_prefix + session_id
        data = await self.redis.get(session_key)
        if not data:
            return None
        
        session_data = json.loads(data)
        # BUGFIX: let _update_last_accessed see the STORED timestamp. The
        # original overwrote last_accessed with time.time() BEFORE the
        # threshold check, so elapsed time was always ~0 and the TTL was
        # never refreshed.
        await self._update_last_accessed(session_key, session_data)
        session_data["last_accessed"] = time.time()
        return session_data
    
    async def _update_last_accessed(self, session_key: str, session_data: Dict[str, Any]):
        """Persist a new last_accessed (and reset the TTL) if the stored one is stale."""
        last_accessed = session_data.get("last_accessed", 0)
        # Only rewrite once per refresh window to limit write traffic.
        if time.time() - last_accessed > self.refresh_threshold:
            session_data["last_accessed"] = time.time()
            await self.redis.setex(
                session_key,
                self.session_expire,
                json.dumps(session_data)
            )
    
    async def refresh_session(self, session_id: str, extend_time: Optional[int] = None) -> bool:
        """Reset the session TTL; returns False when the session no longer exists."""
        session_key = self.session_prefix + session_id
        data = await self.redis.get(session_key)
        if not data:
            return False
        
        expire_time = extend_time or self.session_expire
        return await self.redis.expire(session_key, expire_time)
    
    async def destroy_session(self, session_id: str) -> bool:
        """Delete the session (logout). Returns True if a key was removed."""
        session_key = self.session_prefix + session_id
        result = await self.redis.delete(session_key)
        if result:
            logger.info(f"Session destroyed: {session_id}")
        return bool(result)
    
    async def extend_session(self, session_id: str, additional_time: int) -> bool:
        """Add *additional_time* seconds on top of the session's current TTL."""
        session_key = self.session_prefix + session_id
        current_ttl = await self.redis.ttl(session_key)
        if current_ttl < 0:
            return False  # key missing (-2) or has no TTL (-1)
        
        new_ttl = max(current_ttl, 0) + additional_time
        return await self.redis.expire(session_key, new_ttl)
    
    async def get_user_sessions(self, user_id: int) -> list:
        """Return all active sessions belonging to *user_id*.

        O(total sessions): production code should maintain a per-user index.
        Iterates with SCAN rather than KEYS so the server is never blocked.
        """
        sessions = []
        # NOTE(review): key.replace assumes decode_responses=True (str keys) — confirm.
        async for key in self.redis.scan_iter(f"{self.session_prefix}*"):
            data = await self.redis.get(key)
            if data:
                session_data = json.loads(data)
                if session_data.get("user_id") == user_id:
                    session_data["session_id"] = key.replace(self.session_prefix, "")
                    sessions.append(session_data)
        
        return sessions
    
    async def invalidate_user_sessions(self, user_id: int) -> int:
        """Destroy every session of *user_id*; returns the number removed."""
        user_keys = []
        # Non-blocking SCAN iteration — see get_user_sessions.
        async for key in self.redis.scan_iter(f"{self.session_prefix}*"):
            data = await self.redis.get(key)
            if data:
                session_data = json.loads(data)
                if session_data.get("user_id") == user_id:
                    user_keys.append(key)
        
        if user_keys:
            result = await self.redis.delete(*user_keys)
            logger.info(f"Invalidated {result} sessions for user {user_id}")
            return result
        
        return 0

# Global session manager instance.
# NOTE(review): requires get_redis_client to be imported here and
# redis_manager.initialize() to have run before this module is imported — confirm.
session_manager = RedisSessionManager(get_redis_client())

Session中间件

# middleware/session_middleware.py
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
import logging

logger = logging.getLogger(__name__)

class SessionMiddleware(BaseHTTPMiddleware):
    """Session middleware: resolves the session on the way in, manages the cookie on the way out."""

    async def dispatch(self, request: Request, call_next):
        # The session id may arrive via cookie or header.
        session_id = (
            request.cookies.get("session_id") or
            request.headers.get("X-Session-ID")
        )

        stale_cookie = False
        request.state.session = None
        request.state.user_id = None

        if session_id:
            # Validate against the store.
            # NOTE(review): session_manager must be imported from
            # session.manager in this module's header — confirm.
            session_data = await session_manager.get_session(session_id)
            if session_data:
                # Attach the session to the request for downstream handlers.
                request.state.session = session_data
                request.state.user_id = session_data.get("user_id")
            else:
                # Invalid/expired id presented via cookie: remember to clear it.
                stale_cookie = "session_id" in request.cookies

        response = await call_next(request)

        # If a handler created a new session, set the cookie on the response.
        if hasattr(request.state, 'new_session_id'):
            response.set_cookie(
                "session_id",
                request.state.new_session_id,
                httponly=True,
                secure=True,  # production should serve over HTTPS
                samesite="strict",
                max_age=86400 * 7  # 7 days
            )
        elif stale_cookie:
            # BUGFIX: the original's comment promised to clear the invalid
            # cookie but never did; expire it so the client stops resending it.
            response.delete_cookie("session_id")

        return response

# Register the middleware in main.py
from middleware.session_middleware import SessionMiddleware

app.add_middleware(SessionMiddleware)

Session依赖注入

# dependencies/session_deps.py
from fastapi import Request, HTTPException, status
from typing import Optional, Dict, Any
from session.manager import session_manager

async def get_current_session(request: Request) -> Dict[str, Any]:
    """Return the session attached by SessionMiddleware, or abort with 401.

    (Annotation tightened from Optional: this function never returns None —
    it raises instead.)
    """
    session_data = getattr(request.state, 'session', None)
    if not session_data:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未登录或Session已过期"
        )
    return session_data

async def get_current_user_id(request: Request) -> int:
    """Resolve the authenticated user's id from request state, or abort with 401."""
    user_id = getattr(request.state, 'user_id', None)
    if user_id:
        return user_id
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="未登录或Session已过期"
    )

async def require_login(request: Request) -> Dict[str, Any]:
    """Dependency for authenticated routes; also slides the session TTL."""
    session = await get_current_session(request)
    await get_current_user_id(request)  # raises 401 if user id is missing

    # BUGFIX: session payloads never contain a "session_id" field
    # (create_session does not store one), so the original
    # session.get("session_id") was always None and the session was never
    # refreshed. Read the id from where SessionMiddleware found it.
    session_id = (
        request.cookies.get("session_id") or
        request.headers.get("X-Session-ID")
    )
    if session_id:
        await session_manager.refresh_session(session_id)

    return session

# Usage example
from fastapi import APIRouter, Depends

router = APIRouter()

@router.get("/profile")
async def get_profile(
    session: dict = Depends(require_login),
    user_id: int = Depends(get_current_user_id)
):
    """Route requiring an authenticated session."""
    return {
        "user_id": user_id,
        "session_info": {
            "created_at": session.get("created_at"),
            "last_accessed": session.get("last_accessed")
        }
    }

@router.post("/logout")
async def logout(request: Request):
    """Destroy the server-side session and expire the client's cookie."""
    session_id = request.cookies.get("session_id")
    if session_id:
        await session_manager.destroy_session(session_id)

    # Build the response and drop the cookie.
    from fastapi.responses import JSONResponse
    resp = JSONResponse(content={"message": "登出成功"})
    resp.delete_cookie("session_id")
    return resp

异步消息队列实现

任务队列系统

# queue/task_queue.py
import json
import uuid
import time
from typing import Dict, Any, Optional, Callable, List
from redis.asyncio import Redis
from datetime import datetime
import asyncio
import logging

logger = logging.getLogger(__name__)

class RedisTaskQueue:
    """Redis-backed async task queue with delay, retry and dead-letter support.

    Key layout (payloads are JSON strings):
      queue:{name}:tasks        LIST  ready tasks
      queue:{name}:processing   ZSET  in-flight tasks, scored by visibility deadline
      queue:{name}:delayed      ZSET  scheduled tasks, scored by release time
      queue:{name}:retry        ZSET  failed tasks, scored by retry time
      queue:{name}:dead_letter  LIST  tasks that exhausted their retries
    """
    
    def __init__(self, redis_client: Redis, queue_name: str = "default"):
        self.redis = redis_client
        self.queue_name = queue_name
        self.task_queue = f"queue:{queue_name}:tasks"
        self.processing_queue = f"queue:{queue_name}:processing"
        self.delayed_queue = f"queue:{queue_name}:delayed"
        self.retry_queue = f"queue:{queue_name}:retry"
        self.dead_letter_queue = f"queue:{queue_name}:dead_letter"
        
        # Task policy defaults.
        self.max_retries = 3
        self.retry_delay = 60  # seconds before a failed task is retried
        self.visibility_timeout = 300  # seconds a task may stay in-flight
        # NOTE(review): nothing here reclaims tasks whose visibility deadline
        # has passed in the processing zset — confirm a reaper exists elsewhere.
    
    async def enqueue(
        self, 
        task_name: str, 
        task_args: Dict[str, Any], 
        delay: Optional[int] = None,
        priority: int = 0
    ) -> str:
        """Queue a task for execution and return its generated id.

        With *delay*, the task is parked in the delayed zset until due.
        """
        task_id = str(uuid.uuid4())
        task_data = {
            "id": task_id,
            "name": task_name,
            "args": task_args,
            "priority": priority,
            "created_at": time.time(),
            "attempts": 0,
            "max_retries": self.max_retries
        }
        
        if delay:
            # Delayed task: score is its release timestamp.
            await self.redis.zadd(
                self.delayed_queue,
                {json.dumps(task_data): time.time() + delay}
            )
        else:
            # Ready task: append to the list.
            await self.redis.rpush(
                self.task_queue,
                json.dumps(task_data)
            )
        
        logger.info(f"Task enqueued: {task_id} - {task_name}")
        return task_id
    
    async def dequeue(self, timeout: int = 5) -> Optional[Dict[str, Any]]:
        """Pop the next ready task, blocking up to *timeout* seconds.

        The payload is parked in the processing zset with a visibility
        deadline before being handed to the caller.
        """
        # Promote any delayed/retry tasks that have come due.
        await self._move_delayed_tasks()
        
        result = await self.redis.blpop([self.task_queue], timeout=timeout)
        if result:
            _, task_json = result
            task_data = json.loads(task_json)
            
            await self.redis.zadd(
                self.processing_queue,
                {json.dumps(task_data): time.time() + self.visibility_timeout}
            )
            
            return task_data
        
        return None
    
    async def _move_delayed_tasks(self):
        """Move due tasks from the delayed AND retry zsets onto the ready list.

        BUGFIX: the original drained only the delayed zset, so tasks that
        fail_task() scheduled for retry sat in the retry zset forever and were
        never re-executed.
        """
        now = time.time()
        for source in (self.delayed_queue, self.retry_queue):
            due_tasks = await self.redis.zrangebyscore(source, "-inf", now)
            if due_tasks:
                pipe = self.redis.pipeline()
                for task_json in due_tasks:
                    pipe.lpush(self.task_queue, task_json)
                    pipe.zrem(source, task_json)
                await pipe.execute()
    
    async def complete_task(self, task_data: Dict[str, Any]) -> bool:
        """Acknowledge a task by removing it from the processing zset.

        NOTE(review): matching relies on json.dumps reproducing the exact
        payload stored at dequeue time, which holds only while the handler
        has not mutated *task_data* — confirm handlers treat it as read-only.
        """
        task_json = json.dumps(task_data)
        result = await self.redis.zrem(self.processing_queue, task_json)
        if result:
            logger.info(f"Task completed: {task_data['id']}")
        return bool(result)
    
    async def fail_task(self, task_data: Dict[str, Any], error: str) -> bool:
        """Record a failure: schedule a retry, or dead-letter the task."""
        # BUGFIX: serialize BEFORE mutating task_data. The original dumped the
        # mutated dict for the final ZREM, which never matched the member
        # stored at dequeue time, leaking every failed task in the processing zset.
        original_json = json.dumps(task_data)
        
        task_id = task_data["id"]
        attempts = task_data.get("attempts", 0) + 1
        
        if attempts <= task_data.get("max_retries", self.max_retries):
            # Schedule a retry after retry_delay seconds.
            task_data["attempts"] = attempts
            task_data["error"] = error
            task_data["retry_at"] = time.time() + self.retry_delay
            
            await self.redis.zadd(
                self.retry_queue,
                {json.dumps(task_data): task_data["retry_at"]}
            )
            
            logger.warning(f"Task {task_id} failed, scheduled for retry {attempts}/{task_data['max_retries']}")
        else:
            # Out of retries: park in the dead-letter list for inspection.
            task_data["final_error"] = error
            await self.redis.rpush(self.dead_letter_queue, json.dumps(task_data))
            logger.error(f"Task {task_id} moved to dead letter queue after {attempts} attempts")
        
        # Remove the ORIGINAL payload from the processing zset.
        await self.redis.zrem(self.processing_queue, original_json)
        return True
    
    async def get_queue_stats(self) -> Dict[str, int]:
        """Return the current size of every sub-queue."""
        return {
            "pending": await self.redis.llen(self.task_queue),
            "processing": await self.redis.zcard(self.processing_queue),
            "delayed": await self.redis.zcard(self.delayed_queue),
            "retry": await self.redis.zcard(self.retry_queue),
            "dead_letter": await self.redis.llen(self.dead_letter_queue)
        }
    
    async def process_task_with_handler(self, handler_func: Callable) -> bool:
        """Dequeue one task and run *handler_func* on it.

        Returns True on success; False when the queue was empty or the
        handler raised (in which case the task is failed/retried).
        """
        task_data = await self.dequeue()
        if not task_data:
            return False
        
        try:
            await handler_func(task_data)
            await self.complete_task(task_data)
            return True
        except Exception as e:
            await self.fail_task(task_data, str(e))
            return False

# Global task queue instance.
# NOTE(review): requires get_redis_client imported here and
# redis_manager.initialize() to have run before first use — confirm.
task_queue = RedisTaskQueue(get_redis_client(), "default")

任务处理器

# handlers/task_handlers.py
from typing import Dict, Any
import logging

logger = logging.getLogger(__name__)

class TaskHandlers:
    """Collection of task handler coroutines.

    Each handler receives the raw task payload produced by the queue and
    reads its arguments from task_data["args"].  The real side effects
    (email delivery, DB updates, ...) are left as commented-out stubs.
    """

    @staticmethod
    async def send_email_task(task_data: Dict[str, Any]):
        """Handle an email-sending task."""
        params = task_data["args"]
        recipient = params.get("recipient")
        subject = params.get("subject")
        body = params.get("body")

        logger.info(f"Sending email to {recipient}: {subject}")

        # Real delivery would happen here:
        # await send_email(recipient, subject, body)

        logger.info(f"Email sent successfully to {recipient}")

    @staticmethod
    async def process_user_registration_task(task_data: Dict[str, Any]):
        """Run the post-registration pipeline for a new user."""
        user_id = task_data["args"].get("user_id")

        logger.info(f"Processing registration for user {user_id}")

        # Welcome email:
        # await send_welcome_email(user_id)
        # Activate the account:
        # await update_user_status(user_id, "active")
        # Audit trail:
        # await log_registration_event(user_id)

        logger.info(f"Registration processed for user {user_id}")

    @staticmethod
    async def cleanup_old_data_task(task_data: Dict[str, Any]):
        """Purge records older than the configured number of days."""
        days_old = task_data["args"].get("days_old", 30)

        logger.info(f"Cleaning up data older than {days_old} days")

        # Actual cleanup:
        # cleaned_count = await cleanup_data(days_old)

        logger.info(f"Cleaned up old data: X records removed")

    @staticmethod
    async def generate_report_task(task_data: Dict[str, Any]):
        """Build and persist a report for the requested date range."""
        params = task_data["args"]
        report_type = params.get("report_type")
        start_date = params.get("start_date")
        end_date = params.get("end_date")

        logger.info(f"Generating {report_type} report for {start_date} to {end_date}")

        # Build the report:
        # report = await generate_report(report_type, start_date, end_date)
        # Persist it:
        # await save_report(report)

        logger.info(f"Report generated: {report_type}")

后台任务处理器

# workers/background_worker.py
import asyncio
import signal
import sys
from typing import Dict, Any, Callable
import logging

logger = logging.getLogger(__name__)

class BackgroundWorker:
    """Polls the task queue and dispatches each task to its handler."""

    def __init__(self, task_queue, task_handlers):
        self.task_queue = task_queue
        self.task_handlers = task_handlers
        self.running = False
        # Maps task names (as stored in the queue) to handler coroutines.
        self.task_mapping = {
            "send_email": task_handlers.send_email_task,
            "process_user_registration": task_handlers.process_user_registration_task,
            "cleanup_old_data": task_handlers.cleanup_old_data_task,
            "generate_report": task_handlers.generate_report_task,
        }

    async def start(self):
        """Run the polling loop until a stop signal clears self.running."""
        self.running = True
        logger.info("Background worker started")

        def _on_signal(signum, frame):
            # Flip the flag; the loop exits after the current iteration.
            logger.info(f"Received signal {signum}, shutting down...")
            self.running = False

        signal.signal(signal.SIGINT, _on_signal)
        signal.signal(signal.SIGTERM, _on_signal)

        while self.running:
            try:
                processed = await self.task_queue.process_task_with_handler(
                    self._handle_single_task
                )
                if not processed:
                    # Queue empty: back off briefly to avoid busy-polling.
                    await asyncio.sleep(1)
            except Exception as exc:
                logger.error(f"Error in worker loop: {exc}")
                await asyncio.sleep(5)  # cool down before re-entering the loop

        logger.info("Background worker stopped")

    async def _handle_single_task(self, task_data: Dict[str, Any]):
        """Dispatch one dequeued task to its registered handler."""
        task_name = task_data["name"]
        handler = self.task_mapping.get(task_name)
        if handler is None:
            raise ValueError(f"No handler found for task: {task_name}")

        logger.info(f"Processing task {task_data['id']}: {task_name}")
        await handler(task_data)
        logger.info(f"Completed task {task_data['id']}: {task_name}")

    async def stop(self):
        """Request the polling loop to exit after its current iteration."""
        self.running = False
        logger.info("Stopping background worker...")

# 启动后台处理器的函数
async def run_background_worker():
    """Entry point: build a BackgroundWorker and run it until signalled."""
    from queue.task_queue import task_queue
    from handlers.task_handlers import TaskHandlers

    handlers = TaskHandlers()
    await BackgroundWorker(task_queue, handlers).start()

# 在应用启动时运行后台处理器(可选)
# if __name__ == "__main__":
#     import asyncio
#     asyncio.run(run_background_worker())

高级Redis数据结构应用

Sorted Set实现排行榜

# data_structures/rankings.py
import time
from typing import List, Tuple, Optional
from redis.asyncio import Redis
import json

class RankingSystem:
    """Leaderboard built on a Redis Sorted Set (higher score = better rank)."""

    def __init__(self, redis_client: Redis, name: str):
        self.redis = redis_client
        self.key = f"ranking:{name}"

    @staticmethod
    def _as_pairs(raw) -> List[Tuple[str, float]]:
        # Normalize bytes members to str; scores pass through unchanged.
        pairs = []
        for member, score in raw:
            if isinstance(member, bytes):
                member = member.decode()
            pairs.append((member, score))
        return pairs

    async def add_score(self, member: str, score: float):
        """Set the member's score (overwrites any previous value)."""
        await self.redis.zadd(self.key, {member: score})

    async def increment_score(self, member: str, increment: float) -> float:
        """Add *increment* to the member's score and return the new total."""
        return await self.redis.zincrby(self.key, increment, member)

    async def get_rank(self, member: str) -> Optional[int]:
        """Return the 1-based rank (descending by score), or None if absent."""
        zero_based = await self.redis.zrevrank(self.key, member)
        if zero_based is None:
            return None
        return zero_based + 1

    async def get_score(self, member: str) -> Optional[float]:
        """Return the member's score, or None if absent."""
        return await self.redis.zscore(self.key, member)

    async def get_top_n(self, n: int) -> List[Tuple[str, float]]:
        """Return the top *n* members as (member, score), best first."""
        raw = await self.redis.zrevrange(self.key, 0, n - 1, withscores=True)
        return self._as_pairs(raw)

    async def get_range(self, start: int, end: int) -> List[Tuple[str, float]]:
        """Return members between 0-based ranks *start* and *end* inclusive."""
        raw = await self.redis.zrevrange(self.key, start, end, withscores=True)
        return self._as_pairs(raw)

    async def get_members_by_score(self, min_score: float, max_score: float) -> List[Tuple[str, float]]:
        """Return members with score in [min_score, max_score], ascending."""
        raw = await self.redis.zrangebyscore(
            self.key, min_score, max_score, withscores=True
        )
        return self._as_pairs(raw)

    async def remove_member(self, member: str) -> bool:
        """Remove the member; True if it existed."""
        return bool(await self.redis.zrem(self.key, member))

    async def get_total_count(self) -> int:
        """Return the number of members on the board."""
        return await self.redis.zcard(self.key)

    async def get_percentile_rank(self, member: str) -> Optional[float]:
        """Return the member's percentile (100 = best), or None if absent/empty."""
        total = await self.get_total_count()
        if total == 0:
            return None
        rank = await self.get_rank(member)
        if rank is None:
            return None
        return (total - rank) / total * 100

# 使用示例
class GameRankingSystem:
    """Facade bundling the three per-game leaderboards."""

    def __init__(self, redis_client: Redis):
        self.score_ranking = RankingSystem(redis_client, "game_scores")
        self.win_ranking = RankingSystem(redis_client, "game_wins")
        self.playtime_ranking = RankingSystem(redis_client, "game_playtime")

    async def update_player_stats(self, player_id: str, score: float, wins: int, playtime: float):
        """Write the player's latest score, win count and playtime."""
        for board, value in (
            (self.score_ranking, score),
            (self.win_ranking, wins),
            (self.playtime_ranking, playtime),
        ):
            await board.add_score(player_id, value)

    async def get_player_overview(self, player_id: str) -> dict:
        """Collect the player's value and rank on every board."""
        overview = {"player_id": player_id}
        overview["score"] = await self.score_ranking.get_score(player_id)
        overview["score_rank"] = await self.score_ranking.get_rank(player_id)
        overview["wins"] = await self.win_ranking.get_score(player_id)
        overview["wins_rank"] = await self.win_ranking.get_rank(player_id)
        overview["playtime"] = await self.playtime_ranking.get_score(player_id)
        overview["playtime_rank"] = await self.playtime_ranking.get_rank(player_id)
        return overview

HyperLogLog实现UV统计

# data_structures/uv_counter.py
import time
from redis.asyncio import Redis
from datetime import datetime, timedelta

class UVCounter:
    """HyperLogLog-based unique-visitor counter.

    pfadd/pfcount give approximate distinct counts (~0.81% standard error)
    in at most 12KB per key, regardless of cardinality.
    """

    def __init__(self, redis_client: Redis, name: str):
        self.redis = redis_client
        self.name = name

    def _get_key(self, date: str) -> str:
        """Key for a single day's HLL, e.g. uv:page_views:2024-01-31."""
        return f"uv:{self.name}:{date}"

    def _get_period_key(self, period: str) -> str:
        """Key for a rolling-period HLL ('week' / 'month' / 'year')."""
        return f"uv:{self.name}:{period}"

    async def add_visitor(self, visitor_id: str, date: str = None):
        """Record *visitor_id* for *date* (YYYY-MM-DD, defaults to today)."""
        if date is None:
            date = datetime.now().strftime("%Y-%m-%d")

        await self.redis.pfadd(self._get_key(date), visitor_id)

        # Today's visitors also roll into the week/month aggregates.
        # NOTE(review): these period keys are never expired or rotated here,
        # so "week"/"month" keep growing until cleared externally — confirm.
        today = datetime.now().date()
        if datetime.strptime(date, "%Y-%m-%d").date() == today:
            await self.redis.pfadd(self._get_period_key("week"), visitor_id)
            await self.redis.pfadd(self._get_period_key("month"), visitor_id)

    async def get_daily_uv(self, date: str = None) -> int:
        """Return the approximate UV for *date* (defaults to today)."""
        if date is None:
            date = datetime.now().strftime("%Y-%m-%d")
        return await self.redis.pfcount(self._get_key(date))

    async def get_period_uv(self, period: str) -> int:
        """Return the approximate UV for a period key (week/month/year)."""
        return await self.redis.pfcount(self._get_period_key(period))

    async def get_uv_trend(self, days: int = 7) -> list[tuple[str, int]]:
        """Return (date, uv) pairs for the last *days* days, oldest first.

        Fixed: the original annotated this with typing.List/Tuple, which
        are not imported in this module and would raise NameError when the
        class body executes; builtin generics (PEP 585) need no import.
        """
        trend = []
        for offset in range(days):
            day = (datetime.now() - timedelta(days=offset)).strftime("%Y-%m-%d")
            trend.append((day, await self.get_daily_uv(day)))
        return list(reversed(trend))  # oldest to newest

# 使用示例
class AnalyticsService:
    """Aggregates page / user / region unique-visitor counters."""

    def __init__(self, redis_client: Redis):
        self.page_uv = UVCounter(redis_client, "page_views")
        self.user_uv = UVCounter(redis_client, "unique_users")
        self.region_uv = UVCounter(redis_client, "region_visits")

    async def record_page_view(self, user_id: str, page: str, region: str):
        """Record one page view across all three UV dimensions."""
        date = datetime.now().strftime("%Y-%m-%d")
        # Page- and region-scoped visitor IDs are prefixed so the same user
        # is counted once per page and once per region.
        await self.page_uv.add_visitor(f"{page}:{user_id}", date)
        await self.user_uv.add_visitor(user_id, date)
        await self.region_uv.add_visitor(f"{region}:{user_id}", date)

BitMap实现签到系统

# data_structures/checkin_system.py
import time
from datetime import datetime, date
from redis.asyncio import Redis

class CheckInSystem:
    """Per-user monthly check-in tracker backed by Redis BitMaps.

    One bitmap per (user, month); bit N-1 represents day N of the month.

    Fixes over the original sample:
    - user_id is now part of the bitmap key (previously every user shared
      one bitmap, so any user's check-in marked the day for everyone);
    - the expiry computation no longer subtracts a datetime from a date
      (TypeError) and no longer relies on the un-imported timedelta;
    - a non-positive EXPIRE (back-dated check-ins for past months) no
      longer deletes the key immediately.
    """

    def __init__(self, redis_client: Redis, name: str = "checkin"):
        self.redis = redis_client
        self.name = name

    def _get_key(self, user_id: str, year_month: str) -> str:
        """Bitmap key for one user's month, e.g. bitmap:checkin:42:2024-05."""
        return f"bitmap:{self.name}:{user_id}:{year_month}"

    def _get_day_offset(self, day: int) -> int:
        """Map day-of-month (1-based) to bit offset (0-based)."""
        return day - 1

    @staticmethod
    def _first_of_next_month(d: date) -> date:
        """Return the first day of the month after *d* (no timedelta needed)."""
        if d.month == 12:
            return date(d.year + 1, 1, 1)
        return date(d.year, d.month + 1, 1)

    async def check_in(self, user_id: str, check_date: date = None):
        """Mark *user_id* as checked in on *check_date* (default today)."""
        if check_date is None:
            check_date = date.today()

        year_month = check_date.strftime("%Y-%m")
        key = self._get_key(user_id, year_month)
        await self.redis.setbit(key, self._get_day_offset(check_date.day), 1)

        # Expire the bitmap at the start of the following month.
        rollover = datetime.combine(self._first_of_next_month(check_date),
                                    datetime.min.time())
        expire_seconds = int((rollover - datetime.now()).total_seconds())
        if expire_seconds > 0:
            await self.redis.expire(key, expire_seconds)

    async def is_checked_in(self, user_id: str, check_date: date = None) -> bool:
        """Return True if *user_id* checked in on *check_date* (default today)."""
        if check_date is None:
            check_date = date.today()

        key = self._get_key(user_id, check_date.strftime("%Y-%m"))
        return bool(await self.redis.getbit(key, self._get_day_offset(check_date.day)))

    async def get_checkin_days(self, user_id: str, target_month: str) -> int:
        """Return how many days *user_id* checked in during *target_month* (YYYY-MM)."""
        return await self.redis.bitcount(self._get_key(user_id, target_month))

    async def get_continuous_checkin_streak(self, user_id: str, check_dates: list) -> int:
        """Count consecutive check-ins walking *check_dates* from newest backwards."""
        streak = 0
        for check_date in reversed(check_dates):
            if not await self.is_checked_in(user_id, check_date):
                break
            streak += 1
        return streak

    async def get_monthly_report(self, user_id: str, target_month: str) -> dict:
        """Return total check-ins, per-day status and coverage for *target_month*."""
        total_days = await self.redis.bitcount(self._get_key(user_id, target_month))

        daily_status = []
        for day in range(1, 32):  # up to 31 days; invalid dates end the month
            try:
                check_date = datetime.strptime(f"{target_month}-{day:02d}", "%Y-%m-%d").date()
            except ValueError:
                break  # e.g. Feb 30th: walked past the end of the month
            checked = await self.is_checked_in(user_id, check_date)
            daily_status.append({"day": day, "checked": checked})

        return {
            "month": target_month,
            "total_checkins": total_days,
            "daily_status": daily_status,
            "percentage": round(total_days / len(daily_status) * 100, 2) if daily_status else 0
        }

性能监控与调优

Redis性能监控

# monitoring/redis_monitor.py
from typing import Dict, Any, Optional
import time
import asyncio
from redis.asyncio import Redis
import logging

logger = logging.getLogger(__name__)

class RedisMonitor:
    """Collects Redis server metrics, per-command stats and slowlog entries."""

    def __init__(self, redis_client: Redis):
        self.redis = redis_client
        self.metrics = {}  # reserved for caller-side aggregation

    async def get_server_info(self) -> Dict[str, Any]:
        """Return key INFO fields plus a derived cache hit rate in percent."""
        info = await self.redis.info()
        hits = info.get("keyspace_hits", 0)
        misses = info.get("keyspace_misses", 0)
        return {
            "redis_version": info.get("redis_version"),
            "uptime_in_seconds": info.get("uptime_in_seconds"),
            "connected_clients": info.get("connected_clients"),
            "used_memory": info.get("used_memory"),
            "used_memory_human": info.get("used_memory_human"),
            "used_memory_peak": info.get("used_memory_peak"),
            "used_memory_peak_human": info.get("used_memory_peak_human"),
            "total_commands_processed": info.get("total_commands_processed"),
            "instantaneous_ops_per_sec": info.get("instantaneous_ops_per_sec"),
            "keyspace_hits": hits,
            "keyspace_misses": misses,
            # max(..., 1) guards the zero-traffic case (no division by zero).
            "hit_rate": hits / max(hits + misses, 1) * 100
        }

    async def get_command_stats(self) -> Dict[str, Any]:
        """Return per-command call counts and timing from INFO commandstats."""
        info = await self.redis.info("commandstats")
        command_stats = {}
        for key, value in info.items():
            if key.startswith("cmdstat_"):
                cmd = key.replace("cmdstat_", "")
                command_stats[cmd] = {
                    "calls": value.get("calls", 0),
                    "usec": value.get("usec", 0),
                    "usec_per_call": value.get("usec_per_call", 0)
                }
        return command_stats

    async def get_slowlog(self, count: int = 10) -> list:
        """Return up to *count* slowlog entries as plain dicts.

        Fixed: redis-py exposes the entry timestamp as "start_time",
        not "time" (the original raised KeyError on every entry), and
        "command" arrives as an already-joined string, so it is only
        joined when it comes back as a sequence.
        """
        slowlog = await self.redis.slowlog_get(count)
        entries = []
        for entry in slowlog:
            command = entry["command"]
            if isinstance(command, (list, tuple)):
                command = " ".join(command)
            entries.append({
                "id": entry["id"],
                "timestamp": entry["start_time"],
                "duration": entry["duration"],
                "command": command
            })
        return entries

    async def monitor_performance(self, interval: int = 60) -> None:
        """Log server metrics every *interval* seconds, forever."""
        while True:
            try:
                info = await self.get_server_info()
                logger.info(f"Redis Performance: {info}")

                # Warn when memory usage crosses 80% of the soft limit.
                used_memory = info.get("used_memory", 0)
                memory_limit = 1024 * 1024 * 512  # 512MB soft limit
                if used_memory > memory_limit * 0.8:
                    logger.warning(f"High memory usage: {info.get('used_memory_human')}")

                # Warn when the cache hit rate drops below 80%.
                hit_rate = info.get("hit_rate", 0)
                if hit_rate < 80:
                    logger.warning(f"Low cache hit rate: {hit_rate:.2f}%")

                await asyncio.sleep(interval)

            except Exception as e:
                logger.error(f"Error monitoring Redis: {e}")
                await asyncio.sleep(interval)

# 在应用启动时启动监控
async def start_redis_monitoring():
    """Create a RedisMonitor on the shared client and loop forever."""
    await RedisMonitor(get_redis_client()).monitor_performance()

缓存性能优化

# optimization/cache_optimizer.py
from typing import Dict, Any, List
import time
from redis.asyncio import Redis
import logging

logger = logging.getLogger(__name__)

class CacheOptimizer:
    """Analyzes keyspace usage patterns and suggests cache optimizations."""

    def __init__(self, redis_client: Redis):
        self.redis = redis_client
        self.access_patterns = {}

    async def analyze_cache_usage(self) -> Dict[str, Any]:
        """Scan the keyspace and report key patterns, sizes and TTL spread.

        Fixed: uses SCAN (scan_iter) instead of KEYS so the server is not
        blocked while the whole keyspace is enumerated, and a TTL of -2
        (key vanished mid-scan) is no longer mis-bucketed as "-1h".
        """
        keys = [key async for key in self.redis.scan_iter("*")]

        analysis = {
            "total_keys": len(keys),
            "key_patterns": {},
            "memory_usage": {},
            "expiration_analysis": {}
        }

        for key in keys:
            key_str = key if isinstance(key, str) else key.decode()
            # Pattern = everything before the last ":" segment.
            pattern = ":".join(key_str.split(":")[:-1]) if ":" in key_str else key_str
            analysis["key_patterns"][pattern] = analysis["key_patterns"].get(pattern, 0) + 1

            # MEMORY USAGE may return None for keys deleted mid-scan.
            memory_usage = await self.redis.memory_usage(key)
            if memory_usage:
                analysis["memory_usage"][key_str] = memory_usage

        # TTL distribution, bucketed by hour.
        expirations = {}
        for key in keys:
            ttl = await self.redis.ttl(key)
            if ttl == -2:
                ttl_range = "missing"  # expired/deleted between scan and TTL
            elif ttl == -1:
                ttl_range = "no_expire"
            else:
                ttl_range = f"{ttl // 3600}h"
            expirations[ttl_range] = expirations.get(ttl_range, 0) + 1

        analysis["expiration_analysis"] = expirations
        return analysis

    async def suggest_optimizations(self) -> List[str]:
        """Turn the usage analysis into human-readable recommendations."""
        suggestions = []
        analysis = await self.analyze_cache_usage()

        total_keys = analysis["total_keys"]
        if total_keys > 10000:
            suggestions.append("缓存key数量过多,考虑使用更大的过期时间或分片")

        # Top 5 largest values by reported memory usage.
        large_keys = sorted(
            analysis["memory_usage"].items(),
            key=lambda item: item[1],
            reverse=True
        )[:5]
        for key, size in large_keys:
            if size > 1024 * 100:  # over 100KB
                suggestions.append(f"Key '{key}' size is large: {size} bytes, consider compression")

        no_expire_count = analysis["expiration_analysis"].get("no_expire", 0)
        if no_expire_count > total_keys * 0.5:  # more than half never expire
            suggestions.append("大量key不过期,建议设置合理的过期时间防止内存泄漏")

        for pattern, count in analysis["key_patterns"].items():
            if count > 1000:  # an unusually hot key pattern
                suggestions.append(f"Pattern '{pattern}' has {count} keys, consider optimization")

        return suggestions

    async def optimize_cache_strategy(self, strategy: str = "auto") -> Dict[str, Any]:
        """Map suggestions onto coarse action descriptions (auto mode only)."""
        if strategy != "auto":
            return {"suggestions": [], "actions": []}

        suggestions = await self.suggest_optimizations()
        results = []
        for suggestion in suggestions:
            lowered = suggestion.lower()
            if "large" in lowered:
                results.append("Consider compressing large values")
            elif "expire" in lowered:
                results.append("Adjust expiration times for better memory management")
            elif "pattern" in lowered:
                results.append("Review key naming patterns for efficiency")
        return {"suggestions": suggestions, "actions": results}

安全配置与最佳实践

Redis安全配置

# security/redis_security.py
import redis.asyncio as redis
from redis.exceptions import AuthenticationError, ConnectionError
import logging

logger = logging.getLogger(__name__)

class RedisSecurity:
    """Helpers for building and validating hardened Redis connections."""

    @staticmethod
    def create_secure_connection(
        host: str = "localhost",
        port: int = 6379,
        password: str = None,
        username: str = None,
        ssl: bool = False,
        ssl_cert_reqs: str = "required",
        ssl_ca_certs: str = None,
        ssl_certfile: str = None,
        ssl_keyfile: str = None,
        decode_responses: bool = True,
        encoding: str = "utf-8"
    ) -> redis.Redis:
        """Build a Redis client with optional auth and TLS settings applied."""
        kwargs = {
            "host": host,
            "port": port,
            "decode_responses": decode_responses,
            "encoding": encoding,
        }

        # Credentials only when provided (supports ACL username + password).
        if password:
            kwargs["password"] = password
        if username:
            kwargs["username"] = username

        if ssl:
            kwargs["ssl"] = True
            kwargs["ssl_cert_reqs"] = ssl_cert_reqs
            # Optional certificate material for custom CAs / mutual TLS.
            for option, value in (
                ("ssl_ca_certs", ssl_ca_certs),
                ("ssl_certfile", ssl_certfile),
                ("ssl_keyfile", ssl_keyfile),
            ):
                if value:
                    kwargs[option] = value

        return redis.Redis(**kwargs)

    @staticmethod
    async def validate_connection(client: redis.Redis) -> bool:
        """PING the server to confirm the client can connect and authenticate."""
        try:
            await client.ping()
            # INFO succeeds only on a live, authenticated connection.
            await client.info()
            # Server-side hardening (bind address, protected-mode) must be
            # audited in redis.conf; only the client side is checked here.
            return True
        except AuthenticationError:
            logger.error("Redis authentication failed")
            return False
        except ConnectionError:
            logger.error("Redis connection failed")
            return False
        except Exception as e:
            logger.error(f"Redis security validation failed: {e}")
            return False

# 安全的Redis客户端工厂
class SecureRedisClient:
    """Lazily builds and caches a validated, security-configured client.

    Fixed: the original cached the client in self._client *before*
    validating it, so after a failed validation every subsequent call
    silently returned the invalid client instead of retrying.  The
    client is now cached only after validation succeeds.
    """

    def __init__(self, config):
        self.config = config
        self._client = None  # populated on first successful get_client()

    async def get_client(self) -> redis.Redis:
        """Return the cached client, creating and validating it on first use."""
        if self._client is None:
            client = RedisSecurity.create_secure_connection(
                host=self.config.redis_host,
                port=self.config.redis_port,
                password=self.config.redis_password,
                ssl=self.config.redis_ssl_enabled
            )

            # Cache only a connection that passed validation.
            if not await RedisSecurity.validate_connection(client):
                raise ConnectionError("Failed to establish secure Redis connection")
            self._client = client

        return self._client

缓存安全最佳实践

# security/cache_security.py
import hashlib
import hmac
import secrets
from typing import Any, Dict
import json
import logging

logger = logging.getLogger(__name__)

class SecureCache:
    """Cache wrapper that hashes keys and (stub-)encrypts values.

    Cache keys are derived from SHA-256(base_key + user_id + secret), so a
    user ID never appears verbatim in the keyspace.  Because of that, the
    original invalidate_user_cache() glob ("secure:*:{user_id}:*") could
    never match a real key — invalidation was a silent no-op.  A per-user
    index set now tracks every key written for a user.
    """

    def __init__(self, redis_client, secret_key: str):
        self.redis = redis_client
        self.secret_key = secret_key

    def _generate_cache_key(self, base_key: str, user_id: str = None) -> str:
        """Derive an opaque, fixed-length cache key for (base_key, user_id)."""
        if user_id:
            # User-specific cache key.
            key_material = f"{base_key}:{user_id}:{self.secret_key}"
        else:
            # Shared (public) cache key.
            key_material = f"{base_key}:{self.secret_key}"

        hashed_key = hashlib.sha256(key_material.encode()).hexdigest()
        return f"secure:{base_key}:{hashed_key[:16]}"

    def _user_index_key(self, user_id: str) -> str:
        """Redis set that tracks every cache key written for *user_id*."""
        return f"secure:user_keys:{user_id}"

    def _encrypt_data(self, data: Any) -> str:
        """Serialize (and, in a real deployment, encrypt) *data*."""
        json_data = json.dumps(data, default=str)
        return json_data  # placeholder: plug real encryption in here

    def _decrypt_data(self, encrypted_data: str) -> Any:
        """Inverse of _encrypt_data."""
        return json.loads(encrypted_data)  # placeholder: real decryption here

    async def set_secure(self, key: str, value: Any, ttl: int = 300, user_id: str = None):
        """Store *value* under the derived key with a TTL in seconds."""
        cache_key = self._generate_cache_key(key, user_id)
        await self.redis.setex(cache_key, ttl, self._encrypt_data(value))
        if user_id:
            # Track the key so invalidate_user_cache() can find it later.
            # (Index entries may outlive expired keys; deleting them is harmless.)
            await self.redis.sadd(self._user_index_key(user_id), cache_key)

    async def get_secure(self, key: str, user_id: str = None) -> Any:
        """Fetch and decode a value, or None when absent/expired."""
        cache_key = self._generate_cache_key(key, user_id)
        encrypted_value = await self.redis.get(cache_key)

        if encrypted_value:
            return self._decrypt_data(encrypted_value)

        return None

    async def invalidate_user_cache(self, user_id: str, pattern: str = "*"):
        """Delete the user's tracked cache keys whose base key matches *pattern*."""
        from fnmatch import fnmatch  # stdlib; local import keeps module deps unchanged

        index_key = self._user_index_key(user_id)
        tracked = await self.redis.smembers(index_key)
        # Keys look like secure:<base_key>:<hash>; match on the base_key part.
        matched = [k for k in tracked if fnmatch(k, f"secure:{pattern}:*")]
        if matched:
            await self.redis.delete(*matched)
            await self.redis.srem(index_key, *matched)

# 使用示例
class SecureUserService:
    """User service that reads and writes profiles through SecureCache."""

    def __init__(self, redis_client, secret_key: str):
        self.secure_cache = SecureCache(redis_client, secret_key)

    async def get_user_profile_secure(self, user_id: int):
        """Return the user's profile dict (cached 15 min), or None if absent."""
        cache_key = f"user:profile:{user_id}"
        cached = await self.secure_cache.get_secure(cache_key, str(user_id))
        if cached:
            return cached

        # Cache miss: load the user from the database.
        from repositories.user_repository import UserRepository
        user = await UserRepository().get_by_id(user_id)
        if not user:
            return None

        profile_data = {
            "id": user.id,
            "name": user.name,
            "email": user.email,
            "created_at": user.created_at.isoformat()
        }

        # Cache under the user-scoped secure key for 15 minutes.
        await self.secure_cache.set_secure(
            cache_key, profile_data, ttl=900, user_id=str(user_id)
        )
        return profile_data

常见陷阱与避坑指南

陷阱1:缓存穿透

# ❌ Wrong: cache penetration
async def get_user_bad(user_id: int):
    # A malicious caller can request many non-existent IDs;
    # every one of those misses falls straight through to the database.
    user = await db.get_user(user_id)
    if not user:
        return None  # misses are never cached, so the DB is queried every time
    return user

# ✅ Correct: cache the empty result too
async def get_user_good(user_id: int):
    # Check the cache first; the sentinel "NULL" marks a known-missing user.
    cached_user = await redis.get(f"user:{user_id}")
    if cached_user is not None:
        return json.loads(cached_user) if cached_user != "NULL" else None
    
    # Cache miss: hit the database once.
    user = await db.get_user(user_id)
    
    # Cache the outcome either way.
    if user:
        await redis.setex(f"user:{user_id}", 300, json.dumps(user))
    else:
        # Cache the miss with a short TTL to block repeated DB lookups
        # (cache-penetration protection).
        await redis.setex(f"user:{user_id}", 60, "NULL")
    
    return user

陷阱2:缓存雪崩

# ❌ Wrong: cache avalanche
async def load_hot_data_bad():
    # When many cache entries expire at the same moment, every miss
    # triggers a simultaneous database query — a thundering herd.
    tasks = []
    for i in range(1000):
        if not await redis.exists(f"data:{i}"):
            # All of these DB loads fire concurrently.
            task = db.load_data(i)
            tasks.append(task)
    return await asyncio.gather(*tasks)

# ✅ Correct: add jitter to the TTL so expiries spread out
async def load_hot_data_good():
    async def load_with_jitter(data_id: int):
        cache_key = f"data:{data_id}"
        cached_data = await redis.get(cache_key)
        
        if cached_data:
            return json.loads(cached_data)
        
        # Cache miss: load from the database.
        data = await db.load_data(data_id)
        
        # Randomize the TTL (300-599s) so entries do not all expire together.
        jitter = secrets.randbelow(300)  # 0-299 seconds of random jitter
        await redis.setex(cache_key, 300 + jitter, json.dumps(data))
        
        return data
    
    tasks = [load_with_jitter(i) for i in range(1000)]
    return await asyncio.gather(*tasks)

陷阱3:缓存击穿

# ❌ Wrong: cache breakdown (hot-key stampede)
async def get_hot_article_bad(article_id: int):
    article = await redis.get(f"article:{article_id}")
    if not article:
        # When a hot entry expires, every concurrent request
        # rebuilds it from the database at the same time.
        article = await db.get_article(article_id)
        await redis.setex(f"article:{article_id}", 3600, json.dumps(article))
    return article

# ✅ Correct: a distributed lock lets only one request rebuild the cache
async def get_hot_article_good(article_id: int):
    cache_key = f"article:{article_id}"
    lock_key = f"lock:article:{article_id}"
    
    # Try to take the rebuild lock (NX = only if absent, EX = auto-expire).
    lock_acquired = await redis.set(lock_key, "1", nx=True, ex=10)  # 10-second lock
    
    if lock_acquired:
        try:
            # Double-check: another worker may have refilled the cache already.
            article = await redis.get(cache_key)
            if article:
                return json.loads(article)
            
            # Rebuild the cache entry from the database.
            article = await db.get_article(article_id)
            await redis.setex(cache_key, 3600, json.dumps(article))
            return article
        finally:
            # Release the lock.
            # NOTE(review): an unconditional DELETE can release a lock that
            # expired and was re-acquired by another worker; a unique token
            # checked in a Lua script is the safe pattern — confirm if needed.
            await redis.delete(lock_key)
    else:
        # Lock not acquired: wait briefly, then serve whatever is cached.
        await asyncio.sleep(0.1)
        article = await redis.get(cache_key)
        return json.loads(article) if article else None

陷阱4:内存泄漏

# ❌ Wrong: no expiry set
async def cache_without_expire():
    # Cached forever; unbounded keys eventually exhaust Redis memory.
    await redis.set("permanent_data", json.dumps(data))

# ✅ Correct: always give cached data a sensible TTL
async def cache_with_expire():
    # Expire after an hour so stale entries reclaim their memory.
    await redis.setex("temp_data", 3600, json.dumps(data))  # expires in 1 hour

陷阱5:序列化问题

# ❌ Wrong: serializing complex objects directly
async def cache_complex_object_bad():
    # datetime and custom objects are not JSON-serializable by default.
    data = {"timestamp": datetime.now(), "user": User()}
    await redis.setex("complex", 300, json.dumps(data))  # may raise TypeError

# ✅ Correct: normalize values before serializing
async def cache_complex_object_good():
    data = {
        "timestamp": datetime.now().isoformat(),
        "user": user.dict() if hasattr(user, 'dict') else str(user)
    }
    # default=str is a last-resort fallback for anything still unserializable.
    await redis.setex("complex", 300, json.dumps(data, default=str))

与其他缓存方案对比

Redis vs Memcached

| 特性 | Redis | Memcached |
| --- | --- | --- |
| 数据结构 | 支持多种(String, Hash, List, Set, ZSet等) | 仅支持String |
| 持久化 | ✅ 支持RDB和AOF持久化 | ❌ 仅内存存储 |
| 数据过期 | ✅ 精确到key的过期 | ✅ 支持过期时间 |
| 内存效率 | ⚠️ 相对较低(功能丰富) | ✅ 更高的内存效率 |
| 集群支持 | ✅ 原生集群模式 | ✅ 一致性哈希 |
| 事务支持 | ✅ 支持事务 | ❌ 无事务 |
| 发布订阅 | ✅ 支持Pub/Sub | ❌ 不支持 |
| 适用场景 | 复杂缓存需求、会话存储、消息队列 | 简单KV缓存 |

Redis vs 本地缓存

| 特性 | Redis | 本地缓存(如 cachetools / functools.lru_cache) |
| --- | --- | --- |
| 进程间共享 | ✅ 多进程、多实例共享 | ❌ 仅限单进程 |
| 访问延迟 | 毫秒级(含网络往返) | 微秒级(进程内内存访问) |
| 一致性 | ✅ 集中存储,天然一致 | ⚠️ 各实例各自维护,容易不一致 |


1. Redis 快速入门

1.1 为什么 Web 服务需要 Redis?

请求 → FastAPI → 数据库查询 → 返回

缓存后:
请求 → FastAPI → Redis(命中)→ 直接返回
请求 → FastAPI → Redis(未命中)→ 数据库 → 写入Redis → 返回

Redis 作用:把频繁查询但不常变化的数据(如用户信息、文章列表)放在内存里,查询速度比数据库快 10-100 倍。

1.2 安装与连接

pip install redis[hiredis]
# hiredis 是 C 语言实现的高性能解析器,推荐安装
# redis_client.py — single shared asyncio Redis client for the whole app.
import redis.asyncio as redis
from config import get_settings

settings = get_settings()

# Asynchronous Redis connection built from the configured URL.
redis_client = redis.from_url(
    settings.redis_url,  # e.g. "redis://localhost:6379/0"
    encoding="utf-8",
    decode_responses=True,   # decode replies from bytes to str automatically
)

# Connectivity smoke test: PING must round-trip to the server.
async def test_redis():
    # Raises (e.g. ConnectionError) if the server is unreachable.
    await redis_client.ping()
    print("✅ Redis 连接成功")

2. 缓存实现

2.1 通用缓存装饰器

# cache.py
import json
from functools import wraps
from redis.asyncio import Redis

def cached(
    key_prefix: str,
    expire: int = 300,  # default TTL: 5 minutes
):
    """
    Async caching decorator.

    Usage:
        @cached("user:{user_id}", expire=600)
        async def get_user(user_id: int):
            return await db.get_user(user_id)

    The cache key is built by formatting ``key_prefix`` with the *bound*
    arguments of the call, so ``get_user(42)`` and ``get_user(user_id=42)``
    both resolve "{user_id}" to 42.  (The naive
    ``key_prefix.format(*args, **kwargs)`` raises KeyError for named
    placeholders whenever the caller passes the value positionally.)
    """
    import inspect

    def decorator(func):
        sig = inspect.signature(func)  # computed once, not on every call

        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Map positional args onto their parameter names so named
            # placeholders always resolve; positional placeholders
            # ({0}, {1}, ...) keep working through *args.
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            cache_key = key_prefix.format(*args, **bound.arguments)

            # Cache hit: return the deserialized value.  Compare against
            # None explicitly — a stored JSON payload is never an empty string.
            cached_value = await redis_client.get(cache_key)
            if cached_value is not None:
                return json.loads(cached_value)

            # Miss: run the real function, then populate the cache.
            result = await func(*args, **kwargs)
            await redis_client.setex(
                cache_key,
                expire,
                # default=str keeps datetime and similar types serializable
                json.dumps(result, default=str)
            )
            return result
        return wrapper
    return decorator


async def invalidate_cache(pattern: str):
    """Delete every cache key matching *pattern* (e.g. "user:42", "post:*").

    Uses SCAN (incremental, cursor-based iteration) instead of KEYS:
    KEYS walks the entire keyspace in one O(N) call and blocks the Redis
    server while doing so, which is dangerous on a busy production instance.
    """
    keys = [key async for key in redis_client.scan_iter(match=pattern)]
    if keys:
        # One DELETE with all matched keys — a single round trip.
        await redis_client.delete(*keys)

2.2 用户信息缓存

from cache import cached, invalidate_cache

# 缓存用户信息,5 分钟内重复查询直接返回缓存
@cached("user:{user_id}", expire=300)
async def get_user_cached(user_id: int):
    """Fetch a user and expose only the public profile fields.

    Results are served from Redis for 5 minutes via the @cached decorator.
    """
    user = await db.get_user(user_id)
    profile_fields = ("id", "name", "email", "avatar")
    return {field: getattr(user, field) for field in profile_fields}

# Cache-aside invalidation: write to the database first, then evict.
async def update_user(user_id: int, data: dict):
    """Persist user changes, then drop the now-stale cache entry."""
    await db.update_user(user_id, data)
    cache_key = f"user:{user_id}"
    await invalidate_cache(cache_key)

2.3 页面缓存(整页缓存)

@app.get("/posts/{slug}")
async def get_post(slug: str):
    """Whole-page cache: serve rendered HTML straight from Redis when possible."""
    cache_key = f"post:page:{slug}"

    # Fast path: a previously rendered page is still cached.
    if page := await redis_client.get(cache_key):
        return HTMLResponse(content=page)

    # Slow path: load the data, render, and cache the HTML for 10 minutes.
    post = await get_post_data(slug)
    html = render_post_template(post)
    await redis_client.setex(cache_key, 600, html)
    return HTMLResponse(content=html)

3. 分布式 Session

3.1 基于 Redis 的 Session 管理

# session.py
import uuid
import json
from datetime import timedelta
from redis.asyncio import Redis

class RedisSession:
    """Redis-backed session store for stateless application servers.

    Each session is a JSON dict stored under "session:<uuid4>" with a
    fixed TTL, so any server instance can validate any session.
    """

    def __init__(self, redis: Redis):
        self.redis = redis
        self.prefix = "session:"
        self.expire = 86400 * 7  # sessions live for 7 days

    def _key(self, session_id: str) -> str:
        # Namespaced Redis key for a given session id.
        return self.prefix + session_id

    async def create(self, user_id: int, extra_data: dict = None) -> str:
        """Create a new session and return its id."""
        session_id = str(uuid.uuid4())
        payload = {"user_id": user_id, **(extra_data or {})}
        await self.redis.setex(
            self._key(session_id),
            self.expire,
            json.dumps(payload)
        )
        return session_id

    async def get(self, session_id: str) -> dict | None:
        """Return the session payload, or None when missing or expired."""
        raw = await self.redis.get(self._key(session_id))
        return json.loads(raw) if raw else None

    async def refresh(self, session_id: str) -> bool:
        """Push the session's expiry out by the full TTL (sliding window)."""
        return await self.redis.expire(self._key(session_id), self.expire)

    async def destroy(self, session_id: str) -> bool:
        """Remove the session (logout). True if a key was actually deleted."""
        return await self.redis.delete(self._key(session_id)) > 0

# Module-level singleton: all request handlers share this session store.
session_manager = RedisSession(redis_client)

3.2 Session 认证依赖

from fastapi import Cookie, HTTPException

async def get_current_user_via_session(session_id: str = Cookie(None)):
    """FastAPI dependency: resolve the session cookie into session data.

    Raises 401 when the cookie is absent or the session has expired.
    Every successful lookup re-extends the session TTL (sliding expiry).
    """
    if not session_id:
        raise HTTPException(401, "请先登录")
    session_data = await session_manager.get(session_id)
    if not session_data:
        raise HTTPException(401, "Session 已过期,请重新登录")
    # Sliding window: touching the session pushes its expiry out again.
    await session_manager.refresh(session_id)
    return session_data

@app.get("/profile")
async def profile(user: dict = Depends(get_current_user_via_session)):
    """Return the authenticated user's session payload."""
    response = {"user_id": user["user_id"], "data": user}
    return response

4. 简单的消息队列

4.1 任务队列:List 实现

# queue.py — a minimal FIFO task queue built on a Redis List.
TASK_QUEUE = "fastapi:task_queue"  # list key shared by producers and workers

async def enqueue_task(task_data: dict):
    """Producer side: serialize the task and push it onto the queue tail."""
    import json
    payload = json.dumps(task_data)
    await redis_client.rpush(TASK_QUEUE, payload)

async def dequeue_task(blocking: bool = True, timeout: int = 5):
    """Consumer side: pop one task from the queue head.

    blocking=True uses BLPOP and waits up to `timeout` seconds
    (0 = wait forever); otherwise LPOP returns immediately.
    Returns the decoded task dict, or None when the queue is empty.
    """
    import json
    if not blocking:
        raw = await redis_client.lpop(TASK_QUEUE)
        return json.loads(raw) if raw else None
    popped = await redis_client.blpop(TASK_QUEUE, timeout=timeout)
    if not popped:
        return None
    _queue_name, raw = popped
    return json.loads(raw)

# Background worker: consumes the task queue forever.
async def process_tasks():
    while True:
        # timeout=0 makes BLPOP block indefinitely until a task arrives,
        # so this loop does not busy-spin on an empty queue.
        task = await dequeue_task(blocking=True, timeout=0)
        if task:
            print(f"处理任务: {task}")
            # await process_task(task)

4.2 延迟队列:ZSet 实现

# ZSet-based delay queue key; presumably members are scored by their due
# timestamp — the implementation is truncated below this excerpt, so confirm.
DELAY_QUEUE = "fastapi:delay_queue"
DELAY_TEMP_QUEUE =