#FastAPI中间件完全指南
📂 所属阶段:第二阶段 — 进阶黑科技(核心篇)
🔗 相关章节:FastAPI异步编程深度解析 · FastAPI异常处理
#目录
- 中间件基础概念
- 中间件生命周期与执行顺序
- 自定义中间件开发
- CORS跨域处理
- GZip响应压缩
- 请求日志记录
- 性能监控与耗时统计
- 安全中间件
- 错误处理中间件
- 会话管理中间件
- 中间件最佳实践
- 生产环境部署
- 与其他框架对比
- 总结
#中间件基础概念
#什么是中间件?
中间件是在请求到达路由处理器之前和响应返回给客户端之后执行的代码。它可以用来执行各种任务,如身份验证、日志记录、错误处理、跨域资源共享(CORS)等。
#中间件在请求生命周期中的位置
请求到达
↓
┌─────────────────────────────────────────┐
│ 中间件 1(记录日志) │
│ ↓ │
│ 中间件 2(CORS 处理) │
│ ↓ │
│ 中间件 3(GZip 压缩) │
│ ↓ │
│ 路由处理函数 ──── 返回响应 ────→ 反向经过所有中间件 │
└─────────────────────────────────────────┘中间件是一个"包裹"着应用的函数,请求和响应都会经过它。你可以把它理解为请求的"安检通道"和"包装工厂"。
#FastAPI 中间件 vs 依赖注入
| 特性 | 中间件 | 依赖注入 |
|---|---|---|
| 触发时机 | 每个请求都自动经过 | 按需注入到路由参数 |
| 执行顺序 | 按注册顺序(先注册先执行) | 与路由函数参数对应 |
| 用途 | 全局拦截、日志、跨域、压缩 | 提取参数、认证、数据库 |
| 无法获取 | 路由函数的返回值 | 无法拦截全局请求 |
| 性能影响 | 影响所有请求 | 只影响使用了依赖的路由 |
| 错误处理 | 可以捕获和处理异常 | 依赖注入函数中的异常会传播到路由 |
| 返回响应 | 可以提前返回响应 | 不能直接返回响应 |
| 访问响应对象 | 可以访问和修改响应 | 无法访问响应对象 |
#中间件的优势
- 全局处理:对所有请求进行统一处理
- 代码复用:避免在每个路由中重复相同的逻辑
- 关注点分离:将横切关注点(如日志、安全)与业务逻辑分离
- 性能优化:如压缩、缓存等
- 安全防护:如CSRF防护、IP限制等
- 监控和追踪:性能监控、请求追踪等
#1. 中间件是什么?
#1.1 中间件在请求生命周期中的位置
请求到达
↓
┌─────────────────────────────────────────┐
│ 中间件 1(记录日志) │
│ ↓ │
│ 中间件 2(CORS 处理) │
│ ↓ │
│ 中间件 3(GZip 压缩) │
│ ↓ │
│ 路由处理函数 ──── 返回响应 ────→ 反向经过所有中间件 │
└─────────────────────────────────────────┘中间件是一个"包裹"着应用的函数,请求和响应都会经过它。你可以把它理解为请求的"安检通道"和"包装工厂"。
#1.2 FastAPI 中间件 vs 依赖注入
| 中间件 | 依赖注入 | |
|---|---|---|
| 触发时机 | 每个请求都自动经过 | 按需注入到路由参数 |
| 执行顺序 | 按注册顺序(先注册先执行) | 与路由函数参数对应 |
| 用途 | 全局拦截、日志、跨域、压缩 | 提取参数、认证、数据库 |
| 无法获取 | 路由函数的返回值 | 无法拦截全局请求 |
#2. 自定义中间件
#2.1 基础语法
from fastapi import FastAPI, Request
import time
app = FastAPI()
# 自定义中间件
@app.middleware("http")
async def middleware_name(request: Request, call_next):
# ── 请求前处理 ──
print(f"收到请求: {request.url.path}")
# call_next 传递请求到下一个中间件或路由
response = await call_next(request)
# ── 响应后处理 ──
print(f"响应状态: {response.status_code}")
return response#2.2 请求耗时统计中间件
@app.middleware("http")
async def add_process_time(request: Request, call_next):
start_time = time.perf_counter()
# 执行业务逻辑
response = await call_next(request)
# 计算耗时
process_time = time.perf_counter() - start_time
response.headers["X-Process-Time"] = str(round(process_time * 1000, 2)) + "ms"
return response#2.3 日志记录中间件
import logging
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@app.middleware("http")
async def log_requests(request: Request, call_next):
# 请求前
logger.info(f"📥 {request.method} {request.url.path}")
# 执行业务逻辑
response = await call_next(request)
# 请求后
logger.info(
f"📤 {request.method} {request.url.path} → {response.status_code}"
)
return response#3. CORS 跨域中间件
#3.1 为什么需要 CORS?
浏览器出于安全考虑,默认禁止一个域名(如 https://app.example.com)的网页向另一个域名(如 https://api.example.com)发请求。CORS(Cross-Origin Resource Sharing)就是让服务器声明"允许哪些外域来访问我"。
#3.2 FastAPI 内置 CORS 中间件
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
# 配置 CORS 中间件
app.add_middleware(
CORSMiddleware,
allow_origins=[
"https://daomanpy.com", # 允许特定域名
"https://www.daomanpy.com",
"http://localhost:3000", # 开发环境
],
allow_credentials=True, # 允许携带 Cookie
allow_methods=["*"], # 允许所有方法,或指定 ["GET", "POST"]
allow_headers=["*"], # 允许所有请求头,或指定 ["Authorization"]
)
# 所有路由自动支持 CORS
@app.get("/api/data")
async def get_data():
return {"hello": "跨域成功!"}#3.3 不同场景的 CORS 配置
# 场景一:允许所有来源(仅开发环境使用)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 场景二:生产环境(推荐,明确指定域名)
app.add_middleware(
CORSMiddleware,
allow_origins=[
"https://daomanpy.com",
"https://www.daomanpy.com",
],
allow_credentials=True,
allow_methods=["GET", "POST"],
allow_headers=["Authorization", "Content-Type"],
)
# 场景三:允许携带 Cookie(必须指定具体 origins,不能用 *)
app.add_middleware(
CORSMiddleware,
allow_origins=["https://daomanpy.com"],
allow_credentials=True, # 必须为 True
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["X-Request-ID"], # 暴露给前端的响应头
max_age=600, # 预检请求缓存时间(秒)
)#4. GZip 压缩中间件
#4.1 Starlette 内置 GZip 中间件
from fastapi import FastAPI
from starlette.middleware.gzip import GZipMiddleware
app = FastAPI()
app.add_middleware(GZipMiddleware, minimum_size=1000)minimum_size:响应体大于此字节数才压缩(默认 500 字节)- 自动对文本内容进行压缩
#5. BaseHTTPMiddleware 详解
#5.1 继承 BaseHTTPMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# 请求前处理
start_time = time.time()
# 调用下一个中间件或路由
response = await call_next(request)
# 响应后处理
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
return response
# 注册中间件
app.add_middleware(CustomMiddleware)#5.2 高级中间件示例
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from starlette.requests import Request
import json
import time
class AdvancedLoggingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# 请求信息记录
request_id = str(uuid.uuid4())
start_time = time.time()
# 添加请求ID到请求状态
request.state.request_id = request_id
# 记录请求
print(f"REQUEST ID: {request_id}")
print(f"METHOD: {request.method}")
print(f"PATH: {request.url.path}")
print(f"HEADERS: {dict(request.headers)}")
try:
response = await call_next(request)
except Exception as e:
# 记录异常
print(f"EXCEPTION: {str(e)}")
raise
# 计算处理时间
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
response.headers["X-Request-ID"] = request_id
# 记录响应
print(f"RESPONSE STATUS: {response.status_code}")
print(f"PROCESS TIME: {process_time}s")
return response#6. 实际应用示例
#6.1 身份验证中间件
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi import HTTPException, status
class AuthMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# 检查认证头
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
# 某些路径不需要认证
if request.url.path in ["/health", "/docs", "/redoc"]:
response = await call_next(request)
return response
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing or invalid Authorization header"
)
token = auth_header[7:] # 移除 "Bearer "
# 验证 token(这里简化处理)
if not self.validate_token(token):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token"
)
response = await call_next(request)
return response
def validate_token(self, token: str) -> bool:
# 实现 token 验证逻辑
return True # 简化示例#6.2 速率限制中间件
import time
from collections import defaultdict
class RateLimitMiddleware(BaseHTTPMiddleware):
def __init__(self, app, max_requests: int = 100, window: int = 60):
super().__init__(app)
self.max_requests = max_requests
self.window = window
self.requests = defaultdict(list)
async def dispatch(self, request: Request, call_next):
client_ip = request.client.host
now = time.time()
# 清理过期请求记录
self.requests[client_ip] = [
req_time for req_time in self.requests[client_ip]
if now - req_time < self.window
]
# 检查是否超过限制
if len(self.requests[client_ip]) >= self.max_requests:
return Response(
status_code=429,
content="Too Many Requests"
)
# 记录当前请求
self.requests[client_ip].append(now)
response = await call_next(request)
return response#7. 中间件注册顺序
app = FastAPI()
# 注意:中间件的注册顺序很重要!
# 请求时按顺序执行,响应时按相反顺序执行
# 1. CORS 中间件(最先执行请求部分)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 2. GZip 中间件
app.add_middleware(GZipMiddleware, minimum_size=1000)
# 3. 自定义日志中间件
app.add_middleware(LoggingMiddleware)
# 4. 自定义认证中间件
app.add_middleware(AuthMiddleware)
# 请求流向:CORS → GZip → Logging → Auth → Route Handler
# 响应流向:Route Handler → Auth → Logging → GZip → CORS#8. 错误处理
@app.middleware("http")
async def handle_errors(request: Request, call_next):
try:
response = await call_next(request)
except HTTPException as e:
# 记录 HTTP 异常
print(f"HTTP Error: {e.status_code} - {e.detail}")
raise
except Exception as e:
# 记录其他异常
print(f"Server Error: {str(e)}")
return JSONResponse(
status_code=500,
content={"detail": "Internal Server Error"}
)
return response#9. 性能考虑
#9.1 中间件性能优化
# 避免在中间件中进行耗时操作
@app.middleware("http")
async def optimized_middleware(request: Request, call_next):
# ❌ 避免:数据库查询、外部API调用等耗时操作
# expensive_operation()
# ✅ 推荐:只进行必要的轻量级操作
if should_process_request(request):
# 仅在必要时进行处理
pass
response = await call_next(request)
return response#9.2 条件执行
@app.middleware("http")
async def conditional_middleware(request: Request, call_next):
# 仅为特定路径执行中间件逻辑
if request.url.path.startswith("/api/"):
# 执行API特定的中间件逻辑
pass
response = await call_next(request)
return response#10. 小结
# FastAPI 中间件完整示例
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.gzip import GZipMiddleware
import time
import logging
app = FastAPI()
# 注册中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["https://daomanpy.com"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.add_middleware(GZipMiddleware, minimum_size=1000)
@app.middleware("http")
async def timing_middleware(request: Request, call_next):
start_time = time.time()
response = await call_next(request)
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
return response
@app.get("/")
async def root():
return {"message": "Hello World"}⚠️ 注意:中间件会影响所有请求,所以要谨慎使用,避免不必要的性能开销。
🔗 扩展阅读
#中间件生命周期与执行顺序
#中间件执行流程详解
中间件在请求处理过程中遵循特定的执行顺序。理解这个顺序对于正确实现中间件至关重要:
客户端请求 →
↓
中间件1 (请求处理) →
↓
中间件2 (请求处理) →
↓
中间件3 (请求处理) →
↓
路由处理函数 →
↓
中间件3 (响应处理) →
↓
中间件2 (响应处理) →
↓
中间件1 (响应处理) →
↓
返回给客户端#中间件执行顺序示例
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
import time
app = FastAPI()
@app.middleware("http")
async def middleware_1(request: Request, call_next):
print("Middleware 1 - 请求前")
response = await call_next(request)
print("Middleware 1 - 响应后")
return response
@app.middleware("http")
async def middleware_2(request: Request, call_next):
print("Middleware 2 - 请求前")
response = await call_next(request)
print("Middleware 2 - 响应后")
return response
@app.get("/")
async def root():
print("路由处理函数")
return {"message": "Hello World"}请求输出顺序:
Middleware 1 - 请求前
Middleware 2 - 请求前
路由处理函数
Middleware 2 - 响应后
Middleware 1 - 响应后#BaseHTTPMiddleware 与装饰器中间件的区别
# 装饰器方式
@app.middleware("http")
async def decorator_middleware(request: Request, call_next):
# 请求前处理
start_time = time.time()
response = await call_next(request)
# 响应后处理
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
return response
# BaseHTTPMiddleware 方式
from starlette.middleware.base import BaseHTTPMiddleware
class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# 请求前处理
start_time = time.time()
response = await call_next(request)
# 响应后处理
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
return response
# 注册中间件
app.add_middleware(CustomMiddleware)#自定义中间件开发
#高级自定义中间件实现
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from starlette.types import ASGIApp
import time
import json
import uuid
from typing import Optional, Dict, Any
import logging
class AdvancedRequestMiddleware(BaseHTTPMiddleware):
"""高级请求处理中间件 - 集成日志、性能监控、安全检查等功能"""
def __init__(self, app: ASGIApp, enable_logging: bool = True):
super().__init__(app)
self.enable_logging = enable_logging
self.logger = logging.getLogger(__name__)
async def dispatch(self, request: Request, call_next) -> Response:
# 生成请求ID
request_id = str(uuid.uuid4())
start_time = time.time()
# 添加请求ID到请求状态
request.state.request_id = request_id
# 记录请求开始
if self.enable_logging:
self.logger.info(
f"REQUEST_START id={request_id} method={request.method} path={request.url.path}"
)
try:
# 执行下一个中间件或路由
response = await call_next(request)
# 计算处理时间
process_time = time.time() - start_time
# 添加响应头
response.headers["X-Request-ID"] = request_id
response.headers["X-Process-Time"] = f"{process_time:.3f}s"
# 记录请求完成
if self.enable_logging:
self.logger.info(
f"REQUEST_END id={request_id} status={response.status_code} "
f"time={process_time:.3f}s"
)
return response
except Exception as exc:
# 记录异常
process_time = time.time() - start_time
if self.enable_logging:
self.logger.error(
f"REQUEST_ERROR id={request_id} error={str(exc)} "
f"time={process_time:.3f}s"
)
# 返回错误响应
return JSONResponse(
status_code=500,
content={
"error": "Internal Server Error",
"request_id": request_id,
"timestamp": time.time()
}
)
# 使用中间件
app.add_middleware(AdvancedRequestMiddleware, enable_logging=True)#条件中间件实现
class ConditionalMiddleware(BaseHTTPMiddleware):
"""条件执行中间件 - 根据路径、方法等条件决定是否执行"""
def __init__(
self,
app: ASGIApp,
skip_paths: Optional[list] = None,
skip_methods: Optional[list] = None
):
super().__init__(app)
self.skip_paths = skip_paths or ["/health", "/metrics"]
self.skip_methods = skip_methods or ["OPTIONS", "HEAD"]
async def dispatch(self, request: Request, call_next) -> Response:
# 检查是否跳过此中间件
if request.url.path in self.skip_paths:
return await call_next(request)
if request.method in self.skip_methods:
return await call_next(request)
# 执行中间件逻辑
start_time = time.time()
response = await call_next(request)
process_time = time.time() - start_time
response.headers["X-Conditional-Middleware-Time"] = f"{process_time:.3f}s"
return response
# 使用条件中间件
app.add_middleware(
ConditionalMiddleware,
skip_paths=["/health", "/docs", "/redoc"],
skip_methods=["OPTIONS"]
)#CORS跨域处理
#CORS基础概念
CORS(Cross-Origin Resource Sharing,跨源资源共享)是一种W3C标准,它允许服务器声明哪些外域可以访问其资源。在Web开发中,当一个页面试图从不同源(域、协议或端口)请求资源时,就会触发CORS机制。
#FastAPI内置CORS中间件详解
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
# 详细的CORS配置
app.add_middleware(
CORSMiddleware,
allow_origins=[
"https://daomanpy.com",
"https://www.daomanpy.com",
"http://localhost:3000", # React开发服务器
"http://127.0.0.1:3000", # React开发服务器
"https://localhost:3001", # Vue开发服务器
],
allow_credentials=True, # 允许携带Cookie
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], # 允许的方法
allow_headers=[
"Content-Type",
"Authorization",
"X-Requested-With",
"Accept",
"Origin",
"User-Agent",
"DNT",
"Cache-Control",
"X-Mx-ReqToken",
"Keep-Alive",
"X-Requested-With",
"If-Modified-Since",
"X-CSRF-Token",
"X-Request-ID"
],
# 暴露给前端的响应头
expose_headers=["X-Request-ID", "X-Process-Time"],
# 预检请求缓存时间(秒)
max_age=600,
)#动态CORS配置
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import os
def create_cors_config():
"""根据环境变量动态创建CORS配置"""
if os.getenv("ENVIRONMENT") == "development":
return {
"allow_origins": ["*"], # 开发环境允许所有源
"allow_credentials": True,
"allow_methods": ["*"],
"allow_headers": ["*"],
}
else:
# 生产环境严格配置
return {
"allow_origins": [
"https://daomanpy.com",
"https://api.daomanpy.com",
],
"allow_credentials": True,
"allow_methods": ["GET", "POST", "PUT", "DELETE"],
"allow_headers": ["Content-Type", "Authorization"],
"max_age": 3600,
}
app = FastAPI()
cors_config = create_cors_config()
app.add_middleware(
CORSMiddleware,
**cors_config
)#自定义CORS中间件
class CustomCORSMiddleware(BaseHTTPMiddleware):
"""自定义CORS中间件 - 更灵活的跨域控制"""
def __init__(
self,
app: ASGIApp,
allow_origins: list = None,
allow_credentials: bool = False,
allow_methods: list = None,
allow_headers: list = None
):
super().__init__(app)
self.allow_origins = allow_origins or ["*"]
self.allow_credentials = allow_credentials
self.allow_methods = allow_methods or ["*"]
self.allow_headers = allow_headers or ["*"]
async def dispatch(self, request: Request, call_next) -> Response:
# 获取请求来源
origin = request.headers.get("origin")
# 检查是否允许该来源
if origin and self._is_origin_allowed(origin):
# 设置CORS头部
response = await call_next(request)
response.headers["Access-Control-Allow-Origin"] = origin
response.headers["Access-Control-Allow-Credentials"] = str(self.allow_credentials).lower()
response.headers["Access-Control-Allow-Methods"] = ", ".join(self.allow_methods)
response.headers["Access-Control-Allow-Headers"] = ", ".join(self.allow_headers)
return response
else:
# 拒绝跨域请求
response = await call_next(request)
return response
def _is_origin_allowed(self, origin: str) -> bool:
"""检查来源是否被允许"""
if "*" in self.allow_origins:
return True
return origin in self.allow_origins
# 使用自定义CORS中间件
app.add_middleware(
CustomCORSMiddleware,
allow_origins=["https://daomanpy.com", "https://app.daomanpy.com"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE"],
allow_headers=["Content-Type", "Authorization"]
)#GZip响应压缩
#GZip中间件配置
from starlette.middleware.gzip import GZipMiddleware
# 基础配置
app.add_middleware(GZipMiddleware, minimum_size=1000)
# 高级配置
class AdvancedGZipMiddleware(GZipMiddleware):
"""高级GZip压缩中间件 - 支持更多配置选项"""
def __init__(
self,
app: ASGIApp,
minimum_size: int = 500,
compresslevel: int = 6,
content_types: tuple = None
):
super().__init__(app, minimum_size=minimum_size, compresslevel=compresslevel)
self.content_types = content_types or (
"text/plain",
"text/html",
"text/css",
"text/javascript",
"application/javascript",
"application/json",
"application/xml",
"application/rss+xml",
"image/svg+xml"
)
async def compress_response(self, response):
"""压缩响应内容"""
if response.status_code < 200 or response.status_code >= 300:
return response
content_type = response.headers.get("content-type", "").split(";")[0].strip()
if content_type in self.content_types:
return await super().compress_response(response)
return response
# 使用高级GZip中间件
app.add_middleware(
AdvancedGZipMiddleware,
minimum_size=500,
compresslevel=6,
content_types=(
"text/plain",
"text/html",
"application/json",
"application/javascript",
"text/css"
)
)#条件压缩中间件
class ConditionalGZipMiddleware(BaseHTTPMiddleware):
"""条件GZip压缩中间件 - 根据请求头和内容类型决定是否压缩"""
def __init__(self, app: ASGIApp, minimum_size: int = 500):
super().__init__(app)
self.minimum_size = minimum_size
async def dispatch(self, request: Request, call_next) -> Response:
# 检查客户端是否支持gzip
accept_encoding = request.headers.get("accept-encoding", "")
if "gzip" not in accept_encoding.lower():
response = await call_next(request)
return response
response = await call_next(request)
# 检查内容长度和类型
content_length = response.headers.get("content-length")
if content_length and int(content_length) < self.minimum_size:
return response
content_type = response.headers.get("content-type", "")
if any(ct in content_type.lower() for ct in ["text/", "application/json", "application/javascript"]):
# 应用GZip压缩
import gzip
from starlette.datastructures import Headers
# 获取原始响应体
body = b""
async for chunk in response.body_iterator:
body += chunk
# 压缩内容
compressed_body = gzip.compress(body)
# 创建压缩后的响应
headers = Headers(response.headers)
headers["Content-Encoding"] = "gzip"
headers["Content-Length"] = str(len(compressed_body))
response = Response(
content=compressed_body,
status_code=response.status_code,
headers=headers,
media_type=response.media_type
)
return response#请求日志记录
#详细请求日志中间件
import logging
from datetime import datetime
from typing import Optional
import json
class DetailedLoggingMiddleware(BaseHTTPMiddleware):
"""详细请求日志中间件 - 记录请求的详细信息"""
def __init__(self, app: ASGIApp, logger_name: str = "request_logger"):
super().__init__(app)
self.logger = logging.getLogger(logger_name)
# 配置日志格式
handler = logging.StreamHandler()
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
handler.setFormatter(formatter)
self.logger.addHandler(handler)
self.logger.setLevel(logging.INFO)
async def dispatch(self, request: Request, call_next) -> Response:
request_id = str(uuid.uuid4())
start_time = time.time()
# 记录请求信息
request_info = {
"request_id": request_id,
"method": request.method,
"path": str(request.url),
"query_params": dict(request.query_params),
"headers": dict(request.headers),
"client": request.client.host if request.client else None,
"timestamp": datetime.utcnow().isoformat()
}
self.logger.info(f"REQUEST_START - {json.dumps(request_info)}")
try:
response = await call_next(request)
# 计算处理时间
process_time = time.time() - start_time
# 记录响应信息
response_info = {
"request_id": request_id,
"status_code": response.status_code,
"process_time": f"{process_time:.3f}s",
"content_length": response.headers.get("content-length", "unknown"),
"timestamp": datetime.utcnow().isoformat()
}
log_level = logging.INFO if response.status_code < 400 else logging.WARNING
self.logger.log(log_level, f"REQUEST_END - {json.dumps(response_info)}")
# 添加请求ID到响应头
response.headers["X-Request-ID"] = request_id
return response
except Exception as e:
process_time = time.time() - start_time
error_info = {
"request_id": request_id,
"error": str(e),
"process_time": f"{process_time:.3f}s",
"timestamp": datetime.utcnow().isoformat()
}
self.logger.error(f"REQUEST_ERROR - {json.dumps(error_info)}")
raise
# 使用详细日志中间件
app.add_middleware(DetailedLoggingMiddleware)#性能监控与耗时统计
#性能监控中间件
import time
from collections import defaultdict, deque
from typing import Deque, Dict, List
import threading
class PerformanceMonitoringMiddleware(BaseHTTPMiddleware):
"""性能监控中间件 - 统计请求耗时、吞吐量等指标"""
def __init__(self, app: ASGIApp, window_size: int = 1000):
super().__init__(app)
self.window_size = window_size
self.lock = threading.Lock()
# 统计数据存储
self.stats = {
"total_requests": 0,
"successful_requests": 0,
"failed_requests": 0,
"request_times": deque(maxlen=window_size), # 滑动窗口
"status_codes": defaultdict(int), # 状态码统计
"path_stats": defaultdict(lambda: {
"count": 0,
"total_time": 0.0,
"avg_time": 0.0
})
}
async def dispatch(self, request: Request, call_next) -> Response:
start_time = time.time()
try:
response = await call_next(request)
# 计算处理时间
process_time = time.time() - start_time
# 更新统计数据
with self.lock:
self.stats["total_requests"] += 1
self.stats["successful_requests"] += 1
self.stats["status_codes"][response.status_code] += 1
self.stats["request_times"].append(process_time)
# 更新路径统计
path_stats = self.stats["path_stats"][request.url.path]
path_stats["count"] += 1
path_stats["total_time"] += process_time
path_stats["avg_time"] = path_stats["total_time"] / path_stats["count"]
# 添加性能指标到响应头
response.headers["X-Process-Time"] = f"{process_time:.3f}s"
return response
except Exception as e:
process_time = time.time() - start_time
with self.lock:
self.stats["total_requests"] += 1
self.stats["failed_requests"] += 1
self.stats["request_times"].append(process_time)
raise
def get_performance_metrics(self) -> Dict:
"""获取性能指标"""
with self.lock:
if not self.stats["request_times"]:
return {"error": "No data available"}
times_list = list(self.stats["request_times"])
return {
"total_requests": self.stats["total_requests"],
"successful_requests": self.stats["successful_requests"],
"failed_requests": self.stats["failed_requests"],
"avg_response_time": sum(times_list) / len(times_list),
"p95_response_time": sorted(times_list)[int(0.95 * len(times_list))] if times_list else 0,
"p99_response_time": sorted(times_list)[int(0.99 * len(times_list))] if times_list else 0,
"requests_per_second": len(times_list) / 60 if len(times_list) > 0 else 0,
"status_codes": dict(self.stats["status_codes"]),
"top_slow_paths": sorted(
[(path, stats["avg_time"]) for path, stats in self.stats["path_stats"].items()],
key=lambda x: x[1],
reverse=True
)[:10]
}
# 创建性能监控中间件实例
perf_middleware = PerformanceMonitoringMiddleware(app)
@app.get("/metrics")
async def get_metrics():
"""获取性能指标端点"""
return perf_middleware.get_performance_metrics()
# 注册中间件
app.add_middleware(PerformanceMonitoringMiddleware, window_size=1000)#安全中间件
#安全头中间件
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""安全头中间件 - 添加重要的安全头部"""
def __init__(self, app: ASGIApp, strict_transport_security: bool = True):
super().__init__(app)
self.strict_transport_security = strict_transport_security
async def dispatch(self, request: Request, call_next) -> Response:
response = await call_next(request)
# 添加安全头部
security_headers = {
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "DENY",
"X-XSS-Protection": "1; mode=block",
"Referrer-Policy": "strict-origin-when-cross-origin",
"Permissions-Policy": "geolocation=(), microphone=(), camera=()",
}
if self.strict_transport_security:
security_headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
for header, value in security_headers.items():
response.headers[header] = value
return response
# 使用安全头中间件
app.add_middleware(SecurityHeadersMiddleware, strict_transport_security=True)#IP过滤中间件
import ipaddress
from typing import List, Union
class IPFilterMiddleware(BaseHTTPMiddleware):
"""IP过滤中间件 - 允许/阻止特定IP地址"""
def __init__(
self,
app: ASGIApp,
allowed_ips: List[str] = None,
blocked_ips: List[str] = None,
allowed_ranges: List[str] = None
):
super().__init__(app)
self.allowed_ips = set(allowed_ips or [])
self.blocked_ips = set(blocked_ips or [])
self.allowed_ranges = [ipaddress.IPv4Network(r, strict=False) for r in (allowed_ranges or [])]
async def dispatch(self, request: Request, call_next) -> Response:
client_ip_str = request.client.host if request.client else "127.0.0.1"
try:
client_ip = ipaddress.IPv4Address(client_ip_str)
except ipaddress.AddressValueError:
# 如果IP格式无效,允许通过
return await call_next(request)
# 检查是否在阻止列表中
if str(client_ip) in self.blocked_ips:
return JSONResponse(
status_code=403,
content={"error": "IP address is blocked"}
)
# 检查是否在允许列表中
if self.allowed_ips and str(client_ip) not in self.allowed_ips:
# 检查是否在允许的范围内
is_allowed_range = any(client_ip in network for network in self.allowed_ranges)
if not is_allowed_range:
return JSONResponse(
status_code=403,
content={"error": "IP address not allowed"}
)
return await call_next(request)
# 使用IP过滤中间件
app.add_middleware(
IPFilterMiddleware,
allowed_ips=["192.168.1.100", "10.0.0.1"],
blocked_ips=["192.168.1.200"],
allowed_ranges=["192.168.1.0/24"]
)#错误处理中间件
#统一错误处理中间件
import traceback
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
class UnifiedErrorHandlerMiddleware(BaseHTTPMiddleware):
"""统一错误处理中间件 - 捕获并格式化各种类型的错误"""
def __init__(self, app: ASGIApp, debug: bool = False):
super().__init__(app)
self.debug = debug
async def dispatch(self, request: Request, call_next) -> Response:
try:
response = await call_next(request)
return response
except StarletteHTTPException as e:
# FastAPI HTTP异常
error_detail = {
"error": "HTTP Exception",
"status_code": e.status_code,
"detail": e.detail
}
if self.debug:
error_detail["traceback"] = traceback.format_exc()
return JSONResponse(
status_code=e.status_code,
content=error_detail
)
except RequestValidationError as e:
# 请求验证错误
error_detail = {
"error": "Validation Error",
"status_code": 422,
"detail": e.errors()
}
if self.debug:
error_detail["body"] = e.body
return JSONResponse(
status_code=422,
content=error_detail
)
except Exception as e:
# 未处理的异常
error_detail = {
"error": "Internal Server Error",
"status_code": 500,
"message": "An unexpected error occurred"
}
if self.debug:
error_detail["traceback"] = traceback.format_exc()
error_detail["exception"] = str(e)
return JSONResponse(
status_code=500,
content=error_detail
)
# 使用统一错误处理中间件
app.add_middleware(UnifiedErrorHandlerMiddleware, debug=False)#会话管理中间件
#简单会话中间件
import pickle
from typing import Dict, Any
import hashlib
import secrets
class SessionMiddleware(BaseHTTPMiddleware):
"""简单会话管理中间件"""
def __init__(self, app: ASGIApp, secret_key: str, session_cookie_name: str = "session_id"):
super().__init__(app)
self.secret_key = secret_key.encode()
self.session_cookie_name = session_cookie_name
self.sessions: Dict[str, Dict[str, Any]] = {} # 简单内存存储,生产环境应使用Redis
async def dispatch(self, request: Request, call_next) -> Response:
# 获取会话ID
session_id = request.cookies.get(self.session_cookie_name)
# 验证会话ID签名
if session_id and self._verify_session_id(session_id):
session_data = self.sessions.get(session_id, {})
else:
session_id = self._generate_session_id()
session_data = {}
# 将会话数据添加到请求状态
request.state.session = session_data
request.state.session_id = session_id
response = await call_next(request)
# 保存会话数据
if hasattr(request.state, 'session_modified') and request.state.session_modified:
self.sessions[session_id] = session_data
response.set_cookie(
self.session_cookie_name,
session_id,
httponly=True,
secure=True,
samesite="lax"
)
return response
def _generate_session_id(self) -> str:
"""生成会话ID"""
session_bytes = secrets.token_bytes(32)
signature = hashlib.sha256(session_bytes + self.secret_key).hexdigest()
return session_bytes.hex() + signature
def _verify_session_id(self, session_id: str) -> bool:
"""验证会话ID签名"""
if len(session_id) < 64:
return False
session_part = session_id[:64]
signature_part = session_id[64:]
expected_signature = hashlib.sha256(
bytes.fromhex(session_part) + self.secret_key
).hexdigest()
return signature_part == expected_signature
# 会话辅助函数
def get_session(request: Request):
"""获取会话数据"""
return request.state.session
def set_session_value(request: Request, key: str, value: Any):
"""设置会话值"""
request.state.session[key] = value
request.state.session_modified = True
def get_session_value(request: Request, key: str, default=None):
"""获取会话值"""
return request.state.session.get(key, default)
# 使用会话中间件
app.add_middleware(SessionMiddleware, secret_key="your-secret-key-here")#中间件最佳实践
#中间件性能优化
class OptimizedMiddleware(BaseHTTPMiddleware):
"""性能优化的中间件实现"""
def __init__(self, app: ASGIApp):
super().__init__(app)
# 预编译正则表达式等昂贵操作
import re
self.health_check_pattern = re.compile(r'^/(health|ready|metrics)$')
async def dispatch(self, request: Request, call_next) -> Response:
# 快速路径:跳过健康检查等路径
if self.health_check_pattern.match(request.url.path):
return await call_next(request)
# 执行中间件逻辑
start_time = time.time()
response = await call_next(request)
process_time = time.time() - start_time
response.headers["X-Process-Time"] = f"{process_time:.3f}s"
return response#中间件链管理
class MiddlewareChainManager:
"""中间件链管理器 - 更好地组织和管理中间件"""
def __init__(self):
self.middleware_stack = []
def add_middleware(self, middleware_class, **kwargs):
"""添加中间件到链中"""
self.middleware_stack.append((middleware_class, kwargs))
def apply_to_app(self, app: FastAPI):
"""将中间件链应用到FastAPI应用"""
# 反向应用中间件(因为FastAPI内部会反转顺序)
for middleware_class, kwargs in reversed(self.middleware_stack):
app.add_middleware(middleware_class, **kwargs)
# 使用中间件链管理器
middleware_chain = MiddlewareChainManager()
# 添加中间件
middleware_chain.add_middleware(SecurityHeadersMiddleware)
middleware_chain.add_middleware(
AdvancedRequestMiddleware,
enable_logging=True
)
middleware_chain.add_middleware(
CustomCORSMiddleware,
allow_origins=["https://daomanpy.com"]
)
# 应用到FastAPI
middleware_chain.apply_to_app(app)#生产环境部署
#生产环境中间件配置
import os
from typing import List
def configure_production_middlewares(app: FastAPI):
"""配置生产环境中间件"""
# 根据环境变量决定启用哪些中间件
environment = os.getenv("ENVIRONMENT", "development")
if environment == "production":
# 生产环境严格CORS配置
app.add_middleware(
CORSMiddleware,
allow_origins=[
"https://daomanpy.com",
"https://www.daomanpy.com"
],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE"],
allow_headers=["Content-Type", "Authorization", "X-Request-ID"],
max_age=86400 # 24小时
)
# 生产环境安全头
app.add_middleware(
SecurityHeadersMiddleware,
strict_transport_security=True
)
# 生产环境日志
app.add_middleware(
DetailedLoggingMiddleware,
logger_name="prod_logger"
)
# 生产环境性能监控
app.add_middleware(
PerformanceMonitoringMiddleware,
window_size=10000 # 更大的滑动窗口
)
else:
# 开发环境宽松配置
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.add_middleware(DebuggingMiddleware) # 开发专用中间件
# 应用生产环境配置
configure_production_middlewares(app)#与其他框架对比
| 特性 | FastAPI | Django | Flask | Express.js |
|---|---|---|---|---|
| 中间件类型 | BaseHTTPMiddleware + 装饰器 | MIDDLEWARE 类 | @app.before_request 等 | app.use() 函数 |
| 性能 | 高性能ASGI | WSGI,性能一般 | WSGI,性能一般 | Node.js,高性能 |
| 异步支持 | 原生异步 | 有限支持 | 不支持 | 支持异步 |
| CORS支持 | fastapi-cors | django-cors-headers | flask-cors | cors 包 |
| 学习曲线 | 中等 | 较陡峭 | 平缓 | 平缓 |
#总结
FastAPI中间件提供了强大的请求/响应处理能力:
- 灵活的架构:支持装饰器和继承两种中间件实现方式
- 丰富的生态系统:内置CORS、GZip等常用中间件
- 高性能:异步处理,不影响应用性能
- 安全性:可添加各种安全头部和防护措施
- 监控能力:内置性能监控和日志记录功能
使用中间件时需要注意:
- 中间件执行顺序很重要
- 避免在中间件中进行耗时操作
- 合理使用条件判断避免不必要的处理
- 生产环境中要进行充分测试
💡 关键要点:中间件是FastAPI应用的重要组成部分,正确使用可以显著提升应用的安全性、性能和可维护性。
🔗 扩展阅读
</final_file_content>

