#FastAPI与Redis集成完全指南
📂 所属阶段:第三阶段 — 数据持久化(数据库篇)
🔗 相关章节:FastAPI SQLAlchemy 2.0实战 · FastAPI异步编程深度解析
#目录
- 为什么选择Redis集成
- Redis基础配置与连接管理
- 高性能缓存策略
- 分布式Session管理
- 异步消息队列实现
- 高级Redis数据结构应用
- 性能监控与调优
- 安全配置与最佳实践
- 常见陷阱与避坑指南
- 与其他缓存方案对比
- 总结
#为什么选择Redis集成?
#Redis在现代Web应用中的核心价值
Redis作为内存数据结构存储,为现代Web应用提供了关键的性能和功能优势:
| 应用场景 | 传统方案 | Redis方案 | 性能提升 |
|---|---|---|---|
| 热点数据缓存 | 数据库查询 | 内存读取 | 100x+ |
| Session管理 | 数据库存储 | 内存存储 | 50x+ |
| 消息队列 | 专用MQ | Redis List/ZSet | 10x+ |
| 实时排行榜 | SQL聚合 | Sorted Set | 100x+ |
| 限流控制 | 应用层计数 | Redis计数 | 50x+ |
#高性能架构对比
传统架构:
请求 → FastAPI → 数据库查询 → 返回数据
↓
单点瓶颈,高延迟
Redis优化架构(旁路缓存 Cache-Aside):
请求 → FastAPI → 查询Redis缓存 ──命中──→ 直接返回
                      │
                    未命中
                      ↓
                 数据库查询 → 写入Redis缓存(设置TTL) → 返回数据
#项目依赖安装
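# Redis异步客户端(redis-py 4.2+ 已内置 redis.asyncio)
pip install "redis[hiredis]"
# hiredis是C语言实现的高性能协议解析器,可显著加快响应解析
# (zsh等shell中方括号需加引号)
# 注意:无需再安装aioredis——它已并入redis-py(即redis.asyncio),原包已停止维护,
# 且在Python 3.11+上导入会报错
# pydantic v2中BaseSettings已迁移到独立的pydantic-settings包
pip install pydantic-settings
#完整依赖示例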
# requirements.txt
fastapi==0.104.1
redis[hiredis]==5.0.1
# aioredis已并入redis-py(redis.asyncio),不再需要单独安装
pydantic==2.5.0
pydantic-settings==2.1.0  # pydantic v2中BaseSettings位于此包
python-multipart==0.0.6
uvicorn==0.24.0
#Redis基础配置与连接管理
#Redis连接配置详解
# redis_client.py
import redis.asyncio as redis
from redis.backoff import ExponentialBackoff
from redis.retry import Retry
from config import get_settings
import logging
from typing import Optional
from contextlib import asynccontextmanager
settings = get_settings()
logger = logging.getLogger(__name__)


class RedisManager:
    """Central owner of the application's async Redis client.

    ``initialize()`` must be awaited once (the FastAPI lifespan does this)
    before ``get_client()`` is used; ``close()`` releases the connection pool.
    """

    def __init__(self):
        self.client: Optional[redis.Redis] = None
        self._initialized = False

    async def initialize(self):
        """Create the client from settings and verify it with PING.

        Idempotent: returns immediately once an initialization has succeeded.

        Raises:
            Exception: any error raised while creating the client or by PING.
                The partially created client is closed before re-raising.
        """
        if self._initialized:
            return
        # The asyncio client awaits its retry policy. The synchronous
        # redis.retry.Retry calls the operation without awaiting it, so errors
        # from the awaited call would bypass the retry loop.
        from redis.asyncio.retry import Retry as AsyncRetry

        try:
            self.client = redis.from_url(
                settings.redis_url,
                # Connection pool
                max_connections=settings.redis_max_connections,
                # Retry policy: up to 3 retries with exponential backoff
                # (Retry takes the count as its second positional arg `retries`).
                # A keyword may only be given once per call, so
                # retry_on_timeout is set here and nowhere else.
                retry_on_timeout=settings.redis_retry_on_timeout,
                retry=AsyncRetry(ExponentialBackoff(), 3),
                # Decode replies as UTF-8 `str` instead of `bytes`
                encoding="utf-8",
                decode_responses=True,
                # Connection health
                socket_keepalive=settings.redis_socket_keepalive,
                health_check_interval=settings.redis_health_check_interval,
                # Timeouts (seconds)
                socket_connect_timeout=settings.redis_socket_connect_timeout,
                socket_timeout=settings.redis_socket_timeout,
            )
            # Fail fast if the server is unreachable
            await self.client.ping()
        except Exception as e:
            logger.error(f"Failed to connect to Redis: {e}")
            await self._discard_client()
            raise
        logger.info("Redis connection established successfully")
        self._initialized = True

    async def _discard_client(self):
        """Best-effort close of a client whose initialization failed."""
        if self.client is not None:
            try:
                await self.client.close()
            except Exception:
                logger.debug("Ignoring error while closing failed Redis client", exc_info=True)
            self.client = None

    async def close(self):
        """Close the client's connection pool and mark the manager uninitialized."""
        if self.client is not None:
            await self.client.close()
            self.client = None
        self._initialized = False
        logger.info("Redis connection closed")

    def get_client(self) -> redis.Redis:
        """Return the initialized client.

        Raises:
            RuntimeError: if ``initialize()`` has not completed successfully.
        """
        if not self._initialized or not self.client:
            raise RuntimeError("Redis client not initialized")
        return self.client
# Process-wide Redis manager; the application lifespan initializes it.
redis_manager = RedisManager()


def get_redis_client() -> redis.Redis:
    """Shortcut for ``redis_manager.get_client()``.

    Raises:
        RuntimeError: if the manager has not been initialized yet.
    """
    manager = redis_manager
    return manager.get_client()
#配置管理
# config.py (extended settings)
from functools import lru_cache
from typing import List, Optional

try:
    # pydantic v2 (pinned in requirements.txt): BaseSettings lives in the
    # separate pydantic-settings package.
    from pydantic_settings import BaseSettings
except ImportError:
    # pydantic v1 still ships BaseSettings itself.
    from pydantic import BaseSettings


class Settings(BaseSettings):
    """Application settings; every field can be overridden via env vars or `.env`."""

    # Application
    app_name: str = "FastAPI Redis App"
    version: str = "1.0.0"
    debug: bool = False

    # Redis connection
    redis_url: str = "redis://localhost:6379/0"
    redis_max_connections: int = 20
    redis_retry_on_timeout: bool = True
    redis_socket_keepalive: bool = True
    redis_health_check_interval: int = 30  # seconds
    redis_socket_connect_timeout: int = 5  # seconds
    redis_socket_timeout: int = 30  # seconds

    # Key namespace
    redis_namespace: str = "fastapi_app"

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"


@lru_cache()
def get_settings() -> Settings:
    """Return the process-wide Settings instance (constructed once, then cached)."""
    return Settings()


settings = get_settings()
#在FastAPI中管理Redis生命周期
# main.py (扩展)
from fastapi import FastAPI
from contextlib import asynccontextmanager
from redis_client import redis_manager
import logging
logger = logging.getLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Open Redis before the app serves requests and close it on shutdown.

    The ``finally`` makes sure the Redis pool is closed even when the app's
    lifespan ends with an exception or cancellation.
    """
    logger.info("Initializing Redis connection...")
    await redis_manager.initialize()
    logger.info("Redis initialized successfully")
    try:
        yield
    finally:
        logger.info("Shutting down Redis connection...")
        await redis_manager.close()
        logger.info("Redis connection closed")


app = FastAPI(
    title="FastAPI Redis API",
    version="1.0.0",
    description="使用Redis缓存的FastAPI应用",
    lifespan=lifespan
)
# Health-check endpoint
@app.get("/health")
async def health_check():
    """Report liveness plus Redis connectivity.

    The endpoint always answers. If the Redis client is not initialized
    (``get_client`` raises RuntimeError) or PING fails, it reports
    ``"redis": "disconnected"`` and ``"status": "degraded"``.
    """
    from datetime import datetime, timezone

    try:
        # Fetch the client inside the try so "not initialized" is reported, not raised
        await redis_manager.get_client().ping()
        redis_status = "connected"
    except Exception:
        redis_status = "disconnected"
    return {
        "status": "healthy" if redis_status == "connected" else "degraded",
        "redis": redis_status,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
#Redis依赖注入
# dependencies.py (扩展)
from fastapi import Depends
from redis_client import get_redis_client
import redis.asyncio as redis
# Redis客户端依赖
async def get_redis() -> redis.Redis:
    """FastAPI dependency that provides the shared Redis client.

    Raises:
        RuntimeError: propagated from the manager when Redis is not initialized.
    """
    client = get_redis_client()
    return client
# 在路由中使用
from fastapi import APIRouter
import redis.asyncio as redis
router = APIRouter()


@router.get("/cache/health")
async def cache_health_check(redis_client: redis.Redis = Depends(get_redis)):
    """Return basic Redis server stats from INFO, or an error payload if it fails."""
    try:
        server_info = await redis_client.info()
    except Exception as exc:
        return {"status": "unhealthy", "error": str(exc)}
    # response field -> INFO field
    wanted_fields = {
        "version": "redis_version",
        "connected_clients": "connected_clients",
        "used_memory": "used_memory_human",
    }
    payload = {"status": "healthy"}
    payload.update({field: server_info.get(source) for field, source in wanted_fields.items()})
    return payload
#高性能缓存策略
#缓存装饰器实现
# cache/decorators.py
import json
import hashlib
from functools import wraps
from typing import Any, Callable, Optional, Union
from redis.asyncio import Redis
from typing import get_type_hints
import inspect
import logging
logger = logging.getLogger(__name__)


def cached(
    key_prefix: str = "",
    ttl: int = 300,  # default expiry: 5 minutes
    key_func: Optional[Callable] = None,
    serializer: str = "json",
    namespace: str = ""
):
    """Async Redis caching decorator.

    Usage::

        @cached("user:{user_id}", ttl=600)
        async def get_user(user_id: int):
            return await db.get_user(user_id)

    Cache key:
      * ``key_func(*args, **kwargs)`` when provided;
      * else, if ``key_prefix`` contains ``{name}`` placeholders, they are
        filled from the call's bound arguments (``"user:{user_id}"`` ->
        ``"user:42"``);
      * else ``<key_prefix or function name>:<first 8 hex chars of md5(args)>``.
      A non-empty ``namespace`` is prepended as ``"<namespace>:"``. Arguments
      named ``self``/``cls`` are ignored so methods share keys across instances.

    Failure handling: errors talking to Redis (including an uninitialized
    client) are logged and the wrapped function is called directly. Exceptions
    raised by the wrapped function itself propagate, and it is never re-run.

    Args:
        key_prefix: Literal prefix or ``str.format`` template for the key.
        ttl: Expiry in seconds used with SETEX.
        key_func: Optional callable building the complete key from the call args.
        serializer: ``"json"`` to JSON-encode values; anything else stores and
            returns the raw Redis value.
        namespace: Optional key namespace.

    The decorated function gets an ``invalidate(*args, **kwargs)`` attribute
    returning an awaitable that calls ``invalidate_cache`` with the
    namespaced prefix, e.g. ``await get_user.invalidate(user_id=42)``.
    """

    def join_key(*parts: str) -> str:
        # Skip empty parts so an empty namespace does not produce a leading ":".
        return ":".join(part for part in parts if part)

    def decorator(func: Callable) -> Callable:
        signature = inspect.signature(func)
        base_prefix = join_key(namespace, key_prefix or func.__name__)

        def build_key(args: tuple, kwargs: dict) -> str:
            if key_func:
                return key_func(*args, **kwargs)
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()
            arguments = {
                name: value
                for name, value in bound.arguments.items()
                if name not in ("self", "cls")
            }
            if key_prefix and "{" in key_prefix:
                try:
                    return join_key(namespace, key_prefix.format(**arguments))
                except (KeyError, IndexError):
                    # Template names something that is not a parameter: hash instead.
                    pass
            params_str = str(sorted(arguments.items()))
            params_hash = hashlib.md5(params_str.encode()).hexdigest()[:8]
            return f"{base_prefix}:{params_hash}"

        @wraps(func)
        async def wrapper(*args, **kwargs):
            cache_key = build_key(args, kwargs)

            # 1. Read from Redis; any cache-layer failure degrades to a miss.
            redis_client = None
            cached_value = None
            try:
                redis_client = get_redis_client()
                cached_value = await redis_client.get(cache_key)
            except Exception as e:
                logger.error(f"Cache error for key {cache_key}: {e}")

            if cached_value is not None:
                logger.debug(f"Cache hit for key: {cache_key}")
                if serializer != "json":
                    return cached_value
                try:
                    return json.loads(cached_value)
                except ValueError as e:
                    # Corrupt entry: treat as a miss; it is overwritten below.
                    logger.error(f"Cache error for key {cache_key}: {e}")
            else:
                logger.debug(f"Cache miss for key: {cache_key}")

            # 2. Call the real function exactly once; its errors propagate.
            result = await func(*args, **kwargs)

            # 3. Best-effort write-back.
            if redis_client is not None:
                try:
                    if serializer == "json":
                        serialized_result = json.dumps(result, default=str)
                    else:
                        serialized_result = result
                    await redis_client.setex(cache_key, ttl, serialized_result)
                    logger.debug(f"Cache set for key: {cache_key}")
                except Exception as e:
                    logger.error(f"Cache error for key {cache_key}: {e}")
            return result

        # Cache management helper; never passes an empty pattern.
        wrapper.invalidate = lambda *args, **kwargs: invalidate_cache(base_prefix, *args, **kwargs)
        return wrapper
    return decorator
async def invalidate_cache(pattern: str, *args, **kwargs):
    """Delete cached entries.

    * With ``args``/``kwargs``: ``pattern`` is a ``str.format`` template and
      exactly the formatted key is deleted, e.g.
      ``invalidate_cache("user:{user_id}", user_id=42)`` deletes ``user:42``.
    * Without arguments: every key containing ``pattern`` (glob
      ``*pattern*``) is deleted. Keys are found with incremental SCAN instead
      of KEYS, which blocks the server while it walks the whole keyspace.

    Raises:
        ValueError: if ``pattern`` is empty and no arguments are given, since
            the glob ``**`` would match and delete every key in the database.
    """
    if not pattern and not (args or kwargs):
        raise ValueError("invalidate_cache() requires a non-empty pattern")

    redis_client = get_redis_client()

    if args or kwargs:
        await redis_client.delete(pattern.format(*args, **kwargs))
        return

    batch = []
    async for key in redis_client.scan_iter(match=f"*{pattern}*", count=500):
        batch.append(key)
        if len(batch) >= 500:
            await redis_client.delete(*batch)
            batch.clear()
    if batch:
        await redis_client.delete(*batch)
# 批量缓存操作
class BatchCache:
    """Bulk JSON get/set helpers over a Redis client."""

    def __init__(self, redis_client: Redis):
        self.redis = redis_client

    async def mget_cached(self, keys: list) -> list:
        """Return decoded values for ``keys`` in order (None for missing keys).

        An empty ``keys`` list returns ``[]`` without contacting Redis; MGET
        with zero keys is a Redis error.
        """
        if not keys:
            return []
        results = await self.redis.mget(keys)
        return [json.loads(result) if result else None for result in results]

    async def mset_cached(self, mapping: dict, ttl: int = 300):
        """JSON-encode and store every ``key -> value`` with a ``ttl``-second expiry.

        Each key is written with ``SET ... EX`` so the value and its TTL land
        in a single command. Everything is sent in one pipeline. An empty
        mapping is a no-op.
        """
        if not mapping:
            return
        pipe = self.redis.pipeline()
        for key, value in mapping.items():
            pipe.set(key, json.dumps(value, default=str), ex=ttl)
        await pipe.execute()
#高级缓存策略
# cache/strategies.py
import asyncio
import json
import logging
import time
from enum import Enum
from typing import Any, Awaitable, Callable, Dict, Optional

from redis.asyncio import Redis

logger = logging.getLogger(__name__)


class CacheStrategy(Enum):
    """Cache strategy identifiers."""
    TTL = "ttl"  # fixed TTL
    LFU = "lfu"  # least frequently used
    LRU = "lru"  # least recently used
    SLIDING_WINDOW = "sliding_window"  # refresh ahead of expiry


class AdvancedCache:
    """Read-through cache helpers with a stale fallback copy and refresh-ahead."""

    def __init__(self, redis_client: Redis):
        self.redis = redis_client
        # cache_key -> in-flight background refresh task. asyncio holds only
        # weak references to tasks, so keeping them here stops them from being
        # garbage-collected mid-run; it also prevents duplicate refreshes
        # within this process.
        self._refresh_tasks: Dict[str, "asyncio.Task"] = {}

    async def get_with_fallback(
        self,
        cache_key: str,
        fetch_func: Callable[[], Awaitable[Any]],
        ttl: int = 300,
        fallback_ttl: int = 60
    ):
        """Read-through cache with a stale fallback copy.

        On a miss, ``fetch_func()`` is awaited. The result is stored under
        ``cache_key`` for ``ttl`` seconds and under ``"<cache_key>:fallback"``
        for ``ttl + fallback_ttl`` seconds, so the stale copy outlives the
        primary entry by ``fallback_ttl`` seconds.

        If reading Redis or calling ``fetch_func`` fails, the fallback copy is
        returned when present. Otherwise ``fetch_func`` is awaited once more
        and its result is returned uncached. A failure to write the cache is
        logged and does not trigger a second fetch.
        """
        fallback_key = f"{cache_key}:fallback"
        try:
            cached_value = await self.redis.get(cache_key)
            if cached_value is not None:
                return json.loads(cached_value)
            fresh_data = await fetch_func()
        except Exception as e:
            # Degraded path: serve the stale copy if there is one.
            try:
                fallback_value = await self.redis.get(fallback_key)
            except Exception:
                fallback_value = None  # Redis itself is unavailable
            if fallback_value is not None:
                logger.warning(f"Using fallback cache for {cache_key}: {e}")
                return json.loads(fallback_value)
            # Last resort: fetch directly, without caching
            return await fetch_func()

        try:
            payload = json.dumps(fresh_data, default=str)
            pipe = self.redis.pipeline()
            pipe.setex(cache_key, ttl, payload)
            pipe.setex(fallback_key, ttl + fallback_ttl, payload)
            await pipe.execute()
        except Exception as e:
            logger.warning(f"Failed to cache {cache_key}: {e}")
        return fresh_data

    async def sliding_window_cache(
        self,
        cache_key: str,
        fetch_func: Callable[[], Awaitable[Any]],
        window_ttl: int = 300,
        refresh_threshold: int = 60
    ):
        """Refresh-ahead cache.

        A hit older than ``window_ttl - refresh_threshold`` seconds (by the
        ``created_at`` stored in ``"<cache_key>:info"``) is returned
        immediately and a background refresh is scheduled. A miss fetches
        synchronously and stores the value plus its metadata.
        """
        cache_info_key = f"{cache_key}:info"
        cache_info = await self.redis.hgetall(cache_info_key)
        if cache_info:
            cached_value = await self.redis.get(cache_key)
            if cached_value is not None:
                created_at = int(cache_info.get('created_at', 0))
                if time.time() - created_at > (window_ttl - refresh_threshold):
                    self._schedule_refresh(cache_key, fetch_func, window_ttl)
                return json.loads(cached_value)

        # Miss: fetch fresh data and store it with its metadata
        fresh_data = await fetch_func()
        await self._store_with_info(cache_key, fresh_data, window_ttl)
        return fresh_data

    async def _store_with_info(self, cache_key: str, data: Any, ttl: int):
        """Write the value and its ``created_at`` metadata in one pipeline."""
        cache_info_key = f"{cache_key}:info"
        pipe = self.redis.pipeline()
        pipe.setex(cache_key, ttl, json.dumps(data, default=str))
        pipe.hset(cache_info_key, 'created_at', int(time.time()))
        pipe.expire(cache_info_key, ttl + 60)  # metadata expires slightly later than the value
        await pipe.execute()

    def _schedule_refresh(self, cache_key: str, fetch_func: Callable[[], Awaitable[Any]], ttl: int):
        """Start a background refresh for ``cache_key`` unless one is already running here."""
        if cache_key in self._refresh_tasks:
            return
        task = asyncio.create_task(self._refresh_cache(cache_key, fetch_func, ttl))
        self._refresh_tasks[cache_key] = task
        task.add_done_callback(lambda _task: self._refresh_tasks.pop(cache_key, None))

    async def _refresh_cache(self, cache_key: str, fetch_func: Callable[[], Awaitable[Any]], ttl: int):
        """Re-fetch and re-store the value, resetting ``created_at``.

        Resetting ``created_at`` keeps subsequent hits from immediately
        scheduling yet another refresh.
        """
        try:
            fresh_data = await fetch_func()
            await self._store_with_info(cache_key, fresh_data, ttl)
            logger.info(f"Cache refreshed for {cache_key}")
        except Exception as e:
            logger.error(f"Failed to refresh cache for {cache_key}: {e}")
# 缓存预热
class CacheWarmer:
    """Pre-populates cache entries for known hot queries."""

    def __init__(self, redis_client: Redis):
        self.redis = redis_client

    async def warm_common_queries(self, user_ids: Optional[list] = None, ttl: int = 3600):
        """Cache ``user:profile:<id>`` for each hot user that is not cached yet.

        Args:
            user_ids: Ids to warm; defaults to the sample ids ``[1, 2, 3, 4, 5]``.
            ttl: Expiry in seconds for warmed entries (default 1 hour).
        """
        import json  # not imported at the top of this module

        hot_user_ids = [1, 2, 3, 4, 5] if user_ids is None else user_ids
        for user_id in hot_user_ids:
            cache_key = f"user:profile:{user_id}"
            if await self.redis.exists(cache_key):
                continue  # already cached; don't overwrite fresher data
            user_data = await self._fetch_user_data(user_id)
            await self.redis.setex(cache_key, ttl, json.dumps(user_data, default=str))

    async def _fetch_user_data(self, user_id: int):
        """Placeholder loader; a real implementation would query the database."""
        return {"id": user_id, "name": f"User_{user_id}"}
#用户信息缓存示例
# services/user_cache_service.py
from cache.decorators import cached, invalidate_cache
from cache.strategies import AdvancedCache
from typing import Optional, Dict, Any
import json
class UserCacheService:
    """User data reads with Redis caching, plus cache invalidation on update."""

    def __init__(self):
        redis_client = get_redis_client()
        self.advanced_cache = AdvancedCache(redis_client)

    @cached("user:profile:{user_id}", ttl=900)  # 15 minutes
    async def get_user_profile(self, user_id: int) -> Optional[Dict[str, Any]]:
        """Load a user's profile as a JSON-friendly dict; None when not found."""
        from repositories.user_repository import UserRepository

        user = await UserRepository().get_by_id(user_id)
        if not user:
            return None

        def as_iso(moment):
            # Timestamps may be unset; keep them as None.
            return moment.isoformat() if moment else None

        return {
            "id": user.id,
            "name": user.name,
            "email": user.email,
            "avatar": user.avatar,
            "created_at": as_iso(user.created_at),
            "updated_at": as_iso(user.updated_at),
        }

    async def update_user_profile(self, user_id: int, data: Dict[str, Any]) -> bool:
        """Persist ``data`` for the user; on success drop that user's cached entries."""
        from repositories.user_repository import UserRepository

        success = await UserRepository().update(user_id, data)
        if success:
            for section in ("profile", "basic", "stats"):
                await invalidate_cache(f"user:{section}:{user_id}")
        return success

    async def get_user_statistics(self, user_id: int) -> Dict[str, Any]:
        """User statistics via the sliding-window cache.

        Entries live 30 minutes and are refreshed in the background during
        their last 5 minutes.
        """
        async def load_stats():
            from repositories.user_repository import UserRepository
            return await UserRepository().get_user_statistics(user_id)

        return await self.advanced_cache.sliding_window_cache(
            f"user:stats:{user_id}",
            load_stats,
            window_ttl=1800,
            refresh_threshold=300,
        )

    async def get_hot_users(self) -> list:
        """Hot-user list via the fallback cache (1 hour TTL, fallback_ttl=600)."""
        async def load_hot_users():
            from repositories.user_repository import UserRepository
            return await UserRepository().get_hot_users()

        return await self.advanced_cache.get_with_fallback(
            "user:hot_list",
            load_hot_users,
            ttl=3600,
            fallback_ttl=600,
        )
#分布式Session管理
#Session管理器实现
# session/manager.py
import uuid
import json
import time
from typing import Dict, Optional, Any
from redis.asyncio import Redis
from datetime import datetime, timedelta
import logging
logger = logging.getLogger(__name__)


class RedisSessionManager:
    """Redis-backed distributed session store.

    Each session is one JSON string stored under ``<namespace>:<session_id>``
    with a TTL (7 days by default).
    """

    def __init__(self, redis_client: Redis, namespace: str = "session"):
        self.redis = redis_client
        self.namespace = namespace
        self.session_prefix = f"{namespace}:"
        self.session_expire = 86400 * 7  # default TTL: 7 days
        self.refresh_threshold = 3600  # rewrite/extend a session at most once per hour

    async def create_session(
        self,
        user_id: int,
        extra_data: Optional[Dict[str, Any]] = None,
        custom_expire: Optional[int] = None
    ) -> str:
        """Create a session for ``user_id`` and return its new id.

        Args:
            user_id: Owner of the session.
            extra_data: Additional fields to store. ``ip_address`` and
                ``user_agent`` are always present (None when not supplied).
                Extra fields cannot override ``user_id``, ``created_at`` or
                ``last_accessed``.
            custom_expire: TTL in seconds; defaults to ``session_expire``.
        """
        session_id = str(uuid.uuid4())
        session_key = self.session_prefix + session_id
        extra = extra_data or {}
        now = time.time()
        session_data = {
            "ip_address": extra.get("ip_address"),
            "user_agent": extra.get("user_agent"),
            **extra,
            # Core fields go last so caller-supplied extra_data cannot spoof them.
            "user_id": user_id,
            "created_at": now,
            "last_accessed": now,
        }
        expire_time = custom_expire or self.session_expire
        await self.redis.setex(session_key, expire_time, json.dumps(session_data))
        logger.info(f"Session created: {session_id} for user {user_id}")
        return session_id

    async def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return the session dict, or None if it does not exist or has expired.

        Sliding expiration: when the stored ``last_accessed`` is older than
        ``refresh_threshold`` seconds, it is bumped and the TTL is reset. The
        returned dict also carries ``"session_id"``, which is not persisted.
        """
        session_key = self.session_prefix + session_id
        data = await self.redis.get(session_key)
        if not data:
            return None
        session_data = json.loads(data)
        # The threshold check must see the *stored* timestamp; overwriting it
        # with "now" first would make the check always fail.
        await self._update_last_accessed(session_key, session_data)
        session_data["session_id"] = session_id
        return session_data

    async def _update_last_accessed(self, session_key: str, session_data: Dict[str, Any]):
        """Bump ``last_accessed`` and reset the TTL once the refresh threshold passes.

        NOTE(review): the TTL is reset to ``session_expire`` even for sessions
        created with ``custom_expire``.
        """
        last_accessed = session_data.get("last_accessed", 0)
        if time.time() - last_accessed > self.refresh_threshold:
            session_data["last_accessed"] = time.time()
            await self.redis.setex(session_key, self.session_expire, json.dumps(session_data))

    async def refresh_session(self, session_id: str, extend_time: Optional[int] = None) -> bool:
        """Reset the session TTL; returns False when the session does not exist."""
        session_key = self.session_prefix + session_id
        # EXPIRE reports False for a missing key, so no prior GET is needed.
        expire_time = extend_time or self.session_expire
        return bool(await self.redis.expire(session_key, expire_time))

    async def destroy_session(self, session_id: str) -> bool:
        """Delete the session (logout); returns whether it existed."""
        session_key = self.session_prefix + session_id
        result = await self.redis.delete(session_key)
        if result:
            logger.info(f"Session destroyed: {session_id}")
        return bool(result)

    async def extend_session(self, session_id: str, additional_time: int) -> bool:
        """Add ``additional_time`` seconds to the remaining TTL.

        Returns False when the key is missing (TTL -2) or has no expiry (TTL -1).
        """
        session_key = self.session_prefix + session_id
        current_ttl = await self.redis.ttl(session_key)
        if current_ttl < 0:
            return False
        return await self.redis.expire(session_key, current_ttl + additional_time)

    async def _find_user_sessions(self, user_id: int) -> list:
        """Return ``(key, session_data)`` pairs for every stored session of ``user_id``.

        Walks session keys with SCAN (non-blocking, unlike KEYS) and reads
        values in MGET batches. This is O(total sessions); a per-user index
        set would make it O(user sessions).
        """
        # SCAN may return a key more than once; de-duplicate while keeping order.
        keys = list(dict.fromkeys(
            [key async for key in self.redis.scan_iter(match=f"{self.session_prefix}*", count=500)]
        ))
        found = []
        for start in range(0, len(keys), 500):
            chunk = keys[start:start + 500]
            for key, raw in zip(chunk, await self.redis.mget(chunk)):
                if not raw:
                    continue  # expired between SCAN and MGET
                session_data = json.loads(raw)
                if session_data.get("user_id") == user_id:
                    found.append((key, session_data))
        return found

    async def get_user_sessions(self, user_id: int) -> list:
        """Return all live sessions of ``user_id``, each with its ``session_id``."""
        sessions = []
        for key, session_data in await self._find_user_sessions(user_id):
            # Strip only the leading prefix (str.replace would strip it anywhere).
            session_data["session_id"] = key[len(self.session_prefix):]
            sessions.append(session_data)
        return sessions

    async def invalidate_user_sessions(self, user_id: int) -> int:
        """Delete all sessions of ``user_id``; returns how many were removed."""
        user_keys = [key for key, _ in await self._find_user_sessions(user_id)]
        if not user_keys:
            return 0
        result = await self.redis.delete(*user_keys)
        logger.info(f"Invalidated {result} sessions for user {user_id}")
        return result
# Global session manager, built lazily on first use. get_redis_client()
# raises until the application lifespan has initialized Redis, so building
# the manager at import time would make this module fail to import.
class _LazySessionManager:
    """Proxy that forwards attribute access to a RedisSessionManager created on demand."""

    def __init__(self):
        self._instance = None

    def __getattr__(self, name):
        if name == "_instance":
            # Avoid infinite recursion if __init__ never ran (e.g. copy/pickle).
            raise AttributeError(name)
        if self._instance is None:
            self._instance = RedisSessionManager(get_redis_client())
        return getattr(self._instance, name)


session_manager = _LazySessionManager()
#Session中间件
# middleware/session_middleware.py
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
import logging
logger = logging.getLogger(__name__)


class SessionMiddleware(BaseHTTPMiddleware):
    """Attach the caller's session to ``request.state`` for every request.

    The session id comes from the ``session_id`` cookie, falling back to the
    ``X-Session-ID`` header. ``request.state.session`` and
    ``request.state.user_id`` are None when there is no valid session.

    After the handler runs:
      * if it set ``request.state.new_session_id``, that id is sent as an
        HttpOnly, Secure, SameSite=Strict cookie valid for 7 days;
      * otherwise, if the request carried a ``session_id`` cookie that no
        longer resolves to a session, the cookie is deleted.
    """

    async def dispatch(self, request: Request, call_next):
        # Imported per request: session.manager builds its manager from the
        # Redis client, which exists only after the app lifespan initialized it.
        from session.manager import session_manager

        cookie_session_id = request.cookies.get("session_id")
        session_id = cookie_session_id or request.headers.get("X-Session-ID")

        session_data = await session_manager.get_session(session_id) if session_id else None
        request.state.session = session_data or None
        request.state.user_id = session_data.get("user_id") if session_data else None

        response = await call_next(request)

        if hasattr(request.state, 'new_session_id'):
            response.set_cookie(
                "session_id",
                request.state.new_session_id,
                httponly=True,
                secure=True,  # HTTPS should be enabled in production
                samesite="strict",
                max_age=86400 * 7  # 7 days, matching the default session TTL
            )
        elif cookie_session_id and not session_data:
            # Stale cookie (session expired or destroyed): tell the browser to drop it.
            response.delete_cookie("session_id")
        return response
# Register the session middleware on the app (in main.py).
from middleware.session_middleware import SessionMiddleware
app.add_middleware(SessionMiddleware)
#Session依赖注入
# dependencies/session_deps.py
from fastapi import Request, HTTPException, status
from typing import Optional, Dict, Any
from session.manager import session_manager
async def get_current_session(request: Request) -> Optional[Dict[str, Any]]:
    """Dependency returning the session dict that SessionMiddleware attached.

    Raises:
        HTTPException: 401 when the request carries no valid session.
    """
    session_data = getattr(request.state, 'session', None)
    if session_data:
        return session_data
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="未登录或Session已过期"
    )
async def get_current_user_id(request: Request) -> int:
    """Dependency returning the authenticated user's id.

    Raises:
        HTTPException: 401 when SessionMiddleware found no valid session
            (``request.state.user_id`` missing or None).
    """
    user_id = getattr(request.state, 'user_id', None)
    # Compare with None: a falsy but valid id such as 0 must not be rejected.
    if user_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未登录或Session已过期"
        )
    return user_id
async def require_login(request: Request) -> Dict[str, Any]:
    """Dependency requiring a logged-in session; slides the session's expiry.

    Both lookups raise HTTP 401 when the request is not authenticated; the
    user-id lookup is done only for that check. When the session dict carries
    a ``"session_id"`` key, the session TTL is refreshed.
    """
    session = await get_current_session(request)
    await get_current_user_id(request)  # raises 401 when there is no user id

    session_id = session.get("session_id")
    if not session_id:
        return session
    await session_manager.refresh_session(session_id)
    return session
# 使用示例
from fastapi import APIRouter, Depends
router = APIRouter()


@router.get("/profile")
async def get_profile(
    session: dict = Depends(require_login),
    user_id: int = Depends(get_current_user_id)
):
    """Login-protected route returning the user id and session timestamps."""
    session_info = {field: session.get(field) for field in ("created_at", "last_accessed")}
    return {"user_id": user_id, "session_info": session_info}
@router.post("/logout")
async def logout(request: Request):
    """Destroy the caller's session and clear the session cookie.

    The session id is read from the ``session_id`` cookie or, failing that,
    the ``X-Session-ID`` header (the same sources SessionMiddleware reads),
    so header-based clients can log out too.
    """
    from fastapi.responses import JSONResponse

    session_id = request.cookies.get("session_id") or request.headers.get("X-Session-ID")
    if session_id:
        await session_manager.destroy_session(session_id)

    resp = JSONResponse(content={"message": "登出成功"})
    resp.delete_cookie("session_id")
    return resp
#异步消息队列实现
#任务队列系统
# queue/task_queue.py
import json
import uuid
import time
from typing import Dict, Any, Optional, Callable, List
from redis.asyncio import Redis
from datetime import datetime
import asyncio
import logging
logger = logging.getLogger(__name__)


class RedisTaskQueue:
    """Redis-backed async task queue.

    Keys (all named ``queue:<queue_name>:<part>``):
      * ``tasks``       LIST  ready tasks (RPUSH to enqueue, BLPOP to consume)
      * ``processing``  ZSET  dequeued tasks, scored by visibility deadline
      * ``delayed``     ZSET  delayed tasks, scored by run-at timestamp
      * ``retry``       ZSET  failed tasks, scored by retry-at timestamp
      * ``dead_letter`` LIST  tasks that exhausted their retries

    NOTE(review): ``priority`` is stored but not used for ordering, and tasks
    whose visibility deadline passes (e.g. a worker crashed) are not re-queued.
    """

    def __init__(self, redis_client: Redis, queue_name: str = "default"):
        self.redis = redis_client
        self.queue_name = queue_name
        self.task_queue = f"queue:{queue_name}:tasks"
        self.processing_queue = f"queue:{queue_name}:processing"
        self.delayed_queue = f"queue:{queue_name}:delayed"
        self.retry_queue = f"queue:{queue_name}:retry"
        self.dead_letter_queue = f"queue:{queue_name}:dead_letter"
        # Task settings
        self.max_retries = 3
        self.retry_delay = 60  # seconds before a failed task becomes ready again
        self.visibility_timeout = 300  # seconds; deadline score in the processing set
        # task id -> exact member string this instance put into the processing
        # set, so complete/fail can remove it even after task_data was mutated.
        self._processing_members: Dict[str, str] = {}

    async def enqueue(
        self,
        task_name: str,
        task_args: Dict[str, Any],
        delay: Optional[int] = None,
        priority: int = 0
    ) -> str:
        """Enqueue a task and return its id; a truthy ``delay`` (seconds) defers it."""
        task_id = str(uuid.uuid4())
        task_data = {
            "id": task_id,
            "name": task_name,
            "args": task_args,
            "priority": priority,
            "created_at": time.time(),
            "attempts": 0,
            "max_retries": self.max_retries
        }
        if delay:
            await self.redis.zadd(self.delayed_queue, {json.dumps(task_data): time.time() + delay})
        else:
            await self.redis.rpush(self.task_queue, json.dumps(task_data))
        logger.info(f"Task enqueued: {task_id} - {task_name}")
        return task_id

    async def dequeue(self, timeout: int = 5) -> Optional[Dict[str, Any]]:
        """Pop the next ready task (blocking up to ``timeout`` seconds) or return None.

        Due delayed/retry tasks are promoted first. The popped task is
        recorded in the processing set with a visibility deadline.
        """
        await self._move_delayed_tasks()
        result = await self.redis.blpop([self.task_queue], timeout=timeout)
        if not result:
            return None
        _, task_json = result
        task_data = json.loads(task_json)
        member = json.dumps(task_data)
        await self.redis.zadd(self.processing_queue, {member: time.time() + self.visibility_timeout})
        self._processing_members[task_data["id"]] = member
        return task_data

    async def _move_delayed_tasks(self):
        """Move due tasks from the delayed and retry sets to the front of the ready list."""
        now = time.time()
        for source in (self.delayed_queue, self.retry_queue):
            due_tasks = await self.redis.zrangebyscore(source, "-inf", now)
            for task_json in due_tasks:
                # ZREM returns 1 to exactly one caller, so concurrent workers
                # cannot both promote (and thus duplicate) the same task.
                if await self.redis.zrem(source, task_json):
                    await self.redis.lpush(self.task_queue, task_json)

    def _take_processing_member(self, task_data: Dict[str, Any]) -> str:
        """Return (and forget) the processing-set member recorded for ``task_data``."""
        member = self._processing_members.pop(task_data.get("id"), None)
        return member if member is not None else json.dumps(task_data)

    async def complete_task(self, task_data: Dict[str, Any]) -> bool:
        """Remove a finished task from the processing set; True if it was there."""
        member = self._take_processing_member(task_data)
        result = await self.redis.zrem(self.processing_queue, member)
        if result:
            logger.info(f"Task completed: {task_data['id']}")
        return bool(result)

    async def fail_task(self, task_data: Dict[str, Any], error: str) -> bool:
        """Schedule a retry, or dead-letter the task once retries are exhausted.

        A task is retried while ``attempts <= max_retries``, i.e. it runs at
        most ``max_retries + 1`` times. ``task_data`` is updated in place.
        """
        task_id = task_data["id"]
        # Resolve the processing member *before* task_data is mutated below;
        # re-serializing afterwards would no longer match the stored member.
        member = self._take_processing_member(task_data)
        attempts = task_data.get("attempts", 0) + 1

        if attempts <= task_data.get("max_retries", self.max_retries):
            task_data["attempts"] = attempts
            task_data["error"] = error
            task_data["retry_at"] = time.time() + self.retry_delay
            # Promoted back to the ready list by _move_delayed_tasks once due.
            await self.redis.zadd(self.retry_queue, {json.dumps(task_data): task_data["retry_at"]})
            logger.warning(f"Task {task_id} failed, scheduled for retry {attempts}/{task_data['max_retries']}")
        else:
            task_data["final_error"] = error
            await self.redis.rpush(self.dead_letter_queue, json.dumps(task_data))
            logger.error(f"Task {task_id} moved to dead letter queue after {attempts} attempts")

        await self.redis.zrem(self.processing_queue, member)
        return True

    async def get_queue_stats(self) -> Dict[str, int]:
        """Return the size of every queue/set."""
        return {
            "pending": await self.redis.llen(self.task_queue),
            "processing": await self.redis.zcard(self.processing_queue),
            "delayed": await self.redis.zcard(self.delayed_queue),
            "retry": await self.redis.zcard(self.retry_queue),
            "dead_letter": await self.redis.llen(self.dead_letter_queue)
        }

    async def process_task_with_handler(self, handler_func: Callable) -> bool:
        """Dequeue one task and run ``handler_func`` on it.

        Returns True when a task was handled successfully, and False when no
        task was available or the handler raised (the task is then passed to
        ``fail_task``).
        """
        task_data = await self.dequeue()
        if not task_data:
            return False
        try:
            await handler_func(task_data)
            await self.complete_task(task_data)
            return True
        except Exception as e:
            await self.fail_task(task_data, str(e))
            return False
# Global task queue, built lazily on first use. get_redis_client() raises
# until Redis has been initialized, so building it at import time would make
# this module fail to import.
class _LazyTaskQueue:
    """Proxy that forwards attribute access to the default RedisTaskQueue, created on demand."""

    def __init__(self):
        self._instance = None

    def __getattr__(self, name):
        if name == "_instance":
            # Avoid infinite recursion if __init__ never ran (e.g. copy/pickle).
            raise AttributeError(name)
        if self._instance is None:
            self._instance = RedisTaskQueue(get_redis_client(), "default")
        return getattr(self._instance, name)


task_queue = _LazyTaskQueue()
#任务处理器
# handlers/task_handlers.py
from typing import Dict, Any
import logging
logger = logging.getLogger(__name__)


class TaskHandlers:
    """Coroutine handlers for queued tasks; each receives the full task dict.

    Parameters are read from ``task_data["args"]``. The real work (sending
    mail, DB updates, report generation) is still a placeholder: the handlers
    only emit log lines.
    """

    @staticmethod
    async def send_email_task(task_data: Dict[str, Any]):
        """Handle ``send_email`` (args: recipient, subject, body)."""
        payload = task_data["args"]
        recipient, subject = payload.get("recipient"), payload.get("subject")
        body = payload.get("body")  # consumed once real sending is wired in
        logger.info(f"Sending email to {recipient}: {subject}")
        # TODO: await send_email(recipient, subject, body)
        logger.info(f"Email sent successfully to {recipient}")

    @staticmethod
    async def process_user_registration_task(task_data: Dict[str, Any]):
        """Handle ``process_user_registration`` (args: user_id)."""
        user_id = task_data["args"].get("user_id")
        logger.info(f"Processing registration for user {user_id}")
        # TODO: send the welcome email, mark the user active, log the registration event
        logger.info(f"Registration processed for user {user_id}")

    @staticmethod
    async def cleanup_old_data_task(task_data: Dict[str, Any]):
        """Handle ``cleanup_old_data`` (args: days_old, default 30)."""
        days_old = task_data["args"].get("days_old", 30)
        logger.info(f"Cleaning up data older than {days_old} days")
        # TODO: cleaned_count = await cleanup_data(days_old)
        logger.info(f"Cleaned up old data: X records removed")

    @staticmethod
    async def generate_report_task(task_data: Dict[str, Any]):
        """Handle ``generate_report`` (args: report_type, start_date, end_date)."""
        payload = task_data["args"]
        report_type = payload.get("report_type")
        start_date, end_date = payload.get("start_date"), payload.get("end_date")
        logger.info(f"Generating {report_type} report for {start_date} to {end_date}")
        # TODO: report = await generate_report(report_type, start_date, end_date)
        # TODO: await save_report(report)
        logger.info(f"Report generated: {report_type}")
#后台任务处理器
# workers/background_worker.py
import asyncio
import signal
import sys
from typing import Dict, Any, Callable
import logging
logger = logging.getLogger(__name__)


class BackgroundWorker:
    """Loop that pulls tasks from a task queue and dispatches them by task name."""

    def __init__(self, task_queue, task_handlers):
        self.task_queue = task_queue
        self.task_handlers = task_handlers
        self.running = False
        # Task name (as given to enqueue) -> handler coroutine function
        self.task_mapping = {
            "send_email": task_handlers.send_email_task,
            "process_user_registration": task_handlers.process_user_registration_task,
            "cleanup_old_data": task_handlers.cleanup_old_data_task,
            "generate_report": task_handlers.generate_report_task,
        }

    def _install_signal_handlers(self):
        """Make SIGINT/SIGTERM request a graceful stop of the processing loop."""
        def request_stop(signum):
            logger.info(f"Received signal {signum}, shutting down...")
            self.running = False

        loop = asyncio.get_running_loop()
        for signum in (signal.SIGINT, signal.SIGTERM):
            try:
                # Preferred for asyncio: the callback runs inside the event loop.
                loop.add_signal_handler(signum, request_stop, signum)
            except (NotImplementedError, RuntimeError):
                # Windows loops lack add_signal_handler, and it refuses to run
                # outside the main thread; fall back to the signal module.
                try:
                    signal.signal(signum, lambda s, _frame: request_stop(s))
                except ValueError:
                    # signal.signal only works in the main thread.
                    logger.warning(f"Cannot install handler for signal {signum} outside the main thread")

    async def start(self):
        """Process tasks until stop() is called or SIGINT/SIGTERM arrives.

        Each iteration handles at most one task. When no task was handled
        successfully, the loop sleeps 1 second; after an unexpected error in
        the loop itself, it sleeps 5 seconds.
        """
        self.running = True
        logger.info("Background worker started")
        self._install_signal_handlers()

        while self.running:
            try:
                handled = await self.task_queue.process_task_with_handler(
                    self._handle_single_task
                )
                if not handled:
                    await asyncio.sleep(1)
            except Exception as e:
                logger.error(f"Error in worker loop: {e}")
                await asyncio.sleep(5)

        logger.info("Background worker stopped")

    async def _handle_single_task(self, task_data: Dict[str, Any]):
        """Run the handler registered for ``task_data["name"]``.

        Raises:
            ValueError: when no handler is registered for the task name.
        """
        task_name = task_data["name"]
        handler = self.task_mapping.get(task_name)
        if not handler:
            raise ValueError(f"No handler found for task: {task_name}")
        logger.info(f"Processing task {task_data['id']}: {task_name}")
        await handler(task_data)
        logger.info(f"Completed task {task_data['id']}: {task_name}")

    async def stop(self):
        """Ask the loop to exit after its current iteration."""
        self.running = False
        logger.info("Stopping background worker...")
# 启动后台处理器的函数
async def run_background_worker():
    """Run the task worker until it is stopped (SIGINT/SIGTERM).

    The worker is meant to run in its own process, outside the FastAPI
    lifespan, so it initializes the shared Redis manager itself before
    anything calls get_redis_client(). It closes the manager on exit.

    NOTE(review): a top-level package named ``queue`` likely collides with the
    standard-library ``queue`` module (imported internally by asyncio's
    dependencies); consider renaming the package.
    """
    from redis_client import redis_manager

    await redis_manager.initialize()
    try:
        # Imported after initialization: the queue module may need a live client.
        from queue.task_queue import task_queue
        from handlers.task_handlers import TaskHandlers

        worker = BackgroundWorker(task_queue, TaskHandlers())
        await worker.start()
    finally:
        await redis_manager.close()
# 作为独立进程运行后台处理器(与Web应用分开部署):
# if __name__ == "__main__":
#     import asyncio
#     asyncio.run(run_background_worker())
#高级Redis数据结构应用
#Sorted Set实现排行榜
# data_structures/rankings.py
import time
from typing import List, Tuple, Optional
from redis.asyncio import Redis
import json
class RankingSystem:
    """Leaderboard on one Redis sorted set (``ranking:<name>``); higher score ranks first."""

    def __init__(self, redis_client: Redis, name: str):
        self.redis = redis_client
        self.key = f"ranking:{name}"

    @staticmethod
    def _decode_pairs(members) -> List[Tuple[str, float]]:
        """Normalize ``(member, score)`` pairs, decoding bytes members to str."""
        return [
            (member.decode() if isinstance(member, bytes) else member, score)
            for member, score in members
        ]

    async def add_score(self, member: str, score: float):
        """Set ``member``'s score (replaces any previous score)."""
        await self.redis.zadd(self.key, {member: score})

    async def increment_score(self, member: str, increment: float) -> float:
        """Add ``increment`` to ``member``'s score and return the new score."""
        return await self.redis.zincrby(self.key, increment, member)

    async def get_rank(self, member: str) -> Optional[int]:
        """Return the 1-based rank (1 = highest score), or None if absent."""
        rank = await self.redis.zrevrank(self.key, member)
        return rank + 1 if rank is not None else None

    async def get_score(self, member: str) -> Optional[float]:
        """Return ``member``'s score, or None if absent."""
        return await self.redis.zscore(self.key, member)

    async def get_top_n(self, n: int) -> List[Tuple[str, float]]:
        """Return the top ``n`` ``(member, score)`` pairs, best first; ``[]`` for ``n <= 0``."""
        if n <= 0:
            # ZREVRANGE 0 -1 (what n=0 would produce) returns the whole leaderboard.
            return []
        members = await self.redis.zrevrange(self.key, 0, n - 1, withscores=True)
        return self._decode_pairs(members)

    async def get_range(self, start: int, end: int) -> List[Tuple[str, float]]:
        """Return pairs for 0-based ranks ``start``..``end`` inclusive (ZREVRANGE semantics)."""
        members = await self.redis.zrevrange(self.key, start, end, withscores=True)
        return self._decode_pairs(members)

    async def get_members_by_score(self, min_score: float, max_score: float) -> List[Tuple[str, float]]:
        """Return pairs with ``min_score <= score <= max_score``, lowest score first."""
        members = await self.redis.zrangebyscore(self.key, min_score, max_score, withscores=True)
        return self._decode_pairs(members)

    async def remove_member(self, member: str) -> bool:
        """Remove ``member``; True if it was present."""
        return bool(await self.redis.zrem(self.key, member))

    async def get_total_count(self) -> int:
        """Return the number of ranked members."""
        return await self.redis.zcard(self.key)

    async def get_percentile_rank(self, member: str) -> Optional[float]:
        """Return the percentage of members ranked below ``member``.

        Computed as ``(total - rank) / total * 100``, so the top member of N
        gets ``(N - 1) / N * 100``. Returns None for an empty board or an
        unknown member.
        """
        total = await self.get_total_count()
        if total == 0:
            return None
        rank = await self.get_rank(member)
        if rank is None:
            return None
        return (total - rank) / total * 100
# 使用示例
class GameRankingSystem:
    """Three leaderboards (score, wins, playtime) kept for the same players."""

    def __init__(self, redis_client: Redis):
        self.score_ranking = RankingSystem(redis_client, "game_scores")
        self.win_ranking = RankingSystem(redis_client, "game_wins")
        self.playtime_ranking = RankingSystem(redis_client, "game_playtime")

    async def update_player_stats(self, player_id: str, score: float, wins: int, playtime: float):
        """Overwrite the player's value on each of the three leaderboards."""
        updates = (
            (self.score_ranking, score),
            (self.win_ranking, wins),
            (self.playtime_ranking, playtime),
        )
        for board, value in updates:
            await board.add_score(player_id, value)

    async def get_player_overview(self, player_id: str) -> dict:
        """Return the player's value and rank on every leaderboard (None when unranked)."""
        overview = {"player_id": player_id}
        boards = (
            ("score", self.score_ranking),
            ("wins", self.win_ranking),
            ("playtime", self.playtime_ranking),
        )
        for label, board in boards:
            overview[label] = await board.get_score(player_id)
            overview[f"{label}_rank"] = await board.get_rank(player_id)
        return overview
#HyperLogLog实现UV统计
# data_structures/uv_counter.py
import time
from redis.asyncio import Redis
from datetime import datetime, timedelta
class UVCounter:
    """Unique-visitor counter built on Redis HyperLogLog (PFADD/PFCOUNT)."""

    def __init__(self, redis_client: Redis, name: str):
        self.redis = redis_client
        self.name = name

    def _get_key(self, date: str) -> str:
        """Key of the daily counter for a ``YYYY-MM-DD`` date."""
        return f"uv:{self.name}:{date}"

    def _get_period_key(self, period: str, day=None) -> str:
        """Key of the calendar period containing ``day`` (a ``date``; default today).

        ``"week"`` -> ``uv:<name>:week:<ISO year>-W<ISO week>``,
        ``"month"`` -> ``uv:<name>:month:YYYY-MM``, ``"year"`` ->
        ``uv:<name>:year:YYYY``. Any other value is used verbatim
        (``uv:<name>:<period>``). Dated keys make each period start from zero
        instead of accumulating forever.
        """
        if day is None:
            day = datetime.now().date()
        if period == "week":
            iso_year, iso_week, _ = day.isocalendar()
            suffix = f"week:{iso_year}-W{iso_week:02d}"
        elif period == "month":
            suffix = f"month:{day:%Y-%m}"
        elif period == "year":
            suffix = f"year:{day:%Y}"
        else:
            suffix = period
        return f"uv:{self.name}:{suffix}"

    async def add_visitor(self, visitor_id: str, date: str = None):
        """Count ``visitor_id`` for ``date`` (``YYYY-MM-DD``, default today).

        The visitor is also added to the week, month and year counters that
        contain that date, so back-filled dates land in the right periods.
        """
        if date is None:
            date = datetime.now().strftime("%Y-%m-%d")
        day = datetime.strptime(date, "%Y-%m-%d").date()
        pipe = self.redis.pipeline(transaction=False)
        pipe.pfadd(self._get_key(date), visitor_id)
        for period in ("week", "month", "year"):
            pipe.pfadd(self._get_period_key(period, day), visitor_id)
        await pipe.execute()

    async def get_daily_uv(self, date: str = None) -> int:
        """Approximate UV for ``date`` (``YYYY-MM-DD``, default today)."""
        if date is None:
            date = datetime.now().strftime("%Y-%m-%d")
        return await self.redis.pfcount(self._get_key(date))

    async def get_period_uv(self, period: str) -> int:
        """Approximate UV for the current week/month/year (or a custom period key)."""
        return await self.redis.pfcount(self._get_period_key(period))

    async def get_uv_trend(self, days: int = 7) -> "list[tuple[str, int]]":
        """Return ``(date, uv)`` pairs for the last ``days`` days including today, oldest first."""
        # One "now" for the whole loop so a midnight crossing cannot skew the dates.
        today = datetime.now()
        trend = []
        for offset in range(days - 1, -1, -1):
            day = (today - timedelta(days=offset)).strftime("%Y-%m-%d")
            trend.append((day, await self.get_daily_uv(day)))
        return trend
# Usage example
class AnalyticsService:
    """Analytics facade that fans one page view out to three UV counters."""

    def __init__(self, redis_client: Redis):
        # One counter per dimension, all sharing the same Redis client.
        counter_names = ("page_views", "unique_users", "region_visits")
        self.page_uv, self.user_uv, self.region_uv = (
            UVCounter(redis_client, counter_name) for counter_name in counter_names
        )

    async def record_page_view(self, user_id: str, page: str, region: str):
        """Record one page view by ``user_id`` on ``page`` from ``region`` for today.

        The page and region counters receive composite ``"<page>:<user_id>"``
        / ``"<region>:<user_id>"`` members; the user counter receives the
        bare ``user_id``.
        """
        today = datetime.now().strftime("%Y-%m-%d")
        page_member = f"{page}:{user_id}"
        region_member = f"{region}:{user_id}"
        await self.page_uv.add_visitor(page_member, today)
        await self.user_uv.add_visitor(user_id, today)
        await self.region_uv.add_visitor(region_member, today)
#BitMap实现签到系统
# data_structures/checkin_system.py
import time
from datetime import datetime, date
from redis.asyncio import Redis
class CheckInSystem:
    """BitMap-based daily check-in system.

    Each user gets one bitmap per month, stored at
    ``bitmap:{name}:{user_id}:{YYYY-MM}``; bit ``day - 1`` is set when the
    user checked in on that day.  Keys are set to expire at local midnight
    on the first day of the following month ("cleared next month").
    """

    def __init__(self, redis_client: "Redis", name: str = "checkin"):
        self.redis = redis_client
        self.name = name

    def _get_key(self, year_month: str, user_id: str) -> str:
        """Return ``user_id``'s bitmap key for ``year_month`` (``YYYY-MM``)."""
        return f"bitmap:{self.name}:{user_id}:{year_month}"

    def _get_day_offset(self, day: int) -> int:
        """Return the bit offset of day-of-month ``day`` (1-based -> 0-based)."""
        return day - 1

    async def check_in(self, user_id: str, check_date: date = None):
        """Mark ``user_id`` as checked in on ``check_date`` (default: today).

        Raises:
            ValueError: if ``check_date`` lies in a month that has already
                ended -- its bitmap would have to expire immediately, and a
                non-positive EXPIRE would delete the key.
        """
        from datetime import timedelta  # this snippet's header imports only datetime/date

        if check_date is None:
            check_date = date.today()
        first_of_next_month = (check_date.replace(day=1) + timedelta(days=32)).replace(day=1)
        # Compare datetime to datetime: date - datetime raises TypeError.
        expires_at = datetime.combine(first_of_next_month, datetime.min.time())
        ttl = int((expires_at - datetime.now()).total_seconds())
        if ttl <= 0:
            raise ValueError(
                f"cannot check in for {check_date:%Y-%m}: that month has already ended"
            )
        key = self._get_key(check_date.strftime("%Y-%m"), user_id)
        await self.redis.setbit(key, self._get_day_offset(check_date.day), 1)
        await self.redis.expire(key, ttl)

    async def is_checked_in(self, user_id: str, check_date: date = None) -> bool:
        """Return True if ``user_id`` checked in on ``check_date`` (default: today)."""
        if check_date is None:
            check_date = date.today()
        key = self._get_key(check_date.strftime("%Y-%m"), user_id)
        return bool(await self.redis.getbit(key, self._get_day_offset(check_date.day)))

    async def get_checkin_days(self, user_id: str, target_month: str) -> int:
        """Return how many days ``user_id`` checked in during ``target_month`` (``YYYY-MM``)."""
        return await self.redis.bitcount(self._get_key(target_month, user_id))

    async def get_continuous_checkin_streak(self, user_id: str, check_dates: list) -> int:
        """Count check-ins walking ``check_dates`` backwards from its last element.

        Stops at the first date without a check-in.  Presumably callers pass
        consecutive dates ordered oldest -> newest -- TODO confirm.
        """
        streak = 0
        for check_date in reversed(check_dates):
            if not await self.is_checked_in(user_id, check_date):
                break
            streak += 1
        return streak

    async def get_monthly_report(self, user_id: str, target_month: str) -> dict:
        """Build a per-day check-in report for ``user_id`` in ``target_month`` (``YYYY-MM``).

        Returns ``month``, ``total_checkins``, ``daily_status`` (one
        ``{"day", "checked"}`` entry per day of the month) and ``percentage``
        (check-ins / days in month * 100, rounded to 2 decimals).
        """
        total_days = await self.get_checkin_days(user_id, target_month)
        daily_status = []
        for day in range(1, 32):  # a month has at most 31 days
            try:
                check_date = datetime.strptime(f"{target_month}-{day:02d}", "%Y-%m-%d").date()
            except ValueError:
                # Past the end of the month (e.g. Feb 30).
                break
            daily_status.append(
                {"day": day, "checked": await self.is_checked_in(user_id, check_date)}
            )
        return {
            "month": target_month,
            "total_checkins": total_days,
            "daily_status": daily_status,
            "percentage": round(total_days / len(daily_status) * 100, 2) if daily_status else 0,
        }
#性能监控与调优
#Redis性能监控
# monitoring/redis_monitor.py
from typing import Dict, Any, Optional
import time
import asyncio
from redis.asyncio import Redis
import logging
logger = logging.getLogger(__name__)
class RedisMonitor:
    """Redis performance monitor built on INFO, INFO commandstats and SLOWLOG."""

    def __init__(self, redis_client: "Redis"):
        self.redis = redis_client
        self.metrics = {}

    async def get_server_info(self) -> Dict[str, Any]:
        """Return a curated subset of ``INFO`` plus a computed keyspace hit rate.

        ``hit_rate`` is ``hits / (hits + misses) * 100`` and is 0.0 while the
        server has served no keyspace lookups.  ``maxmemory`` falls back to 0
        when INFO does not report it.
        """
        info = await self.redis.info()
        hits = info.get("keyspace_hits", 0)
        misses = info.get("keyspace_misses", 0)
        lookups = hits + misses
        return {
            "redis_version": info.get("redis_version"),
            "uptime_in_seconds": info.get("uptime_in_seconds"),
            "connected_clients": info.get("connected_clients"),
            "used_memory": info.get("used_memory"),
            "used_memory_human": info.get("used_memory_human"),
            "used_memory_peak": info.get("used_memory_peak"),
            "used_memory_peak_human": info.get("used_memory_peak_human"),
            "maxmemory": info.get("maxmemory", 0),
            "total_commands_processed": info.get("total_commands_processed"),
            "instantaneous_ops_per_sec": info.get("instantaneous_ops_per_sec"),
            "keyspace_hits": info.get("keyspace_hits"),
            "keyspace_misses": info.get("keyspace_misses"),
            "hit_rate": hits / lookups * 100 if lookups else 0.0,
        }

    async def get_command_stats(self) -> Dict[str, Any]:
        """Return per-command ``calls`` / ``usec`` / ``usec_per_call`` from INFO commandstats."""
        raw = await self.redis.info("commandstats")
        prefix = "cmdstat_"
        return {
            name[len(prefix):]: {
                "calls": stats.get("calls", 0),
                "usec": stats.get("usec", 0),
                "usec_per_call": stats.get("usec_per_call", 0),
            }
            for name, stats in raw.items()
            if name.startswith(prefix)
        }

    async def get_slowlog(self, count: int = 10) -> list:
        """Return the ``count`` most recent SLOWLOG entries as plain dicts.

        redis-py's parsed entries carry the timestamp under ``start_time`` and
        the command as an already space-joined ``str``/``bytes``, so the
        command is normalised rather than joined again.
        """
        entries = await self.redis.slowlog_get(count)
        return [
            {
                "id": entry.get("id"),
                "timestamp": entry.get("start_time"),
                "duration": entry.get("duration"),
                "command": self._format_command(entry.get("command", "")),
            }
            for entry in entries
        ]

    @staticmethod
    def _format_command(command) -> str:
        """Normalise a SLOWLOG command (str, bytes, or list of args) to a str."""
        if isinstance(command, (bytes, bytearray)):
            return command.decode("utf-8", errors="replace")
        if isinstance(command, (list, tuple)):
            return " ".join(
                part.decode("utf-8", errors="replace")
                if isinstance(part, (bytes, bytearray))
                else str(part)
                for part in command
            )
        return str(command)

    async def monitor_performance(
        self,
        interval: int = 60,
        memory_limit: Optional[int] = None,
        hit_rate_threshold: float = 80.0,
    ) -> None:
        """Log INFO metrics every ``interval`` seconds, forever.

        Args:
            interval: Seconds between samples.
            memory_limit: Bytes; when omitted, the server's ``maxmemory`` is
                used, or 512MB if that is unset/0.
            hit_rate_threshold: Warn when the hit rate (percent) is below this.

        Warns when used memory exceeds 80% of the limit, and when the hit rate
        is low -- but only once the server has served at least one lookup, so
        a freshly started server does not trigger the warning.  Errors are
        logged and the loop continues; cancel the task to stop it.
        """
        default_limit = 512 * 1024 * 1024
        while True:
            try:
                info = await self.get_server_info()
                logger.info("Redis Performance: %s", info)
                limit = memory_limit or info.get("maxmemory") or default_limit
                if (info.get("used_memory") or 0) > limit * 0.8:
                    logger.warning("High memory usage: %s", info.get("used_memory_human"))
                lookups = (info.get("keyspace_hits") or 0) + (info.get("keyspace_misses") or 0)
                hit_rate = info.get("hit_rate", 0)
                if lookups and hit_rate < hit_rate_threshold:
                    logger.warning("Low cache hit rate: %.2f%%", hit_rate)
            except Exception as e:  # keep monitoring through transient failures
                logger.error("Error monitoring Redis: %s", e)
            await asyncio.sleep(interval)
# Start monitoring when the application starts
async def start_redis_monitoring():
    """Run the Redis performance-monitor loop; this coroutine never returns.

    ``monitor_performance`` loops with ``while True``, so schedule this in the
    background (e.g. ``asyncio.create_task(start_redis_monitoring())`` from a
    startup hook) rather than awaiting it inline.
    """
    client = get_redis_client()
    await RedisMonitor(client).monitor_performance()
#缓存性能优化
# optimization/cache_optimizer.py
from typing import Dict, Any, List
import time
from redis.asyncio import Redis
import logging
logger = logging.getLogger(__name__)
class CacheOptimizer:
    """Inspect the keyspace and turn what it finds into optimisation hints."""

    # Thresholds used when building suggestions.
    MAX_TOTAL_KEYS = 10000
    LARGE_VALUE_BYTES = 100 * 1024
    MAX_KEYS_PER_PATTERN = 1000
    NO_EXPIRE_RATIO = 0.5

    # Suggestion category -> follow-up action reported by optimize_cache_strategy().
    _ACTIONS = {
        "large": "Consider compressing large values",
        "expire": "Adjust expiration times for better memory management",
        "pattern": "Review key naming patterns for efficiency",
    }

    def __init__(self, redis_client: "Redis", scan_count: int = 1000):
        self.redis = redis_client
        self.access_patterns = {}
        # COUNT hint passed to SCAN on each iteration.
        self.scan_count = scan_count

    async def analyze_cache_usage(self) -> Dict[str, Any]:
        """Walk the keyspace once and summarise it.

        Iterates with incremental ``SCAN`` (``scan_iter``) rather than
        ``KEYS *`` so the server is not blocked on a large keyspace.  SCAN may
        return a key more than once, so keys are de-duplicated; keys that
        disappear mid-scan (TTL -2) are skipped.

        Returns:
            ``total_keys``; ``key_patterns`` (key prefix up to the last ``:``
            -> count); ``memory_usage`` (key -> bytes from ``MEMORY USAGE``);
            ``expiration_analysis`` (``"no_expire"`` or ``"<hours>h"`` -> count).
        """
        patterns: Dict[str, int] = {}
        memory: Dict[str, int] = {}
        expirations: Dict[str, int] = {}
        seen = set()
        async for key in self.redis.scan_iter(match="*", count=self.scan_count):
            key_str = key if isinstance(key, str) else key.decode()
            if key_str in seen:
                continue
            seen.add(key_str)
            ttl = await self.redis.ttl(key)
            if ttl == -2:
                continue  # expired or deleted after SCAN returned it
            pattern = key_str.rsplit(":", 1)[0] if ":" in key_str else key_str
            patterns[pattern] = patterns.get(pattern, 0) + 1
            size = await self.redis.memory_usage(key)
            if size:
                memory[key_str] = size
            bucket = "no_expire" if ttl == -1 else f"{ttl // 3600}h"
            expirations[bucket] = expirations.get(bucket, 0) + 1
        return {
            "total_keys": sum(patterns.values()),
            "key_patterns": patterns,
            "memory_usage": memory,
            "expiration_analysis": expirations,
        }

    async def _collect_suggestions(self) -> List[tuple]:
        """Return ``(category, message)`` pairs for the current keyspace.

        Categories: ``"key_count"``, ``"large"``, ``"expire"``, ``"pattern"``.
        """
        analysis = await self.analyze_cache_usage()
        found = []
        total_keys = analysis["total_keys"]
        if total_keys > self.MAX_TOTAL_KEYS:
            # Shorter TTLs reduce the number of live keys; longer ones would increase it.
            found.append(("key_count", "缓存key数量过多,考虑使用更短的过期时间或分片"))
        largest = sorted(
            analysis["memory_usage"].items(), key=lambda item: item[1], reverse=True
        )[:5]
        for key, size in largest:
            if size > self.LARGE_VALUE_BYTES:
                found.append(
                    ("large", f"Key '{key}' size is large: {size} bytes, consider compression")
                )
        no_expire_count = analysis["expiration_analysis"].get("no_expire", 0)
        if no_expire_count > total_keys * self.NO_EXPIRE_RATIO:
            found.append(("expire", "大量key不过期,建议设置合理的过期时间防止内存泄漏"))
        for pattern, count in analysis["key_patterns"].items():
            if count > self.MAX_KEYS_PER_PATTERN:
                found.append(
                    ("pattern", f"Pattern '{pattern}' has {count} keys, consider optimization")
                )
        return found

    async def suggest_optimizations(self) -> List[str]:
        """Return human-readable optimisation suggestions for the current keyspace."""
        return [message for _, message in await self._collect_suggestions()]

    async def optimize_cache_strategy(self, strategy: str = "auto") -> Dict[str, Any]:
        """Return the suggestions plus the follow-up actions they imply.

        Only ``strategy="auto"`` does any work; any other value returns empty
        lists.  Actions are derived from each suggestion's category (so
        messages in any language map correctly) and are de-duplicated in
        first-seen order.
        """
        if strategy != "auto":
            return {"suggestions": [], "actions": []}
        found = await self._collect_suggestions()
        actions = [self._ACTIONS[category] for category, _ in found if category in self._ACTIONS]
        return {
            "suggestions": [message for _, message in found],
            "actions": list(dict.fromkeys(actions)),
        }
#Redis安全配置
# security/redis_security.py
import redis.asyncio as redis
from redis.exceptions import AuthenticationError, ConnectionError
import logging
logger = logging.getLogger(__name__)
class RedisSecurity:
    """Helpers for building and sanity-checking authenticated / TLS Redis clients."""

    @staticmethod
    def create_secure_connection(
        host: str = "localhost",
        port: int = 6379,
        password: str = None,
        username: str = None,
        ssl: bool = False,
        ssl_cert_reqs: str = "required",
        ssl_ca_certs: str = None,
        ssl_certfile: str = None,
        ssl_keyfile: str = None,
        decode_responses: bool = True,
        encoding: str = "utf-8"
    ) -> redis.Redis:
        """Build a ``redis.Redis`` client from explicit security options.

        ``password`` and ``username`` are passed only when truthy.  When
        ``ssl`` is true, TLS is enabled with ``ssl_cert_reqs`` plus whichever
        of ``ssl_ca_certs`` / ``ssl_certfile`` / ``ssl_keyfile`` are given;
        those three are ignored when ``ssl`` is false.
        """
        options = dict(
            host=host,
            port=port,
            decode_responses=decode_responses,
            encoding=encoding,
        )
        credentials = {"password": password, "username": username}
        options.update({name: value for name, value in credentials.items() if value})

        if ssl:
            options["ssl"] = True
            options["ssl_cert_reqs"] = ssl_cert_reqs
            tls_files = {
                "ssl_ca_certs": ssl_ca_certs,
                "ssl_certfile": ssl_certfile,
                "ssl_keyfile": ssl_keyfile,
            }
            options.update({name: value for name, value in tls_files.items() if value})

        return redis.Redis(**options)

    @staticmethod
    async def validate_connection(client: redis.Redis) -> bool:
        """Return True if ``client`` can PING and run INFO, False otherwise.

        Authentication and connection failures are logged with dedicated
        messages; anything else is logged with its text.  Server-side settings
        (such as the bind address) are not inspected -- this only proves the
        client connection works.
        """
        try:
            await client.ping()
            # The INFO result is not inspected; only the call succeeding is checked.
            await client.info()
        except AuthenticationError:
            logger.error("Redis authentication failed")
            return False
        except ConnectionError:
            logger.error("Redis connection failed")
            return False
        except Exception as e:
            logger.error(f"Redis security validation failed: {e}")
            return False
        return True
# Secure Redis client factory
class SecureRedisClient:
    """Lazily create, validate and cache a single secure Redis client."""

    def __init__(self, config):
        # config must expose redis_host, redis_port, redis_password, redis_ssl_enabled.
        self.config = config
        self._client = None

    async def get_client(self) -> redis.Redis:
        """Return the cached client, creating and validating it on first use.

        The client is cached only after ``RedisSecurity.validate_connection``
        succeeds.  A failed attempt is therefore retried on the next call
        rather than silently handing back an unvalidated client.

        Raises:
            ConnectionError: (the ``redis.exceptions`` class imported by this
                module) if the new connection fails validation.
        """
        if self._client is None:
            candidate = RedisSecurity.create_secure_connection(
                host=self.config.redis_host,
                port=self.config.redis_port,
                password=self.config.redis_password,
                ssl=self.config.redis_ssl_enabled
            )
            if not await RedisSecurity.validate_connection(candidate):
                raise ConnectionError("Failed to establish secure Redis connection")
            self._client = candidate
        return self._client
#缓存安全最佳实践
# security/cache_security.py
import hashlib
import hmac
import secrets
from typing import Any, Dict
import json
import logging
logger = logging.getLogger(__name__)
class SecureCache:
    """User-scoped cache whose key names and payloads are bound to a server secret.

    * Key names embed an HMAC-SHA256 of ``base_key`` (+ ``user_id``), so they
      cannot be predicted without ``secret_key``.
    * Payloads are JSON signed with HMAC-SHA256 and verified on read, which
      gives tamper detection.

    NOTE: payloads are NOT encrypted -- anyone with Redis access can read
    them.  Use an authenticated cipher (e.g. AES-GCM from a crypto library)
    for confidential data.
    """

    def __init__(self, redis_client, secret_key: str):
        self.redis = redis_client
        self.secret_key = secret_key

    def _sign(self, message: str) -> str:
        """Return the hex HMAC-SHA256 of ``message`` under ``secret_key``."""
        return hmac.new(self.secret_key.encode(), message.encode(), hashlib.sha256).hexdigest()

    def _user_index_key(self, user_id: str) -> str:
        """Key of the Redis set that tracks every cache key written for ``user_id``."""
        digest = self._sign(f"user_index:{user_id}")
        return f"secure:user_index:{digest[:16]}"

    def _generate_cache_key(self, base_key: str, user_id: str = None) -> str:
        """Return ``secure:{base_key}:{16 hex chars of HMAC(base_key[:user_id])}``."""
        material = f"{base_key}:{user_id}" if user_id else base_key
        return f"secure:{base_key}:{self._sign(material)[:16]}"

    def _encrypt_data(self, data: Any) -> str:
        """Serialise ``data`` to JSON and prefix it with its HMAC signature.

        This signs but does not encrypt (see the class docstring).
        Stored format: ``<hex signature>.<json>``.
        """
        payload = json.dumps(data, default=str)
        return f"{self._sign(payload)}.{payload}"

    def _decrypt_data(self, encrypted_data: str) -> Any:
        """Verify a value written by ``_encrypt_data`` and decode its JSON.

        Raises:
            ValueError: if the value is malformed or its signature does not
                match (``json.JSONDecodeError`` is a ``ValueError`` too).
        """
        if isinstance(encrypted_data, (bytes, bytearray)):
            encrypted_data = encrypted_data.decode()
        signature, sep, payload = encrypted_data.partition(".")
        if not sep or not hmac.compare_digest(signature, self._sign(payload)):
            raise ValueError("secure cache entry failed signature verification")
        return json.loads(payload)

    async def set_secure(self, key: str, value: Any, ttl: int = 300, user_id: str = None):
        """Store ``value`` under the secure key for ``key``/``user_id`` for ``ttl`` seconds.

        User-scoped entries are also recorded in a per-user index set so that
        ``invalidate_user_cache`` can find them without ``KEYS``.
        """
        cache_key = self._generate_cache_key(key, user_id)
        await self.redis.setex(cache_key, ttl, self._encrypt_data(value))
        if user_id:
            index_key = self._user_index_key(user_id)
            await self.redis.sadd(index_key, cache_key)
            # Only extend the index TTL, so it outlives every entry it tracks.
            if await self.redis.ttl(index_key) < ttl:
                await self.redis.expire(index_key, ttl)

    async def get_secure(self, key: str, user_id: str = None) -> Any:
        """Return the cached value for ``key``/``user_id``, or None on miss.

        Entries that fail signature verification are logged, deleted and
        treated as a miss.
        """
        cache_key = self._generate_cache_key(key, user_id)
        stored = await self.redis.get(cache_key)
        if not stored:
            return None
        try:
            return self._decrypt_data(stored)
        except ValueError:
            logger.warning("Discarding tampered or malformed secure cache entry %s", cache_key)
            await self.redis.delete(cache_key)
            return None

    async def invalidate_user_cache(self, user_id: str, pattern: str = "*"):
        """Delete ``user_id``'s cached entries whose base key matches ``pattern``.

        ``pattern`` is a glob matched against the ``key`` given to
        ``set_secure``; the default ``"*"`` removes all of the user's entries.
        Uses the per-user index set instead of ``KEYS``.
        """
        import fnmatch

        index_key = self._user_index_key(user_id)
        members = await self.redis.smembers(index_key)
        key_glob = f"secure:{pattern}:*"
        targets = []
        for member in members:
            cache_key = member.decode() if isinstance(member, (bytes, bytearray)) else member
            if fnmatch.fnmatchcase(cache_key, key_glob):
                targets.append(cache_key)
        if targets:
            await self.redis.delete(*targets)
            await self.redis.srem(index_key, *targets)
# Usage example
class SecureUserService:
    """User-profile service that reads through the secure cache."""

    def __init__(self, redis_client, secret_key: str):
        self.secure_cache = SecureCache(redis_client, secret_key)

    async def get_user_profile_secure(self, user_id: int):
        """Return the user's profile dict, consulting the secure cache first.

        On a miss the profile is loaded via ``UserRepository``; if the user
        exists it is cached for 900 seconds under a user-scoped key.
        Returns None when the user does not exist.
        """
        cache_key = f"user:profile:{user_id}"
        owner = str(user_id)
        hit = await self.secure_cache.get_secure(cache_key, owner)
        if hit:
            return hit

        # Deferred import of the project repository.
        from repositories.user_repository import UserRepository

        repository = UserRepository()
        user = await repository.get_by_id(user_id)
        if not user:
            return None

        profile = dict(
            id=user.id,
            name=user.name,
            email=user.email,
            created_at=user.created_at.isoformat(),
        )
        await self.secure_cache.set_secure(cache_key, profile, ttl=900, user_id=owner)
        return profile
#常见陷阱与避坑指南
#陷阱1:缓存穿透
# ❌ Wrong: cache penetration
async def get_user_bad(user_id: int):
    """Anti-pattern: misses are never cached.

    A client requesting many non-existent ids sends every request to the
    database.
    """
    user = await db.get_user(user_id)
    return user or None  # empty result is not cached, so the DB is hit every time


# ✅ Correct: cache the "not found" result as well
async def get_user_good(user_id: int):
    """Cache-aside lookup that also caches misses, which prevents cache penetration.

    Existing users are cached as JSON for 300s; a ``"NULL"`` sentinel is
    cached for 60s when the user does not exist.
    """
    key = f"user:{user_id}"
    cached_user = await redis.get(key)
    if cached_user is not None:
        if cached_user == "NULL":
            return None
        return json.loads(cached_user)

    user = await db.get_user(user_id)
    # Cache both outcomes; misses get the shorter TTL.
    ttl, payload = (300, json.dumps(user)) if user else (60, "NULL")
    await redis.setex(key, ttl, payload)
    return user
#陷阱2:缓存雪崩
# ❌ Wrong: cache avalanche
async def load_hot_data_bad():
    """Anti-pattern: when many keys expire together, every missing one hits the DB at once."""
    missing_loads = [
        db.load_data(i)
        for i in range(1000)
        if not await redis.exists(f"data:{i}")
    ]
    # All database loads are fired concurrently.
    return await asyncio.gather(*missing_loads)


# ✅ Correct: randomised expiration times
async def load_hot_data_good():
    """Load 1000 items cache-aside, spreading their TTLs over 300-599 seconds."""

    async def _fetch_one(data_id: int):
        cache_key = f"data:{data_id}"
        hit = await redis.get(cache_key)
        if hit:
            return json.loads(hit)
        data = await db.load_data(data_id)
        # 300s base TTL plus 0-299s of jitter so keys do not expire together.
        await redis.setex(cache_key, 300 + secrets.randbelow(300), json.dumps(data))
        return data

    return await asyncio.gather(*map(_fetch_one, range(1000)))
#陷阱3:缓存击穿
# ❌ Wrong: cache breakdown (hot-key expiry)
async def get_hot_article_bad(article_id: int):
    """Anti-pattern: when a hot key expires, every concurrent request rebuilds it from the DB."""
    key = f"article:{article_id}"
    article = await redis.get(key)
    if article:
        return article
    article = await db.get_article(article_id)
    await redis.setex(key, 3600, json.dumps(article))
    return article
# ✅ Correct: rebuild an expired hot key under a distributed lock
async def get_hot_article_good(article_id: int, wait_attempts: int = 20, wait_interval: float = 0.05):
    """Read a hot article with protection against cache breakdown.

    1. Serve straight from the cache when the key exists; no lock traffic.
    2. On a miss, one caller takes ``lock:article:<id>`` (SET NX EX 10) with
       a random token, re-checks the cache, loads from the DB and refills
       the cache for 3600s.
    3. Other callers poll the cache up to ``wait_attempts`` times, every
       ``wait_interval`` seconds.  If it is still empty they load from the DB
       themselves instead of returning None for an existing article.

    The lock is released with an atomic compare-and-delete script, so a
    caller whose lock already expired cannot delete a lock that now belongs
    to someone else.
    """
    import uuid

    cache_key = f"article:{article_id}"
    lock_key = f"lock:article:{article_id}"
    release_script = (
        "if redis.call('get', KEYS[1]) == ARGV[1] then "
        "return redis.call('del', KEYS[1]) "
        "else return 0 end"
    )

    # Fast path: a cache hit never touches the lock.
    article = await redis.get(cache_key)
    if article:
        return json.loads(article)

    token = uuid.uuid4().hex
    if await redis.set(lock_key, token, nx=True, ex=10):  # 10-second lock
        try:
            # Double check: another caller may have refilled the cache meanwhile.
            article = await redis.get(cache_key)
            if article:
                return json.loads(article)
            article = await db.get_article(article_id)
            await redis.setex(cache_key, 3600, json.dumps(article))
            return article
        finally:
            await redis.eval(release_script, 1, lock_key, token)

    # Someone else is rebuilding the key: poll the cache briefly.
    for _ in range(wait_attempts):
        await asyncio.sleep(wait_interval)
        article = await redis.get(cache_key)
        if article:
            return json.loads(article)
    # The lock holder is slow or failed; load directly rather than return None.
    return await db.get_article(article_id)
#陷阱4:内存泄漏
# ❌ Wrong: no expiration time
async def cache_without_expire():
    """Anti-pattern: ``data`` is stored with no TTL, so the key never expires."""
    payload = json.dumps(data)
    await redis.set("permanent_data", payload)


# ✅ Correct: give the key a sensible expiration time
async def cache_with_expire():
    """Store ``data`` under ``temp_data`` with a one-hour TTL."""
    one_hour = 60 * 60
    payload = json.dumps(data)
    await redis.setex("temp_data", one_hour, payload)
#陷阱5:序列化问题
# ❌ Wrong: serialising complex objects directly
async def cache_complex_object_bad():
    """Anti-pattern: plain ``json.dumps`` cannot encode datetime or arbitrary objects."""
    data = dict(timestamp=datetime.now(), user=User())
    serialized = json.dumps(data)  # raises TypeError here
    await redis.setex("complex", 300, serialized)
# ✅ Correct: convert values to JSON-safe types first
async def cache_complex_object_good(user=None):
    """Cache the current timestamp plus ``user`` under ``complex`` for 300 seconds.

    ``user`` is converted before serialisation, in this order:

    * ``None`` and plain dicts are stored as-is;
    * objects with ``model_dump()`` (Pydantic v2) use it;
    * objects with ``dict()`` (Pydantic v1) use that;
    * anything else falls back to ``str(user)``.

    Any remaining non-JSON values are stringified by ``default=str``.
    """
    if user is None or isinstance(user, dict):
        user_data = user
    elif hasattr(user, "model_dump"):
        user_data = user.model_dump()
    elif hasattr(user, "dict"):
        user_data = user.dict()
    else:
        user_data = str(user)
    data = {
        "timestamp": datetime.now().isoformat(),
        "user": user_data,
    }
    await redis.setex("complex", 300, json.dumps(data, default=str))
#与其他缓存方案对比
#Redis vs Memcached
| 特性 | Redis | Memcached |
|---|---|---|
| 数据结构 | 支持多种(String, Hash, List, Set, ZSet等) | 仅支持String |
| 持久化 | ✅ 支持RDB和AOF持久化 | ❌ 仅内存存储 |
| 数据过期 | ✅ 精确到key的过期 | ✅ 支持过期时间 |
| 内存效率 | ⚠️ 相对较低(功能丰富) | ✅ 更高的内存效率 |
| 集群支持 | ✅ 原生集群模式(Redis Cluster) | ⚠️ 无服务端集群,依赖客户端一致性哈希分片 |
| 事务支持 | ✅ 支持事务 | ❌ 无事务 |
| 发布订阅 | ✅ 支持Pub/Sub | ❌ 不支持 |
| 适用场景 | 复杂缓存需求、会话存储、消息队列 | 简单KV缓存 |
#Redis vs 本地缓存
| 特性 | Redis | 本地缓存 (如c
#1. Redis 快速入门
#1.1 为什么 Web 服务需要 Redis?
请求 → FastAPI → 数据库查询 → 返回
缓存后:
请求 → FastAPI → Redis(命中)→ 直接返回
请求 → FastAPI → Redis(未命中)→ 数据库 → 写入Redis → 返回
Redis 作用:把频繁查询但不常变化的数据(如用户信息、文章列表)放在内存里,查询速度比数据库快 10-100 倍。
#1.2 安装与连接
pip install redis[hiredis]
# hiredis 是 C 语言实现的高性能解析器,推荐安装
# redis_client.py
import redis.asyncio as redis
from config import get_settings
settings = get_settings()

# Shared asyncio Redis client. The URL looks like "redis://localhost:6379/0";
# decode_responses=True makes replies arrive as str instead of bytes.
redis_client = redis.from_url(
    settings.redis_url,
    decode_responses=True,
    encoding="utf-8",
)


async def test_redis():
    """Ping Redis and print a confirmation; the ping raises if the server is unreachable."""
    await redis_client.ping()
    print("✅ Redis 连接成功")
#2. 缓存实现
#2.1 通用缓存装饰器
# cache.py
import json
from functools import wraps
from redis.asyncio import Redis
def cached(
    key_prefix: str,
    expire: int = 300,  # default TTL: 5 minutes
    *,
    client=None,
):
    """Async read-through cache decorator backed by Redis.

    ``key_prefix`` is a ``str.format`` template.  Named fields are filled
    from the decorated function's bound arguments (defaults applied), so
    ``"user:{user_id}"`` works whether ``user_id`` is passed positionally or
    by keyword.  Positional fields such as ``{0}`` still receive the raw
    positional arguments.  Results are stored as JSON (``default=str``) for
    ``expire`` seconds.

    Args:
        key_prefix: Cache-key template, e.g. ``"user:{user_id}"``.
        expire: TTL in seconds.
        client: Redis client to use; when omitted, the module-level
            ``redis_client`` is looked up at call time.

    Usage::

        @cached("user:{user_id}", expire=600)
        async def get_user(user_id: int):
            return await db.get_user(user_id)
    """
    import inspect

    def decorator(func):
        signature = inspect.signature(func)

        @wraps(func)
        async def wrapper(*args, **kwargs):
            store = client if client is not None else redis_client
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()
            cache_key = key_prefix.format(*args, **{**kwargs, **bound.arguments})

            # Try Redis first.
            cached_value = await store.get(cache_key)
            if cached_value:
                return json.loads(cached_value)

            # Miss: run the wrapped function and store its JSON result.
            result = await func(*args, **kwargs)
            await store.setex(
                cache_key,
                expire,
                json.dumps(result, default=str)
            )
            return result

        return wrapper

    return decorator
async def invalidate_cache(pattern: str, client=None) -> int:
    """Delete every key matching the glob ``pattern`` and return how many were deleted.

    Iterates with incremental ``SCAN`` (``scan_iter``) instead of ``KEYS``, so
    a large keyspace does not block the server, and deletes in batches of 500.
    ``client`` defaults to the module-level ``redis_client``.
    """
    store = client if client is not None else redis_client
    batch_size = 500
    deleted = 0
    batch = []
    async for key in store.scan_iter(match=pattern, count=batch_size):
        batch.append(key)
        if len(batch) >= batch_size:
            deleted += await store.delete(*batch)
            batch = []
    if batch:
        deleted += await store.delete(*batch)
    return deleted
#2.2 用户信息缓存
from cache import cached, invalidate_cache
# Cache user info: repeat lookups within 5 minutes are served from Redis
@cached("user:{user_id}", expire=300)
async def get_user_cached(user_id: int):
    """Return a public profile dict for ``user_id``, or None if the user does not exist.

    A None result is serialised to JSON ``null`` by the ``cached``
    decorator, so lookups of missing ids are cached too instead of raising
    AttributeError on ``None.id``.
    """
    user = await db.get_user(user_id)
    if user is None:
        return None
    return {
        "id": user.id,
        "name": user.name,
        "email": user.email,
        "avatar": user.avatar,
    }
# After a user is updated, drop their cached entry
async def update_user(user_id: int, data: dict):
    """Persist ``data`` for ``user_id``, then evict the ``user:<id>`` cache entry."""
    await db.update_user(user_id, data)
    stale_key = f"user:{user_id}"
    await invalidate_cache(stale_key)
#2.3 页面缓存(整页缓存)
@app.get("/posts/{slug}")
async def get_post(slug: str):
    """Serve a post page, caching the fully rendered HTML in Redis for 10 minutes."""
    cache_key = f"post:page:{slug}"
    html = await redis_client.get(cache_key)
    if not html:
        # Miss: render the page and keep it for 600 seconds.
        post = await get_post_data(slug)
        html = render_post_template(post)
        await redis_client.setex(cache_key, 10 * 60, html)
    return HTMLResponse(content=html)
#3. 分布式 Session
#3.1 基于 Redis 的 Session 管理
# session.py
import uuid
import json
from datetime import timedelta
from redis.asyncio import Redis
class RedisSession:
    """Server-side sessions stored in Redis as JSON under ``session:<id>``."""

    def __init__(self, redis: "Redis"):
        self.redis = redis
        self.prefix = "session:"
        self.expire = 86400 * 7  # 7 days, in seconds

    async def create(self, user_id: int, extra_data: dict = None) -> str:
        """Create a session for ``user_id`` and return its new session id.

        ``extra_data`` is merged into the stored payload, but it can never
        override ``user_id``; otherwise caller-supplied data could bind the
        session to a different user.
        """
        session_id = str(uuid.uuid4())
        session_key = self.prefix + session_id
        data = {**(extra_data or {}), "user_id": user_id}
        await self.redis.setex(
            session_key,
            self.expire,
            json.dumps(data)
        )
        return session_id

    async def get(self, session_id: str) -> "dict | None":
        """Return the session payload, or None if it does not exist or has expired."""
        session_key = self.prefix + session_id
        data = await self.redis.get(session_key)
        if data:
            return json.loads(data)
        return None

    async def refresh(self, session_id: str) -> bool:
        """Reset the session's TTL to the full lifetime; False if the session is gone."""
        session_key = self.prefix + session_id
        return await self.redis.expire(session_key, self.expire)

    async def destroy(self, session_id: str) -> bool:
        """Delete the session (logout); True if a session was actually removed."""
        session_key = self.prefix + session_id
        return await self.redis.delete(session_key) > 0
# Module-level RedisSession instance used by get_current_user_via_session.
# NOTE(review): assumes `redis_client` (the client created in redis_client.py) is in scope here; this snippet does not import it — TODO confirm.
session_manager = RedisSession(redis_client)#3.2 Session 认证依赖
from fastapi import Cookie, HTTPException
async def get_current_user_via_session(session_id: str = Cookie(None)):
    """FastAPI dependency that resolves the ``session_id`` cookie to session data.

    Raises HTTP 401 when the cookie is missing or the session no longer
    exists.  Otherwise it slides the session's expiry forward and returns the
    stored payload.
    """
    if not session_id:
        raise HTTPException(401, "请先登录")
    if not (session := await session_manager.get(session_id)):
        raise HTTPException(401, "Session 已过期,请重新登录")
    await session_manager.refresh(session_id)  # sliding expiration
    return session
# This snippet's fastapi import only brings in Cookie and HTTPException; Depends is needed below.
from fastapi import Depends


@app.get("/profile")
async def profile(user: dict = Depends(get_current_user_via_session)):
    """Return the logged-in user's id and full session payload (401 if not logged in)."""
    return {"user_id": user["user_id"], "data": user}
#4. 简单的消息队列
#4.1 任务队列:List 实现
# task_queue.py  (avoid naming this module queue.py: it would shadow the standard-library `queue` module)
TASK_QUEUE = "fastapi:task_queue"


async def enqueue_task(task_data: dict):
    """Append ``task_data``, JSON-encoded, to the tail of the task list (RPUSH)."""
    import json

    payload = json.dumps(task_data)
    await redis_client.rpush(TASK_QUEUE, payload)
async def dequeue_task(blocking: bool = True, timeout: int = 5):
    """Pop the task at the head of the queue and return it decoded, or None if empty.

    Blocking mode uses BLPOP with ``timeout`` seconds (per Redis BLPOP
    semantics, 0 blocks indefinitely); otherwise a non-blocking LPOP is used.
    """
    import json

    if blocking:
        popped = await redis_client.blpop(TASK_QUEUE, timeout=timeout)
        if not popped:
            return None
        _queue_name, raw = popped
        return json.loads(raw)

    raw = await redis_client.lpop(TASK_QUEUE)
    return json.loads(raw) if raw else None
# Background worker
async def process_tasks():
    """Worker loop: block on the queue indefinitely and handle each task as it arrives."""
    while True:
        task = await dequeue_task(blocking=True, timeout=0)
        if not task:
            continue
        print(f"处理任务: {task}")
        # await process_task(task)
#4.2 延迟队列:ZSet 实现
DELAY_QUEUE = "fastapi:delay_queue"
DELAY_TEMP_QUEUE =
