# FastAPI WebSocket Real-Time Communication: The Complete Guide
📂 Stage: Stage 6, 2026 Special Topics (AI Integration)

🔗 Related chapters: StreamingResponse streaming responses · OAuth2 and JWT authentication
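## WebSocket Fundamentals

### Why WebSocket?

In modern real-time applications, the traditional HTTP request-response model falls short: only the client can start an exchange, so to learn about new data it has to keep polling the server.

Traditional HTTP polling:

```text
┌─────────────────────────────────────────────────────┐
│ Client → request → Server → response → Client       │
│ (repeated requests waste bandwidth and resources)   │
└─────────────────────────────────────────────────────┘
```

WebSocket real-time communication:

```text
┌─────────────────────────────────────────────────────┐
│ Client ←→ persistent connection ←→ Server           │
│ (bidirectional, real-time, low latency)             │
└─────────────────────────────────────────────────────┘
```

### WebSocket vs. HTTP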
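| Feature | HTTP | WebSocket |
|---|---|---|
| Connection model | One request/response per exchange (keep-alive can reuse the TCP connection) | One long-lived, persistent connection |
| Direction | Client-initiated; the server can only respond | Full-duplex; either side can send at any time |
| Latency | New data arrives only when the client polls | The server pushes data as soon as it exists |
| Per-message overhead | Full HTTP headers on every request | 2 to 14 bytes of framing per message after the handshake |
| Typical uses | REST APIs, file transfer | Chat, games, real-time collaboration |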
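To put a number on the overhead row: an RFC 6455 frame has a 2-byte base header, plus a 2- or 8-byte extended length field for payloads longer than 125 or 65,535 bytes, plus a 4-byte masking key on every client-to-server frame. The helper below (`frame_header_size` is an illustrative name, not a library function) computes that framing cost for a single unfragmented frame:

```python
def frame_header_size(payload_len: int, masked: bool = False) -> int:
    """Bytes of framing RFC 6455 adds to one frame carrying payload_len bytes."""
    if payload_len < 0:
        raise ValueError("payload length must be non-negative")
    size = 2  # FIN/RSV/opcode byte + MASK bit and 7-bit payload length byte
    if payload_len > 65535:
        size += 8  # 64-bit extended payload length
    elif payload_len > 125:
        size += 2  # 16-bit extended payload length
    if masked:
        size += 4  # masking key, required on every client-to-server frame
    return size


for n in (20, 1_000, 100_000):
    # server-to-client (unmasked) vs. client-to-server (masked)
    print(n, frame_header_size(n), frame_header_size(n, masked=True))
```

A polled HTTP request, by contrast, typically resends several hundred bytes of headers (cookies, user agent, and so on) every time.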
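### The WebSocket Protocol Upgrade

The client starts with an ordinary HTTP request asking to upgrade the connection:

```http
GET /ws HTTP/1.1
Host: example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Sec-WebSocket-Version: 13
```

If the server agrees, it answers with status 101, and from then on both sides speak the WebSocket protocol over the same TCP connection:

```http
HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
```

### Typical Use Cases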
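- Real-time chat: instant messaging, customer support
- AI assistants: streaming AI replies in real time
- Collaborative editors: multi-user co-editing
- Real-time notifications: system message push
- Online games: real-time game state synchronization
- Stock tickers: real-time financial data feeds
- IoT devices: real-time device status monitoring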
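Before moving on to FastAPI, note that the `Sec-WebSocket-Accept` value in the 101 response of the upgrade example is not arbitrary: the server appends a fixed GUID from RFC 6455 to the client's `Sec-WebSocket-Key`, hashes the result with SHA-1, and base64-encodes the digest. The ASGI server (for example Uvicorn) does this for you; this sketch only illustrates the rule, and `sec_websocket_accept` is an illustrative name:

```python
import base64
import hashlib

# Fixed GUID defined by RFC 6455 for the opening handshake
WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


def sec_websocket_accept(sec_websocket_key: str) -> str:
    """Derive the Sec-WebSocket-Accept header value from the client's Sec-WebSocket-Key."""
    digest = hashlib.sha1((sec_websocket_key + WEBSOCKET_GUID).encode("ascii")).digest()
    return base64.b64encode(digest).decode("ascii")


print(sec_websocket_accept("dGhlIHNhbXBsZSBub25jZQ=="))  # s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
```

This only proves that the server understood the WebSocket handshake; it is not an authentication mechanism.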
## FastAPI WebSocket Basics

### Basic WebSocket Endpoint

```python
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
import json
import logging
from datetime import datetime, timezone

app = FastAPI(title="Daoman WebSocket Service")
logger = logging.getLogger(__name__)

# Simple test page
html_page = """
<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <title>WebSocket Test</title>
    <style>
        body { font-family: Arial, sans-serif; margin: 20px; }
        #messages { height: 300px; overflow-y: scroll; border: 1px solid #ccc; padding: 10px; margin: 10px 0; }
        .message { margin: 5px 0; padding: 5px; border-radius: 4px; }
        .received { background-color: #e3f2fd; }
        .sent { background-color: #e8f5e8; }
        .system { background-color: #fff3e0; font-style: italic; }
        input[type="text"] { width: 70%; padding: 8px; }
        button { padding: 8px 16px; margin-left: 10px; }
    </style>
</head>
<body>
    <h1>WebSocket Real-Time Communication Test</h1>
    <div id="status">Status: Disconnected</div>
    <div id="messages"></div>
    <div>
        <input type="text" id="messageText" placeholder="Type a message..." onkeypress="handleKeyPress(event)">
        <button onclick="sendMessage()">Send</button>
        <button onclick="closeConnection()">Disconnect</button>
    </div>
    <script>
        let ws;
        const messagesDiv = document.getElementById('messages');
        const statusDiv = document.getElementById('status');

        function connect() {
            const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
            const wsUri = protocol + '//' + window.location.host + '/ws';
            ws = new WebSocket(wsUri);
            ws.onopen = function(evt) {
                statusDiv.innerHTML = 'Status: <span style="color: green;">Connected</span>';
                addMessage('Connected to the server', 'system');
            };
            ws.onclose = function(evt) {
                statusDiv.innerHTML = 'Status: <span style="color: red;">Disconnected</span>';
                addMessage('Connection to the server closed', 'system');
            };
            ws.onmessage = function(evt) {
                const data = JSON.parse(evt.data);
                // pong and echo replies have no "message" field, so fall back to the raw JSON
                addMessage(`Received: ${data.message ?? evt.data}`, 'received');
            };
            ws.onerror = function(evt) {
                // Error events carry no details; onclose fires next with the close code
                addMessage('WebSocket error', 'system');
            };
        }

        function sendMessage() {
            const input = document.getElementById('messageText');
            if (input.value && ws && ws.readyState === WebSocket.OPEN) {
                const message = {
                    type: 'chat',
                    content: input.value,
                    timestamp: new Date().toISOString()
                };
                ws.send(JSON.stringify(message));
                addMessage(`Sent: ${input.value}`, 'sent');
                input.value = '';
            }
        }

        function closeConnection() {
            if (ws) {
                ws.close();
            }
        }

        function addMessage(message, type) {
            const messageDiv = document.createElement('div');
            messageDiv.className = 'message ' + type;
            const time = document.createElement('strong');
            time.textContent = `[${new Date().toLocaleTimeString()}] `;
            messageDiv.appendChild(time);
            // Append as a text node, not via innerHTML: message text comes from users
            messageDiv.appendChild(document.createTextNode(message));
            messagesDiv.appendChild(messageDiv);
            messagesDiv.scrollTop = messagesDiv.scrollHeight;
        }

        function handleKeyPress(event) {
            if (event.key === 'Enter') {
                sendMessage();
            }
        }

        // Connect automatically when the page loads
        window.onload = connect;
    </script>
</body>
</html>
"""


@app.get("/")
async def get_websocket_page():
    """WebSocket test page"""
    return HTMLResponse(html_page)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Basic WebSocket endpoint"""
    await websocket.accept()
    client_host = websocket.client.host
    client_port = websocket.client.port
    logger.info(f"WebSocket connected: {client_host}:{client_port}")
    try:
        while True:
            # Receive a message from the client
            data = await websocket.receive_text()
            logger.info(f"Received message: {data}")
            # Parse the message; plain text (or JSON that is not an object) is treated as chat
            try:
                message_data = json.loads(data)
                if not isinstance(message_data, dict):
                    raise ValueError("not a JSON object")
                message_type = message_data.get("type", "chat")
                content = message_data.get("content", data)
            except ValueError:  # json.JSONDecodeError is a subclass of ValueError
                message_type = "chat"
                content = data
            # Handle each message type
            if message_type == "chat":
                response = {
                    "type": "response",
                    "message": f"Server received: {content}",
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                    "server_time": datetime.now().astimezone().isoformat(),  # server local time
                }
            elif message_type == "ping":
                response = {
                    "type": "pong",
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                }
            elif message_type == "echo":
                response = {
                    "type": "echo",
                    "original": content,
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                }
            else:
                response = {"type": "error", "message": f"Unknown message type: {message_type}"}
            await websocket.send_text(json.dumps(response))
    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {client_host}:{client_port}")
    except Exception as e:
        logger.exception(f"WebSocket handler error: {e}")
        # 1011 = internal error. Close reasons are limited to 123 bytes, and
        # exception details should not be sent to clients anyway.
        await websocket.close(code=1011, reason="Internal server error")
```

### WebSocket Message Type Handling
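The basic endpoint's if/elif chain grows with every new message type. A common alternative is a dispatch table keyed by the `type` field. The sketch below is hypothetical (none of these helper names come from FastAPI) and deliberately socket-free, so the parsing rules can be unit-tested on their own; the `WebSocketMessage` class below organizes the same branching around a `MessageType` enum instead.

```python
import json
from datetime import datetime, timezone
from typing import Callable, Dict


def _now() -> str:
    return datetime.now(timezone.utc).isoformat()


def handle_chat(content: str) -> dict:
    return {"type": "response", "message": f"Server received: {content}", "timestamp": _now()}


def handle_ping(content: str) -> dict:
    return {"type": "pong", "timestamp": _now()}


def handle_echo(content: str) -> dict:
    return {"type": "echo", "original": content, "timestamp": _now()}


# type string -> handler; supporting a new message type means adding one entry
HANDLERS: Dict[str, Callable[[str], dict]] = {
    "chat": handle_chat,
    "ping": handle_ping,
    "echo": handle_echo,
}


def dispatch(raw: str) -> dict:
    """Parse one incoming text frame and build the reply dict."""
    try:
        data = json.loads(raw)
        if not isinstance(data, dict):
            raise ValueError("not a JSON object")
        msg_type = data.get("type", "chat")
        content = data.get("content", raw)
    except ValueError:  # covers json.JSONDecodeError
        msg_type, content = "chat", raw  # plain text is treated as a chat message
    handler = HANDLERS.get(msg_type)
    if handler is None:
        return {"type": "error", "message": f"Unknown message type: {msg_type}", "timestamp": _now()}
    return handler(content)
```

Inside an endpoint, the loop body would then reduce to `await websocket.send_text(json.dumps(dispatch(await websocket.receive_text())))`.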
```python
# Continues the module above: `app` and `logger` are defined there.
import json
from datetime import datetime, timezone
from enum import Enum

from fastapi import WebSocket, WebSocketDisconnect


class MessageType(Enum):
    """WebSocket message types"""
    CHAT = "chat"
    PING = "ping"
    ECHO = "echo"
    JOIN_ROOM = "join_room"
    LEAVE_ROOM = "leave_room"
    PRIVATE_MESSAGE = "private_message"
    SYSTEM_ALERT = "system_alert"
    FILE_TRANSFER = "file_transfer"
    HEARTBEAT = "heartbeat"


class WebSocketMessage:
    """WebSocket message handler"""

    def __init__(self, websocket: WebSocket):
        self.websocket = websocket
        self.client_info = {}

    async def send_message(self, message_type: MessageType, data: dict):
        """Send a message in the standard envelope"""
        message = {
            "type": message_type.value,
            "data": data,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        await self.websocket.send_text(json.dumps(message))

    async def send_error(self, error_message: str, code: int = 4000):
        """Send an error message"""
        error_data = {
            "type": "error",
            "message": error_message,
            "code": code,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        await self.websocket.send_text(json.dumps(error_data))

    async def handle_message(self, raw_message: str):
        """Route an incoming message to its handler"""
        try:
            message_data = json.loads(raw_message)
            message_type = message_data.get("type", "chat")
            content = message_data.get("content", "")
            if message_type == MessageType.CHAT.value:
                await self.handle_chat_message(content)
            elif message_type == MessageType.PING.value:
                await self.handle_ping()
            elif message_type == MessageType.ECHO.value:
                await self.handle_echo(content)
            elif message_type == MessageType.JOIN_ROOM.value:
                await self.handle_join_room(message_data.get("room_id"))
            elif message_type == MessageType.LEAVE_ROOM.value:
                await self.handle_leave_room(message_data.get("room_id"))
            elif message_type == MessageType.PRIVATE_MESSAGE.value:
                await self.handle_private_message(message_data.get("target_user"), content)
            elif message_type == MessageType.HEARTBEAT.value:
                await self.handle_heartbeat()
            else:
                await self.send_error(f"Unknown message type: {message_type}")
        except json.JSONDecodeError:
            await self.send_error("Malformed message: expected JSON")
        except WebSocketDisconnect:
            raise  # let the endpoint handle disconnects
        except Exception as e:
            await self.send_error(f"Error processing message: {e}")

    async def handle_chat_message(self, content: str):
        """Handle a chat message"""
        await self.send_message(MessageType.CHAT, {
            "content": f"Server received: {content}",
            "processed_at": datetime.now(timezone.utc).isoformat(),
        })

    async def handle_ping(self):
        """Handle a ping"""
        await self.send_message(MessageType.PING, {"status": "pong"})

    async def handle_echo(self, content: str):
        """Handle an echo request"""
        await self.send_message(MessageType.ECHO, {"echo": content})

    async def handle_join_room(self, room_id: str):
        """Handle joining a room"""
        # A real application would register the user with the connection manager here
        await self.send_message(MessageType.JOIN_ROOM, {"room_id": room_id, "joined": True})

    async def handle_leave_room(self, room_id: str):
        """Handle leaving a room"""
        await self.send_message(MessageType.LEAVE_ROOM, {"room_id": room_id, "left": True})

    async def handle_private_message(self, target_user: str, content: str):
        """Handle a private message"""
        await self.send_message(MessageType.PRIVATE_MESSAGE, {
            "target_user": target_user,
            "content": content,
            "sent": True,
        })

    async def handle_heartbeat(self):
        """Handle a heartbeat"""
        await self.send_message(MessageType.HEARTBEAT, {
            "status": "alive",
            "timestamp": datetime.now(timezone.utc).isoformat(),
        })


@app.websocket("/ws/advanced")
async def advanced_websocket_endpoint(websocket: WebSocket):
    """Advanced WebSocket endpoint"""
    await websocket.accept()
    message_handler = WebSocketMessage(websocket)
    try:
        while True:
            raw_message = await websocket.receive_text()
            await message_handler.handle_message(raw_message)
    except WebSocketDisconnect:
        logger.info("Advanced WebSocket disconnected")
    except Exception as e:
        logger.exception(f"Advanced WebSocket handler error: {e}")
        try:
            await message_handler.send_error("Internal server error", 1011)
            await websocket.close(code=1011)
        except Exception:
            pass  # the socket may already be unusable
```

## Connection Manager Implementation
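At its core, a connection manager is a registry of live sockets plus a broadcast that tolerates dead ones. The minimal, stdlib-only sketch below isolates two rules that matter for the managers in this section: iterate over a snapshot of the registry (every `await` lets other coroutines connect or disconnect, and mutating a dict while iterating over it raises `RuntimeError`), and prune failed connections after the loop. `FakeSocket` and `MiniManager` are hypothetical stand-ins so the example runs without a server.

```python
import asyncio
from typing import Dict, List


class FakeSocket:
    """Hypothetical stand-in for fastapi.WebSocket: records sent text, or fails if dead."""

    def __init__(self, alive: bool = True):
        self.alive = alive
        self.sent: List[str] = []

    async def send_text(self, text: str) -> None:
        await asyncio.sleep(0)  # yield to the event loop, as a real network send would
        if not self.alive:
            raise ConnectionError("client went away")
        self.sent.append(text)


class MiniManager:
    def __init__(self) -> None:
        self.active: Dict[str, FakeSocket] = {}

    def disconnect(self, user_id: str) -> None:
        self.active.pop(user_id, None)

    async def broadcast(self, text: str) -> None:
        dead = []
        # Snapshot with list(): the dict may change while we await each send
        for user_id, ws in list(self.active.items()):
            try:
                await ws.send_text(text)
            except Exception:
                dead.append(user_id)
        for user_id in dead:  # prune after the loop, not during it
            self.disconnect(user_id)


async def main() -> MiniManager:
    mgr = MiniManager()
    mgr.active = {"alice": FakeSocket(), "bob": FakeSocket(alive=False)}
    await mgr.broadcast("hello")
    return mgr


mgr = asyncio.run(main())
print(sorted(mgr.active), mgr.active["alice"].sent)  # ['alice'] ['hello']
```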
### Basic Connection Manager

```python
# managers/connection_manager.py - connection manager
import json
import logging
from datetime import datetime, timezone
from enum import Enum
from typing import Dict, List, Optional, Set

from fastapi import WebSocket

logger = logging.getLogger(__name__)


class ConnectionStatus(Enum):
    """Connection status"""
    CONNECTED = "connected"
    DISCONNECTED = "disconnected"
    AWAY = "away"
    BUSY = "busy"


class ConnectionInfo:
    """Information about one connection"""

    def __init__(self, websocket: WebSocket, user_id: str, user_info: Optional[dict] = None):
        self.websocket = websocket
        self.user_id = user_id
        self.user_info = user_info or {}
        self.connected_at = datetime.now(timezone.utc)
        self.last_activity = datetime.now(timezone.utc)  # refreshed on every send to this user
        self.status = ConnectionStatus.CONNECTED
        self.rooms: Set[str] = set()
        self.ip_address = websocket.client.host
        self.user_agent = websocket.headers.get("user-agent", "")


class ConnectionManager:
    """WebSocket connection manager"""

    def __init__(self):
        # Active connections: user_id -> ConnectionInfo.
        # Keyed by user_id, so a second connection from the same user overwrites the first
        # (and the first one's IP count is never released); key by connection in production.
        self.active_connections: Dict[str, ConnectionInfo] = {}
        # Rooms: room_id -> set of user_ids
        self.room_connections: Dict[str, Set[str]] = {}
        # Per-IP limit: ip_address -> connection count
        self.ip_connections: Dict[str, int] = {}
        self.max_connections_per_ip = 10
        # Idle timeout in seconds
        self.connection_timeout = 3600  # 1 hour

    async def connect(self, websocket: WebSocket, user_id: str, user_info: Optional[dict] = None) -> bool:
        """Accept and register a connection"""
        user_info = user_info or {}
        ip_address = websocket.client.host
        # Enforce the per-IP connection limit
        current_ip_count = self.ip_connections.get(ip_address, 0)
        if current_ip_count >= self.max_connections_per_ip:
            logger.warning(f"Too many connections from IP {ip_address}")
            # Closing before accept() rejects the handshake with HTTP 403,
            # so this code and reason never reach the client
            await websocket.close(code=4000, reason="Too many connections")
            return False
        await websocket.accept()
        self.active_connections[user_id] = ConnectionInfo(websocket, user_id, user_info)
        self.ip_connections[ip_address] = current_ip_count + 1
        logger.info(f"User {user_id} connected")
        # Announce that the user is online
        await self.broadcast_system_message(
            f"User {user_info.get('name', user_id)} is online",
            exclude_user_ids=[user_id],
        )
        return True

    async def disconnect(self, user_id: str):
        """Unregister a connection and announce it (the socket itself is not closed here)"""
        connection_info = self.active_connections.get(user_id)
        if connection_info is None:
            return
        # Look the name up before the entry is removed, otherwise the notice shows the raw ID
        user_name = self.get_user_name(user_id)
        # Leave every room; iterate over a copy because leave_room() mutates the set
        for room_id in list(connection_info.rooms):
            await self.leave_room(room_id, user_id)
        # pop() instead of del: a concurrent call may already have removed the entry
        if self.active_connections.pop(user_id, None) is None:
            return
        ip_address = connection_info.ip_address
        if ip_address in self.ip_connections:
            self.ip_connections[ip_address] -= 1
            if self.ip_connections[ip_address] <= 0:
                del self.ip_connections[ip_address]
        logger.info(f"User {user_id} disconnected")
        await self.broadcast_system_message(f"User {user_name} has left")

    def get_user_name(self, user_id: str) -> str:
        """Display name of a user, falling back to the ID"""
        if user_id in self.active_connections:
            return self.active_connections[user_id].user_info.get("name", user_id)
        return user_id

    async def send_personal_message(self, user_id: str, message: dict):
        """Send a message to one user"""
        connection_info = self.active_connections.get(user_id)
        if connection_info is None:
            return
        try:
            await connection_info.websocket.send_text(json.dumps(message))
            connection_info.last_activity = datetime.now(timezone.utc)
        except Exception as e:
            logger.error(f"Failed to send personal message: {e}")
            await self.disconnect(user_id)

    async def broadcast_message(self, message: dict, exclude_user_ids: Optional[List[str]] = None):
        """Broadcast a message to all connected users"""
        exclude_user_ids = exclude_user_ids or []
        disconnected_users = []
        # Iterate over a snapshot: other coroutines may modify the dict while we await
        for user_id, connection_info in list(self.active_connections.items()):
            if user_id in exclude_user_ids:
                continue
            try:
                await connection_info.websocket.send_text(json.dumps(message))
                connection_info.last_activity = datetime.now(timezone.utc)
            except Exception as e:
                logger.error(f"Broadcast to user {user_id} failed: {e}")
                disconnected_users.append(user_id)
        # Clean up dead connections after the loop
        for user_id in disconnected_users:
            await self.disconnect(user_id)

    async def broadcast_system_message(self, text: str, exclude_user_ids: Optional[List[str]] = None):
        """Broadcast a system message"""
        message = {
            "type": "system",
            "content": text,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "category": "system",
        }
        await self.broadcast_message(message, exclude_user_ids)

    async def join_room(self, room_id: str, user_id: str) -> bool:
        """Join a room (created on first use)"""
        if user_id not in self.active_connections:
            return False
        self.room_connections.setdefault(room_id, set()).add(user_id)
        self.active_connections[user_id].rooms.add(room_id)
        logger.info(f"User {user_id} joined room {room_id}")
        # Notify the other members of the room
        await self.send_room_message(
            room_id,
            {
                "type": "system",
                "content": f"User {self.get_user_name(user_id)} joined the room",
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "category": "room_join",
            },
            exclude_user_ids=[user_id],
        )
        return True

    async def leave_room(self, room_id: str, user_id: str):
        """Leave a room"""
        members = self.room_connections.get(room_id)
        if not members or user_id not in members:
            return
        members.remove(user_id)
        if user_id in self.active_connections:
            self.active_connections[user_id].rooms.discard(room_id)
        logger.info(f"User {user_id} left room {room_id}")
        if not members:
            # Drop empty rooms so get_room_count() stays accurate
            del self.room_connections[room_id]
            return
        # Notify the remaining members
        await self.send_room_message(
            room_id,
            {
                "type": "system",
                "content": f"User {self.get_user_name(user_id)} left the room",
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "category": "room_leave",
            },
            exclude_user_ids=[user_id],
        )

    async def send_room_message(self, room_id: str, message: dict, exclude_user_ids: Optional[List[str]] = None):
        """Send a message to every member of a room"""
        exclude_user_ids = exclude_user_ids or []
        if room_id not in self.room_connections:
            return
        disconnected_users = []
        # Snapshot the member set; it can change while we await
        for user_id in list(self.room_connections[room_id]):
            if user_id in exclude_user_ids:
                continue
            connection_info = self.active_connections.get(user_id)
            if connection_info is None:
                continue
            try:
                await connection_info.websocket.send_text(json.dumps(message))
                connection_info.last_activity = datetime.now(timezone.utc)
            except Exception as e:
                logger.error(f"Sending room message to user {user_id} failed: {e}")
                disconnected_users.append(user_id)
        for user_id in disconnected_users:
            await self.disconnect(user_id)

    def get_online_users(self) -> List[dict]:
        """List online users"""
        return [
            {
                "user_id": user_id,
                "user_info": info.user_info,
                "connected_at": info.connected_at.isoformat(),
                "last_activity": info.last_activity.isoformat(),
                "status": info.status.value,
                "rooms": list(info.rooms),
                "ip_address": info.ip_address,
            }
            for user_id, info in self.active_connections.items()
        ]

    def get_room_users(self, room_id: str) -> List[dict]:
        """List the users in a room"""
        return [
            {
                "user_id": user_id,
                "user_info": self.active_connections[user_id].user_info,
                "connected_at": self.active_connections[user_id].connected_at.isoformat(),
            }
            for user_id in self.room_connections.get(room_id, set())
            if user_id in self.active_connections
        ]

    def get_connection_count(self) -> int:
        """Total number of connections"""
        return len(self.active_connections)

    def get_room_count(self) -> int:
        """Total number of rooms"""
        return len(self.room_connections)

    async def cleanup_expired_connections(self):
        """Close connections idle for longer than connection_timeout.
        Run this periodically, e.g. from a background task."""
        now = datetime.now(timezone.utc)
        expired_users = [
            user_id
            for user_id, info in self.active_connections.items()
            if (now - info.last_activity).total_seconds() > self.connection_timeout
        ]
        for user_id in expired_users:
            info = self.active_connections.get(user_id)
            if info is None:
                continue
            await self.disconnect(user_id)
            try:
                await info.websocket.close(code=1000, reason="Idle timeout")
            except Exception:
                pass  # already closed
            logger.info(f"Closed idle connection: {user_id}")


# Global connection manager instance
manager = ConnectionManager()
```

### Advanced Connection Manager
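A priority field only has an effect if the queue actually orders by it; a plain `asyncio.Queue` is strictly FIFO. The sketch below shows the ordering rules `asyncio.PriorityQueue` needs: it pops the smallest tuple first, so the priority value is negated, and a monotonically increasing counter breaks ties so that messages of equal priority stay in FIFO order and the message dicts themselves are never compared (comparing two dicts raises `TypeError`). The message contents are made up for illustration.

```python
import asyncio
import itertools
from enum import Enum


class MessagePriority(Enum):
    LOW = 1
    NORMAL = 2
    HIGH = 3


async def drain_in_priority_order() -> list:
    queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
    seq = itertools.count()  # tie-breaker within one priority level
    for priority, text in [
        (MessagePriority.LOW, "digest"),
        (MessagePriority.HIGH, "alert"),
        (MessagePriority.NORMAL, "chat-1"),
        (MessagePriority.NORMAL, "chat-2"),
    ]:
        # Negate so that HIGH (3) sorts before NORMAL (2) and LOW (1)
        await queue.put((-priority.value, next(seq), {"content": text}))
    order = []
    while not queue.empty():
        _, _, message = queue.get_nowait()
        order.append(message["content"])
    return order


print(asyncio.run(drain_in_priority_order()))  # ['alert', 'chat-1', 'chat-2', 'digest']
```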
```python
# managers/advanced_connection_manager.py - advanced connection manager
import asyncio
import itertools
import json
import logging
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from enum import Enum
from typing import Callable, Dict, List, Optional

import redis.asyncio as redis
from fastapi import WebSocket

from managers.connection_manager import ConnectionManager

logger = logging.getLogger(__name__)


@dataclass
class UserSession:
    """User session information"""
    user_id: str
    session_id: str
    created_at: datetime
    last_seen: datetime
    device_info: dict
    permissions: List[str]


class MessagePriority(Enum):
    """Message priority"""
    LOW = 1
    NORMAL = 2
    HIGH = 3
    CRITICAL = 4


class AdvancedConnectionManager(ConnectionManager):
    """Advanced connection manager: Redis-backed sessions, priority queue, filtered broadcast"""

    def __init__(self, redis_url: str = "redis://localhost:6379"):
        super().__init__()
        # Without decode_responses=True, redis-py returns bytes keys and values
        self.redis_client = redis.from_url(redis_url)
        # A PriorityQueue, so that priorities actually affect delivery order
        self.message_queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
        # Tie-breaker: FIFO within one priority, and message dicts are never compared
        self._queue_seq = itertools.count()
        self.message_handlers: Dict[str, Callable] = {}
        self.session_store: Dict[str, UserSession] = {}

    async def connect_with_session(self, websocket: WebSocket, user_id: str,
                                   session_id: str, user_info: Optional[dict] = None) -> bool:
        """Connect after validating the session; closes the socket itself on failure"""
        user_info = user_info or {}
        if not await self.validate_session(session_id, user_id):
            await websocket.close(code=4001, reason="Invalid session")
            return False
        now = datetime.now(timezone.utc)
        self.session_store[session_id] = UserSession(
            user_id=user_id,
            session_id=session_id,
            created_at=now,
            last_seen=now,
            device_info=self.extract_device_info(websocket),
            permissions=user_info.get("permissions", []),
        )
        return await self.connect(websocket, user_id, user_info)

    async def validate_session(self, session_id: str, user_id: str) -> bool:
        """Check the Redis hash session:{session_id} (fields: user_id, expiry as a Unix timestamp)"""
        try:
            session_data = await self.redis_client.hgetall(f"session:{session_id}")
            if not session_data:
                return False
            stored_user_id = session_data.get(b"user_id", b"").decode("utf-8")
            expiry_time = int(session_data.get(b"expiry", 0))
            # time.time() is the Unix epoch. datetime.utcnow().timestamp() is wrong on
            # non-UTC servers, because naive datetimes are interpreted as local time.
            if stored_user_id != user_id or time.time() > expiry_time:
                return False
            # Record the last access time
            await self.redis_client.hset(f"session:{session_id}", "last_seen",
                                         datetime.now(timezone.utc).isoformat())
            return True
        except Exception as e:
            logger.error(f"Session validation failed: {e}")
            return False

    def extract_device_info(self, websocket: WebSocket) -> dict:
        """Collect device information from the handshake headers"""
        headers = websocket.headers
        return {
            "user_agent": headers.get("user-agent", ""),
            "accept_language": headers.get("accept-language", ""),
            "connection": headers.get("connection", ""),
            "upgrade": headers.get("upgrade", ""),
            "sec_websocket_version": headers.get("sec-websocket-version", ""),
            "sec_websocket_extensions": headers.get("sec-websocket-extensions", ""),
        }

    async def send_priority_message(self, user_id: str, message: dict,
                                    priority: MessagePriority = MessagePriority.NORMAL):
        """Send CRITICAL messages immediately; queue everything else by priority"""
        if priority == MessagePriority.CRITICAL:
            await self.send_personal_message(user_id, message)
        else:
            # PriorityQueue pops the smallest tuple first, so negate: HIGH (3) before NORMAL (2)
            await self.message_queue.put((-priority.value, next(self._queue_seq), user_id, message))

    async def process_message_queue(self):
        """Consume the queue forever. Start it once at startup, e.g. with
        asyncio.create_task(advanced_manager.process_message_queue()) in a lifespan handler."""
        while True:
            _, _, user_id, message = await self.message_queue.get()
            try:
                await self.send_personal_message(user_id, message)
            except Exception as e:
                logger.error(f"Failed to deliver queued message: {e}")
            finally:
                self.message_queue.task_done()

    async def broadcast_with_filter(self, message: dict,
                                    filter_func: Optional[Callable[[str, dict], bool]] = None,
                                    exclude_user_ids: Optional[List[str]] = None):
        """Broadcast only to users for which filter_func(user_id, user_info) is true"""
        exclude_user_ids = exclude_user_ids or []
        disconnected_users = []
        for user_id, connection_info in list(self.active_connections.items()):
            if user_id in exclude_user_ids:
                continue
            if filter_func and not filter_func(user_id, connection_info.user_info):
                continue
            try:
                await connection_info.websocket.send_text(json.dumps(message))
                connection_info.last_activity = datetime.now(timezone.utc)
            except Exception as e:
                logger.error(f"Broadcast to user {user_id} failed: {e}")
                disconnected_users.append(user_id)
        for user_id in disconnected_users:
            await self.disconnect(user_id)

    async def get_user_stats(self, user_id: str) -> dict:
        """Current-session info plus historical statistics from Redis"""
        if user_id not in self.active_connections:
            return {}
        connection_info = self.active_connections[user_id]
        stats = await self.redis_client.hgetall(f"user_stats:{user_id}")
        return {
            "current_session": {
                "connected_at": connection_info.connected_at.isoformat(),
                "last_activity": connection_info.last_activity.isoformat(),
                "rooms": list(connection_info.rooms),
                "status": connection_info.status.value,
            },
            "historical_stats": {k.decode(): v.decode() for k, v in stats.items()},
            "total_messages_sent": int(stats.get(b"messages_sent", 0)),
            "total_messages_received": int(stats.get(b"messages_received", 0)),
        }

    async def update_user_stats(self, user_id: str, stat_updates: dict):
        """Set statistics fields (values overwrite; use HINCRBY for counters)"""
        stats_key = f"user_stats:{user_id}"
        # hmset() is deprecated in redis-py; hset(mapping=...) replaces it
        await self.redis_client.hset(stats_key, mapping=stat_updates)
        await self.redis_client.expire(stats_key, 86400)  # expire after 24 hours


# Global advanced connection manager instance
advanced_manager = AdvancedConnectionManager()
```

## Complete Chat Room Application
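The chat router below dispatches on a `type` field and reads type-specific fields such as `content`, `to`, and `room_id`. Checking those fields once, before dispatching, keeps malformed client input away from the handlers. The validator below is a hypothetical illustration, not part of FastAPI; its `REQUIRED_FIELDS` table mirrors the message types the router handles, and which fields count as required is an assumption.

```python
from typing import Any, Dict, Optional, Tuple

# Hypothetical schema: the fields each client message type must carry
REQUIRED_FIELDS: Dict[str, Tuple[str, ...]] = {
    "text": ("content",),
    "private": ("to", "content"),
    "typing": (),
    "ping": (),
    "join_room": ("room_id",),
    "leave_room": (),
    "user_status": (),
}


def validate_client_message(message: Any) -> Optional[str]:
    """Return an error description, or None if the message is well-formed."""
    if not isinstance(message, dict):
        return "message must be a JSON object"
    msg_type = message.get("type", "text")  # same default as the router
    if msg_type not in REQUIRED_FIELDS:
        return f"unknown message type: {msg_type}"
    missing = [field for field in REQUIRED_FIELDS[msg_type] if field not in message]
    if missing:
        return f"missing field(s) for {msg_type!r}: {', '.join(missing)}"
    return None
```

In the receive loop, it would be called right after `json.loads`, replying with an `error` message whenever it returns a string.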
### WebSocket Chat Router

```python
# routers/websocket_chat.py - WebSocket chat router
import asyncio
import json
import logging
from datetime import datetime, timezone
from pathlib import Path

from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse

from auth.jwt import decode_token  # assumes a JWT auth module whose decode_token raises on invalid tokens
from managers.advanced_connection_manager import advanced_manager
from managers.connection_manager import ConnectionStatus

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/ws", tags=["websocket"])


def now_iso() -> str:
    """Timezone-aware UTC timestamp"""
    return datetime.now(timezone.utc).isoformat()


async def send_error(user_id: str, text: str) -> None:
    """Send an error message to one user"""
    await advanced_manager.send_personal_message(
        user_id, {"type": "error", "message": text, "timestamp": now_iso()}
    )


@router.get("/chat-room", response_class=HTMLResponse)
async def get_chat_room():
    """Chat room page"""
    # Read as UTF-8 explicitly; pathlib also closes the file
    return HTMLResponse(Path("templates/chat_room.html").read_text(encoding="utf-8"))


@router.websocket("/chat")
async def chat_websocket(
    websocket: WebSocket,
    token: str = Query(...),
    room_id: str = Query("general", description="Chat room ID"),
    client_version: str = Query("1.0.0", description="Client version"),
):
    """
    Chat WebSocket endpoint.
    Client: new WebSocket("ws://localhost:8000/ws/chat?token=xxx&room_id=general")
    """
    # 1. Verify the JWT; closing before accept() rejects the handshake
    try:
        payload = decode_token(token)
    except Exception:
        await websocket.close(code=4001, reason="Authentication failed")
        return
    user_id = str(payload["sub"])
    user_name = payload.get("name", payload.get("email", user_id))
    permissions = payload.get("permissions", [])

    # 2. Connect. validate_session() looks the session up in Redis, so the ID must be one the
    # login flow stored there. Assumption: the login flow puts it in a "sid" claim of the JWT.
    session_id = str(payload.get("sid", ""))
    user_info = {
        "name": user_name,
        "email": payload.get("email"),
        "avatar": payload.get("avatar"),
        "permissions": permissions,
        "client_version": client_version,
    }
    # connect_with_session() already closes the socket on failure; closing twice raises
    if not await advanced_manager.connect_with_session(websocket, user_id, session_id, user_info):
        return

    try:
        # 3. Join the requested room
        await advanced_manager.join_room(room_id, user_id)
        # 4. Welcome message
        await advanced_manager.send_personal_message(user_id, {
            "type": "welcome",
            "content": f"Welcome to the {room_id} room, {user_name}!",
            "timestamp": now_iso(),
            "user_id": user_id,
            "user_name": user_name,
        })
        # 5. Current member list of the room
        await advanced_manager.send_personal_message(user_id, {
            "type": "user_list",
            "users": advanced_manager.get_room_users(room_id),
            "timestamp": now_iso(),
        })
        logger.info(f"User {user_name} ({user_id}) joined room {room_id}")

        while True:
            # 6. Receive and dispatch messages
            raw_data = await websocket.receive_text()
            try:
                message = json.loads(raw_data)
                if not isinstance(message, dict):
                    raise ValueError("not a JSON object")
            except ValueError:  # includes json.JSONDecodeError
                await send_error(user_id, "Malformed message: expected a JSON object")
                continue
            msg_type = message.get("type", "text")

            # Update statistics; HINCRBY increments the counter instead of overwriting it
            await advanced_manager.update_user_stats(user_id, {"last_message_time": now_iso()})
            await advanced_manager.redis_client.hincrby(f"user_stats:{user_id}", "messages_sent", 1)

            if msg_type == "text":
                # Text message: broadcast to everyone in the room
                await advanced_manager.send_room_message(room_id, {
                    "type": "message",
                    "user_id": user_id,
                    "user_name": user_name,
                    "content": message.get("content", ""),
                    "timestamp": now_iso(),
                    "room_id": room_id,
                })
            elif msg_type == "private":
                # Private message
                target_user = message.get("to")
                if target_user in advanced_manager.active_connections:
                    private_msg = {
                        "type": "private_message",
                        "from": user_id,
                        "from_name": user_name,
                        "content": message.get("content", ""),
                        "timestamp": now_iso(),
                    }
                    await advanced_manager.send_personal_message(target_user, private_msg)
                    # Echo back to the sender as confirmation
                    await advanced_manager.send_personal_message(user_id, {
                        **private_msg,
                        "to": target_user,
                        "to_name": advanced_manager.get_user_name(target_user),
                    })
                else:
                    await send_error(user_id, f"User {target_user} is not online")
            elif msg_type == "typing":
                # Typing indicator, sent to everyone else in the room
                await advanced_manager.send_room_message(room_id, {
                    "type": "typing",
                    "user_id": user_id,
                    "user_name": user_name,
                    "is_typing": message.get("is_typing", True),
                    "timestamp": now_iso(),
                }, exclude_user_ids=[user_id])
            elif msg_type == "ping":
                # Heartbeat: reply with pong
                await advanced_manager.send_personal_message(
                    user_id, {"type": "pong", "timestamp": now_iso()}
                )
            elif msg_type == "join_room":
                # Switch to another room
                new_room_id = message.get("room_id")
                if not new_room_id:
                    await send_error(user_id, "join_room requires a room_id")
                    continue
                await advanced_manager.leave_room(room_id, user_id)
                await advanced_manager.join_room(new_room_id, user_id)
                room_id = new_room_id
                await advanced_manager.send_personal_message(user_id, {
                    "type": "room_changed",
                    "room_id": new_room_id,
                    "timestamp": now_iso(),
                })
            elif msg_type == "leave_room":
                # Leave the current room
                await advanced_manager.leave_room(room_id, user_id)
                await advanced_manager.send_personal_message(user_id, {
                    "type": "room_left",
                    "room_id": room_id,
                    "timestamp": now_iso(),
                })
            elif msg_type == "user_status":
                # Status update; must be a ConnectionStatus value (connected/disconnected/away/busy)
                new_status = message.get("status", ConnectionStatus.CONNECTED.value)
                try:
                    status = ConnectionStatus(new_status)
                except ValueError:
                    await send_error(user_id, f"Invalid status: {new_status}")
                    continue
                if user_id in advanced_manager.active_connections:
                    advanced_manager.active_connections[user_id].status = status
                await advanced_manager.send_room_message(room_id, {
                    "type": "user_status_updated",
                    "user_id": user_id,
                    "status": status.value,
                    "timestamp": now_iso(),
                })
            else:
                await send_error(user_id, f"Unknown message type: {msg_type}")
    except WebSocketDisconnect:
        logger.info(f"User {user_name} ({user_id}) disconnected")
    except Exception as e:
        logger.exception(f"Chat WebSocket error: {e}")
        try:
            await websocket.close(code=4002, reason="Processing error")
        except Exception:
            pass  # the socket may already be closed
    finally:
        # disconnect() also removes the user from every room they joined
        await advanced_manager.disconnect(user_id)


@router.websocket("/ai-chat")
async def ai_chat_websocket(websocket: WebSocket, token: str = Query(...)):
    """Real-time AI chat WebSocket"""
    try:
        payload = decode_token(token)
    except Exception:
        await websocket.close(code=4001, reason="Authentication failed")
        return
    user_id = str(payload["sub"])
    user_name = payload.get("name", payload.get("email", user_id))
    # Same session assumption as /ws/chat. Both endpoints register under user_id,
    # so a user connected to both replaces one registration with the other.
    session_id = str(payload.get("sid", ""))
    if not await advanced_manager.connect_with_session(
        websocket, user_id, session_id, {"name": user_name, "type": "ai_chat"}
    ):
        return

    logger.info(f"AI chat session started: {user_name} ({user_id})")
    conversation_history = []  # conversation history for this connection
    try:
        while True:
            raw_message = await websocket.receive_text()
            try:
                message_data = json.loads(raw_message)
            except json.JSONDecodeError:
                await send_error(user_id, "Malformed message: expected JSON")
                continue
            if not isinstance(message_data, dict) or message_data.get("type") != "chat":
                continue  # only chat messages are handled here
            user_content = message_data.get("content", "")
            conversation_history.append({"role": "user", "content": user_content, "timestamp": now_iso()})
            # Tell the client the AI is working on a reply
            await advanced_manager.send_personal_message(
                user_id, {"type": "ai_thinking", "timestamp": now_iso()}
            )
            # Call the AI service (simulated here)
            ai_response = await simulate_ai_response(user_content)
            conversation_history.append({"role": "assistant", "content": ai_response, "timestamp": now_iso()})
            await advanced_manager.send_personal_message(user_id, {
                "type": "ai_response",
                "content": ai_response,
                "timestamp": now_iso(),
                "conversation_length": len(conversation_history),
            })
            # Cap the history: once it exceeds 20 entries, keep the 10 most recent
            if len(conversation_history) > 20:
                conversation_history = conversation_history[-10:]
    except WebSocketDisconnect:
        logger.info(f"AI chat session ended: {user_name} ({user_id})")
    except Exception as e:
        logger.exception(f"AI chat WebSocket error: {e}")
        try:
            await websocket.close(code=4002, reason="Processing error")
        except Exception:
            pass
    finally:
        await advanced_manager.disconnect(user_id)


async def simulate_ai_response(user_input: str) -> str:
    """Simulated AI reply (a real application would call an AI service here)"""
    await asyncio.sleep(0.5)  # simulate processing time
    responses = {
        "hello": "Hello! Nice to meet you. How is your day going?",
        "weather": "I'm an AI assistant and can't fetch live weather data. Please check a weather app.",
        "help": "I can answer questions, provide information, or just chat. Tell me what you need!",
        "bye": "Goodbye! Looking forward to our next chat. Have a great day!",
    }
    for keyword, response in responses.items():
        if keyword in user_input.lower():
            return response
    return f"I understood: '{user_input}'. Could you give me more details? I'll do my best to help."


# In production, protect these endpoints: the user list includes IP addresses.
@router.get("/online-users")
async def get_online_users():
    """List online users"""
    users = advanced_manager.get_online_users()
    return {"count": len(users), "users": users, "timestamp": now_iso()}


@router.get("/room-users/{room_id}")
async def get_room_users(room_id: str):
    """List the users in a room"""
    users = advanced_manager.get_room_users(room_id)
    return {"room_id": room_id, "count": len(users), "users": users, "timestamp": now_iso()}


@router.get("/connection-stats")
async def get_connection_stats():
    """Connection statistics"""
    return {
        "total_connections": advanced_manager.get_connection_count(),
        "total_rooms": advanced_manager.get_room_count(),
        "timestamp": now_iso(),
    }
```

### Frontend Chat Room Page
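The page below puts the token and the room ID into the connection URL's query string, so both values must be percent-encoded (in JavaScript, with `encodeURIComponent`). The equivalent in Python, for example in a test script, is sketched here with `urllib.parse.urlencode`; `chat_ws_url` is a hypothetical helper name. Query strings also tend to end up in server and proxy access logs, so short-lived tokens are preferable.

```python
from urllib.parse import urlencode


def chat_ws_url(host: str, token: str, room_id: str = "general", secure: bool = False) -> str:
    """Build the /ws/chat URL with every query value percent-encoded."""
    scheme = "wss" if secure else "ws"
    query = urlencode({"token": token, "room_id": room_id})  # '+' '/' '=' escaped, spaces become '+'
    return f"{scheme}://{host}/ws/chat?{query}"


print(chat_ws_url("localhost:8000", "a+b/c=", "tech talk"))
# ws://localhost:8000/ws/chat?token=a%2Bb%2Fc%3D&room_id=tech+talk
```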
<!-- templates/chat_room.html -->
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AI Real-Time Chat Room - 道满PythonAI</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
.chat-container {
max-width: 1200px;
margin: 0 auto;
background: white;
border-radius: 15px;
box-shadow: 0 20px 40px rgba(0,0,0,0.1);
overflow: hidden;
height: 80vh;
display: grid;
grid-template-rows: auto 1fr auto;
}
.chat-header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 20px;
text-align: center;
}
.chat-header h1 {
font-size: 1.5em;
margin-bottom: 5px;
}
.chat-header .status {
font-size: 0.9em;
opacity: 0.9;
}
.chat-main {
display: grid;
grid-template-columns: 300px 1fr;
height: 100%;
}
.sidebar {
border-right: 1px solid #eee;
padding: 20px;
background: #f8f9fa;
}
.sidebar h3 {
margin-bottom: 15px;
color: #333;
}
.room-list, .user-list {
margin-bottom: 20px;
}
.room-item, .user-item {
padding: 8px 12px;
margin: 5px 0;
background: white;
border-radius: 8px;
cursor: pointer;
transition: all 0.3s ease;
}
.room-item:hover, .user-item:hover {
background: #e3f2fd;
transform: translateX(5px);
}
.room-item.active {
background: #2196f3;
color: white;
}
.chat-messages {
padding: 20px;
overflow-y: auto;
height: 100%;
}
.message {
margin-bottom: 15px;
padding: 10px 15px;
border-radius: 15px;
max-width: 70%;
animation: fadeIn 0.3s ease;
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(10px); }
to { opacity: 1; transform: translateY(0); }
}
.message.own {
background: #dcf8c6;
margin-left: auto;
text-align: right;
}
.message.other {
background: #fff;
border: 1px solid #eee;
}
.message.system {
background: #f0f0f0;
text-align: center;
font-style: italic;
max-width: 100%;
border: none;
font-size: 0.9em;
color: #666;
}
.message .sender {
font-weight: bold;
font-size: 0.9em;
margin-bottom: 5px;
}
.message .time {
font-size: 0.7em;
color: #999;
text-align: right;
margin-top: 5px;
}
.chat-input {
padding: 20px;
border-top: 1px solid #eee;
display: flex;
gap: 10px;
}
.chat-input input {
flex: 1;
padding: 12px 15px;
border: 2px solid #e0e0e0;
border-radius: 25px;
outline: none;
font-size: 1em;
transition: border-color 0.3s ease;
}
.chat-input input:focus {
border-color: #667eea;
}
.chat-input button {
padding: 12px 25px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
border-radius: 25px;
cursor: pointer;
font-size: 1em;
transition: transform 0.2s ease;
}
.chat-input button:hover {
transform: scale(1.05);
}
.typing-indicator {
font-style: italic;
color: #999;
font-size: 0.9em;
margin-top: 5px;
}
.ai-thinking {
display: inline-block;
padding: 5px 10px;
background: #fff3cd;
border-radius: 15px;
font-size: 0.9em;
}
.ai-thinking::after {
content: "...";
animation: dots 1.5s infinite;
}
@keyframes dots {
0%, 20% { content: "."; }
40% { content: ".."; }
60% { content: "..."; }
80%, 100% { content: ""; }
}
@media (max-width: 768px) {
.chat-main {
grid-template-columns: 1fr;
}
.sidebar {
display: none;
}
.chat-container {
height: 90vh;
}
}
</style>
</head>
<body>
<div class="chat-container">
<div class="chat-header">
<h1>💬 AI Real-Time Chat Room</h1>
<div class="status" id="status">Connection status: not connected</div>
</div>
<div class="chat-main">
<div class="sidebar">
<div class="room-list">
<h3>Rooms</h3>
<div class="room-item active" data-room="general">Lobby</div>
<div class="room-item" data-room="ai-help">AI Assistant</div>
<div class="room-item" data-room="tech">Tech Talk</div>
<div class="room-item" data-room="random">Off-Topic</div>
</div>
<div class="user-list">
<h3>Online users (<span id="user-count">0</span>)</h3>
<div id="users-container"></div>
</div>
</div>
<div class="chat-messages" id="messages"></div>
</div>
<div class="typing-indicator" id="typingIndicator" style="padding: 0 20px;"></div>
<div class="chat-input">
<input type="text" id="messageInput" placeholder="Type a message..." autocomplete="off">
<button onclick="sendMessage()">Send</button>
</div>
</div>
<script>
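// Note: in a real application the token should be obtained securely from the backend
const token = localStorage.getItem('websocket_token') || 'your_jwt_token_here';
let ws;
let currentRoom = 'general';
// Connect to the WebSocket server
function connect() {
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
// Encode query values so tokens or room ids containing &, = or + are not corrupted
const wsUri = `${protocol}//${window.location.host}/ws/chat?token=${encodeURIComponent(token)}&room_id=${encodeURIComponent(currentRoom)}`;
ws = new WebSocket(wsUri);
ws.onopen = function(evt) {
document.getElementById('status').innerHTML = 'Connection status: <span style="color: #4caf50;">connected</span>';
addMessage('Connected to the chat room', 'system');
};
ws.onclose = function(evt) {
document.getElementById('status').innerHTML = 'Connection status: <span style="color: #f44336;">disconnected</span>';
// 4001 is the authentication-failure close code used by the server examples in this guide;
// retrying with the same token cannot succeed, so do not loop forever
if (evt.code === 4001) {
addMessage('Authentication failed; not reconnecting', 'system');
return;
}
addMessage('Disconnected from the server, reconnecting...', 'system');
setTimeout(connect, 3000); // reconnect after 3 seconds
};
ws.onmessage = function(evt) {
try {
const data = JSON.parse(evt.data);
handleMessage(data);
} catch (e) {
console.error('Failed to handle message:', e);
}
};
ws.onerror = function(evt) {
console.error('WebSocket error:', evt);
addMessage('WebSocket connection error', 'system');
};
}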
// Dispatch an incoming message by its type
function handleMessage(data) {
switch(data.type) {
case 'message':
addMessage(`${data.user_name}: ${data.content}`, 'other');
break;
case 'system':
addMessage(data.content, 'system');
break;
case 'private_message':
addMessage(`Private message from ${data.from_name}: ${data.content}`, 'other');
break;
case 'typing': {
const indicator = document.getElementById('typingIndicator');
if (indicator) { // guard: getElementById returns null if the element is missing
indicator.textContent = data.is_typing ? `${data.user_name} is typing...` : '';
}
break;
}
case 'ai_thinking':
addMessage('AI assistant is thinking', 'ai-thinking');
break;
case 'ai_response':
addMessage(`🤖 AI assistant: ${data.content}`, 'other');
break;
case 'user_list':
updateUserList(data.users);
break;
case 'welcome':
addMessage(data.content, 'system');
break;
case 'pong':
// Heartbeat reply; nothing to do
break;
default:
addMessage(`Unknown message type: ${JSON.stringify(data)}`, 'system');
}
}
// Append a message to the chat window.
// Text is inserted with textContent, never innerHTML, so chat content cannot inject HTML or scripts.
function addMessage(content, type = 'other') {
const messagesDiv = document.getElementById('messages');
const messageDiv = document.createElement('div');
messageDiv.className = `message ${type}`;
if (type === 'system' || type === 'ai-thinking') {
// For 'ai-thinking' the .ai-thinking class on this element draws the animated dots once
messageDiv.textContent = content;
} else {
const isOwn = type === 'own';
const fields = [
['sender', isOwn ? 'You' : ''],
['content', content],
['time', new Date().toLocaleTimeString()]
];
for (const [cls, text] of fields) {
const el = document.createElement('div');
el.className = cls;
el.textContent = text;
messageDiv.appendChild(el);
}
}
messagesDiv.appendChild(messageDiv);
messagesDiv.scrollTop = messagesDiv.scrollHeight;
}
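// Send the text in the input box
function sendMessage() {
const input = document.getElementById('messageInput');
const content = input.value.trim();
if (!content) return;
// send() throws while the socket is still connecting, and data is dropped once it has closed
if (!ws || ws.readyState !== WebSocket.OPEN) {
addMessage('Not connected; the message was not sent', 'system');
return;
}
const message = {
type: 'text',
content: content,
timestamp: new Date().toISOString()
};
ws.send(JSON.stringify(message));
addMessage(content, 'own');
input.value = '';
}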
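// Rebuild the online-user list (textContent again, so user names cannot inject HTML)
function updateUserList(users) {
const usersContainer = document.getElementById('users-container');
document.getElementById('user-count').textContent = users.length;
usersContainer.textContent = ''; // clear the previous list
users.forEach(user => {
const item = document.createElement('div');
item.className = 'user-item';
const name = document.createElement('strong');
name.textContent = user.user_info.name;
const id = document.createElement('small');
id.textContent = ` (${user.user_id})`;
item.append(name, id);
usersContainer.appendChild(item);
});
}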
// Switch to another room
function switchRoom(roomId) {
currentRoom = roomId; // also used by connect() when reconnecting
if (ws && ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({ type: 'join_room', room_id: roomId }));
}
// Highlight the selected room without relying on the non-standard global `event`
document.querySelectorAll('.room-item').forEach(item => {
item.classList.toggle('active', item.dataset.room === roomId);
});
}
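// Send on Enter. keydown replaces the deprecated keypress event, and the isComposing
// check ignores the Enter that confirms an IME candidate (e.g. while typing Chinese)
document.getElementById('messageInput').addEventListener('keydown', function(e) {
if (e.key === 'Enter' && !e.isComposing) {
sendMessage();
}
});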
// Bind room-switching clicks
document.querySelectorAll('.room-item').forEach(item => {
item.addEventListener('click', function() {
switchRoom(this.dataset.room);
});
});
// Start the connection
window.onload = connect;
// Heartbeat
setInterval(() => {
if (ws && ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({type: 'ping'}));
}
}, 30000); // send a heartbeat every 30 seconds
</script>
</body>
</html>
#AI Real-Time Assistant Integration
#AI WebSocket Handler
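The heart of the handler's history management is "drop the oldest turns until the prompt fits". A minimal sketch of that step in isolation, assuming the system prompt is always `history[0]`: `count_tokens` stands in for tiktoken (a word counter is enough for illustration), and `trim_history` / `to_api_messages` are hypothetical helper names. Removing from index 1 rather than index 0 is what keeps the system prompt alive. `to_api_messages` keeps only `role` and `content`, since bookkeeping keys such as `timestamp` are not part of the Chat Completions message format.

```python
from typing import Callable, Dict, List

Message = Dict[str, str]


def trim_history(
    history: List[Message],
    max_tokens: int,
    count_tokens: Callable[[str], int],
) -> List[Message]:
    """Drop the oldest non-system turns until the history fits in max_tokens.

    history[0] is assumed to be the system prompt and is never removed;
    the most recent message is always kept as well.
    """
    trimmed = list(history)  # work on a copy so the caller's list is untouched
    total = sum(count_tokens(m["content"]) for m in trimmed)
    # Index 1 is the oldest turn after the system prompt
    while total > max_tokens and len(trimmed) > 2:
        removed = trimmed.pop(1)
        total -= count_tokens(removed["content"])
    return trimmed


def to_api_messages(history: List[Message]) -> List[Message]:
    """Keep only the fields a chat-completion message needs."""
    return [{"role": m["role"], "content": m["content"]} for m in history]
```

Keeping the trimming pure (no client, no I/O) makes it easy to unit-test with a fake token counter before wiring in a real tokenizer.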
# ai_websocket_handler.py - AI WebSocket handler
import json
import logging
import os
from datetime import datetime
from typing import Any, Dict, List

import tiktoken
from openai import AsyncOpenAI

logger = logging.getLogger(__name__)


class AIWebSocketHandler:
    """Streams chat completions to a WebSocket client and keeps per-user history."""

    def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
        self.client = AsyncOpenAI(api_key=api_key)
        self.model = model
        self.encoders: Dict[str, Any] = {}  # cached token encoders
        self.max_context_tokens = 8192  # maximum context length
        # Per-user conversation history, held in this process's memory only
        self.conversation_history: Dict[str, List[Dict[str, Any]]] = {}

    def get_encoder(self, model_name: str):
        """Return the (cached) token encoder for a model."""
        if model_name not in self.encoders:
            try:
                self.encoders[model_name] = tiktoken.encoding_for_model(model_name)
            except KeyError:
                # Older tiktoken versions may not recognise newer model names
                self.encoders[model_name] = tiktoken.get_encoding("cl100k_base")
        return self.encoders[model_name]

    def count_tokens(self, text: str, model_name: str) -> int:
        """Count the tokens in a piece of text."""
        return len(self.get_encoder(model_name).encode(text))

    def trim_conversation_history(self, user_id: str, max_tokens: int = 4000):
        """Drop the oldest turns so the history stays within max_tokens."""
        history = self.conversation_history.get(user_id)
        if not history:
            return
        total_tokens = sum(self.count_tokens(msg["content"], self.model) for msg in history)
        # history[0] is the system prompt: remove from index 1 so the prompt is never lost,
        # and always keep at least the system prompt plus the latest message
        while total_tokens > max_tokens and len(history) > 2:
            removed_msg = history.pop(1)
            total_tokens -= self.count_tokens(removed_msg["content"], self.model)

    async def handle_ai_request(self, websocket, user_id: str, message_data: Dict[str, Any]):
        """Dispatch one client request by its type."""
        try:
            message_type = message_data.get("type", "chat")
            if message_type == "chat":
                await self.handle_chat_message(websocket, user_id, message_data)
            elif message_type == "reset_conversation":
                await self.reset_conversation(user_id)
                await self.send_message(websocket, "conversation_reset", {"status": "success"})
            elif message_type == "get_conversation_history":
                history = self.get_conversation_history(user_id)
                await self.send_message(websocket, "conversation_history", {"history": history})
            else:
                await self.send_error(websocket, f"Unknown AI message type: {message_type}")
        except Exception as e:
            logger.exception("Error while handling AI request")
            await self.send_error(websocket, f"AI processing error: {e}")

    async def handle_chat_message(self, websocket, user_id: str, message_data: Dict[str, Any]):
        """Handle one chat message and stream the model's reply back."""
        user_message = message_data.get("content", "")
        # Get or initialise the conversation history
        if user_id not in self.conversation_history:
            self.conversation_history[user_id] = [
                {
                    "role": "system",
                    "content": (
                        "You are a helpful AI assistant that helps users solve problems and find "
                        "information. Keep your answers concise and clear; if you are not sure of "
                        "an answer, honestly say that you don't know."
                    ),
                }
            ]
        # Trim the history before adding the new turn
        self.trim_conversation_history(user_id)
        self.conversation_history[user_id].append({
            "role": "user",
            "content": user_message,
            "timestamp": datetime.utcnow().isoformat(),
        })
        # Tell the client that generation has started
        await self.send_message(websocket, "ai_thinking", {
            "status": "generating",
            "timestamp": datetime.utcnow().isoformat(),
        })
        try:
            # Send only role/content; the stored timestamp is for our own bookkeeping
            api_messages = [
                {"role": m["role"], "content": m["content"]}
                for m in self.conversation_history[user_id]
            ]
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=api_messages,
                stream=True,
                temperature=0.7,
                max_tokens=1000,
            )
            full_response = ""
            async for chunk in response:
                # Some stream chunks carry no choices or an empty delta
                if not chunk.choices:
                    continue
                token = chunk.choices[0].delta.content
                if token:
                    full_response += token
                    # Forward each token as soon as it arrives
                    await self.send_message(websocket, "ai_token", {
                        "token": token,
                        "partial_response": full_response,
                    })
            # Store the assistant reply in the history
            self.conversation_history[user_id].append({
                "role": "assistant",
                "content": full_response,
                "timestamp": datetime.utcnow().isoformat(),
            })
            # Signal that the reply is complete
            await self.send_message(websocket, "ai_response_complete", {
                "full_response": full_response,
                "timestamp": datetime.utcnow().isoformat(),
                "conversation_length": len(self.conversation_history[user_id]),
            })
        except Exception as e:
            logger.exception("AI call failed")
            await self.send_error(websocket, f"AI service error: {e}")

    async def reset_conversation(self, user_id: str):
        """Reset the history, keeping only the system prompt."""
        history = self.conversation_history.get(user_id)
        if history:
            self.conversation_history[user_id] = history[:1]

    def get_conversation_history(self, user_id: str) -> list:
        """Return the user's conversation history."""
        return self.conversation_history.get(user_id, [])

    async def send_message(self, websocket, message_type: str, data: Dict[str, Any]):
        """Send a JSON envelope of the form {type, data, timestamp}."""
        message = {
            "type": message_type,
            "data": data,
            "timestamp": datetime.utcnow().isoformat(),
        }
        await websocket.send_text(json.dumps(message))

    async def send_error(self, websocket, error_message: str):
        """Send an error envelope."""
        await self.send_message(websocket, "error", {
            "message": error_message,
            "timestamp": datetime.utcnow().isoformat(),
        })


# Global handler instance. The key comes from the environment instead of being hard-coded;
# with api_key=None, AsyncOpenAI itself falls back to the OPENAI_API_KEY variable.
ai_handler = AIWebSocketHandler(api_key=os.getenv("OPENAI_API_KEY"))
#AI Chat WebSocket Endpoint
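The endpoint relays model tokens over the socket as they arrive. That relay pattern can be exercised without OpenAI or a server: this sketch assumes only an object with an async `send_text()` method, and `FakeWebSocket`, `fake_token_stream` and `forward_stream` are hypothetical stand-ins. It uses the same `{type, data}` envelope as `AIWebSocketHandler.send_message`, but, unlike the handler's `partial_response` field, each frame carries only the new token, so the bytes sent grow linearly with the reply length instead of quadratically.

```python
import asyncio
import json
from typing import AsyncIterator, List


class FakeWebSocket:
    """Stand-in for a WebSocket; only send_text() is needed here."""

    def __init__(self) -> None:
        self.sent: List[str] = []

    async def send_text(self, data: str) -> None:
        self.sent.append(data)


async def fake_token_stream(text: str) -> AsyncIterator[str]:
    """Hypothetical replacement for the model stream: yields one word at a time."""
    for word in text.split(" "):
        await asyncio.sleep(0)  # yield control, as a real network stream would
        yield word + " "


async def forward_stream(websocket, tokens: AsyncIterator[str]) -> str:
    """Send each token as an ai_token frame, then one ai_response_complete frame."""
    parts: List[str] = []
    async for token in tokens:
        parts.append(token)
        await websocket.send_text(json.dumps({"type": "ai_token", "data": {"token": token}}))
    full = "".join(parts)
    await websocket.send_text(
        json.dumps({"type": "ai_response_complete", "data": {"full_response": full}})
    )
    return full
```

A client rebuilds the reply by concatenating the `data.token` values and can treat `ai_response_complete` as the authoritative final text.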
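@router.websocket("/ai-assistant")
async def ai_assistant_websocket(websocket: WebSocket, token: str = Query(...)):
    """AI assistant WebSocket endpoint"""
    # Authenticate the user; decode_token (defined earlier) raises on an invalid token
    try:
        payload = decode_token(token)
    except Exception:
        await websocket.close(code=4001, reason="Authentication failed")
        return
    user_id = str(payload["sub"])
    user_name = payload.get("name", payload.get("email", user_id))
    session_id = f"ai_assistant_{user_id}_{int(datetime.utcnow().timestamp())}"
    user_info = {"name": user_name, "type": "ai_assistant", "capabilities": ["chat", "analysis"]}
    if not await advanced_manager.connect_with_session(websocket, user_id, session_id, user_info):
        await websocket.close(code=4001, reason="Authentication failed")
        return
    logger.info(f"AI assistant session started: {user_name} ({user_id})")
    try:
        while True:
            raw_message = await websocket.receive_text()
            try:
                message_data = json.loads(raw_message)
            except json.JSONDecodeError:
                # A malformed frame should not end the whole session
                await ai_handler.send_error(websocket, "Malformed message: expected JSON")
                continue
            await ai_handler.handle_ai_request(websocket, user_id, message_data)
    except WebSocketDisconnect:
        logger.info(f"AI assistant session ended: {user_name} ({user_id})")
    except Exception:
        logger.exception("AI assistant WebSocket error")
        # 1011 = internal error; keep the close reason short (max 123 bytes) and free of internals
        try:
            await websocket.close(code=1011, reason="Internal error")
        except RuntimeError:
            pass  # the socket was already closed
    finally:
        # Runs on every exit path, so the connection entry and history are never leaked
        await advanced_manager.disconnect(user_id)
        ai_handler.conversation_history.pop(user_id, None)
#Security and Authentication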
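Rate limiting appears at several points in this section. A read-then-write sequence (`GET`, then `SETEX` or `INCR`) lets two concurrent requests read the same count and both pass; a common fixed-window pattern with redis-py's asyncio client is `count = await r.incr(key)` followed by `await r.expire(key, window)` when `count == 1`, because `INCR` is atomic on the server. The sketch below models that logic in memory so it runs without Redis; `FixedWindowLimiter` is a hypothetical name, and the injectable clock exists only to make the behaviour testable.

```python
import time
from typing import Callable, Dict, Tuple


class FixedWindowLimiter:
    """In-memory model of the Redis INCR + EXPIRE fixed-window pattern."""

    def __init__(self, clock: Callable[[], float] = time.monotonic) -> None:
        self._clock = clock
        # key -> (count, window_expires_at); stands in for a Redis counter with a TTL
        self._buckets: Dict[str, Tuple[int, float]] = {}

    def hit(self, key: str, limit: int, window: float) -> bool:
        """Count one request; return True if it is within the limit."""
        now = self._clock()
        count, expires_at = self._buckets.get(key, (0, 0.0))
        if now >= expires_at:
            # The key has "expired": open a new window (Redis: INCR creates 1, then EXPIRE)
            count, expires_at = 0, now + window
        count += 1  # Redis: INCR, atomic on the server
        self._buckets[key] = (count, expires_at)
        return count <= limit
```

A fixed window allows up to twice the limit around a window boundary; when that matters, a sliding-window log or token bucket is the usual alternative.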
#WebSocket Authentication Middleware
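A concurrent-session limit only works if every slot that is taken is also given back. This in-memory sketch mirrors the per-user sorted-set idea used by the middleware (`ZADD` to reserve, `ZREMRANGEBYSCORE` to expire, `ZREM` to release) with a plain dict; `SessionRegistry`, `acquire` and `release` are hypothetical names. If `release` is never called on disconnect, a user who opens and closes `max_sessions` connections is locked out until the stale entries age past the TTL.

```python
import secrets
import time
from typing import Dict, Optional


class SessionRegistry:
    """In-memory model of a per-user sorted set of session ids scored by start time."""

    def __init__(self, max_sessions: int, ttl: float) -> None:
        self.max_sessions = max_sessions
        self.ttl = ttl
        self._sessions: Dict[str, Dict[str, float]] = {}  # user_id -> {session_id: started_at}

    def acquire(self, user_id: str, now: Optional[float] = None) -> Optional[str]:
        """Reserve a slot and return its id, or None if the user is at the limit."""
        now = time.time() if now is None else now
        sessions = self._sessions.setdefault(user_id, {})
        # Expire stale entries first (ZREMRANGEBYSCORE in the Redis version)
        for sid, started in list(sessions.items()):
            if started < now - self.ttl:
                del sessions[sid]
        if len(sessions) >= self.max_sessions:
            return None
        # Random id, so two connections opened in the same second never collide (ZADD)
        session_id = secrets.token_hex(8)
        sessions[session_id] = now
        return session_id

    def release(self, user_id: str, session_id: str) -> None:
        """Free a slot when the socket closes (ZREM in the Redis version)."""
        self._sessions.get(user_id, {}).pop(session_id, None)
```

Calling `release` from a `finally` block in the endpoint covers normal disconnects and unexpected errors alike.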
# middleware/websocket_auth.py - WebSocket authentication middleware
import logging
import secrets
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

import jwt  # PyJWT
import redis.asyncio as redis

logger = logging.getLogger(__name__)


class WebSocketAuth:
    """JWT verification, token revocation and Redis-backed limits for WebSocket connections."""

    def __init__(self, secret_key: str, redis_url: str = "redis://localhost:6379"):
        self.secret_key = secret_key
        self.redis_client = redis.from_url(redis_url)
        self.algorithm = "HS256"

    def generate_token(self, user_id: str, user_info: Optional[Dict[str, Any]] = None,
                       expires_delta: Optional[timedelta] = None) -> str:
        """Issue a signed JWT."""
        if expires_delta is None:
            expires_delta = timedelta(hours=24)
        # Timezone-aware "now": naive utcnow().timestamp() is misread as local time
        now = datetime.now(timezone.utc)
        payload = {
            **(user_info or {}),  # spread first so user_info cannot override the registered claims
            "sub": user_id,
            "exp": now + expires_delta,  # PyJWT converts datetimes to NumericDate
            "iat": now,
            "jti": secrets.token_hex(16),  # unique token id, used for revocation
        }
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    async def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Verify a JWT; return its claims, or None if it is invalid, expired or revoked."""
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except jwt.ExpiredSignatureError:
            logger.warning("JWT has expired")
            return None
        except jwt.InvalidTokenError as e:
            logger.warning(f"Invalid JWT: {e}")
            return None
        # Reject tokens whose jti is on the revocation list
        jti = payload.get("jti")
        if jti and await self.redis_client.exists(f"jwt_revoked:{jti}"):
            return None
        return payload

    async def revoke_token(self, token: str) -> bool:
        """Revoke a JWT until its original expiry time."""
        try:
            # Verify the signature so a forged token cannot be used to revoke someone else's jti
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm],
                                 options={"verify_exp": False})
        except jwt.InvalidTokenError as e:
            logger.error(f"Failed to revoke token: {e}")
            return False
        jti = payload.get("jti")
        if not jti:
            return False
        # Keep the revocation entry only until the token would have expired anyway
        now_ts = datetime.now(timezone.utc).timestamp()
        ttl = int(payload.get("exp", now_ts + 3600) - now_ts)
        if ttl <= 0:
            return True  # already expired; nothing to store (SETEX rejects a non-positive TTL)
        await self.redis_client.setex(f"jwt_revoked:{jti}", ttl, "1")
        return True

    async def rate_limit(self, identifier: str, limit: int = 10, window: int = 60) -> bool:
        """Fixed-window rate limit: allow at most `limit` hits per `window` seconds."""
        key = f"rate_limit:{identifier}"
        # INCR is atomic, so concurrent connections cannot both read the same old count
        current = await self.redis_client.incr(key)
        if current == 1:
            await self.redis_client.expire(key, window)  # the first hit opens the window
        return current <= limit
    async def check_concurrent_sessions(self, user_id: str, max_sessions: int = 3) -> Optional[str]:
        """Reserve a concurrent-session slot for the user.

        Returns the slot id, or None when the user is already at max_sessions.
        The caller must pass the id to release_session() when the connection ends;
        otherwise every disconnect leaks a slot until the 24-hour cleanup.
        """
        sessions_key = f"user_sessions:{user_id}"
        await self.cleanup_expired_sessions(user_id)
        if await self.redis_client.zcard(sessions_key) >= max_sessions:
            return None
        # A random suffix keeps two connections opened in the same second from colliding
        session_id = f"session_{user_id}_{secrets.token_hex(8)}"
        # Score = current POSIX time (a naive datetime.now().timestamp() is correct epoch time)
        await self.redis_client.zadd(sessions_key, {session_id: datetime.now().timestamp()})
        await self.redis_client.expire(sessions_key, 86400)  # 24 hours
        return session_id

    async def release_session(self, user_id: str, session_id: str):
        """Free a slot reserved by check_concurrent_sessions()."""
        await self.redis_client.zrem(f"user_sessions:{user_id}", session_id)

    async def cleanup_expired_sessions(self, user_id: str):
        """Remove session entries older than 24 hours."""
        sessions_key = f"user_sessions:{user_id}"
        cutoff = datetime.now().timestamp() - 86400  # 24 hours ago
        await self.redis_client.zremrangebyscore(sessions_key, "-inf", cutoff)


# Global auth instance (placeholder secret: load the real one from configuration)
websocket_auth = WebSocketAuth(secret_key="your-secret-key-here")
#Authentication-Protected WebSocket Endpoint
@router.websocket("/secure-chat")
async def secure_chat_websocket(
    websocket: WebSocket,
    token: str = Query(...),
    room_id: str = Query("general"),
):
    """Chat WebSocket protected by JWT auth, rate limits, session limits and room permissions"""
    # 1. Verify the JWT
    payload = await websocket_auth.verify_token(token)
    if not payload:
        await websocket.close(code=4001, reason="Authentication failed")
        return
    user_id = str(payload["sub"])
    user_name = payload.get("name", payload.get("email", user_id))
    permissions = payload.get("permissions", [])

    # 2. Rate-limit connection attempts per client IP and per user
    client_ip = websocket.client.host if websocket.client else "unknown"
    if not await websocket_auth.rate_limit(f"ws_{client_ip}", limit=5, window=60):
        await websocket.close(code=4002, reason="Too many requests")
        return
    if not await websocket_auth.rate_limit(f"ws_user_{user_id}", limit=100, window=300):
        await websocket.close(code=4002, reason="Too many requests for this user")
        return

    # 3. Check room access before reserving a session slot
    if f"room:{room_id}" not in permissions and "admin" not in permissions:
        await websocket.close(code=4003, reason="No access to this room")
        return

    # 4. Enforce the concurrent-session limit; the slot is released in `finally` below
    session_slot = await websocket_auth.check_concurrent_sessions(user_id, max_sessions=5)
    if not session_slot:
        await websocket.close(code=4003, reason="Too many concurrent sessions")
        return

    # 5. Establish the connection
    session_id = f"secure_{user_id}_{int(datetime.utcnow().timestamp())}"
    user_info = {
        "name": user_name,
        "email": payload.get("email"),
        "permissions": permissions,
        "security_level": "high",
    }
    if not await advanced_manager.connect_with_session(websocket, user_id, session_id, user_info):
        await websocket_auth.release_session(user_id, session_slot)
        await websocket.close(code=4001, reason="Connection failed")
        return

    async def notify(msg_type: str, content: str):
        """Send a warning/error notice to this user only."""
        await advanced_manager.send_personal_message(user_id, {
            "type": msg_type,
            "content": content,
            "timestamp": datetime.utcnow().isoformat(),
        })

    try:
        # 6. Join the room
        await advanced_manager.join_room(room_id, user_id)
        # 7. Confirm the connection. "encrypted" reflects the actual scheme (wss = TLS);
        #    behind a TLS-terminating proxy this depends on the proxy forwarding the scheme
        await advanced_manager.send_personal_message(user_id, {
            "type": "security",
            "content": "Secure connection established",
            "timestamp": datetime.utcnow().isoformat(),
            "encrypted": websocket.url.scheme == "wss",
        })
        while True:
            raw_data = await websocket.receive_text()
            # Validate the message format
            try:
                message = json.loads(raw_data)
            except json.JSONDecodeError:
                await notify("error", "Malformed message")
                continue
            if not isinstance(message, dict):
                await notify("error", "Malformed message")
                continue
            msg_type = message.get("type", "text")
            # Per-user message rate limit
            if not await websocket_auth.rate_limit(f"msg_{user_id}", limit=10, window=10):
                await notify("warning", "You are sending messages too quickly, please try again later")
                continue
            # Handle each message type
            if msg_type == "text":
                content = str(message.get("content", "")).strip()
                if not content:
                    continue
                # Filter blocked content
                if await contains_sensitive_content(content):
                    await notify("warning", "Your message contained blocked content and was filtered")
                    continue
                # Broadcast to the room
                await advanced_manager.send_room_message(room_id, {
                    "type": "message",
                    "user_id": user_id,
                    "user_name": user_name,
                    "content": content,
                    "timestamp": datetime.utcnow().isoformat(),
                    "room_id": room_id,
                    "verified": True,
                })
            elif msg_type == "file_upload":
                # File uploads require an extra permission (and further security checks)
                if "file_upload" not in permissions:
                    await notify("error", "No permission to upload files")
                    continue
                # File upload handling goes here
            elif msg_type == "admin_command":
                # Admin commands require admin permission; others are silently ignored
                if "admin" not in permissions:
                    continue
                await handle_admin_command(message.get("command"), user_id, room_id)
    except WebSocketDisconnect:
        logger.info(f"Secure chat disconnected: {user_name} ({user_id})")
    finally:
        # Runs on disconnect and on unexpected errors alike
        await advanced_manager.leave_room(room_id, user_id)
        await advanced_manager.disconnect(user_id)
        await websocket_auth.release_session(user_id, session_slot)
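async def contains_sensitive_content(content: str) -> bool:
    """Check whether the text contains blocked words."""
    # Naive substring check; a real application needs a more robust filtering approach
    sensitive_words = ["sensitive_word_1", "sensitive_word_2", "banned_word"]
    content_lower = content.lower()
    return any(word.lower() in content_lower for word in sensitive_words)


async def handle_admin_command(command: str, user_id: str, room_id: str):
    """Handle an admin command."""
    # Implement the admin command logic here
    pass
#Performance Optimization Strategies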
#Connection Pooling and Async Optimization
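Broadcasting to many sockets is dominated by I/O wait, so sends are usually issued concurrently, and one slow client should not stall the rest. A minimal sketch, assuming each connection object exposes an async `send_text(str)` (as Starlette's `WebSocket` does). `broadcast` is a hypothetical helper that serialises the message once, bounds each send with `asyncio.wait_for`, and returns the ids that failed so the caller can unregister them.

```python
import asyncio
import json
from typing import Any, Dict, List


async def broadcast(connections: Dict[str, Any], message: dict, timeout: float = 1.0) -> List[str]:
    """Send one message to every connection concurrently; return the ids that failed."""
    payload = json.dumps(message)  # serialise once, not once per client

    async def send_one(ws: Any) -> bool:
        try:
            # A client that cannot accept data within `timeout` counts as failed
            await asyncio.wait_for(ws.send_text(payload), timeout)
            return True
        except Exception:  # includes timeouts and errors from closed sockets
            return False

    ids = list(connections)
    results = await asyncio.gather(*(send_one(connections[i]) for i in ids))
    return [i for i, ok in zip(ids, results) if not ok]
```

Returning failures instead of raising keeps one broken socket from aborting delivery to everyone else.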
# performance/connection_pool.py - connection pool and performance optimisations
import asyncio
import json
import logging
import time
import weakref
from asyncio import Queue
from collections import deque
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Dict, List, Optional

import psutil

logger = logging.getLogger(__name__)


@dataclass
class ConnectionMetrics:
    """Per-connection metrics"""
    connection_id: str
    connected_at: datetime
    last_activity: datetime
    message_count: int
    bytes_sent: int
    bytes_received: int
    latency: float


class ConnectionPool:
    """WebSocket connection pool"""

    def __init__(self, max_connections: int = 10000, cleanup_interval: int = 300):
        self.max_connections = max_connections
        # Weak references, so the pool never keeps a socket object alive on its own
        self.connections: Dict[str, weakref.ref] = {}
        self.metrics: Dict[str, ConnectionMetrics] = {}
        self.cleanup_interval = cleanup_interval
        self.message_queue = Queue()
        self.broadcast_queue = Queue()
        self.stats = {
            "total_connections": 0,
            "active_connections": 0,
            "peak_connections": 0,
            "messages_processed": 0,
            "bytes_transferred": 0,
        }

    def register_connection(self, connection_id: str, websocket) -> bool:
        """Register a connection; return False when the pool is full."""
        if len(self.connections) >= self.max_connections:
            return False
        now = datetime.utcnow()
        self.connections[connection_id] = weakref.ref(websocket)
        self.metrics[connection_id] = ConnectionMetrics(
            connection_id=connection_id,
            connected_at=now,
            last_activity=now,
            message_count=0,
            bytes_sent=0,
            bytes_received=0,
            latency=0.0,
        )
        self.stats["total_connections"] += 1
        self.stats["active_connections"] += 1
        self.stats["peak_connections"] = max(
            self.stats["peak_connections"], self.stats["active_connections"]
        )
        logger.info(f"Connection registered: {connection_id}, active: {self.stats['active_connections']}")
        return True

    def unregister_connection(self, connection_id: str):
        """Unregister a connection; safe to call more than once."""
        if self.connections.pop(connection_id, None) is None:
            return  # already removed: do not decrement the counter twice
        self.metrics.pop(connection_id, None)
        self.stats["active_connections"] -= 1
        logger.info(f"Connection unregistered: {connection_id}, active: {self.stats['active_connections']}")

    async def cleanup_inactive_connections(self, timeout_seconds: int = 3600):
        """Close and unregister connections idle for longer than timeout_seconds."""
        cutoff_time = datetime.utcnow() - timedelta(seconds=timeout_seconds)
        inactive = [cid for cid, m in self.metrics.items() if m.last_activity < cutoff_time]
        for conn_id in inactive:
            logger.info(f"Closing idle connection: {conn_id}")
            ref = self.connections.get(conn_id)
            websocket = ref() if ref is not None else None
            if websocket is not None:
                try:
                    await websocket.close(code=1000)
                except Exception:
                    pass  # the peer may already be gone
            self.unregister_connection(conn_id)

    async def broadcast_message_optimized(self, message: dict, connection_ids: Optional[List[str]] = None):
        """Broadcast one message to many connections concurrently."""
        if connection_ids is None:
            connection_ids = list(self.connections.keys())
        # Serialise once and reuse the same string for every recipient
        serialized_message = json.dumps(message)
        message_size = len(serialized_message.encode("utf-8"))
        tasks = []
        for conn_id in connection_ids:
            ref = self.connections.get(conn_id)
            if ref is None:
                continue
            websocket = ref()
            if websocket is None:
                # The socket object was garbage-collected; drop the stale entry
                self.unregister_connection(conn_id)
                continue
            tasks.append(self._send_message_safe(websocket, serialized_message, message_size, conn_id))
        successful_sends = failed_sends = 0
        if tasks:
            results = await asyncio.gather(*tasks, return_exceptions=True)
            successful_sends = sum(1 for r in results if r is True)
            failed_sends = len(results) - successful_sends
        logger.debug(f"Broadcast finished: {successful_sends} succeeded, {failed_sends} failed")
        return {"successful": successful_sends, "failed": failed_sends}

    async def _send_message_safe(self, websocket, message: str, message_size: int, conn_id: str) -> bool:
        """Send one message; on failure unregister the connection instead of raising."""
        try:
            await websocket.send_text(message)
        except Exception as e:
            logger.error(f"Failed to send message to connection {conn_id}: {e}")
            self.unregister_connection(conn_id)
            return False
        # Update metrics
        metrics = self.metrics.get(conn_id)
        if metrics is not None:
            metrics.last_activity = datetime.utcnow()
            metrics.bytes_sent += message_size
            metrics.message_count += 1
        self.stats["messages_processed"] += 1
        self.stats["bytes_transferred"] += message_size
        return True

    def get_performance_metrics(self) -> dict:
        """Return system and connection metrics."""
        # cpu_percent() without an interval compares against the previous call,
        # so the very first call returns a meaningless 0.0
        cpu_percent = psutil.cpu_percent()
        memory_percent = psutil.virtual_memory().percent
        active_metrics = list(self.metrics.values())
        avg_latency = (
            sum(m.latency for m in active_metrics) / len(active_metrics) if active_metrics else 0
        )
        return {
            "system": {
                "cpu_percent": cpu_percent,
                "memory_percent": memory_percent,
                "timestamp": time.time(),
            },
            "connections": {
                "total_registered": len(self.connections),
                "currently_active": self.stats["active_connections"],
                "peak_connections": self.stats["peak_connections"],
            },
            "messages": {
                "processed": self.stats["messages_processed"],
                "bytes_transferred": self.stats["bytes_transferred"],
                "avg_latency": avg_latency,
            },
            "health": {
                "status": "healthy" if memory_percent < 80 and cpu_percent < 80 else "warning",
            },
        }


# Global connection pool instance
connection_pool = ConnectionPool(max_connections=5000)


class MessageBatchProcessor:
    """Buffers outgoing messages and flushes them in batches."""

    def __init__(self, batch_size: int = 100, flush_interval: float = 0.1):
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.message_buffer = deque()
        self.is_running = False

    async def start_processor(self):
        """Run the flush loop until stop_processor() is called."""
        self.is_running = True
        while self.is_running:
            await self._process_batch()
            await asyncio.sleep(self.flush_interval)

    async def stop_processor(self):
        """Stop the loop and flush whatever is still buffered."""
        self.is_running = False
        await self._flush_remaining()

    async def add_message(self, message: dict, target_connections: Optional[List[str]] = None):
        """Queue a message; flush immediately once a full batch is waiting."""
        self.message_buffer.append((message, target_connections))
        if len(self.message_buffer) >= self.batch_size:
            await self._process_batch()

    async def _process_batch(self):
        """Send up to batch_size buffered messages."""
        batch = []
        while len(batch) < self.batch_size and self.message_buffer:
            batch.append(self.message_buffer.popleft())
        # Sent sequentially to preserve queue order within the batch
        for message, target_connections in batch:
            await connection_pool.broadcast_message_optimized(message, target_connections)

    async def _flush_remaining(self):
        """Send everything left in the buffer."""
        while self.message_buffer:
            message, target_connections = self.message_buffer.popleft()
            await connection_pool.broadcast_message_optimized(message, target_connections)


# Global message processor instance
message_processor = MessageBatchProcessor(batch_size=50, flush_interval=0.05)


def start_message_processor() -> asyncio.Task:
    """Start the processor loop; call this from the application's startup/lifespan hook.

    asyncio.create_task() needs a running event loop, so calling it at module
    level (import time) raises RuntimeError.
    """
    return asyncio.create_task(message_processor.start_processor())
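Background loops such as `MessageBatchProcessor.start_processor()` need a running event loop, which exists at application startup rather than at import time. FastAPI's `lifespan` parameter accepts an async context manager for exactly this purpose. The sketch below shows the start/stop pattern with only the standard library; `Ticker` is a hypothetical stand-in for the batch processor and `background_tasks` a hypothetical helper.

```python
import asyncio
import contextlib
from typing import AsyncIterator


class Ticker:
    """Hypothetical stand-in for MessageBatchProcessor: loops until is_running is cleared."""

    def __init__(self, interval: float) -> None:
        self.interval = interval
        self.ticks = 0
        self.is_running = False

    async def run(self) -> None:
        self.is_running = True
        while self.is_running:
            self.ticks += 1
            await asyncio.sleep(self.interval)


@contextlib.asynccontextmanager
async def background_tasks(worker: Ticker) -> AsyncIterator[asyncio.Task]:
    # create_task() requires a running event loop, so it runs on entry (startup), not at import
    task = asyncio.create_task(worker.run())
    try:
        yield task
    finally:
        # Shutdown: ask the loop to stop, cancel it, and wait until it has really finished
        worker.is_running = False
        task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await task
```

In a FastAPI app this would typically be entered inside a lifespan function (`async with background_tasks(worker): yield`) passed as `FastAPI(lifespan=...)`. That wiring is an assumption about how the application is assembled.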
## Cluster Deployment
### Cluster Coordination with Redis
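Cross-node delivery rests on two small rules: every process needs a node id that is unique across hosts, and a node must ignore pub/sub messages that it published itself or that explicitly exclude it. A minimal sketch of those rules, assuming the JSON envelope used by `broadcast_to_cluster`; `make_node_id`, `make_envelope` and `should_forward` are hypothetical helpers. Without `decode_responses=True`, redis-py delivers pub/sub payloads as `bytes`, which `json.loads` accepts directly.

```python
import json
import uuid
from typing import Optional


def make_node_id() -> str:
    # id(self) can repeat across processes and hosts; a random UUID practically cannot
    return f"node_{uuid.uuid4().hex[:12]}"


def make_envelope(source_node: str, message: dict, exclude_node: Optional[str] = None) -> str:
    """Wrap a message with routing metadata before PUBLISH."""
    return json.dumps({"source_node": source_node, "exclude_node": exclude_node, "message": message})


def should_forward(raw: bytes, my_node: str) -> Optional[dict]:
    """Return the inner message if this node should deliver it locally, else None."""
    data = json.loads(raw)
    if data.get("source_node") == my_node or data.get("exclude_node") == my_node:
        return None  # our own echo, or explicitly not for us
    return data["message"]
```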
```python
# cluster/redis_cluster.py - cross-node coordination through Redis
import asyncio
import json
import logging
import uuid
from datetime import datetime
from typing import List

import redis.asyncio as redis

logger = logging.getLogger(__name__)


class RedisClusterCoordinator:
    """Tracks application nodes and relays messages between them via Redis pub/sub.

    All nodes must point at the same Redis instance: registration and pub/sub
    use only redis_clients[0].
    """

    def __init__(self, redis_urls: List[str]):
        self.redis_clients = [redis.from_url(url) for url in redis_urls]
        self.cluster_nodes = redis_urls
        # id(self) can repeat across processes and hosts; a random UUID practically cannot
        self.node_id = f"node_{uuid.uuid4().hex[:12]}"
        self.heartbeat_interval = 5  # heartbeat every 5 seconds
        self.node_ttl = 30  # a node without a heartbeat for 30 seconds is considered dead

    async def register_node(self):
        """Register this node."""
        now = datetime.utcnow().isoformat()
        node_info = {
            "node_id": self.node_id,
            "registered_at": now,
            "last_heartbeat": now,  # check_cluster_health() reads this field
            "connections": 0,
            "messages_per_second": 0,
        }
        await self.redis_clients[0].hset("cluster:nodes", self.node_id, json.dumps(node_info))
        await self.redis_clients[0].expire("cluster:nodes", self.node_ttl * 2)
        logger.info(f"Node registered: {self.node_id}")

    async def heartbeat(self):
        """Periodically refresh this node's entry and check the other nodes."""
        while True:
            try:
                node_info = {
                    "node_id": self.node_id,
                    "last_heartbeat": datetime.utcnow().isoformat(),
                    "connections": len(advanced_manager.active_connections),
                    "messages_per_second": self.get_current_mps(),
                }
                await self.redis_clients[0].hset("cluster:nodes", self.node_id, json.dumps(node_info))
                await self.redis_clients[0].expire("cluster:nodes", self.node_ttl * 2)
                await self.check_cluster_health()
                await asyncio.sleep(self.heartbeat_interval)
            except Exception as e:
                logger.error(f"Heartbeat error: {e}")
                await asyncio.sleep(1)

    async def check_cluster_health(self):
        """Return nodes with a heartbeat within node_ttl and delete the stale ones."""
        try:
            all_nodes = await self.redis_clients[0].hgetall("cluster:nodes")
        except Exception as e:
            logger.error(f"Cluster health check failed: {e}")
            return []
        healthy_nodes = []
        for node_id, node_data in all_nodes.items():
            try:
                node_info = json.loads(node_data)
                last_heartbeat = datetime.fromisoformat(node_info["last_heartbeat"])
            except (ValueError, KeyError):
                continue  # skip a malformed entry instead of failing the whole check
            if (datetime.utcnow() - last_heartbeat).total_seconds() < self.node_ttl:
                healthy_nodes.append(node_info)
            else:
                # Remove dead nodes so the registry does not grow forever
                await self.redis_clients[0].hdel("cluster:nodes", node_id)
        logger.debug(f"Cluster health: {len(healthy_nodes)} healthy node(s)")
        return healthy_nodes

    async def broadcast_to_cluster(self, message: dict, exclude_node: str = None):
        """Publish a message to every node via Redis pub/sub."""
        channel_message = {
            "source_node": self.node_id,
            "exclude_node": exclude_node,
            "message": message,
            "timestamp": datetime.utcnow().isoformat(),
        }
        await self.redis_clients[0].publish("cluster:messages", json.dumps(channel_message))

    async def subscribe_to_cluster_messages(self):
        """Receive cluster messages and deliver them to local connections."""
        pubsub = self.redis_clients[0].pubsub()
        await pubsub.subscribe("cluster:messages")
        async for message in pubsub.listen():
            if message["type"] != "message":
                continue  # skip subscribe confirmations
            try:
                data = json.loads(message["data"])
                # Skip our own messages (avoids loops) and messages that exclude this node
                if data.get("source_node") != self.node_id and data.get("exclude_node") != self.node_id:
                    await self.forward_cluster_message(data["message"])
            except Exception as e:
                logger.error(f"Error handling cluster message: {e}")

    async def forward_cluster_message(self, message: dict):
        """Forward a cluster message to local connections."""
        # Here: broadcast to every local connection; routing by message type or user id also works.
        # Iterate over a snapshot: connections may be added or removed during the awaits below.
        for user_id in list(advanced_manager.active_connections.keys()):
            try:
                await advanced_manager.send_personal_message(user_id, message)
            except Exception as e:
                logger.error(f"Failed to forward message to local user {user_id}: {e}")

    def get_current_mps(self) -> int:
        """Messages processed per second."""
        # Placeholder: compute this from real monitoring data
        return 0


# Global coordinator instance
cluster_coordinator = RedisClusterCoordinator([
    "redis://localhost:6379/0",
    "redis://localhost:6380/0",  # additional Redis URL; the methods above only use the first client
])


def start_cluster_tasks() -> List[asyncio.Task]:
    """Start registration, heartbeat and subscription; call from the app's startup/lifespan hook.

    asyncio.create_task() needs a running event loop, so it cannot be called at import time.
    """
    return [
        asyncio.create_task(cluster_coordinator.register_node()),
        asyncio.create_task(cluster_coordinator.heartbeat()),
        asyncio.create_task(cluster_coordinator.subscribe_to_cluster_messages()),
    ]
```
#Load Balancing Configuration
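WebSocket connections are long-lived, so the balancing decision is effectively made once, when a client connects, and the client then stays on that node. The sketch below reproduces the 50/50 score used by `LoadBalancer.select_best_node` as a pure function. As an assumption not present in the original class, it also skips nodes that are already at `max_connections`; `pick_node` is a hypothetical name.

```python
from statistics import mean
from typing import Dict, List, Optional


def pick_node(
    connections: Dict[str, int],
    response_ms: Dict[str, List[float]],
    max_connections: int = 1000,
    max_response_ms: float = 1000.0,
) -> Optional[str]:
    """Return the node with the lowest load score, skipping full nodes (None if all are full)."""
    best, best_score = None, float("inf")
    for node, conns in connections.items():
        if conns >= max_connections:
            continue  # saturated: never route new sockets here
        samples = response_ms.get(node) or [0.0]
        # Same weighting as the class: 50% connection share, 50% response time
        score = 0.5 * conns / max_connections + 0.5 * mean(samples) / max_response_ms
        if score < best_score:
            best, best_score = node, score
    return best
```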
# cluster/load_balancer.py - load balancing
import logging
from typing import Dict, List

logger = logging.getLogger(__name__)


class LoadBalancer:
    """Picks the least-loaded node from connection counts and response times."""

    def __init__(self):
        self.node_stats: Dict[str, dict] = {}
        self.connection_counts: Dict[str, int] = {}
        self.response_times: Dict[str, List[float]] = {}
        self.max_response_time = 1000  # ms
        self.max_connections_per_node = 1000

    def update_node_stats(self, node_id: str, connections: int, response_time: float):
        """Record a node's connection count and one response-time sample (ms)."""
        self.connection_counts[node_id] = connections
        samples = self.response_times.setdefault(node_id, [])
        samples.append(response_time)
        # Keep only the 100 most recent samples
        if len(samples) > 100:
            del samples[:-100]

    def _avg_response_time(self, node_id: str) -> float:
        """Average response time for a node (0 when there are no samples)."""
        samples = self.response_times.get(node_id)
        return sum(samples) / len(samples) if samples else 0.0

    def select_best_node(self) -> str:
        """Return the node with the lowest load score, or "local" if none are known."""
        best_score = float("inf")
        best_node = None
        for node_id, connections in self.connection_counts.items():
            # Load score, lower is better: 50% connection share, 50% response time
            load_score = (
                connections / self.max_connections_per_node * 0.5
                + self._avg_response_time(node_id) / self.max_response_time * 0.5
            )
            if load_score < best_score:
                best_score = load_score
                best_node = node_id
        return best_node or "local"

    def get_node_health_score(self, node_id: str) -> float:
        """Health score in [0, 1]; 1 is the healthiest."""
        connections = self.connection_counts.get(node_id, 0)
        # Fewer connections -> healthier
        connection_health = max(0, 1 - connections / self.max_connections_per_node)
        # Faster responses -> healthier
        response_health = max(0, 1 - self._avg_response_time(node_id) / self.max_response_time)
        return (connection_health + response_health) / 2


# Global load balancer instance
load_balancer = LoadBalancer()
#Production Best Practices
#Monitoring and Alerting
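`calculate_mps()` in the monitor is a placeholder that returns 0. One way to fill it is a sliding-window counter: record a timestamp per message and count only those newer than the window. A minimal sketch; `RateMeter` is a hypothetical class, and where `mark()` gets called (for example after each successful send) is left to the application.

```python
import time
from collections import deque
from typing import Callable, Deque


class RateMeter:
    """Sliding-window event counter, e.g. for messages per second."""

    def __init__(self, window: float = 1.0, clock: Callable[[], float] = time.monotonic) -> None:
        self.window = window
        self.clock = clock
        self._events: Deque[float] = deque()

    def mark(self) -> None:
        """Record one event (for example, one message sent)."""
        self._events.append(self.clock())

    def rate(self) -> float:
        """Events per second over the last `window` seconds."""
        cutoff = self.clock() - self.window
        while self._events and self._events[0] <= cutoff:
            self._events.popleft()  # drop events that have left the window
        return len(self._events) / self.window
```

Memory grows with the message rate times the window length; at very high rates, per-second buckets are a cheaper approximation.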
# monitoring/websocket_monitor.py - WebSocket monitoring
import asyncio
import logging
import os
import time
from collections import deque
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Deque, Dict, List, Optional

import aiohttp
import psutil

logger = logging.getLogger(__name__)


@dataclass
class MonitorMetrics:
    """One snapshot of monitoring metrics."""
    timestamp: datetime
    active_connections: int
    messages_per_second: int
    cpu_percent: float
    memory_percent: float
    disk_percent: float
    network_io: Dict[str, int]
    response_time_ms: float


class WebSocketMonitor:
    """WebSocket service monitor."""

    def __init__(self, alert_webhook_url: Optional[str] = None):
        self.alert_webhook_url = alert_webhook_url
        self.max_history = 1000  # keep the latest 1000 snapshots
        # A bounded deque discards the oldest snapshot automatically
        self.metrics_history: Deque[MonitorMetrics] = deque(maxlen=self.max_history)
        self.alert_thresholds = {
            'cpu_percent': 80,
            'memory_percent': 85,
            'active_connections': 5000,
            'response_time_ms': 500
        }
        self.monitor_interval = 10  # collect metrics every 10 seconds

    async def start_monitoring(self):
        """Run the monitoring loop until the task is cancelled."""
        while True:
            try:
                metrics = await self.collect_metrics()
                self.metrics_history.append(metrics)
                # Check alert thresholds
                await self.check_alerts(metrics)
                # Forward metrics to an external monitoring system (if configured)
                await self.send_monitoring_data(metrics)
                await asyncio.sleep(self.monitor_interval)
            except Exception as e:
                # asyncio.CancelledError is not an Exception subclass (Python 3.8+),
                # so cancelling the task still stops this loop
                logger.error(f"Monitoring error: {e}")
                await asyncio.sleep(5)  # wait briefly before retrying

    async def collect_metrics(self) -> MonitorMetrics:
        """Collect system and WebSocket metrics."""
        # System metrics; cpu_percent() without an interval is non-blocking and
        # compares against the previous call, so the very first value is 0.0
        cpu_percent = psutil.cpu_percent()
        memory_percent = psutil.virtual_memory().percent
        disk_percent = psutil.disk_usage('/').percent
        network_io = psutil.net_io_counters()._asdict()
        # WebSocket metrics; advanced_manager is assumed to be the connection
        # manager instance defined earlier in this guide
        active_connections = len(advanced_manager.active_connections)
        messages_per_second = self.calculate_mps()
        response_time_ms = await self.measure_response_time()
        return MonitorMetrics(
            timestamp=datetime.now(timezone.utc),
            active_connections=active_connections,
            messages_per_second=messages_per_second,
            cpu_percent=cpu_percent,
            memory_percent=memory_percent,
            disk_percent=disk_percent,
            network_io=network_io,
            response_time_ms=response_time_ms
        )

    def calculate_mps(self) -> int:
        """Messages processed per second."""
        # Placeholder: compute this from your actual message-handling statistics
        return 0

    async def measure_response_time(self) -> float:
        """Measure response time in milliseconds."""
        start = time.perf_counter()
        # Placeholder: send a real probe message here and await its reply;
        # as written this only times an empty block
        end = time.perf_counter()
        return (end - start) * 1000  # seconds -> milliseconds

    async def check_alerts(self, metrics: MonitorMetrics):
        """Compare the latest metrics against the alert thresholds."""
        alerts = []
        if metrics.cpu_percent > self.alert_thresholds['cpu_percent']:
            alerts.append(f"High CPU usage: {metrics.cpu_percent}%")
        if metrics.memory_percent > self.alert_thresholds['memory_percent']:
            alerts.append(f"High memory usage: {metrics.memory_percent}%")
        if metrics.active_connections > self.alert_thresholds['active_connections']:
            alerts.append(f"Too many connections: {metrics.active_connections}")
        if metrics.response_time_ms > self.alert_thresholds['response_time_ms']:
            alerts.append(f"High response time: {metrics.response_time_ms}ms")
        if alerts:
            await self.send_alert(alerts, metrics)

    async def send_alert(self, alerts: List[str], metrics: MonitorMetrics):
        """Log an alert and post it to the webhook, if one is configured."""
        alert_message = {
            # Slack incoming webhooks display the "text" field
            "text": "; ".join(alerts),
            "alert_type": "websocket_monitoring",
            "timestamp": metrics.timestamp.isoformat(),
            "alerts": alerts,
            "current_metrics": {
                "active_connections": metrics.active_connections,
                "cpu_percent": metrics.cpu_percent,
                "memory_percent": metrics.memory_percent,
                "response_time_ms": metrics.response_time_ms
            }
        }
        logger.error(f"Monitoring alert: {alert_message}")
        if self.alert_webhook_url:
            try:
                async with aiohttp.ClientSession() as session:
                    # Use the response as a context manager so the connection is released
                    async with session.post(
                        self.alert_webhook_url,
                        json=alert_message,
                        timeout=aiohttp.ClientTimeout(total=10)
                    ) as resp:
                        if resp.status >= 400:
                            logger.error(f"Alert webhook returned HTTP {resp.status}")
            except Exception as e:
                logger.error(f"Failed to send alert: {e}")

    async def send_monitoring_data(self, metrics: MonitorMetrics):
        """Send metrics to an external monitoring system."""
        # Integrate Prometheus, InfluxDB, or a similar system here
        pass

    def get_health_status(self) -> Dict[str, Any]:
        """Return the current health status."""
        if not self.metrics_history:
            return {"status": "unknown", "reason": "no metrics collected"}
        latest_metrics = self.metrics_history[-1]
        # Healthy only if every key metric is below its alert threshold
        if (latest_metrics.cpu_percent < self.alert_thresholds['cpu_percent'] and
                latest_metrics.memory_percent < self.alert_thresholds['memory_percent'] and
                latest_metrics.response_time_ms < self.alert_thresholds['response_time_ms']):
            status = "healthy"
        else:
            status = "degraded"
        return {
            "status": status,
            "metrics": {
                "active_connections": latest_metrics.active_connections,
                "cpu_percent": latest_metrics.cpu_percent,
                "memory_percent": latest_metrics.memory_percent,
                "response_time_ms": latest_metrics.response_time_ms
            },
            "timestamp": latest_metrics.timestamp.isoformat()
        }


# Global monitor instance; the Slack URL is only a placeholder, set ALERT_WEBHOOK_URL instead
websocket_monitor = WebSocketMonitor(
    alert_webhook_url=os.getenv("ALERT_WEBHOOK_URL", "https://hooks.slack.com/services/YOUR/SLACK/WEBHOOK")
)


def start_websocket_monitor() -> asyncio.Task:
    """Start the monitoring loop as a background task.

    asyncio.create_task() requires a running event loop, so call this from the
    application's startup hook (for example a FastAPI lifespan handler), not at
    import time. Keep a reference to the returned Task so it is not garbage-collected.
    """
    return asyncio.create_task(websocket_monitor.start_monitoring())
#Production Configuration
# config/production.py - production configuration
# Requires the pydantic-settings package: in Pydantic v2, BaseSettings moved out of pydantic
import os

import redis.asyncio as redis
from pydantic_settings import BaseSettings, SettingsConfigDict


class ProductionConfig(BaseSettings):
    """Production settings.

    BaseSettings reads each field from the environment variable of the same name
    (or from .env), so os.getenv() calls are unnecessary; the values below are
    only fallbacks used when a variable is not set.
    """

    model_config = SettingsConfigDict(env_file=".env", case_sensitive=True)

    # Basic settings
    APP_NAME: str = "Daoman WebSocket Service"
    DEBUG: bool = False
    LOG_LEVEL: str = "INFO"
    # Redis
    REDIS_URL: str = "redis://localhost:6379/0"
    REDIS_CLUSTER_URLS: str = "redis://localhost:6379/0"  # comma-separated list of node URLs
    # WebSocket
    MAX_CONNECTIONS: int = 10000
    CONNECTION_TIMEOUT: int = 3600  # seconds
    MESSAGE_QUEUE_SIZE: int = 10000
    # AI service
    OPENAI_API_KEY: str = ""
    AI_MODEL: str = "gpt-4o-mini"
    # Security: no default for the JWT secret, so startup fails fast if it is missing
    JWT_SECRET_KEY: str
    JWT_ALGORITHM: str = "HS256"
    MAX_LOGIN_ATTEMPTS: int = 5
    # Monitoring
    MONITORING_ENABLED: bool = True
    ALERT_WEBHOOK_URL: str = ""
    METRICS_COLLECTION_INTERVAL: int = 30  # seconds
    # Performance
    BATCH_PROCESSOR_SIZE: int = 100
    BATCH_FLUSH_INTERVAL: float = 0.1  # seconds
    # Cluster
    CLUSTER_NODE_ID: str = f"node_{os.getpid()}"
    CLUSTER_HEARTBEAT_INTERVAL: int = 5  # seconds


# Production settings instance
prod_config = ProductionConfig()


def get_redis_client() -> redis.Redis:
    """Return an asyncio Redis client."""
    return redis.from_url(prod_config.REDIS_URL)


def get_cluster_redis_clients() -> list[redis.Redis]:
    """Return one asyncio Redis client per URL in REDIS_CLUSTER_URLS."""
    urls = prod_config.REDIS_CLUSTER_URLS.split(",")
    return [redis.from_url(url.strip()) for url in urls if url.strip()]
#Frequently Asked Questions
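#Q1: What should I do if WebSocket connections keep dropping?
A: Common causes and fixes:
- Unstable network: implement automatic reconnection on the client with a sensible, ideally growing, retry interval
- Insufficient server resources: monitor CPU and memory usage and optimize hot code paths
- Poorly tuned timeouts: add a heartbeat (ping/pong) and keep timeouts longer than the heartbeat interval; reverse proxies also close idle connections (Nginx's proxy_read_timeout defaults to 60 seconds)
- Firewall restrictions: check the network configuration and make sure the WebSocket port is open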
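Automatic reconnection is usually paired with exponential backoff so that thousands of clients do not all reconnect at the same instant after an outage. A minimal sketch, assuming "full jitter" (a random delay between 0 and a capped exponential ceiling); `base`, `cap`, and the `connect` callable are illustrative parameters, not part of any library:

```python
import asyncio
import random
from typing import Awaitable, Callable, Optional


def backoff_delay(attempt: int, base: float = 0.5, cap: float = 30.0,
                  rng: Optional[random.Random] = None) -> float:
    """Full-jitter backoff: a uniform delay in [0, min(cap, base * 2**attempt)]."""
    ceiling = min(cap, base * (2 ** attempt))
    # Use the caller's RNG if given (handy for reproducible tests), else the global one
    return (rng or random).uniform(0, ceiling)


async def connect_with_retry(connect: Callable[[], Awaitable[object]],
                             max_attempts: int = 8,
                             base: float = 0.5, cap: float = 30.0) -> object:
    """Call `connect` until it succeeds, sleeping a jittered backoff between attempts.

    `connect` is a hypothetical coroutine factory that opens the WebSocket and
    raises OSError (e.g. ConnectionRefusedError) when the attempt fails.
    """
    for attempt in range(max_attempts):
        try:
            return await connect()
        except OSError:
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the last error
            await asyncio.sleep(backoff_delay(attempt, base, cap))
    raise RuntimeError("max_attempts must be at least 1")
```

The cap bounds the worst-case wait, and the attempt counter naturally restarts from zero the next time the connection drops and `connect_with_retry` is called again.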
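#Q2: How do you handle a large number of concurrent connections?
A:
- Manage connections efficiently: keep them in a central connection manager (pool) and promptly clean up dead connections
- Scale horizontally: deploy multiple service instances behind a load balancer
- Redis cluster: use Redis to share state across nodes
- Batch messages: reduce the number of small, frequent I/O operations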
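Cleaning up dead connections usually means recording when each connection was last heard from (any message or heartbeat reply) and dropping those that have been silent longer than a timeout. A minimal sketch, assuming a hypothetical `ConnectionRegistry` keyed by connection ID with an injectable clock for testing; in a real server the sweep would run in a background task and also close the underlying WebSocket:

```python
import time
from typing import Callable, Dict, List


class ConnectionRegistry:
    """Tracks last-activity timestamps and finds connections that have gone silent."""

    def __init__(self, timeout_s: float = 60.0,
                 clock: Callable[[], float] = time.monotonic):
        self.timeout_s = timeout_s
        self.clock = clock
        self.last_seen: Dict[str, float] = {}

    def touch(self, conn_id: str) -> None:
        """Record activity (a message or heartbeat reply) on a connection."""
        self.last_seen[conn_id] = self.clock()

    def remove(self, conn_id: str) -> None:
        """Forget a connection that closed normally."""
        self.last_seen.pop(conn_id, None)

    def sweep(self) -> List[str]:
        """Remove and return the IDs of connections idle longer than timeout_s."""
        now = self.clock()
        stale = [cid for cid, ts in self.last_seen.items() if now - ts > self.timeout_s]
        for cid in stale:
            del self.last_seen[cid]
        return stale
```

Using `time.monotonic` rather than wall-clock time keeps the idle calculation correct even if the system clock is adjusted.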
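#Q3: How do you guarantee message ordering?
A:
- Sequence numbers: attach a sequence number to every message
- Queues: process messages through an ordered queue
- Client-side buffering: buffer and reorder messages on the client
- Distributed locks: guard critical operations with a distributed lock to enforce ordering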
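Sequence numbers and client-side buffering work together: the receiver delivers messages strictly in order, holds early arrivals until the gap is filled, and drops duplicates. A minimal sketch, assuming the sender stamps each message with a consecutive integer `seq` starting at 0 (an application-level convention, not a feature of the WebSocket protocol):

```python
from typing import Any, Dict, List


class ReorderBuffer:
    """Delivers messages in sequence order, buffering out-of-order arrivals."""

    def __init__(self, start_seq: int = 0):
        self.next_seq = start_seq          # next sequence number we can deliver
        self.pending: Dict[int, Any] = {}  # early arrivals keyed by seq

    def receive(self, seq: int, payload: Any) -> List[Any]:
        """Accept one message; return every payload that is now deliverable, in order."""
        if seq < self.next_seq or seq in self.pending:
            return []  # duplicate or already delivered: ignore it
        self.pending[seq] = payload
        ready: List[Any] = []
        # Drain the buffer while the next expected message is available
        while self.next_seq in self.pending:
            ready.append(self.pending.pop(self.next_seq))
            self.next_seq += 1
        return ready
```

A production client would also cap the buffer size and ask the server to resend when a gap stays open for too long.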
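#Q4: How do you persist messages?
A:
- Database storage: store important messages in a database
- Message queues: use a message queue such as RabbitMQ or Kafka
- Redis persistence: rely on Redis's built-in persistence (RDB snapshots or the AOF log)
- File storage: write messages to log files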
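#Q5: How do you load-test a WebSocket service?
A:
- Use tools that speak WebSocket: for example k6, Artillery, or JMeter with a WebSocket plugin; plain HTTP benchmarks such as wrk and ab do not exercise WebSocket traffic out of the box
- Write test scripts: simulate many clients connecting concurrently
- Watch the metrics: connection count, message throughput, and response time
- Ramp up gradually: start small and increase the load step by step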
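Simulating concurrent clients and ramping up step by step can be sketched with asyncio alone. This assumes a hypothetical `client_fn(i)` coroutine that runs one simulated client session and returns its round-trip latency in milliseconds (a real one might open a connection with a WebSocket client library such as `websockets`); percentiles use the nearest-rank method:

```python
import asyncio
import math
import time
from typing import Awaitable, Callable, Dict, List, Sequence

# Hypothetical client signature: runs one simulated session, returns latency in ms
ClientFn = Callable[[int], Awaitable[float]]


def percentile(ordered: List[float], p: float) -> float:
    """Nearest-rank percentile of an already sorted, non-empty list."""
    rank = max(1, math.ceil(p / 100 * len(ordered)))
    return ordered[rank - 1]


async def run_stage(client_fn: ClientFn, concurrency: int) -> Dict[str, float]:
    """Run `concurrency` clients at the same time and summarize their latencies."""
    if concurrency < 1:
        raise ValueError("concurrency must be at least 1")
    start = time.perf_counter()
    latencies = sorted(await asyncio.gather(*(client_fn(i) for i in range(concurrency))))
    elapsed = time.perf_counter() - start
    return {
        "clients": concurrency,
        "elapsed_s": elapsed,
        "p50_ms": percentile(latencies, 50),
        "p95_ms": percentile(latencies, 95),
        "max_ms": latencies[-1],
    }


async def ramp(client_fn: ClientFn,
               stages: Sequence[int] = (10, 50, 100)) -> List[Dict[str, float]]:
    """Increase the load stage by stage and collect one summary per stage."""
    results = []
    for concurrency in stages:
        results.append(await run_stage(client_fn, concurrency))
    return results
```

Comparing p95 latency across stages shows the point where the service starts to degrade, which is usually more informative than the average.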
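#Summary
WebSocket gives real-time applications efficient bidirectional communication, and integrating it into FastAPI makes it possible to build capable real-time services:
- Architectural strengths: persistent connections, bidirectional communication, low latency
- Use cases: chat, games, collaboration, notifications, and many other real-time scenarios
- Performance optimization: connection management, message batching, asynchronous processing
- Security: authentication and authorization, rate limiting, content filtering
- Scalability: cluster deployment, load balancing, monitoring and alerting
💡 Key takeaways: design a sound message protocol, manage connections carefully, and build a thorough monitoring setup to keep the system stable and scalable.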
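#SEO Optimization Tips
To help this WebSocket tutorial rank better in search engines, here are a few key SEO suggestions:
#Title Optimization
- Main title: use a title that contains the core keywords, such as "FastAPI WebSocket Real-Time Communication: The Complete Guide"
- Section headings: include relevant long-tail keywords in every section heading
- H1-H6 hierarchy: keep heading levels consistent so search engines can understand the content structure
#Content Optimization
- Keyword density: work keywords such as "WebSocket", "real-time communication", "AI chat room", "collaborative apps", and "message push" naturally into the content
- Meta description: include a compelling description in the metadata at the top of the article
- Internal links: link to related tutorials, such as Streaming Responses with StreamingResponse
- Authoritative external links: cite official documentation and other authoritative resources
#Technical SEO
- Page load speed: optimize how code blocks and images load
- Mobile support: make sure the page displays well on mobile devices
- Structured data: use appropriate HTML tags and semantic elements
#User Experience
- Readability: use a clear paragraph structure and code examples
- Interactive elements: provide code examples that actually run
- Update frequency: update the content regularly to keep it current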
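#Frequently Asked Questions (FAQ)
#Q1: How do you integrate WebSockets with FastAPI?
A: Declare an endpoint with FastAPI's @app.websocket() decorator, call await websocket.accept() to accept the connection, then receive and send messages in a while loop, handling WebSocketDisconnect when the client goes away. WebSocket endpoints can also use dependency injection, authentication, and other FastAPI features.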
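#Q2: How do you handle a large number of concurrent connections?
A: Manage connections through a central manager (pool), batch messages, deploy multiple service instances, share state across nodes through Redis, and apply an effective load-balancing strategy.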
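#Q3: How do you deal with lost WebSocket messages?
A: Add a message acknowledgement mechanism, persist messages through a message queue, have the client retransmit unacknowledged messages, and deduplicate messages on the server.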
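Acknowledgement, retransmission, and deduplication fit together like this: the sender keeps every message until it is acknowledged and resends it after a timeout, and the receiver remembers the IDs it has already processed so a resent copy is not handled twice. A minimal in-memory sketch, assuming every message carries a unique `msg_id` (a hypothetical field) and using an injectable clock; a production version would persist pending messages (for example in Redis) and bound the size of the seen-ID set:

```python
import time
from typing import Any, Callable, Dict, List, Set, Tuple


class AckTracker:
    """Sender side: keeps unacknowledged messages and reports which are due for resend."""

    def __init__(self, retry_after_s: float = 5.0,
                 clock: Callable[[], float] = time.monotonic):
        self.retry_after_s = retry_after_s
        self.clock = clock
        # msg_id -> (payload, time of the last send)
        self.pending: Dict[str, Tuple[Any, float]] = {}

    def sent(self, msg_id: str, payload: Any) -> None:
        self.pending[msg_id] = (payload, self.clock())

    def ack(self, msg_id: str) -> None:
        self.pending.pop(msg_id, None)

    def due_for_retry(self) -> List[Tuple[str, Any]]:
        """Return messages whose ack is overdue and reset their send time."""
        now = self.clock()
        due = [(mid, p) for mid, (p, ts) in self.pending.items()
               if now - ts >= self.retry_after_s]
        for mid, payload in due:
            self.pending[mid] = (payload, now)
        return due


class Deduplicator:
    """Receiver side: processes each msg_id at most once."""

    def __init__(self) -> None:
        self.seen: Set[str] = set()

    def first_time(self, msg_id: str) -> bool:
        if msg_id in self.seen:
            return False
        self.seen.add(msg_id)
        return True
```

Retransmission alone gives at-least-once delivery; pairing it with the receiver-side deduplication is what makes each message take effect only once.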
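#Q4: How do you monitor WebSocket performance?
A: Collect detailed metrics, including connection count, message throughput, response time, and system resource usage; set up alerting; and use a dedicated APM tool.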
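Message throughput is the figure that the `calculate_mps()` placeholder in the monitor above is meant to report. One simple way to compute it is a sliding window of message timestamps. A minimal sketch, assuming the connection manager calls `record()` once per handled message (a hypothetical hook) and using an injectable clock for testing:

```python
import time
from collections import deque
from typing import Callable, Deque


class MessageRateCounter:
    """Counts messages in a sliding time window and reports messages per second."""

    def __init__(self, window_s: float = 10.0,
                 clock: Callable[[], float] = time.monotonic):
        self.window_s = window_s
        self.clock = clock
        self.timestamps: Deque[float] = deque()

    def record(self) -> None:
        """Call once for every message sent or received."""
        self.timestamps.append(self.clock())

    def rate(self) -> float:
        """Messages per second, averaged over the last window_s seconds."""
        cutoff = self.clock() - self.window_s
        # Drop timestamps that have fallen out of the window
        while self.timestamps and self.timestamps[0] <= cutoff:
            self.timestamps.popleft()
        return len(self.timestamps) / self.window_s
```

Since `MonitorMetrics.messages_per_second` is declared as an int, `calculate_mps()` could return `round(counter.rate())` from a shared counter instance.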
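#Q5: What should you watch out for in production deployments?
A: Configure the reverse proxy to pass WebSocket upgrades (for Nginx: proxy_http_version 1.1 and forwarding the Upgrade and Connection headers), set suitable timeouts, apply security measures, build a complete monitoring and alerting system, and plan for disaster recovery.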
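🔗 Recommended related tutorials
- Streaming Responses with StreamingResponse - real-time data push
- OAuth2 and JWT Authentication - secure authentication
- Redis Integration - caching and session management
- Celery Asynchronous Task Queue - background task processing
- Docker Containerized Deployment - container deployment strategies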

