#FastAPI Pytest单元测试完全指南
📂 所属阶段:第五阶段 — 工程化与部署(实战篇)
🔗 相关章节:FastAPI依赖注入系统 · FastAPI多环境配置
#目录
- 单元测试基础概念
- Pytest环境搭建
- TestClient基础使用
- 异步测试详解
- Fixtures与依赖注入测试
- 数据库测试策略
- API端点测试
- 参数化测试与边界测试
- 测试覆盖率分析
- TDD开发实践
- 测试性能优化
- 与其他测试框架对比
- 总结
#单元测试基础概念
#为什么需要单元测试?
单元测试是软件开发中的基石,它确保代码的正确性、可维护性和可扩展性。在FastAPI应用中,单元测试尤为重要:
# 无测试 vs 有测试的开发体验
"""
无测试场景:
- 修改代码 → 手动测试 → 上线 → 发现bug → 回滚 → 用户投诉 😱
- 重构代码 → 担心破坏功能 → 代码腐化 → 技术债务堆积
有测试场景:
- 修改代码 → 运行测试 → 快速反馈 → 自信上线 🚀
- 重构代码 → 测试保证 → 持续优化 → 代码质量提升
"""#测试金字塔
在FastAPI应用中,遵循测试金字塔原则:
┌─────────────────────────┐ ← 单元测试 (Unit Tests) - 70%
│ 业务逻辑层 │ • 快速、隔离、专注
│ (Service Layer) │ • 测试纯函数和业务逻辑
├─────────────────────────┤ ← 集成测试 (Integration) - 20%
│ API/路由层 │ • 测试API端点和数据流
│ (Route Layer) │ • 包含数据库、外部服务
├─────────────────────────┤ ← 端到端测试 (E2E) - 10%
│ UI/界面层 │ • 测试完整用户流程
│ (UI Layer) │ • 使用Selenium等工具
└─────────────────────────┘#测试驱动开发(TDD)的好处
- 设计驱动:先思考接口设计
- 快速反馈:即时验证代码正确性
- 重构安全:测试保护重构过程
- 文档作用:测试即使用示例
- 信心保证:确保功能按预期工作
#Pytest环境搭建
#核心依赖安装
# 基础测试依赖
pip install pytest pytest-asyncio httpx
# 覆盖率分析
pip install pytest-cov coverage
# Mock和补丁
pip install pytest-mock
# 参数化测试增强
pip install pytest-parametrize
# 测试数据生成
pip install factory-boy faker
# 断言增强
pip install pytest-check
# 环境变量管理
pip install python-dotenv pytest-dotenv
# 完整安装命令
pip install pytest pytest-asyncio httpx pytest-cov pytest-mock factory-boy faker#Pytest配置文件
# pytest.ini - 项目根目录
[tool:pytest]
# 测试路径
testpaths = tests
python_files = test_*.py *_test.py
python_classes = Test*
python_functions = test_*
# 异步支持
asyncio_mode = auto
# 输出选项
addopts =
-v
--tb=short
--strict-markers
--strict-config
--disable-warnings
# 覆盖率配置
filterwarnings =
ignore::DeprecationWarning
ignore::UserWarning
error::pytest.PytestWarning
# 标记配置
markers =
slow: marks tests as slow
integration: marks tests as integration tests
unit: marks tests as unit tests
api: marks tests as api tests
database: marks tests as database tests
auth: marks tests as authentication tests#项目结构配置
# pyproject.toml - 现代Python项目配置
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py", "*_test.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
asyncio_mode = "auto"
addopts = [
"-ra", # 显示摘要和失败详情
"--showlocals", # 显示局部变量
"--tb=short", # 简短回溯
"--strict-markers",
"--strict-config",
]
markers = [
"slow: marks tests as slow",
"integration: marks tests as integration tests",
"unit: marks tests as unit tests",
"api: marks tests as api tests",
"database: marks tests as database tests",
"auth: marks tests as authentication tests"
]
[tool.coverage.run]
source = ["src/", "app/"]
omit = [
"*/venv/*",
"*/tests/*",
"*/migrations/*",
"*/config/*",
"*/__init__.py"
]
[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"def __repr__",
"raise AssertionError",
"raise NotImplementedError",
"if __name__ == .__main__.:"
]#conftest.py配置文件
# tests/conftest.py - Pytest配置文件
import asyncio
import pytest
from httpx import AsyncClient, ASGITransport
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from fastapi.testclient import TestClient
# 应用程序导入
from app.main import app
from app.database import get_db, engine, Base
from app.models import User, Item
from app.config import settings
# 测试数据库URL - 使用内存数据库
TEST_DATABASE_URL = "sqlite:///./test.db"
@pytest.fixture(scope="session")
def event_loop():
"""创建事件循环"""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="session")
async def test_engine():
"""创建测试数据库引擎"""
engine = create_engine(
TEST_DATABASE_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
yield engine
engine.dispose()
@pytest.fixture(scope="function")
async def test_session(test_engine):
"""创建测试数据库会话"""
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async_session = sessionmaker(
test_engine,
expire_on_commit=False,
class_=AsyncSession
)
async with async_session() as session:
yield session
# 回滚事务,确保数据清洁
await session.rollback()
# 清理数据库
await conn.run_sync(Base.metadata.drop_all)
@pytest.fixture(scope="function")
def override_dependencies(test_session):
"""覆盖应用依赖"""
def get_test_db():
return test_session
app.dependency_overrides[get_db] = get_test_db
yield
app.dependency_overrides.clear()
@pytest.fixture
async def async_client(override_dependencies):
"""异步测试客户端"""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac
@pytest.fixture
async def authenticated_client(async_client):
"""认证的测试客户端"""
# 创建测试用户
response = await async_client.post("/auth/register", json={
"email": "test@example.com",
"password": "TestPassword123!",
"name": "Test User"
})
# 登录获取token
login_response = await async_client.post("/auth/login", data={
"username": "test@example.com",
"password": "TestPassword123!"
})
token = login_response.json()["access_token"]
async_client.headers["Authorization"] = f"Bearer {token}"
yield async_client
@pytest.fixture
def sample_user_data():
"""样本用户数据"""
return {
"email": "test@example.com",
"password": "TestPassword123!",
"name": "Test User",
"is_active": True
}
@pytest.fixture
def sample_item_data():
"""样本项目数据"""
return {
"title": "Test Item",
"description": "This is a test item",
"price": 99.99
}#TestClient基础使用
#TestClient同步测试
# tests/test_basic.py
import pytest
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
def test_root_endpoint():
"""测试根端点"""
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"message": "Welcome to DaomanAPI"}
def test_health_check():
"""测试健康检查端点"""
response = client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "healthy", "timestamp": pytest.approx(0, abs=1)}
def test_post_request():
"""测试POST请求"""
test_data = {"name": "Test User", "email": "test@example.com"}
response = client.post("/users/", json=test_data)
assert response.status_code == 201
assert response.json()["name"] == "Test User"
assert response.json()["email"] == "test@example.com"
def test_query_parameters():
"""测试查询参数"""
response = client.get("/items/", params={"skip": 0, "limit": 10})
assert response.status_code == 200
assert isinstance(response.json(), list)
def test_path_parameters():
"""测试路径参数"""
user_id = 1
response = client.get(f"/users/{user_id}")
assert response.status_code == 200
assert response.json()["id"] == user_id
def test_error_handling():
"""测试错误处理"""
response = client.get("/nonexistent-endpoint")
assert response.status_code == 404
def test_validation_errors():
"""测试验证错误"""
invalid_data = {"name": "Too Short"} # 假设name需要至少10个字符
response = client.post("/users/", json=invalid_data)
assert response.status_code == 422
assert "detail" in response.json()#HTTP方法测试
# tests/test_http_methods.py
import pytest
class TestHTTPMethods:
"""HTTP方法测试类"""
def test_get_method(self, client):
"""测试GET方法"""
response = client.get("/users/1")
assert response.status_code == 200
def test_post_method(self, client):
"""测试POST方法"""
data = {"name": "New User", "email": "new@example.com"}
response = client.post("/users/", json=data)
assert response.status_code == 201
assert response.json()["name"] == "New User"
def test_put_method(self, client):
"""测试PUT方法"""
data = {"name": "Updated User", "email": "updated@example.com"}
response = client.put("/users/1", json=data)
assert response.status_code == 200
assert response.json()["name"] == "Updated User"
def test_patch_method(self, client):
"""测试PATCH方法"""
partial_data = {"name": "Partially Updated"}
response = client.patch("/users/1", json=partial_data)
assert response.status_code == 200
assert response.json()["name"] == "Partially Updated"
def test_delete_method(self, client):
"""测试DELETE方法"""
response = client.delete("/users/1")
assert response.status_code == 204 # No Content#请求头和Cookie测试
# tests/test_headers_cookies.py
def test_custom_headers(client):
"""测试自定义请求头"""
headers = {"X-Custom-Header": "test-value", "User-Agent": "Test-Client/1.0"}
response = client.get("/headers", headers=headers)
assert response.status_code == 200
# 验证服务器收到了正确的头信息
def test_authentication_headers(client):
"""测试认证头"""
# 无认证
response = client.get("/protected")
assert response.status_code == 401
# 有认证
headers = {"Authorization": "Bearer fake-token"}
response = client.get("/protected", headers=headers)
assert response.status_code in [200, 401] # 取决于token有效性
def test_cookie_handling(client):
"""测试Cookie处理"""
response = client.get("/set-cookie")
assert response.status_code == 200
assert "set-cookie" in response.headers
# 使用Cookie进行后续请求
cookie_value = response.cookies.get("session_id")
response = client.get("/with-cookie", cookies={"session_id": cookie_value})
assert response.status_code == 200#异步测试详解
#异步测试基础
# tests/test_async.py
import pytest
import asyncio
from httpx import AsyncClient, ASGITransport
from app.main import app
@pytest.mark.asyncio
async def test_async_endpoint():
"""异步端点测试"""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
response = await ac.get("/")
assert response.status_code == 200
assert response.json() == {"message": "Welcome to DaomanAPI"}
@pytest.mark.asyncio
async def test_async_post():
"""异步POST请求测试"""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
data = {"name": "Async User", "email": "async@example.com"}
response = await ac.post("/users/", json=data)
assert response.status_code == 201
class TestAsyncEndpoints:
"""异步端点测试类"""
@pytest.mark.asyncio
async def test_list_users(self, async_client):
"""测试用户列表端点"""
response = await async_client.get("/users/")
assert response.status_code == 200
assert isinstance(response.json(), list)
@pytest.mark.asyncio
async def test_create_user(self, async_client):
"""测试创建用户"""
user_data = {
"email": "async_test@example.com",
"password": "AsyncPassword123!",
"name": "Async Test User"
}
response = await async_client.post("/users/", json=user_data)
assert response.status_code == 201
assert response.json()["email"] == user_data["email"]
@pytest.mark.asyncio
async def test_get_user(self, async_client):
"""测试获取用户"""
# 先创建用户
user_data = {
"email": "get_test@example.com",
"password": "GetPassword123!",
"name": "Get Test User"
}
create_response = await async_client.post("/users/", json=user_data)
user_id = create_response.json()["id"]
# 获取用户
response = await async_client.get(f"/users/{user_id}")
assert response.status_code == 200
assert response.json()["email"] == user_data["email"]#异步并发测试
# tests/test_async_concurrency.py
import pytest
import asyncio
from httpx import AsyncClient, ASGITransport
@pytest.mark.asyncio
async def test_concurrent_requests():
"""测试并发请求"""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
# 创建多个并发请求
tasks = [
ac.get("/users"),
ac.get("/items"),
ac.get("/health"),
ac.get("/metrics")
]
responses = await asyncio.gather(*tasks)
# 验证所有请求都成功
for response in responses:
assert response.status_code in [200, 201]
@pytest.mark.asyncio
async def test_load_simulation():
"""负载模拟测试"""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
# 模拟10个并发用户
tasks = [ac.get("/health") for _ in range(10)]
responses = await asyncio.gather(*tasks, return_exceptions=True)
# 验证响应
successful_responses = [
r for r in responses
if not isinstance(r, Exception) and r.status_code == 200
]
assert len(successful_responses) == 10
@pytest.mark.asyncio
async def test_streaming_response():
"""测试流式响应"""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
# 假设有流式端点
async with ac.stream("GET", "/stream-data") as response:
assert response.status_code == 200
chunks = []
async for chunk in response.aiter_text():
chunks.append(chunk)
assert len(chunks) > 0#WebSocket测试
# tests/test_websocket.py
import pytest
import asyncio
from websockets.sync.client import connect
from app.main import app
@pytest.mark.asyncio
async def test_websocket_connection():
"""WebSocket连接测试"""
# 注意:FastAPI的WebSocket测试需要特殊的处理
import websockets
import json
async def test_websocket():
uri = "ws://localhost:8000/ws"
async with websockets.connect(uri) as websocket:
# 发送消息
await websocket.send(json.dumps({"type": "ping"}))
# 接收响应
response = await websocket.recv()
data = json.loads(response)
assert data["type"] == "pong"
# 由于WebSocket测试的复杂性,这里简化示例
# 实际应用中可能需要使用专门的WebSocket测试库
pass#Fixtures与依赖注入测试
#Fixtures高级用法
# tests/test_fixtures_advanced.py
import pytest
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
from app.services.user_service import UserService
from app.dependencies import get_current_user
@pytest.fixture
def mock_user_service():
"""模拟用户服务"""
mock = AsyncMock(spec=UserService)
mock.get_user_by_id.return_value = {
"id": 1,
"email": "mock@example.com",
"name": "Mock User"
}
mock.create_user.return_value = {
"id": 1,
"email": "new@example.com",
"name": "New User"
}
return mock
@pytest.fixture
def mock_current_user():
"""模拟当前用户"""
return {
"id": 1,
"email": "current@example.com",
"name": "Current User",
"role": "admin"
}
@pytest.fixture
def override_user_dependency(mock_current_user):
"""覆盖用户依赖"""
def mock_get_current_user():
return mock_current_user
app.dependency_overrides[get_current_user] = mock_get_current_user
yield
app.dependency_overrides.clear()
def test_with_mocked_service(mock_user_service):
"""使用模拟服务的测试"""
# 测试逻辑使用mock_user_service
assert mock_user_service.get_user_by_id.called is False
user = mock_user_service.get_user_by_id(1)
assert user["email"] == "mock@example.com"
class TestWithFixtures:
"""使用Fixtures的测试类"""
def test_authenticated_endpoint(self, authenticated_client):
"""测试认证端点"""
response = authenticated_client.get("/users/me")
assert response.status_code == 200
def test_with_sample_data(self, client, sample_user_data):
"""使用样本数据的测试"""
response = client.post("/users/", json=sample_user_data)
assert response.status_code == 201
assert response.json()["email"] == sample_user_data["email"]
def test_with_multiple_fixtures(self, client, sample_user_data, sample_item_data):
"""使用多个Fixtures的测试"""
# 创建用户
user_response = client.post("/users/", json=sample_user_data)
assert user_response.status_code == 201
user_id = user_response.json()["id"]
# 创建关联的项目
sample_item_data["user_id"] = user_id
item_response = client.post("/items/", json=sample_item_data)
assert item_response.status_code == 201
assert item_response.json()["user_id"] == user_id
@pytest.fixture(scope="session")
def test_data_generator():
"""测试数据生成器"""
from faker import Faker
fake = Faker()
def generate_user_data():
return {
"email": fake.email(),
"name": fake.name(),
"password": fake.password(length=12, special_chars=True, digits=True),
"phone": fake.phone_number()
}
def generate_item_data():
return {
"title": fake.catch_phrase(),
"description": fake.text(max_nb_chars=200),
"price": round(fake.random.uniform(10.0, 1000.0), 2)
}
return {
"user": generate_user_data,
"item": generate_item_data
}
def test_with_generated_data(client, test_data_generator):
"""使用生成数据的测试"""
user_data = test_data_generator["user"]()
response = client.post("/users/", json=user_data)
assert response.status_code == 201#依赖注入测试
# tests/test_dependency_injection.py
import pytest
from fastapi import Depends, HTTPException
from unittest.mock import patch, MagicMock
from app.dependencies import get_db, get_current_user, get_settings
from app.models import User
def test_db_dependency():
"""测试数据库依赖"""
# 验证依赖函数返回正确的类型
db_gen = get_db()
db_session = next(db_gen)
assert hasattr(db_session, 'execute') # 验证是数据库会话
db_gen.close() # 清理
def test_current_user_dependency():
"""测试当前用户依赖"""
# 这个测试需要模拟认证逻辑
pass
class TestDependencyOverrides:
"""依赖覆盖测试"""
def test_override_database(self, client):
"""测试数据库依赖覆盖"""
# 在conftest.py中已经覆盖了数据库依赖
# 这里验证覆盖生效
response = client.get("/health")
assert response.status_code == 200
def test_override_settings(self):
"""测试设置依赖覆盖"""
# 模拟覆盖设置
from app.config import settings
# 保存原始值
original_debug = settings.debug
# 临时覆盖
settings.debug = True
try:
# 测试使用覆盖后的设置
assert settings.debug is True
finally:
# 恢复原始值
settings.debug = original_debug
@pytest.fixture
def mock_db_dependency():
"""模拟数据库依赖"""
mock_session = MagicMock()
mock_session.query.return_value.filter.return_value.first.return_value = User(
id=1,
email="test@example.com",
name="Test User"
)
def get_mock_db():
yield mock_session
app.dependency_overrides[get_db] = get_mock_db
yield mock_session
app.dependency_overrides.clear()
def test_with_mocked_db(mock_db_dependency):
"""使用模拟数据库的测试"""
# 这里的测试会使用模拟的数据库会话
from app.routers.users import get_user # 假设有一个这样的路由函数
# 由于我们无法直接测试路由函数,这里展示概念
# 实际测试中,路由会被FastAPI自动调用并使用覆盖的依赖
assert mock_db_dependency.query.called is False
# 测试依赖注入的完整示例
class TestCompleteDependencyInjection:
"""完整的依赖注入测试"""
@pytest.fixture
def setup_dependencies(self, test_session):
"""设置多个依赖"""
def override_get_db():
yield test_session
def override_get_current_user():
return {"id": 1, "email": "test@example.com", "role": "admin"}
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_get_current_user
yield
app.dependency_overrides.clear()
def test_complete_setup(self, client, setup_dependencies):
"""测试完整依赖设置"""
# 测试需要认证的端点
response = client.get("/admin/dashboard") # 假设这是管理员端点
# 根据实际路由实现验证响应
assert response.status_code in [200, 404] # 404表示端点不存在但认证成功#数据库测试策略
#数据库事务管理
# tests/test_database_transactions.py
import pytest
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError
from app.models import User, Item
from app.schemas import UserCreate, ItemCreate
class TestDatabaseTransactions:
"""数据库事务测试"""
async def test_successful_transaction(self, test_session):
"""测试成功的事务"""
from app.crud import create_user
user_data = UserCreate(
email="transaction_test@example.com",
password="Password123!",
name="Transaction Test"
)
user = await create_user(test_session, user_data)
assert user.email == user_data.email
assert user.name == user_data.name
# 验证数据已提交
result = await test_session.execute(
text("SELECT COUNT(*) FROM users WHERE email = :email"),
{"email": user_data.email}
)
count = result.scalar()
assert count == 1
async def test_rollback_on_error(self, test_session):
"""测试错误时回滚"""
from app.crud import create_user
# 创建一个有效的用户
valid_user = UserCreate(
email="valid@example.com",
password="Password123!",
name="Valid User"
)
await create_user(test_session, valid_user)
# 尝试创建重复的用户(应该失败)
duplicate_user = UserCreate(
email="valid@example.com", # 重复邮箱
password="Password123!",
name="Duplicate User"
)
try:
await create_user(test_session, duplicate_user)
assert False, "应该抛出异常"
except IntegrityError:
# 验证事务已回滚,数据库中仍只有1个用户
result = await test_session.execute(text("SELECT COUNT(*) FROM users"))
count = result.scalar()
assert count == 1
async def test_nested_transactions(self, test_session):
"""测试嵌套事务"""
from app.crud import create_user, create_item
# 创建用户
user_data = UserCreate(
email="nested_test@example.com",
password="Password123!",
name="Nested Test"
)
user = await create_user(test_session, user_data)
# 创建关联的项目
item_data = ItemCreate(
title="Test Item",
description="Test Description",
price=99.99,
owner_id=user.id
)
item = await create_item(test_session, item_data)
# 验证关联关系
assert item.owner_id == user.id
# 验证数据持久化
user_result = await test_session.execute(
text("SELECT * FROM users WHERE id = :id"), {"id": user.id}
)
user_from_db = user_result.fetchone()
assert user_from_db is not None
item_result = await test_session.execute(
text("SELECT * FROM items WHERE id = :id"), {"id": item.id}
)
item_from_db = item_result.fetchone()
assert item_from_db is not None
@pytest.fixture
def clean_database_state(test_session):
"""确保数据库状态清洁"""
# 在测试开始前清理相关数据
await test_session.execute(text("DELETE FROM items"))
await test_session.execute(text("DELETE FROM users"))
await test_session.commit()
yield test_session
# 测试结束后再次清理(尽管在conftest.py中已经有回滚)
class TestDataIsolation:
"""数据隔离测试"""
async def test_isolation_between_tests(self, clean_database_state):
"""测试测试间的数据隔离"""
# 验证开始时数据库为空
result = await clean_database_state.execute(text("SELECT COUNT(*) FROM users"))
assert result.scalar() == 0
# 添加一些数据
result = await clean_database_state.execute(
text("INSERT INTO users (email, name, password_hash) VALUES (:email, :name, :password)"),
{"email": "isolation@example.com", "name": "Isolation", "password": "hash"}
)
# 在同个session中验证数据存在
result = await clean_database_state.execute(text("SELECT COUNT(*) FROM users"))
assert result.scalar() == 1
# 数据库性能测试
class TestDatabasePerformance:
"""数据库性能测试"""
@pytest.mark.benchmark
async def test_bulk_insert_performance(self, test_session):
"""测试批量插入性能"""
import time
from app.crud import create_user
start_time = time.time()
# 批量创建用户
for i in range(100):
user_data = UserCreate(
email=f"bulk_test_{i}@example.com",
password="Password123!",
name=f"Bulk User {i}"
)
await create_user(test_session, user_data)
end_time = time.time()
duration = end_time - start_time
print(f"Created 100 users in {duration:.2f} seconds")
assert duration < 5.0 # 应该在5秒内完成
@pytest.mark.benchmark
async def test_query_performance(self, test_session):
"""测试查询性能"""
import time
from app.crud import get_users
# 先创建一些测试数据
for i in range(50):
user_data = UserCreate(
email=f"query_test_{i}@example.com",
password="Password123!",
name=f"Query User {i}"
)
await create_user(test_session, user_data)
start_time = time.time()
# 执行多次查询
for _ in range(10):
users = await get_users(test_session, skip=0, limit=10)
assert len(users) <= 10
end_time = time.time()
duration = end_time - start_time
print(f"Executed 10 queries in {duration:.2f} seconds")
assert duration < 2.0 # 应该在2秒内完成#数据库迁移测试
# tests/test_database_migrations.py
import pytest
from alembic import command
from alembic.config import Config
from sqlalchemy import create_engine, text
from app.database import DATABASE_URL
class TestDatabaseMigrations:
"""数据库迁移测试"""
def test_migration_up_down(self):
"""测试迁移的上移和下移"""
# 创建临时数据库用于测试迁移
temp_engine = create_engine("sqlite:///./temp_test.db")
try:
# 应用所有迁移
alembic_cfg = Config("alembic.ini")
alembic_cfg.set_main_option("sqlalchemy.url", "sqlite:///./temp_test.db")
command.upgrade(alembic_cfg, "head")
# 验证表是否创建
with temp_engine.connect() as conn:
result = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table'"))
tables = [row[0] for row in result]
# 根据实际情况验证表是否存在
expected_tables = ["users", "items", "orders"] # 根据实际模型调整
for table in expected_tables:
assert table in tables, f"Table {table} not found after migration"
# 回滚迁移
command.downgrade(alembic_cfg, "base")
# 验证表是否删除
with temp_engine.connect() as conn:
result = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table'"))
tables = [row[0] for row in result]
# 回滚后应该没有业务表
business_tables = [t for t in tables if t not in ["alembic_version"]]
assert len(business_tables) == 0
finally:
# 清理临时数据库
temp_engine.dispose()
def test_migration_script_generation(self):
"""测试迁移脚本生成"""
# 这个测试比较复杂,通常在CI/CD中运行
# 这里只做基本的概念验证
pass
@pytest.fixture
def migration_test_db():
"""迁移测试数据库"""
engine = create_engine("sqlite:///./migration_test.db")
# 创建基本表结构用于测试
with engine.connect() as conn:
conn.execute(text("""
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY,
email VARCHAR(255) UNIQUE NOT NULL,
name VARCHAR(255) NOT NULL
)
"""))
yield engine
# 清理
engine.dispose()
def test_data_migration(migration_test_db):
"""测试数据迁移"""
# 插入测试数据
with migration_test_db.connect() as conn:
conn.execute(text("""
INSERT INTO users (email, name) VALUES
('old@example.com', 'Old User'),
('another@example.com', 'Another User')
"""))
conn.commit()
# 验证数据存在
with migration_test_db.connect() as conn:
result = conn.execute(text("SELECT COUNT(*) FROM users"))
count = result.scalar()
assert count == 2#API端点测试
#用户管理API测试
# tests/test_user_api.py
import pytest
from httpx import AsyncClient
from app.models import User
from app.schemas import UserCreate, UserUpdate
class TestUserAPI:
"""用户API测试"""
@pytest.mark.asyncio
async def test_create_user(self, async_client):
"""测试创建用户"""
user_data = {
"email": "newuser@example.com",
"password": "SecurePassword123!",
"name": "New User"
}
response = await async_client.post("/users/", json=user_data)
assert response.status_code == 201
data = response.json()
assert data["email"] == user_data["email"]
assert data["name"] == user_data["name"]
assert "id" in data
assert "hashed_password" not in data # 敏感信息不应返回
@pytest.mark.asyncio
async def test_create_user_duplicate_email(self, async_client):
"""测试创建重复邮箱用户"""
user_data = {
"email": "duplicate@example.com",
"password": "Password123!",
"name": "Duplicate User"
}
# 第一次创建成功
response1 = await async_client.post("/users/", json=user_data)
assert response1.status_code == 201
# 第二次创建应该失败
response2 = await async_client.post("/users/", json=user_data)
assert response2.status_code == 409 # 冲突
@pytest.mark.asyncio
async def test_get_user(self, async_client):
"""测试获取用户"""
# 先创建用户
user_data = {
"email": "getuser@example.com",
"password": "Password123!",
"name": "Get User"
}
create_response = await async_client.post("/users/", json=user_data)
user_id = create_response.json()["id"]
# 获取用户
response = await async_client.get(f"/users/{user_id}")
assert response.status_code == 200
data = response.json()
assert data["id"] == user_id
assert data["email"] == user_data["email"]
@pytest.mark.asyncio
async def test_get_nonexistent_user(self, async_client):
"""测试获取不存在的用户"""
response = await async_client.get("/users/999999")
assert response.status_code == 404
@pytest.mark.asyncio
async def test_update_user(self, async_client):
"""测试更新用户"""
# 先创建用户
user_data = {
"email": "update@example.com",
"password": "Password123!",
"name": "Update User"
}
create_response = await async_client.post("/users/", json=user_data)
user_id = create_response.json()["id"]
# 更新用户
update_data = {
"name": "Updated User Name",
"email": "updated@example.com"
}
response = await async_client.put(f"/users/{user_id}", json=update_data)
assert response.status_code == 200
data = response.json()
assert data["name"] == update_data["name"]
assert data["email"] == update_data["email"]
@pytest.mark.asyncio
async def test_delete_user(self, async_client):
"""测试删除用户"""
# 先创建用户
user_data = {
"email": "delete@example.com",
"password": "Password123!",
"name": "Delete User"
}
create_response = await async_client.post("/users/", json=user_data)
user_id = create_response.json()["id"]
# 验证用户存在
get_response = await async_client.get(f"/users/{user_id}")
assert get_response.status_code == 200
# 删除用户
delete_response = await async_client.delete(f"/users/{user_id}")
assert delete_response.status_code == 204
# 验证用户已删除
get_response_after = await async_client.get(f"/users/{user_id}")
assert get_response_after.status_code == 404
@pytest.mark.asyncio
async def test_list_users(self, async_client):
"""测试用户列表"""
# 创建多个用户
users_data = [
{"email": f"user{i}@example.com", "password": "Password123!", "name": f"User {i}"}
for i in range(5)
]
for user_data in users_data:
response = await async_client.post("/users/", json=user_data)
assert response.status_code == 201
# 获取用户列表
response = await async_client.get("/users/")
assert response.status_code == 200
users = response.json()
assert isinstance(users, list)
assert len(users) >= 5 # 至少有5个用户
# 测试分页
paginated_response = await async_client.get("/users/", params={"skip": 0, "limit": 2})
paginated_users = paginated_response.json()
assert len(paginated_users) <= 2
# 认证相关API测试
class TestAuthAPI:
"""认证API测试"""
@pytest.mark.asyncio
async def test_user_registration(self, async_client):
"""测试用户注册"""
registration_data = {
"email": "register@example.com",
"password": "RegisterPassword123!",
"name": "Register User"
}
response = await async_client.post("/auth/register", json=registration_data)
assert response.status_code == 201
data = response.json()
assert data["email"] == registration_data["email"]
assert "id" in data
@pytest.mark.asyncio
async def test_user_login_success(self, async_client):
"""测试用户登录成功"""
# 先注册用户
await async_client.post("/auth/register", json={
"email": "login@example.com",
"password": "LoginPassword123!",
"name": "Login User"
})
# 登录
login_data = {
"username": "login@example.com",
"password": "LoginPassword123!"
}
response = await async_client.post("/auth/login", data=login_data)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert data["token_type"] == "bearer"
@pytest.mark.asyncio
async def test_user_login_failure(self, async_client):
"""测试用户登录失败"""
login_data = {
"username": "nonexistent@example.com",
"password": "wrongpassword"
}
response = await async_client.post("/auth/login", data=login_data)
assert response.status_code == 401
# 项目相关API测试
class TestItemAPI:
"""项目API测试"""
@pytest.mark.asyncio
async def test_create_item(self, authenticated_client):
"""测试创建项目(需要认证)"""
item_data = {
"title": "Test Item",
"description": "This is a test item",
"price": 99.99
}
response = await authenticated_client.post("/items/", json=item_data)
assert response.status_code == 201
data = response.json()
assert data["title"] == item_data["title"]
assert data["price"] == item_data["price"]
@pytest.mark.asyncio
async def test_get_items(self, authenticated_client):
"""测试获取项目列表"""
response = await authenticated_client.get("/items/")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)#错误处理测试
# tests/test_error_handling.py
import pytest
class TestErrorHandling:
"""错误处理测试"""
@pytest.mark.asyncio
async def test_validation_errors(self, async_client):
"""测试验证错误"""
invalid_data = {
"email": "invalid-email", # 无效邮箱
"password": "123", # 密码太短
"name": "T" # 名字太短
}
response = await async_client.post("/users/", json=invalid_data)
assert response.status_code == 422 # 验证错误
error_detail = response.json()
assert "detail" in error_detail
assert len(error_detail["detail"]) > 0
@pytest.mark.asyncio
async def test_not_found_errors(self, async_client):
"""测试未找到错误"""
response = await async_client.get("/users/999999")
assert response.status_code == 404
error_detail = response.json()
assert "detail" in error_detail
@pytest.mark.asyncio
async def test_unauthorized_errors(self, async_client):
"""测试未授权错误"""
response = await async_client.get("/users/me")
assert response.status_code == 401 # 需要认证
@pytest.mark.asyncio
async def test_rate_limiting(self, async_client):
"""测试速率限制"""
# 发送大量请求来测试速率限制
responses = []
for i in range(20): # 假设速率限制是10次/分钟
response = await async_client.get("/health")
responses.append(response.status_code)
# 检查是否有请求被限制
rate_limited_count = responses.count(429) # 429 Too Many Requests
# 根据实际的速率限制配置来验证
# assert rate_limited_count > 0
@pytest.mark.asyncio
async def test_internal_server_errors(self, async_client):
"""测试服务器内部错误"""
# 这个测试比较难模拟,通常通过异常处理中间件测试
# 可以通过模拟异常来测试
pass
# 自定义异常测试
class TestCustomExceptions:
"""自定义异常测试"""
@pytest.mark.asyncio
async def test_business_logic_errors(self, async_client):
"""测试业务逻辑错误"""
# 例如:尝试创建已存在的资源
user_data = {
"email": "business@example.com",
"password": "Password123!",
"name": "Business User"
}
# 第一次创建
response1 = await async_client.post("/users/", json=user_data)
assert response1.status_code == 201
# 第二次创建相同邮箱(业务逻辑错误)
response2 = await async_client.post("/users/", json=user_data)
assert response2.status_code == 409 # 业务冲突
@pytest.mark.asyncio
async def test_permission_errors(self, async_client):
"""测试权限错误"""
# 创建一个普通用户
user_data = {
"email": "normal@example.com",
"password": "Password123!",
"name": "Normal User"
}
response = await async_client.post("/auth/register", json=user_data)
assert response.status_code == 201
# 尝试访问需要管理员权限的端点
admin_response = await async_client.get("/admin/users")
# 根据实际权限控制实现验证响应
assert admin_response.status_code in [401, 403] # 未认证或禁止访问#参数化测试与边界测试
#参数化测试
# tests/test_parameterized.py
import pytest
class TestParameterizedTests:
"""参数化测试"""
@pytest.mark.parametrize("email,is_valid", [
("valid@example.com", True),
("user.name@example.com", True),
("user+tag@example.co.uk", True),
("invalid", False),
("@example.com", False),
("user@", False),
("", False),
("user@.com", False),
("user@domain", False),
])
@pytest.mark.asyncio
async def test_email_validation(self, async_client, email, is_valid):
"""参数化测试邮箱验证"""
user_data = {
"email": email,
"password": "Password123!",
"name": "Test User"
}
response = await async_client.post("/users/", json=user_data)
if is_valid:
assert response.status_code in [201, 409] # 201成功,409已存在
else:
assert response.status_code == 422 # 验证失败
@pytest.mark.parametrize("password,min_length,max_length,expected_validity", [
("short", 8, 128, False),
("LongEnoughButNoSpecialChar123", 8, 128, False),
("ValidPass1!", 8, 128, True),
("Another_Valid_Pass123!", 8, 128, True),
("", 8, 128, False),
("A" * 129, 8, 128, False), # 超过最大长度
])
@pytest.mark.asyncio
async def test_password_validation(
self, async_client, password, min_length, max_length, expected_validity
):
"""参数化测试密码验证"""
user_data = {
"email": "password_test@example.com",
"password": password,
"name": "Password Test User"
}
response = await async_client.post("/users/", json=user_data)
if expected_validity:
assert response.status_code in [201, 409]
else:
assert response.status_code == 422
@pytest.mark.parametrize("number,input_type,expected_result", [
(0, "integer", 0),
(1, "integer", 1),
(-1, "integer", -1),
(3.14, "float", 3.14),
("123", "string_number", 123),
("invalid", "invalid_string", None),
])
@pytest.mark.asyncio
async def test_input_processing(
self, async_client, number, input_type, expected_result
):
"""参数化测试输入处理"""
# 这里测试一个假想的数值处理端点
# response = await async_client.post("/process-number", json={"input": number})
# 根据实际端点实现验证
pass
@pytest.mark.parametrize("user_role,can_access_admin_panel", [
("admin", True),
("superuser", True),
("moderator", False),
("user", False),
("guest", False),
])
@pytest.mark.asyncio
async def test_role_based_access(
self, async_client, user_role, can_access_admin_panel
):
"""参数化测试基于角色的访问控制"""
# 这个测试需要更复杂的设置,包括创建不同角色的用户
# 并验证他们对管理面板的访问权限
pass
# 边界值测试
class TestBoundaryValueTesting:
"""边界值测试"""
@pytest.mark.asyncio
async def test_pagination_boundaries(self, async_client):
"""测试分页边界"""
# 创建足够的测试数据
for i in range(25):
await async_client.post("/users/", json={
"email": f"page_test_{i}@example.com",
"password": "Password123!",
"name": f"Page Test User {i}"
})
# 测试边界值
boundary_tests = [
{"skip": 0, "limit": 0}, # 0限制
{"skip": 0, "limit": 1}, # 1限制
{"skip": 0, "limit": 10}, # 正常限制
{"skip": 0, "limit": 100}, # 大限制
{"skip": 20, "limit": 5}, # 高偏移
{"skip": 100, "limit": 10}, # 高偏移,无结果
]
for params in boundary_tests:
response = await async_client.get("/users/", params=params)
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
@pytest.mark.parametrize("field,value,min_val,max_val", [
("price", 0.01, 0.01, 999999.99), # 最小价格
("price", 999999.99, 0.01, 999999.99), # 最大价格
("quantity", 1, 1, 10000), # 最小数量
("quantity", 10000, 1, 10000), # 最大数量
("rating", 1.0, 1.0, 5.0), # 最小评分
("rating", 5.0, 1.0, 5.0), # 最大评分
])
@pytest.mark.asyncio
async def test_numeric_field_boundaries(
self, async_client, field, value, min_val, max_val
):
"""测试数值字段边界"""
item_data = {
"title": f"Boundary Test Item {field}",
"description": f"Testing {field} boundaries",
}
if field in ["price", "quantity", "rating"]:
item_data[field] = value
response = await async_client.post("/items/", json=item_data)
# 边界值应该被接受
assert response.status_code in [201, 400] # 400可能是由于其他验证
@pytest.mark.parametrize("string_field,string_value,min_length,max_length", [
("name", "A", 1, 255), # 最短名称
("name", "A" * 255, 1, 255), # 最长名称
("description", "A", 1, 1000), # 最短描述
("description", "A" * 1000, 1, 1000), # 最长描述
("title", "A", 1, 200), # 最短标题
("title", "A" * 200, 1, 200), # 最长标题
])
@pytest.mark.asyncio
async def test_string_field_boundaries(
self, async_client, string_field, string_value, min_length, max_length
):
"""测试字符串字段边界"""
data = {
"title": "String Boundary Test",
"description": "Testing string field boundaries",
}
if string_field in ["name", "description", "title"]:
data[string_field] = string_value
response = await async_client.post("/items/", json=data)
# 边界值应该被接受
assert response.status_code in [201, 422] # 422是验证错误
# 组合参数化测试
class TestCombinatorialTesting:
"""组合测试"""
@pytest.mark.parametrize("email_domain,password_complexity,rate_limit_status", [
("gmail.com", "simple", "normal"),
("company.co.uk", "complex", "normal"),
("free.email", "simple", "limited"),
("enterprise.com", "very_complex", "normal"),
])
@pytest.mark.asyncio
async def test_combination_scenarios(
self, async_client, email_domain, password_complexity, rate_limit_status
):
"""测试组合场景"""
# 根据参数组合创建不同的测试场景
password_chars = {
"simple": "password123",
"complex": "ComplexPass1!",
"very_complex": "VeryComplexPass123!@#"
}
user_data = {
"email": f"user@{email_domain}",
"password": password_chars[password_complexity],
"name": f"Test User {email_domain}"
}
response = await async_client.post("/users/", json=user_data)
assert response.status_code in [201, 409, 422]#性能边界测试
# tests/test_performance_boundaries.py
import pytest
import asyncio
import time
class TestPerformanceBoundaries:
"""性能边界测试"""
@pytest.mark.performance
@pytest.mark.asyncio
async def test_response_time_under_load(self, async_client):
"""测试负载下的响应时间"""
start_time = time.time()
# 并发发送多个请求
tasks = []
for i in range(50):
task = async_client.get("/health")
tasks.append(task)
responses = await asyncio.gather(*tasks)
end_time = time.time()
total_time = end_time - start_time
print(f"Completed 50 requests in {total_time:.2f} seconds")
print(f"Average response time: {total_time/50:.3f} seconds")
# 验证所有请求都成功
for response in responses:
assert response.status_code == 200
# 验证总时间在可接受范围内(根据系统性能调整)
assert total_time < 10.0 # 10秒内完成50个请求
@pytest.mark.performance
@pytest.mark.asyncio
async def test_memory_usage(self, async_client):
"""测试内存使用"""
import psutil
import os
process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss / 1024 / 1024 # MB
# 执行内存密集型操作
for i in range(100):
response = await async_client.get("/health")
assert response.status_code == 200
final_memory = process.memory_info().rss / 1024 / 1024 # MB
memory_increase = final_memory - initial_memory
print(f"Memory usage increase: {memory_increase:.2f} MB")
# 验证内存增长在可接受范围内
assert memory_increase < 50.0 # 小于50MB
@pytest.mark.stress
@pytest.mark.asyncio
async def test_stress_test(self, async_client):
"""压力测试"""
import asyncio
from concurrent.futures import ThreadPoolExecutor
# 配置压力测试参数
num_requests = 200
concurrency_level = 20
async def make_request():
try:
response = await async_client.get("/health")
return response.status_code == 200
except Exception:
return False
# 创建并发请求
tasks = [make_request() for _ in range(num_requests)]
results = await asyncio.gather(*tasks)
success_count = sum(results)
success_rate = success_count / num_requests * 100
print(f"Stress test results: {success_count}/{num_requests} successful ({success_rate:.1f}%)")
# 验证成功率
assert success_rate >= 95.0 # 95%成功率
@pytest.mark.performance
@pytest.mark.asyncio
async def test_database_connection_pool(self, test_session):
"""测试数据库连接池"""
import asyncio
# 并发数据库操作测试
async def db_operation(i):
from app.crud import get_user_by_email
# 执行数据库查询
try:
await get_user_by_email(test_session, f"test{i}@example.com")
return True
except Exception:
return False
# 并发执行多个数据库操作
tasks = [db_operation(i) for i in range(10)]
results = await asyncio.gather(*tasks)
success_count = sum(results)
assert success_count == 10 # 所有操作都应该成功#测试覆盖率分析
#覆盖率配置与执行
# coverage_setup.sh - 覆盖率设置脚本
#!/bin/bash
echo "🚀 Setting up test coverage analysis..."
# 安装覆盖率工具
pip install pytest-cov coverage
# 运行测试并生成覆盖率报告
echo "📊 Running tests with coverage analysis..."
pytest --cov=app --cov-report=html --cov-report=term-missing --cov-fail-under=80
# 生成XML报告用于CI/CD
pytest --cov=app --cov-report=xml
echo "✅ Coverage analysis completed!"
echo "📄 HTML report available at: htmlcov/index.html"
echo "📊 XML report available at: coverage.xml"#详细的覆盖率配置
# .coveragerc - 覆盖率配置文件
[run]
# 源代码路径
source = app/
omit =
*/venv/*
*/env/*
*/.venv/*
*/tests/*
*/migrations/*
*/config/*
*/__init__.py
*/settings.py
*/alembic/*
*/venv/*/**
*/env/*/**
*/.venv/*/**
# 包含分支覆盖率
branch = True
# 平行模式(用于并行测试)
parallel = True
[report]
# 忽略的行
exclude_lines =
pragma: no cover
def __repr__
raise AssertionError
raise NotImplementedError
if __name__ == .__main__.:
if TYPE_CHECKING:
if settings.debug:
if DEBUG:
@abstractmethod
@overload
@runtime_checkable
if sys.version_info <
if typing.TYPE_CHECKING:
if MYPY:
assert_never\(
assert False
assert 0
if 0:
if __debug__:
if not __debug__:
if typing.TYPE_CHECKING:
if TYPE_CHECKING:
if "NO PYTEST COVERAGE" in __doc__:
if "PYTEST_NO_COVERAGE" in __file__:
if "COVERAGE DISABLE" in __file__:
# 覆盖率阈值
precision = 2
show_missing = True
skip_covered = False
skip_empty = True
[html]
directory = htmlcov
title = FastAPI Application Test Coverage Report
[xml]
output = coverage.xml#覆盖率分析工具
# tools/coverage_analyzer.py - 覆盖率分析工具
import subprocess
import json
import xml.etree.ElementTree as ET
from pathlib import Path
import sys
class CoverageAnalyzer:
"""覆盖率分析器"""
def __init__(self, source_path="app/", coverage_threshold=80):
self.source_path = source_path
self.threshold = coverage_threshold
self.results = {}
def run_coverage_analysis(self):
"""运行覆盖率分析"""
print("🔍 Running coverage analysis...")
# 运行pytest with coverage
cmd = [
"pytest",
f"--cov={self.source_path}",
"--cov-report=json",
"--cov-report=term-missing",
f"--cov-fail-under={self.threshold}"
]
result = subprocess.run(cmd, capture_output=True, text=True)
print("STDOUT:", result.stdout)
if result.stderr:
print("STDERR:", result.stderr)
# 读取JSON报告
json_report = Path("coverage.json")
if json_report.exists():
with open(json_report, 'r') as f:
self.results = json.load(f)
return result.returncode == 0
def get_coverage_summary(self):
"""获取覆盖率摘要"""
if not self.results:
return None
summary = {
"total_coverage": self.results.get("totals", {}).get("percent_covered", 0),
"num_files": len(self.results.get("files", {})),
"missing_lines": self.results.get("totals", {}).get("missing_lines", 0),
"excluded_lines": self.results.get("totals", {}).get("excluded", 0)
}
return summary
def get_low_coverage_files(self, threshold=80):
"""获取低覆盖率文件"""
low_coverage_files = []
for filepath, file_data in self.results.get("files", {}).items():
coverage_percent = file_data.get("summary", {}).get("percent_covered", 0)
if coverage_percent < threshold:
low_coverage_files.append({
"file": filepath,
"coverage": coverage_percent,
"missing_lines": file_data.get("summary", {}).get("missing_lines", 0)
})
return sorted(low_coverage_files, key=lambda x: x["coverage"])
def generate_report(self):
"""生成覆盖率报告"""
summary = self.get_coverage_summary()
if not summary:
print("❌ No coverage data available")
return
print("\n" + "="*60)
print("📋 COVERAGE ANALYSIS REPORT")
print("="*60)
print(f"Total Coverage: {summary['total_coverage']:.2f}%")
print(f"Number of Files: {summary['num_files']}")
print(f"Missing Lines: {summary['missing_lines']}")
print(f"Excluded Lines: {summary['excluded_lines']}")
print("-"*60)
# 显示低覆盖率文件
low_coverage = self.get_low_coverage_files()
if low_coverage:
print("⚠️ LOW COVERAGE FILES (< 80%):")
for file_info in low_coverage[:10]: # 只显示前10个
print(f" {file_info['file']}: {file_info['coverage']:.1f}% "
f"({file_info['missing_lines']} lines missing)")
else:
print("✅ All files have good coverage!")
print("="*60)
# 检查是否达到阈值
if summary['total_coverage'] >= self.threshold:
print(f"✅ Coverage meets threshold ({self.threshold}%)")
return True
else:
print(f"❌ Coverage below threshold ({self.threshold}%)")
return False
def main():
"""主函数"""
analyzer = CoverageAnalyzer(source_path="app", coverage_threshold=80)
success = analyzer.run_coverage_analysis()
analyzer.generate_report()
if not success:
sys.exit(1)
if __name__ == "__main__":
main()#覆盖率优化策略
# tests/test_coverage_optimization.py
"""
覆盖率优化策略:
1. 业务逻辑覆盖率:确保核心业务逻辑100%覆盖
2. 错误路径覆盖:测试所有可能的错误情况
3. 边界条件覆盖:测试输入参数的边界值
4. 异常处理覆盖:确保所有异常都被捕获和处理
5. 集成路径覆盖:测试模块间的交互
"""
class TestCoverageOptimization:
"""覆盖率优化测试"""
def test_all_exception_paths(self):
"""测试所有异常路径"""
# 示例:测试除零异常
def divide(a, b):
if b == 0:
raise ZeroDivisionError("Cannot divide by zero")
return a / b
# 测试正常情况
assert divide(10, 2) == 5
# 测试异常情况
with pytest.raises(ZeroDivisionError):
divide(10, 0)
def test_all_branches(self):
"""测试所有分支"""
def process_user(user_type, is_active):
if user_type == "admin":
if is_active:
return "Admin Active"
else:
return "Admin Inactive"
elif user_type == "user":
if is_active:
return "User Active"
else:
return "User Inactive"
else:
return "Unknown"
# 测试所有分支
assert process_user("admin", True) == "Admin Active"
assert process_user("admin", False) == "Admin Inactive"
assert process_user("user", True) == "User Active"
assert process_user("user", False) == "User Inactive"
assert process_user("guest", True) == "Unknown"
def test_boundary_conditions(self):
"""测试边界条件"""
def validate_age(age):
if age < 0:
raise ValueError("Age cannot be negative")
elif age > 150:
raise ValueError("Age seems unrealistic")
elif age < 18:
return "Minor"
elif age >= 18 and age <= 65:
return "Adult"
else:
return "Senior"
# 测试边界值
assert validate_age(0) == "Minor" # 边界
assert validate_age(17) == "Minor" # 边界
assert validate_age(18) == "Adult" # 边界
assert validate_age(65) == "Adult" # 边界
assert validate_age(66) == "Senior" # 边界
assert validate_age(150) == "Senior" # 边界
# 测试异常边界
with pytest.raises(ValueError, match="negative"):
validate_age(-1)
with pytest.raises(ValueError, match="unrealistic"):
validate_age(151)
def test_all_input_combinations(self):
"""测试所有输入组合"""
def calculate_discount(customer_type, purchase_amount, is_member):
discount = 0
if customer_type == "premium":
discount += 15
elif customer_type == "regular":
discount += 5
if is_member:
discount += 10
if purchase_amount > 1000:
discount += 5
elif purchase_amount > 500:
discount += 3
return min(discount, 50) # 最大折扣50%
# 测试各种组合
test_cases = [
# (customer_type, purchase_amount, is_member, expected_discount)
("premium", 100, False, 15),
("premium", 100, True, 25),
("premium", 600, True, 28),
("premium", 1100, True, 35),
("regular", 100, False, 5),
("regular", 100, True, 15),
("regular", 600, True, 18),
("regular", 1100, True, 25),
("guest", 100, False, 0),
("guest", 1100, True, 15),
]
for customer_type, amount, is_member, expected in test_cases:
result = calculate_discount(customer_type, amount, is_member)
assert result == expected, f"Failed for {customer_type}, {amount}, {is_member}"
# 专门的覆盖率提高测试
class TestIncreaseCoverage:
"""提高覆盖率的测试"""
def test_repr_methods(self):
"""测试repr方法"""
class TestClass:
def __init__(self, value):
self.value = value
def __repr__(self):
return f"TestClass(value={self.value!r})"
obj = TestClass("test")
repr_str = repr(obj)
assert "TestClass" in repr_str
assert "test" in repr_str
def test_eq_methods(self):
"""测试相等性方法"""
class TestClass:
def __init__(self, value):
self.value = value
def __eq__(self, other):
if not isinstance(other, TestClass):
return NotImplemented
return self.value == other.value
obj1 = TestClass("test")
obj2 = TestClass("test")
obj3 = TestClass("different")
assert obj1 == obj2
assert obj1 != obj3
assert obj1 != "not_a_TestClass"
def test_property_methods(self):
"""测试属性方法"""
class TestClass:
def __init__(self):
self._value = 0
@property
def value(self):
return self._value
@value.setter
def value(self, val):
if val < 0:
raise ValueError("Value cannot be negative")
self._value = val
obj = TestClass()
assert obj.value == 0
obj.value = 10
assert obj.value == 10
with pytest.raises(ValueError):
obj.value = -1#TDD开发实践
#TDD基础实践
"""
TDD (Test-Driven Development) 开发实践:
1. RED - 编写失败的测试
2. GREEN - 编写最少的代码使测试通过
3. REFACTOR - 重构代码,保持测试通过
"""
# 示例:开发一个用户服务类
# 1. RED: 先写测试
# tests/test_user_service_tdd.py
import pytest
from unittest.mock import AsyncMock, MagicMock
class TestUserServiceTDD:
"""TDD方式开发用户服务"""
@pytest.fixture
def mock_db_session(self):
"""模拟数据库会话"""
session = AsyncMock()
session.commit = AsyncMock()
session.refresh = AsyncMock()
return session
@pytest.fixture
def user_service(self, mock_db_session):
"""用户服务实例"""
from app.services.user_service import UserService
return UserService(db_session=mock_db_session)
def test_create_user_with_valid_data(self, user_service, mock_db_session):
"""测试使用有效数据创建用户 - RED阶段"""
# 这个测试应该失败,因为UserService还没实现
user_data = {
"email": "test@example.com",
"password": "SecurePassword123!",
"name": "Test User"
}
# 预期会失败,因为我们还没有实现UserService
with pytest.raises(Exception): # 期望抛出异常直到实现完成
result = user_service.create_user(user_data)
def test_get_user_by_email(self, user_service):
"""测试通过邮箱获取用户"""
# RED: 编写失败的测试
with pytest.raises(Exception):
user_service.get_user_by_email("test@example.com")
# 2. GREEN: 实现最简单的代码使测试通过
"""
# app/services/user_service.py - 实现UserService
from typing import Optional, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
class UserService:
def __init__(self, db_session: AsyncSession):
self.db_session = db_session
async def create_user(self, user_data: Dict[str, Any]):
# 简单实现使测试通过
return {
"id": 1,
"email": user_data["email"],
"name": user_data["name"],
"created_at": "2024-01-01T00:00:00Z"
}
async def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]:
# 简单实现使测试通过
if email == "test@example.com":
return {
"id": 1,
"email": email,
"name": "Test User"
}
return None
"""
# 3. REFACTOR: 重构代码
"""
# 重构后的UserService实现
from typing import Optional, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import User
from app.schemas import UserCreate
from passlib.context import CryptContext
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
class UserService:
def __init__(self, db_session: AsyncSession):
self.db_session = db_session
async def create_user(self, user_data: UserCreate) -> User:
hashed_password = pwd_context.hash(user_data.password)
db_user = User(
email=user_data.email,
name=user_data.name,
hashed_password=hashed_password
)
self.db_session.add(db_user)
await self.db_session.commit()
await self.db_session.refresh(db_user)
return db_user
async def get_user_by_email(self, email: str) -> Optional[User]:
result = await self.db_session.execute(
select(User).filter(User.email == email)
)
return result.scalars().first()
async def authenticate_user(self, email: str, password: str) -> Optional[User]:
user = await self.get_user_by_email(email)
if user and pwd_context.verify(password, user.hashed_password):
return user
return None
"""
# TDD完整示例:开发购物车功能
class TestShoppingCartTDD:
"""购物车功能的TDD开发示例"""
def test_cart_starts_empty(self):
"""测试购物车初始化时空的 - RED"""
# 首先编写测试
cart = ShoppingCart()
assert len(cart.items) == 0
assert cart.total_price == 0
def test_add_item_to_cart(self):
"""测试向购物车添加商品 - RED"""
cart = ShoppingCart()
item = {"id": 1, "name": "Product", "price": 10.0, "quantity": 1}
cart.add_item(item)
assert len(cart.items) == 1
assert cart.items[0]["name"] == "Product"
assert cart.total_price == 10.0
def test_remove_item_from_cart(self):
"""测试从购物车移除商品 - RED"""
cart = ShoppingCart()
item = {"id": 1, "name": "Product", "price": 10.0, "quantity": 1}
cart.add_item(item)
cart.remove_item(1)
assert len(cart.items) == 0
assert cart.total_price == 0
def test_update_item_quantity(self):
"""测试更新商品数量 - RED"""
cart = ShoppingCart()
item = {"id": 1, "name": "Product", "price": 10.0, "quantity": 1}
cart.add_item(item)
cart.update_quantity(1, 3)
assert cart.items[0]["quantity"] == 3
assert cart.total_price == 30.0
"""
# GREEN: 实现功能使测试通过
class ShoppingCart:
def __init__(self):
self.items = []
@property
def total_price(self):
return sum(item["price"] * item["quantity"] for item in self.items)
def add_item(self, item):
self.items.append(item)
def remove_item(self, item_id):
self.items = [item for item in self.items if item["id"] != item_id]
def update_quantity(self, item_id, quantity):
for item in self.items:
if item["id"] == item_id:
item["quantity"] = quantity
break
"""
# REFACTOR: 重构为更健壮的实现
"""
class ShoppingCart:
def __init__(self):
self._items = {} # 使用字典以ID为键,便于查找
def add_item(self, item):
item_id = item["id"]
if item_id in self._items:
# 如果商品已存在,更新数量
self._items[item_id]["quantity"] += item["quantity"]
else:
self._items[item_id] = item.copy()
def remove_item(self, item_id):
if item_id in self._items:
del self._items[item_id]
def update_quantity(self, item_id, quantity):
if item_id in self._items:
if quantity <= 0:
self.remove_item(item_id)
else:
self._items[item_id]["quantity"] = quantity
def get_item(self, item_id):
return self._items.get(item_id)
@property
def items(self):
return list(self._items.values())
@property
def total_price(self):
return sum(item["price"] * item["quantity"]
for item in self._items.values())
@property
def item_count(self):
return sum(item["quantity"] for item in self._items.values())
"""
# TDD测试策略
class TestTDDStrategies:
"""TDD测试策略"""
def test_outside_in_tdd(self):
"""外部到内部TDD - 从API开始测试"""
# 先测试外部接口
response = client.get("/cart")
assert response.status_code == 200
# 然后逐步深入到内部实现
# 这种方式适合API驱动开发
def test_inside_out_tdd(self):
"""内部到外部TDD - 从核心逻辑开始测试"""
# 先测试核心业务逻辑
calculator = DiscountCalculator()
assert calculator.calculate(100, "regular") == 5 # 5%折扣
# 然后测试使用该逻辑的上层组件
# 这种方式适合算法和业务逻辑驱动开发
def test_mockist_tdd(self):
"""Mockist TDD - 使用模拟对象测试"""
mock_payment_gateway = Mock()
mock_payment_gateway.process.return_value = True
processor = PaymentProcessor(mock_payment_gateway)
result = processor.pay(100, "test_card")
assert result is True
mock_payment_gateway.process.assert_called_once()
def test_classic_tdd(self):
"""Classic TDD - 使用真实对象测试"""
# 使用真实的数据库、真实的外部服务
# 更接近生产环境,但测试速度较慢
# TDD最佳实践
"""
TDD Best Practices:
1. Test One Thing: 每个测试只验证一个行为
2. Fast Tests: 测试应该快速运行
3. Independent Tests: 测试之间不应该相互依赖
4. Repeatable Tests: 测试结果应该是一致的
5. Self-Validating: 测试应该自动判断通过还是失败
6. Timely Tests: 测试应该在代码之前或同时编写
"""
## 测试性能优化 \{#测试性能优化}
### 测试并行化
```python
# pytest_parallel_example.py - 并行测试示例
"""
pytest-xdist 插件可以实现测试并行化:
安装:
pip install pytest-xdist
运行:
pytest -n auto # 自动检测CPU核心数
pytest -n 4 # 指定4个进程
pytest --dist worksteal # 工作窃取模式
"""
# conftest.py 中的并行化配置
@pytest.fixture(scope="session")
def shared_resource():
"""共享资源,只初始化一次"""
print("Initializing shared resource")
return {"connection": "mock_connection"}
# 使用pytest-benchmark进行性能测试
"""
pip install pytest-benchmark
"""
def test_with_benchmark(benchmark):
"""使用benchmark进行性能测试"""
def expensive_function():
# 模拟耗时操作
result = sum(i**2 for i in range(1000))
return result
result = benchmark(expensive_function)
assert result == sum(i**2 for i in range(1000))
# 性能回归测试
def test_performance_regression(benchmark):
"""性能回归测试"""
# 设定性能阈值
threshold = 0.1 # 100ms
def target_function():
import time
time.sleep(0.05) # 模拟50ms操作
return "result"
result = benchmark(target_function)
# 检查是否超过阈值
assert benchmark.stats.stats.max < threshold#测试数据优化
# test_data_optimization.py
from pytest_lazyfixture import lazy_fixture
# 使用工厂函数创建测试数据
@pytest.fixture
def user_factory():
"""用户工厂"""
def _create_user(email=None, name=None, role="user"):
return {
"email": email or f"user{random.randint(1, 1000)}@example.com",
"name": name or f"User {random.randint(1, 1000)}",
"role": role
}
return _create_user
# 使用Faker生成测试数据
from faker import Faker
fake = Faker()
@pytest.fixture
def realistic_user_data():
"""生成逼真的测试数据"""
return {
"email": fake.email(),
"name": fake.name(),
"address": fake.address(),
"phone": fake.phone_number(),
"company": fake.company()
}
# 测试数据重用
@pytest.fixture(scope="module")
def test_dataset():
"""模块级别的测试数据集"""
# 生成大量测试数据,只生成一次
dataset = []
for i in range(1000):
dataset.append({
"id": i,
"name": f"Item {i}",
"value": random.random() * 100
})
return dataset
# 数据清理优化
@pytest.fixture(autouse=True)
def cleanup_test_data():
"""自动清理测试数据"""
yield
# 测试后清理
# 清理会话、临时文件等
pass#测试缓存和跳过
# test_caching_and_skipping.py
import pytest
# 使用缓存避免重复计算
@pytest.fixture(scope="session")
def expensive_computation():
"""昂贵的计算,只执行一次"""
print("Performing expensive computation...")
# 模拟昂贵的计算
result = sum(i**2 for i in range(100000))
return result
# 条件跳过测试
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+")
def test_new_feature():
"""只在Python 3.8+上运行的测试"""
# 使用Python 3.8+的新特性
pass
# 跳过不稳定的测试
@pytest.mark.skip(reason="Unstable test, needs investigation")
def test_unstable_feature():
"""暂时跳过的不稳定测试"""
pass
# 预期失败的测试
@pytest.mark.xfail(strict=True)
def test_expected_failure():
"""预期会失败的测试"""
assert False # 预期失败#与其他测试框架对比
#Pytest vs Unittest
| 特性 | Pytest | Unittest |
|---|---|---|
| 语法简洁性 | ⭐⭐⭐⭐⭐ 简洁直观 | ⭐⭐⭐ 传统但冗长 |
| Fixtures | ⭐⭐⭐⭐⭐ 强大的Fixtures系统 | ⭐⭐⭐ setUp/tearDown |
| 参数化测试 | ⭐⭐⭐⭐⭐ 简单易用 | ⭐⭐⭐ 较复杂 |
| 插件生态系统 | ⭐⭐⭐⭐⭐ 丰富的插件 | ⭐⭐⭐ 有限 |
| 异步测试支持 | ⭐⭐⭐⭐⭐ 原生支持 | ⭐⭐ 需要额外库 |
| 断言 | ⭐⭐⭐⭐⭐ 简单的assert | ⭐⭐⭐ 多种assert方法 |
| 学习曲线 | ⭐⭐⭐⭐⭐ 容易上手 | ⭐⭐⭐⭐ 标准库 |
#Pytest vs Nose
| 特性 | Pytest | Nose |
|---|---|---|
| 活跃度 | ⭐⭐⭐⭐⭐ 活跃维护 | ⭐ 不再维护 |
| Python 3支持 | ⭐⭐⭐⭐⭐ 完全支持 | ⭐⭐⭐⭐ 部分支持 |
| 插件系统 | ⭐⭐⭐⭐⭐ 优秀 | ⭐⭐⭐ 一般 |
| 性能 | ⭐⭐⭐⭐⭐ 快速 | ⭐⭐⭐⭐ 可接受 |
#选择建议
"""
选择测试框架的决策树:
1. 如果使用Python 3.6+ → 优先选择Pytest
2. 如果需要与unittest集成 → Pytest兼容性好
3. 如果是大型企业项目 → Pytest + 插件生态系统
4. 如果是小型项目 → Pytest简单够用
5. 如果有特殊需求 → 查看Pytest插件能否满足
Pytest优势:
- 语法简洁
- 功能强大
- 插件丰富
- 社区活跃
- 与主流工具集成好
"""#总结
Pytest是Python中最强大的测试框架之一,特别适合FastAPI应用的测试:
- 简洁语法:使用简单的assert语句
- 强大Fixtures:灵活的依赖管理
- 参数化测试:轻松测试多种场景
- 丰富的插件:扩展功能强大
- 异步支持:天然支持async/await
- 集成友好:与CI/CD无缝集成
💡 关键要点:良好的测试策略是应用质量的保障。Pytest提供了完整的测试解决方案,从单元测试到集成测试,从简单验证到复杂场景,都能胜任。
#SEO优化建议
为了提高这篇Pytest单元测试教程在搜索引擎中的排名,以下是几个关键的SEO优化建议:
#标题优化
- 主标题:使用包含核心关键词的标题,如"FastAPI Pytest单元测试完全指南"
- 二级标题:每个章节标题都包含相关的长尾关键词
- H1-H6层次结构:保持正确的标题层级,便于搜索引擎理解内容结构
#内容优化
- 关键词密度:在内容中自然地融入关键词如"Pytest"、"单元测试"、"FastAPI"、"异步测试"等
- 元描述:在文章开头的元数据中包含吸引人的描述
- 内部链接:链接到其他相关教程,如FastAPI依赖注入系统等
- 外部权威链接:引用官方文档和权威资源
#技术SEO
- 页面加载速度:优化代码块和图片加载
- 移动端适配:确保在移动设备上良好显示
- 结构化数据:使用适当的HTML标签和语义化元素
#用户体验优化
- 内容可读性:使用清晰的段落结构和代码示例
- 互动元素:提供实际可运行的代码示例
- 更新频率:定期更新内容以保持时效性
#常见问题解答(FAQ)
#Q1: Pytest和unittest有什么区别?
A: Pytest相比unittest有更简洁的语法、强大的fixture系统、原生的参数化测试支持,以及丰富的插件生态系统。Pytest可以直接使用简单的assert语句,而unittest需要使用特定的assert方法。
#Q2: 如何在FastAPI中测试异步函数?
A: 使用pytest-asyncio插件,添加@pytest.mark.asyncio装饰器,然后使用httpx.AsyncClient进行异步测试。
#Q3: 如何测试数据库相关的功能?
A: 使用内存数据库(如SQLite in-memory)进行测试,通过fixture提供测试数据库会话,确保每个测试后数据库状态清洁。
#Q4: 什么是TDD,为什么要使用它?
A: TDD(Test-Driven Development)是测试驱动开发,先写测试再写实现代码。它能提高代码质量、减少bug、改善设计并提供文档作用。
#Q5: 如何提高测试覆盖率?
A: 使用pytest-cov工具分析覆盖率,重点关注业务逻辑、错误路径、边界条件和异常处理的测试覆盖。
🔗 相关教程推荐
- FastAPI依赖注入系统 - 深入理解依赖注入机制
- FastAPI多环境配置 - 环境配置管理
- FastAPI异常处理 - 统一异常处理策略
- FastAPI中间件应用 - 中间件开发与应用
- Python基础语法 - Python语言基础
🏷️ 标签云: FastAPI测试 Pytest教程 单元测试 异步测试 TDD开发 API测试 测试覆盖率 Fixtures 参数化测试 数据库测试

