FastAPI Pytest单元测试完全指南

📂 所属阶段:第五阶段 — 工程化与部署(实战篇)
🔗 相关章节:FastAPI依赖注入系统 · FastAPI多环境配置

目录

单元测试基础概念

为什么需要单元测试?

单元测试是软件开发中的基石,它确保代码的正确性、可维护性和可扩展性。在FastAPI应用中,单元测试尤为重要:

# 无测试 vs 有测试的开发体验
"""
无测试场景:
- 修改代码 → 手动测试 → 上线 → 发现bug → 回滚 → 用户投诉 😱
- 重构代码 → 担心破坏功能 → 代码腐化 → 技术债务堆积

有测试场景:
- 修改代码 → 运行测试 → 快速反馈 → 自信上线 🚀
- 重构代码 → 测试保证 → 持续优化 → 代码质量提升
"""

测试金字塔

在FastAPI应用中,遵循测试金字塔原则:

┌─────────────────────────┐  ← 单元测试 (Unit Tests) - 70%
│      业务逻辑层         │    • 快速、隔离、专注
│    (Service Layer)      │    • 测试纯函数和业务逻辑
├─────────────────────────┤  ← 集成测试 (Integration) - 20%
│      API/路由层         │    • 测试API端点和数据流
│     (Route Layer)       │    • 包含数据库、外部服务
├─────────────────────────┤  ← 端到端测试 (E2E) - 10%
│      UI/界面层          │    • 测试完整用户流程
│    (UI Layer)           │    • 使用Selenium等工具
└─────────────────────────┘

测试驱动开发(TDD)的好处

  1. 设计驱动:先思考接口设计
  2. 快速反馈:即时验证代码正确性
  3. 重构安全:测试保护重构过程
  4. 文档作用:测试即使用示例
  5. 信心保证:确保功能按预期工作

Pytest环境搭建

核心依赖安装

# 基础测试依赖
pip install pytest pytest-asyncio httpx

# 覆盖率分析
pip install pytest-cov coverage

# Mock和补丁
pip install pytest-mock

# 参数化测试增强
pip install pytest-parametrize

# 测试数据生成
pip install factory-boy faker

# 断言增强
pip install pytest-check

# 环境变量管理
pip install python-dotenv pytest-dotenv

# 完整安装命令
pip install pytest pytest-asyncio httpx pytest-cov pytest-mock factory-boy faker

Pytest配置文件

# pytest.ini - 项目根目录
[tool:pytest]
# 测试路径
testpaths = tests
python_files = test_*.py *_test.py
python_classes = Test*
python_functions = test_*

# 异步支持
asyncio_mode = auto

# 输出选项
addopts = 
    -v
    --tb=short
    --strict-markers
    --strict-config
    --disable-warnings

# 覆盖率配置
filterwarnings = 
    ignore::DeprecationWarning
    ignore::UserWarning
    error::pytest.PytestWarning

# 标记配置
markers = 
    slow: marks tests as slow
    integration: marks tests as integration tests
    unit: marks tests as unit tests
    api: marks tests as api tests
    database: marks tests as database tests
    auth: marks tests as authentication tests

项目结构配置

# pyproject.toml - 现代Python项目配置
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py", "*_test.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
asyncio_mode = "auto"
addopts = [
    "-ra",           # 显示摘要和失败详情
    "--showlocals",  # 显示局部变量
    "--tb=short",    # 简短回溯
    "--strict-markers",
    "--strict-config",
]
markers = [
    "slow: marks tests as slow",
    "integration: marks tests as integration tests", 
    "unit: marks tests as unit tests",
    "api: marks tests as api tests",
    "database: marks tests as database tests",
    "auth: marks tests as authentication tests"
]

[tool.coverage.run]
source = ["src/", "app/"]
omit = [
    "*/venv/*",
    "*/tests/*",
    "*/migrations/*",
    "*/config/*",
    "*/__init__.py"
]

[tool.coverage.report]
exclude_lines = [
    "pragma: no cover",
    "def __repr__",
    "raise AssertionError",
    "raise NotImplementedError",
    "if __name__ == .__main__.:"
]

conftest.py配置文件

# tests/conftest.py - Pytest配置文件
import asyncio
import pytest
from httpx import AsyncClient, ASGITransport
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from fastapi.testclient import TestClient

# 应用程序导入
from app.main import app
from app.database import get_db, engine, Base
from app.models import User, Item
from app.config import settings

# 测试数据库URL - 使用内存数据库
TEST_DATABASE_URL = "sqlite:///./test.db"

@pytest.fixture(scope="session")
def event_loop():
    """创建事件循环"""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()

@pytest.fixture(scope="session")
async def test_engine():
    """创建测试数据库引擎"""
    engine = create_engine(
        TEST_DATABASE_URL,
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    yield engine
    engine.dispose()

@pytest.fixture(scope="function")
async def test_session(test_engine):
    """创建测试数据库会话"""
    async with test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    
    async_session = sessionmaker(
        test_engine, 
        expire_on_commit=False, 
        class_=AsyncSession
    )
    
    async with async_session() as session:
        yield session
        
        # 回滚事务,确保数据清洁
        await session.rollback()
        
        # 清理数据库
        await conn.run_sync(Base.metadata.drop_all)

@pytest.fixture(scope="function")
def override_dependencies(test_session):
    """覆盖应用依赖"""
    def get_test_db():
        return test_session
    
    app.dependency_overrides[get_db] = get_test_db
    yield
    app.dependency_overrides.clear()

@pytest.fixture
async def async_client(override_dependencies):
    """异步测试客户端"""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac

@pytest.fixture
async def authenticated_client(async_client):
    """认证的测试客户端"""
    # 创建测试用户
    response = await async_client.post("/auth/register", json={
        "email": "test@example.com",
        "password": "TestPassword123!",
        "name": "Test User"
    })
    
    # 登录获取token
    login_response = await async_client.post("/auth/login", data={
        "username": "test@example.com",
        "password": "TestPassword123!"
    })
    
    token = login_response.json()["access_token"]
    async_client.headers["Authorization"] = f"Bearer {token}"
    
    yield async_client

@pytest.fixture
def sample_user_data():
    """样本用户数据"""
    return {
        "email": "test@example.com",
        "password": "TestPassword123!",
        "name": "Test User",
        "is_active": True
    }

@pytest.fixture
def sample_item_data():
    """样本项目数据"""
    return {
        "title": "Test Item",
        "description": "This is a test item",
        "price": 99.99
    }

TestClient基础使用

TestClient同步测试

# tests/test_basic.py
import pytest
from fastapi.testclient import TestClient
from app.main import app

client = TestClient(app)

def test_root_endpoint():
    """测试根端点"""
    response = client.get("/")
    assert response.status_code == 200
    assert response.json() == {"message": "Welcome to DaomanAPI"}

def test_health_check():
    """测试健康检查端点"""
    response = client.get("/health")
    assert response.status_code == 200
    assert response.json() == {"status": "healthy", "timestamp": pytest.approx(0, abs=1)}

def test_post_request():
    """测试POST请求"""
    test_data = {"name": "Test User", "email": "test@example.com"}
    response = client.post("/users/", json=test_data)
    assert response.status_code == 201
    assert response.json()["name"] == "Test User"
    assert response.json()["email"] == "test@example.com"

def test_query_parameters():
    """测试查询参数"""
    response = client.get("/items/", params={"skip": 0, "limit": 10})
    assert response.status_code == 200
    assert isinstance(response.json(), list)

def test_path_parameters():
    """测试路径参数"""
    user_id = 1
    response = client.get(f"/users/{user_id}")
    assert response.status_code == 200
    assert response.json()["id"] == user_id

def test_error_handling():
    """测试错误处理"""
    response = client.get("/nonexistent-endpoint")
    assert response.status_code == 404

def test_validation_errors():
    """测试验证错误"""
    invalid_data = {"name": "Too Short"}  # 假设name需要至少10个字符
    response = client.post("/users/", json=invalid_data)
    assert response.status_code == 422
    assert "detail" in response.json()

HTTP方法测试

# tests/test_http_methods.py
import pytest

class TestHTTPMethods:
    """HTTP方法测试类"""
    
    def test_get_method(self, client):
        """测试GET方法"""
        response = client.get("/users/1")
        assert response.status_code == 200
    
    def test_post_method(self, client):
        """测试POST方法"""
        data = {"name": "New User", "email": "new@example.com"}
        response = client.post("/users/", json=data)
        assert response.status_code == 201
        assert response.json()["name"] == "New User"
    
    def test_put_method(self, client):
        """测试PUT方法"""
        data = {"name": "Updated User", "email": "updated@example.com"}
        response = client.put("/users/1", json=data)
        assert response.status_code == 200
        assert response.json()["name"] == "Updated User"
    
    def test_patch_method(self, client):
        """测试PATCH方法"""
        partial_data = {"name": "Partially Updated"}
        response = client.patch("/users/1", json=partial_data)
        assert response.status_code == 200
        assert response.json()["name"] == "Partially Updated"
    
    def test_delete_method(self, client):
        """测试DELETE方法"""
        response = client.delete("/users/1")
        assert response.status_code == 204  # No Content

请求头和Cookie测试

# tests/test_headers_cookies.py
def test_custom_headers(client):
    """测试自定义请求头"""
    headers = {"X-Custom-Header": "test-value", "User-Agent": "Test-Client/1.0"}
    response = client.get("/headers", headers=headers)
    assert response.status_code == 200
    # 验证服务器收到了正确的头信息

def test_authentication_headers(client):
    """测试认证头"""
    # 无认证
    response = client.get("/protected")
    assert response.status_code == 401
    
    # 有认证
    headers = {"Authorization": "Bearer fake-token"}
    response = client.get("/protected", headers=headers)
    assert response.status_code in [200, 401]  # 取决于token有效性

def test_cookie_handling(client):
    """测试Cookie处理"""
    response = client.get("/set-cookie")
    assert response.status_code == 200
    assert "set-cookie" in response.headers
    
    # 使用Cookie进行后续请求
    cookie_value = response.cookies.get("session_id")
    response = client.get("/with-cookie", cookies={"session_id": cookie_value})
    assert response.status_code == 200

异步测试详解

异步测试基础

# tests/test_async.py
import pytest
import asyncio
from httpx import AsyncClient, ASGITransport
from app.main import app

@pytest.mark.asyncio
async def test_async_endpoint():
    """异步端点测试"""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        response = await ac.get("/")
        assert response.status_code == 200
        assert response.json() == {"message": "Welcome to DaomanAPI"}

@pytest.mark.asyncio
async def test_async_post():
    """异步POST请求测试"""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        data = {"name": "Async User", "email": "async@example.com"}
        response = await ac.post("/users/", json=data)
        assert response.status_code == 201

class TestAsyncEndpoints:
    """异步端点测试类"""
    
    @pytest.mark.asyncio
    async def test_list_users(self, async_client):
        """测试用户列表端点"""
        response = await async_client.get("/users/")
        assert response.status_code == 200
        assert isinstance(response.json(), list)
    
    @pytest.mark.asyncio
    async def test_create_user(self, async_client):
        """测试创建用户"""
        user_data = {
            "email": "async_test@example.com",
            "password": "AsyncPassword123!",
            "name": "Async Test User"
        }
        response = await async_client.post("/users/", json=user_data)
        assert response.status_code == 201
        assert response.json()["email"] == user_data["email"]
    
    @pytest.mark.asyncio
    async def test_get_user(self, async_client):
        """测试获取用户"""
        # 先创建用户
        user_data = {
            "email": "get_test@example.com",
            "password": "GetPassword123!",
            "name": "Get Test User"
        }
        create_response = await async_client.post("/users/", json=user_data)
        user_id = create_response.json()["id"]
        
        # 获取用户
        response = await async_client.get(f"/users/{user_id}")
        assert response.status_code == 200
        assert response.json()["email"] == user_data["email"]

异步并发测试

# tests/test_async_concurrency.py
import pytest
import asyncio
from httpx import AsyncClient, ASGITransport

@pytest.mark.asyncio
async def test_concurrent_requests():
    """测试并发请求"""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        # 创建多个并发请求
        tasks = [
            ac.get("/users"),
            ac.get("/items"),
            ac.get("/health"),
            ac.get("/metrics")
        ]
        
        responses = await asyncio.gather(*tasks)
        
        # 验证所有请求都成功
        for response in responses:
            assert response.status_code in [200, 201]

@pytest.mark.asyncio
async def test_load_simulation():
    """负载模拟测试"""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        # 模拟10个并发用户
        tasks = [ac.get("/health") for _ in range(10)]
        responses = await asyncio.gather(*tasks, return_exceptions=True)
        
        # 验证响应
        successful_responses = [
            r for r in responses 
            if not isinstance(r, Exception) and r.status_code == 200
        ]
        assert len(successful_responses) == 10

@pytest.mark.asyncio
async def test_streaming_response():
    """测试流式响应"""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        # 假设有流式端点
        async with ac.stream("GET", "/stream-data") as response:
            assert response.status_code == 200
            chunks = []
            async for chunk in response.aiter_text():
                chunks.append(chunk)
            assert len(chunks) > 0

WebSocket测试

# tests/test_websocket.py
import pytest
import asyncio
from websockets.sync.client import connect
from app.main import app

@pytest.mark.asyncio
async def test_websocket_connection():
    """WebSocket连接测试"""
    # 注意:FastAPI的WebSocket测试需要特殊的处理
    import websockets
    import json
    
    async def test_websocket():
        uri = "ws://localhost:8000/ws"
        async with websockets.connect(uri) as websocket:
            # 发送消息
            await websocket.send(json.dumps({"type": "ping"}))
            
            # 接收响应
            response = await websocket.recv()
            data = json.loads(response)
            assert data["type"] == "pong"
    
    # 由于WebSocket测试的复杂性,这里简化示例
    # 实际应用中可能需要使用专门的WebSocket测试库
    pass

Fixtures与依赖注入测试

Fixtures高级用法

# tests/test_fixtures_advanced.py
import pytest
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
from app.services.user_service import UserService
from app.dependencies import get_current_user

@pytest.fixture
def mock_user_service():
    """模拟用户服务"""
    mock = AsyncMock(spec=UserService)
    mock.get_user_by_id.return_value = {
        "id": 1,
        "email": "mock@example.com",
        "name": "Mock User"
    }
    mock.create_user.return_value = {
        "id": 1,
        "email": "new@example.com",
        "name": "New User"
    }
    return mock

@pytest.fixture
def mock_current_user():
    """模拟当前用户"""
    return {
        "id": 1,
        "email": "current@example.com",
        "name": "Current User",
        "role": "admin"
    }

@pytest.fixture
def override_user_dependency(mock_current_user):
    """覆盖用户依赖"""
    def mock_get_current_user():
        return mock_current_user
    
    app.dependency_overrides[get_current_user] = mock_get_current_user
    yield
    app.dependency_overrides.clear()

def test_with_mocked_service(mock_user_service):
    """使用模拟服务的测试"""
    # 测试逻辑使用mock_user_service
    assert mock_user_service.get_user_by_id.called is False
    user = mock_user_service.get_user_by_id(1)
    assert user["email"] == "mock@example.com"

class TestWithFixtures:
    """使用Fixtures的测试类"""
    
    def test_authenticated_endpoint(self, authenticated_client):
        """测试认证端点"""
        response = authenticated_client.get("/users/me")
        assert response.status_code == 200
    
    def test_with_sample_data(self, client, sample_user_data):
        """使用样本数据的测试"""
        response = client.post("/users/", json=sample_user_data)
        assert response.status_code == 201
        assert response.json()["email"] == sample_user_data["email"]
    
    def test_with_multiple_fixtures(self, client, sample_user_data, sample_item_data):
        """使用多个Fixtures的测试"""
        # 创建用户
        user_response = client.post("/users/", json=sample_user_data)
        assert user_response.status_code == 201
        user_id = user_response.json()["id"]
        
        # 创建关联的项目
        sample_item_data["user_id"] = user_id
        item_response = client.post("/items/", json=sample_item_data)
        assert item_response.status_code == 201
        assert item_response.json()["user_id"] == user_id

@pytest.fixture(scope="session")
def test_data_generator():
    """测试数据生成器"""
    from faker import Faker
    fake = Faker()
    
    def generate_user_data():
        return {
            "email": fake.email(),
            "name": fake.name(),
            "password": fake.password(length=12, special_chars=True, digits=True),
            "phone": fake.phone_number()
        }
    
    def generate_item_data():
        return {
            "title": fake.catch_phrase(),
            "description": fake.text(max_nb_chars=200),
            "price": round(fake.random.uniform(10.0, 1000.0), 2)
        }
    
    return {
        "user": generate_user_data,
        "item": generate_item_data
    }

def test_with_generated_data(client, test_data_generator):
    """使用生成数据的测试"""
    user_data = test_data_generator["user"]()
    response = client.post("/users/", json=user_data)
    assert response.status_code == 201

依赖注入测试

# tests/test_dependency_injection.py
import pytest
from fastapi import Depends, HTTPException
from unittest.mock import patch, MagicMock
from app.dependencies import get_db, get_current_user, get_settings
from app.models import User

def test_db_dependency():
    """测试数据库依赖"""
    # 验证依赖函数返回正确的类型
    db_gen = get_db()
    db_session = next(db_gen)
    assert hasattr(db_session, 'execute')  # 验证是数据库会话
    db_gen.close()  # 清理

def test_current_user_dependency():
    """测试当前用户依赖"""
    # 这个测试需要模拟认证逻辑
    pass

class TestDependencyOverrides:
    """依赖覆盖测试"""
    
    def test_override_database(self, client):
        """测试数据库依赖覆盖"""
        # 在conftest.py中已经覆盖了数据库依赖
        # 这里验证覆盖生效
        response = client.get("/health")
        assert response.status_code == 200
    
    def test_override_settings(self):
        """测试设置依赖覆盖"""
        # 模拟覆盖设置
        from app.config import settings
        
        # 保存原始值
        original_debug = settings.debug
        
        # 临时覆盖
        settings.debug = True
        
        try:
            # 测试使用覆盖后的设置
            assert settings.debug is True
        finally:
            # 恢复原始值
            settings.debug = original_debug

@pytest.fixture
def mock_db_dependency():
    """模拟数据库依赖"""
    mock_session = MagicMock()
    mock_session.query.return_value.filter.return_value.first.return_value = User(
        id=1,
        email="test@example.com",
        name="Test User"
    )
    
    def get_mock_db():
        yield mock_session
    
    app.dependency_overrides[get_db] = get_mock_db
    yield mock_session
    app.dependency_overrides.clear()

def test_with_mocked_db(mock_db_dependency):
    """使用模拟数据库的测试"""
    # 这里的测试会使用模拟的数据库会话
    from app.routers.users import get_user  # 假设有一个这样的路由函数
    
    # 由于我们无法直接测试路由函数,这里展示概念
    # 实际测试中,路由会被FastAPI自动调用并使用覆盖的依赖
    assert mock_db_dependency.query.called is False

# 测试依赖注入的完整示例
class TestCompleteDependencyInjection:
    """完整的依赖注入测试"""
    
    @pytest.fixture
    def setup_dependencies(self, test_session):
        """设置多个依赖"""
        def override_get_db():
            yield test_session
        
        def override_get_current_user():
            return {"id": 1, "email": "test@example.com", "role": "admin"}
        
        app.dependency_overrides[get_db] = override_get_db
        app.dependency_overrides[get_current_user] = override_get_current_user
        yield
        app.dependency_overrides.clear()
    
    def test_complete_setup(self, client, setup_dependencies):
        """测试完整依赖设置"""
        # 测试需要认证的端点
        response = client.get("/admin/dashboard")  # 假设这是管理员端点
        # 根据实际路由实现验证响应
        assert response.status_code in [200, 404]  # 404表示端点不存在但认证成功

数据库测试策略

数据库事务管理

# tests/test_database_transactions.py
import pytest
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError
from app.models import User, Item
from app.schemas import UserCreate, ItemCreate

class TestDatabaseTransactions:
    """数据库事务测试"""
    
    async def test_successful_transaction(self, test_session):
        """测试成功的事务"""
        from app.crud import create_user
        
        user_data = UserCreate(
            email="transaction_test@example.com",
            password="Password123!",
            name="Transaction Test"
        )
        
        user = await create_user(test_session, user_data)
        assert user.email == user_data.email
        assert user.name == user_data.name
        
        # 验证数据已提交
        result = await test_session.execute(
            text("SELECT COUNT(*) FROM users WHERE email = :email"),
            {"email": user_data.email}
        )
        count = result.scalar()
        assert count == 1
    
    async def test_rollback_on_error(self, test_session):
        """测试错误时回滚"""
        from app.crud import create_user
        
        # 创建一个有效的用户
        valid_user = UserCreate(
            email="valid@example.com",
            password="Password123!",
            name="Valid User"
        )
        await create_user(test_session, valid_user)
        
        # 尝试创建重复的用户(应该失败)
        duplicate_user = UserCreate(
            email="valid@example.com",  # 重复邮箱
            password="Password123!",
            name="Duplicate User"
        )
        
        try:
            await create_user(test_session, duplicate_user)
            assert False, "应该抛出异常"
        except IntegrityError:
            # 验证事务已回滚,数据库中仍只有1个用户
            result = await test_session.execute(text("SELECT COUNT(*) FROM users"))
            count = result.scalar()
            assert count == 1
    
    async def test_nested_transactions(self, test_session):
        """测试嵌套事务"""
        from app.crud import create_user, create_item
        
        # 创建用户
        user_data = UserCreate(
            email="nested_test@example.com",
            password="Password123!",
            name="Nested Test"
        )
        user = await create_user(test_session, user_data)
        
        # 创建关联的项目
        item_data = ItemCreate(
            title="Test Item",
            description="Test Description",
            price=99.99,
            owner_id=user.id
        )
        item = await create_item(test_session, item_data)
        
        # 验证关联关系
        assert item.owner_id == user.id
        
        # 验证数据持久化
        user_result = await test_session.execute(
            text("SELECT * FROM users WHERE id = :id"), {"id": user.id}
        )
        user_from_db = user_result.fetchone()
        assert user_from_db is not None
        
        item_result = await test_session.execute(
            text("SELECT * FROM items WHERE id = :id"), {"id": item.id}
        )
        item_from_db = item_result.fetchone()
        assert item_from_db is not None

@pytest.fixture
def clean_database_state(test_session):
    """确保数据库状态清洁"""
    # 在测试开始前清理相关数据
    await test_session.execute(text("DELETE FROM items"))
    await test_session.execute(text("DELETE FROM users"))
    await test_session.commit()
    yield test_session
    # 测试结束后再次清理(尽管在conftest.py中已经有回滚)

class TestDataIsolation:
    """数据隔离测试"""
    
    async def test_isolation_between_tests(self, clean_database_state):
        """测试测试间的数据隔离"""
        # 验证开始时数据库为空
        result = await clean_database_state.execute(text("SELECT COUNT(*) FROM users"))
        assert result.scalar() == 0
        
        # 添加一些数据
        result = await clean_database_state.execute(
            text("INSERT INTO users (email, name, password_hash) VALUES (:email, :name, :password)"),
            {"email": "isolation@example.com", "name": "Isolation", "password": "hash"}
        )
        
        # 在同个session中验证数据存在
        result = await clean_database_state.execute(text("SELECT COUNT(*) FROM users"))
        assert result.scalar() == 1

# 数据库性能测试
class TestDatabasePerformance:
    """数据库性能测试"""
    
    @pytest.mark.benchmark
    async def test_bulk_insert_performance(self, test_session):
        """测试批量插入性能"""
        import time
        from app.crud import create_user
        
        start_time = time.time()
        
        # 批量创建用户
        for i in range(100):
            user_data = UserCreate(
                email=f"bulk_test_{i}@example.com",
                password="Password123!",
                name=f"Bulk User {i}"
            )
            await create_user(test_session, user_data)
        
        end_time = time.time()
        duration = end_time - start_time
        
        print(f"Created 100 users in {duration:.2f} seconds")
        assert duration < 5.0  # 应该在5秒内完成
    
    @pytest.mark.benchmark
    async def test_query_performance(self, test_session):
        """测试查询性能"""
        import time
        from app.crud import get_users
        
        # 先创建一些测试数据
        for i in range(50):
            user_data = UserCreate(
                email=f"query_test_{i}@example.com",
                password="Password123!",
                name=f"Query User {i}"
            )
            await create_user(test_session, user_data)
        
        start_time = time.time()
        
        # 执行多次查询
        for _ in range(10):
            users = await get_users(test_session, skip=0, limit=10)
            assert len(users) <= 10
        
        end_time = time.time()
        duration = end_time - start_time
        
        print(f"Executed 10 queries in {duration:.2f} seconds")
        assert duration < 2.0  # 应该在2秒内完成

数据库迁移测试

# tests/test_database_migrations.py
import pytest
from alembic import command
from alembic.config import Config
from sqlalchemy import create_engine, text
from app.database import DATABASE_URL

class TestDatabaseMigrations:
    """数据库迁移测试"""
    
    def test_migration_up_down(self):
        """测试迁移的上移和下移"""
        # 创建临时数据库用于测试迁移
        temp_engine = create_engine("sqlite:///./temp_test.db")
        
        try:
            # 应用所有迁移
            alembic_cfg = Config("alembic.ini")
            alembic_cfg.set_main_option("sqlalchemy.url", "sqlite:///./temp_test.db")
            command.upgrade(alembic_cfg, "head")
            
            # 验证表是否创建
            with temp_engine.connect() as conn:
                result = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table'"))
                tables = [row[0] for row in result]
                
                # 根据实际情况验证表是否存在
                expected_tables = ["users", "items", "orders"]  # 根据实际模型调整
                for table in expected_tables:
                    assert table in tables, f"Table {table} not found after migration"
            
            # 回滚迁移
            command.downgrade(alembic_cfg, "base")
            
            # 验证表是否删除
            with temp_engine.connect() as conn:
                result = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table'"))
                tables = [row[0] for row in result]
                
                # 回滚后应该没有业务表
                business_tables = [t for t in tables if t not in ["alembic_version"]]
                assert len(business_tables) == 0
        
        finally:
            # 清理临时数据库
            temp_engine.dispose()

    def test_migration_script_generation(self):
        """测试迁移脚本生成"""
        # 这个测试比较复杂,通常在CI/CD中运行
        # 这里只做基本的概念验证
        pass

@pytest.fixture
def migration_test_db():
    """迁移测试数据库"""
    engine = create_engine("sqlite:///./migration_test.db")
    
    # 创建基本表结构用于测试
    with engine.connect() as conn:
        conn.execute(text("""
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY,
                email VARCHAR(255) UNIQUE NOT NULL,
                name VARCHAR(255) NOT NULL
            )
        """))
    
    yield engine
    
    # 清理
    engine.dispose()

def test_data_migration(migration_test_db):
    """测试数据迁移"""
    # 插入测试数据
    with migration_test_db.connect() as conn:
        conn.execute(text("""
            INSERT INTO users (email, name) VALUES 
            ('old@example.com', 'Old User'),
            ('another@example.com', 'Another User')
        """))
        conn.commit()
    
    # 验证数据存在
    with migration_test_db.connect() as conn:
        result = conn.execute(text("SELECT COUNT(*) FROM users"))
        count = result.scalar()
        assert count == 2

API端点测试

用户管理API测试

# tests/test_user_api.py
import pytest
from httpx import AsyncClient
from app.models import User
from app.schemas import UserCreate, UserUpdate

class TestUserAPI:
    """用户API测试"""
    
    @pytest.mark.asyncio
    async def test_create_user(self, async_client):
        """测试创建用户"""
        user_data = {
            "email": "newuser@example.com",
            "password": "SecurePassword123!",
            "name": "New User"
        }
        
        response = await async_client.post("/users/", json=user_data)
        assert response.status_code == 201
        
        data = response.json()
        assert data["email"] == user_data["email"]
        assert data["name"] == user_data["name"]
        assert "id" in data
        assert "hashed_password" not in data  # 敏感信息不应返回
    
    @pytest.mark.asyncio
    async def test_create_user_duplicate_email(self, async_client):
        """测试创建重复邮箱用户"""
        user_data = {
            "email": "duplicate@example.com",
            "password": "Password123!",
            "name": "Duplicate User"
        }
        
        # 第一次创建成功
        response1 = await async_client.post("/users/", json=user_data)
        assert response1.status_code == 201
        
        # 第二次创建应该失败
        response2 = await async_client.post("/users/", json=user_data)
        assert response2.status_code == 409  # 冲突
    
    @pytest.mark.asyncio
    async def test_get_user(self, async_client):
        """测试获取用户"""
        # 先创建用户
        user_data = {
            "email": "getuser@example.com",
            "password": "Password123!",
            "name": "Get User"
        }
        create_response = await async_client.post("/users/", json=user_data)
        user_id = create_response.json()["id"]
        
        # 获取用户
        response = await async_client.get(f"/users/{user_id}")
        assert response.status_code == 200
        
        data = response.json()
        assert data["id"] == user_id
        assert data["email"] == user_data["email"]
    
    @pytest.mark.asyncio
    async def test_get_nonexistent_user(self, async_client):
        """测试获取不存在的用户"""
        response = await async_client.get("/users/999999")
        assert response.status_code == 404
    
    @pytest.mark.asyncio
    async def test_update_user(self, async_client):
        """测试更新用户"""
        # 先创建用户
        user_data = {
            "email": "update@example.com",
            "password": "Password123!",
            "name": "Update User"
        }
        create_response = await async_client.post("/users/", json=user_data)
        user_id = create_response.json()["id"]
        
        # 更新用户
        update_data = {
            "name": "Updated User Name",
            "email": "updated@example.com"
        }
        response = await async_client.put(f"/users/{user_id}", json=update_data)
        assert response.status_code == 200
        
        data = response.json()
        assert data["name"] == update_data["name"]
        assert data["email"] == update_data["email"]
    
    @pytest.mark.asyncio
    async def test_delete_user(self, async_client):
        """测试删除用户"""
        # 先创建用户
        user_data = {
            "email": "delete@example.com",
            "password": "Password123!",
            "name": "Delete User"
        }
        create_response = await async_client.post("/users/", json=user_data)
        user_id = create_response.json()["id"]
        
        # 验证用户存在
        get_response = await async_client.get(f"/users/{user_id}")
        assert get_response.status_code == 200
        
        # 删除用户
        delete_response = await async_client.delete(f"/users/{user_id}")
        assert delete_response.status_code == 204
        
        # 验证用户已删除
        get_response_after = await async_client.get(f"/users/{user_id}")
        assert get_response_after.status_code == 404
    
    @pytest.mark.asyncio
    async def test_list_users(self, async_client):
        """测试用户列表"""
        # 创建多个用户
        users_data = [
            {"email": f"user{i}@example.com", "password": "Password123!", "name": f"User {i}"}
            for i in range(5)
        ]
        
        for user_data in users_data:
            response = await async_client.post("/users/", json=user_data)
            assert response.status_code == 201
        
        # 获取用户列表
        response = await async_client.get("/users/")
        assert response.status_code == 200
        
        users = response.json()
        assert isinstance(users, list)
        assert len(users) >= 5  # 至少有5个用户
        
        # 测试分页
        paginated_response = await async_client.get("/users/", params={"skip": 0, "limit": 2})
        paginated_users = paginated_response.json()
        assert len(paginated_users) <= 2

# 认证相关API测试
class TestAuthAPI:
    """认证API测试"""
    
    @pytest.mark.asyncio
    async def test_user_registration(self, async_client):
        """测试用户注册"""
        registration_data = {
            "email": "register@example.com",
            "password": "RegisterPassword123!",
            "name": "Register User"
        }
        
        response = await async_client.post("/auth/register", json=registration_data)
        assert response.status_code == 201
        
        data = response.json()
        assert data["email"] == registration_data["email"]
        assert "id" in data
    
    @pytest.mark.asyncio
    async def test_user_login_success(self, async_client):
        """测试用户登录成功"""
        # 先注册用户
        await async_client.post("/auth/register", json={
            "email": "login@example.com",
            "password": "LoginPassword123!",
            "name": "Login User"
        })
        
        # 登录
        login_data = {
            "username": "login@example.com",
            "password": "LoginPassword123!"
        }
        response = await async_client.post("/auth/login", data=login_data)
        assert response.status_code == 200
        
        data = response.json()
        assert "access_token" in data
        assert data["token_type"] == "bearer"
    
    @pytest.mark.asyncio
    async def test_user_login_failure(self, async_client):
        """测试用户登录失败"""
        login_data = {
            "username": "nonexistent@example.com",
            "password": "wrongpassword"
        }
        response = await async_client.post("/auth/login", data=login_data)
        assert response.status_code == 401

# 项目相关API测试
class TestItemAPI:
    """项目API测试"""
    
    @pytest.mark.asyncio
    async def test_create_item(self, authenticated_client):
        """测试创建项目(需要认证)"""
        item_data = {
            "title": "Test Item",
            "description": "This is a test item",
            "price": 99.99
        }
        
        response = await authenticated_client.post("/items/", json=item_data)
        assert response.status_code == 201
        
        data = response.json()
        assert data["title"] == item_data["title"]
        assert data["price"] == item_data["price"]
    
    @pytest.mark.asyncio
    async def test_get_items(self, authenticated_client):
        """测试获取项目列表"""
        response = await authenticated_client.get("/items/")
        assert response.status_code == 200
        
        data = response.json()
        assert isinstance(data, list)

错误处理测试

# tests/test_error_handling.py
import pytest

class TestErrorHandling:
    """错误处理测试"""
    
    @pytest.mark.asyncio
    async def test_validation_errors(self, async_client):
        """测试验证错误"""
        invalid_data = {
            "email": "invalid-email",  # 无效邮箱
            "password": "123",         # 密码太短
            "name": "T"               # 名字太短
        }
        
        response = await async_client.post("/users/", json=invalid_data)
        assert response.status_code == 422  # 验证错误
        
        error_detail = response.json()
        assert "detail" in error_detail
        assert len(error_detail["detail"]) > 0
    
    @pytest.mark.asyncio
    async def test_not_found_errors(self, async_client):
        """测试未找到错误"""
        response = await async_client.get("/users/999999")
        assert response.status_code == 404
        
        error_detail = response.json()
        assert "detail" in error_detail
    
    @pytest.mark.asyncio
    async def test_unauthorized_errors(self, async_client):
        """测试未授权错误"""
        response = await async_client.get("/users/me")
        assert response.status_code == 401  # 需要认证
    
    @pytest.mark.asyncio
    async def test_rate_limiting(self, async_client):
        """测试速率限制"""
        # 发送大量请求来测试速率限制
        responses = []
        for i in range(20):  # 假设速率限制是10次/分钟
            response = await async_client.get("/health")
            responses.append(response.status_code)
        
        # 检查是否有请求被限制
        rate_limited_count = responses.count(429)  # 429 Too Many Requests
        # 根据实际的速率限制配置来验证
        # assert rate_limited_count > 0
    
    @pytest.mark.asyncio
    async def test_internal_server_errors(self, async_client):
        """测试服务器内部错误"""
        # 这个测试比较难模拟,通常通过异常处理中间件测试
        # 可以通过模拟异常来测试
        pass

# 自定义异常测试
class TestCustomExceptions:
    """自定义异常测试"""
    
    @pytest.mark.asyncio
    async def test_business_logic_errors(self, async_client):
        """测试业务逻辑错误"""
        # 例如:尝试创建已存在的资源
        user_data = {
            "email": "business@example.com",
            "password": "Password123!",
            "name": "Business User"
        }
        
        # 第一次创建
        response1 = await async_client.post("/users/", json=user_data)
        assert response1.status_code == 201
        
        # 第二次创建相同邮箱(业务逻辑错误)
        response2 = await async_client.post("/users/", json=user_data)
        assert response2.status_code == 409  # 业务冲突
    
    @pytest.mark.asyncio
    async def test_permission_errors(self, async_client):
        """测试权限错误"""
        # 创建一个普通用户
        user_data = {
            "email": "normal@example.com",
            "password": "Password123!",
            "name": "Normal User"
        }
        response = await async_client.post("/auth/register", json=user_data)
        assert response.status_code == 201
        
        # 尝试访问需要管理员权限的端点
        admin_response = await async_client.get("/admin/users")
        # 根据实际权限控制实现验证响应
        assert admin_response.status_code in [401, 403]  # 未认证或禁止访问

参数化测试与边界测试

参数化测试

# tests/test_parameterized.py
import pytest

class TestParameterizedTests:
    """参数化测试"""
    
    @pytest.mark.parametrize("email,is_valid", [
        ("valid@example.com", True),
        ("user.name@example.com", True),
        ("user+tag@example.co.uk", True),
        ("invalid", False),
        ("@example.com", False),
        ("user@", False),
        ("", False),
        ("user@.com", False),
        ("user@domain", False),
    ])
    @pytest.mark.asyncio
    async def test_email_validation(self, async_client, email, is_valid):
        """参数化测试邮箱验证"""
        user_data = {
            "email": email,
            "password": "Password123!",
            "name": "Test User"
        }
        
        response = await async_client.post("/users/", json=user_data)
        
        if is_valid:
            assert response.status_code in [201, 409]  # 201成功,409已存在
        else:
            assert response.status_code == 422  # 验证失败
    
    @pytest.mark.parametrize("password,min_length,max_length,expected_validity", [
        ("short", 8, 128, False),
        ("LongEnoughButNoSpecialChar123", 8, 128, False),
        ("ValidPass1!", 8, 128, True),
        ("Another_Valid_Pass123!", 8, 128, True),
        ("", 8, 128, False),
        ("A" * 129, 8, 128, False),  # 超过最大长度
    ])
    @pytest.mark.asyncio
    async def test_password_validation(
        self, async_client, password, min_length, max_length, expected_validity
    ):
        """参数化测试密码验证"""
        user_data = {
            "email": "password_test@example.com",
            "password": password,
            "name": "Password Test User"
        }
        
        response = await async_client.post("/users/", json=user_data)
        
        if expected_validity:
            assert response.status_code in [201, 409]
        else:
            assert response.status_code == 422
    
    @pytest.mark.parametrize("number,input_type,expected_result", [
        (0, "integer", 0),
        (1, "integer", 1),
        (-1, "integer", -1),
        (3.14, "float", 3.14),
        ("123", "string_number", 123),
        ("invalid", "invalid_string", None),
    ])
    @pytest.mark.asyncio
    async def test_input_processing(
        self, async_client, number, input_type, expected_result
    ):
        """参数化测试输入处理"""
        # 这里测试一个假想的数值处理端点
        # response = await async_client.post("/process-number", json={"input": number})
        # 根据实际端点实现验证
        pass
    
    @pytest.mark.parametrize("user_role,can_access_admin_panel", [
        ("admin", True),
        ("superuser", True),
        ("moderator", False),
        ("user", False),
        ("guest", False),
    ])
    @pytest.mark.asyncio
    async def test_role_based_access(
        self, async_client, user_role, can_access_admin_panel
    ):
        """参数化测试基于角色的访问控制"""
        # 这个测试需要更复杂的设置,包括创建不同角色的用户
        # 并验证他们对管理面板的访问权限
        pass

# 边界值测试
class TestBoundaryValueTesting:
    """边界值测试"""
    
    @pytest.mark.asyncio
    async def test_pagination_boundaries(self, async_client):
        """测试分页边界"""
        # 创建足够的测试数据
        for i in range(25):
            await async_client.post("/users/", json={
                "email": f"page_test_{i}@example.com",
                "password": "Password123!",
                "name": f"Page Test User {i}"
            })
        
        # 测试边界值
        boundary_tests = [
            {"skip": 0, "limit": 0},      # 0限制
            {"skip": 0, "limit": 1},      # 1限制
            {"skip": 0, "limit": 10},     # 正常限制
            {"skip": 0, "limit": 100},    # 大限制
            {"skip": 20, "limit": 5},     # 高偏移
            {"skip": 100, "limit": 10},   # 高偏移,无结果
        ]
        
        for params in boundary_tests:
            response = await async_client.get("/users/", params=params)
            assert response.status_code == 200
            data = response.json()
            assert isinstance(data, list)
    
    @pytest.mark.parametrize("field,value,min_val,max_val", [
        ("price", 0.01, 0.01, 999999.99),    # 最小价格
        ("price", 999999.99, 0.01, 999999.99), # 最大价格
        ("quantity", 1, 1, 10000),           # 最小数量
        ("quantity", 10000, 1, 10000),       # 最大数量
        ("rating", 1.0, 1.0, 5.0),          # 最小评分
        ("rating", 5.0, 1.0, 5.0),          # 最大评分
    ])
    @pytest.mark.asyncio
    async def test_numeric_field_boundaries(
        self, async_client, field, value, min_val, max_val
    ):
        """测试数值字段边界"""
        item_data = {
            "title": f"Boundary Test Item {field}",
            "description": f"Testing {field} boundaries",
        }
        
        if field in ["price", "quantity", "rating"]:
            item_data[field] = value
        
        response = await async_client.post("/items/", json=item_data)
        
        # 边界值应该被接受
        assert response.status_code in [201, 400]  # 400可能是由于其他验证
    
    @pytest.mark.parametrize("string_field,string_value,min_length,max_length", [
        ("name", "A", 1, 255),              # 最短名称
        ("name", "A" * 255, 1, 255),        # 最长名称
        ("description", "A", 1, 1000),      # 最短描述
        ("description", "A" * 1000, 1, 1000), # 最长描述
        ("title", "A", 1, 200),             # 最短标题
        ("title", "A" * 200, 1, 200),       # 最长标题
    ])
    @pytest.mark.asyncio
    async def test_string_field_boundaries(
        self, async_client, string_field, string_value, min_length, max_length
    ):
        """测试字符串字段边界"""
        data = {
            "title": "String Boundary Test",
            "description": "Testing string field boundaries",
        }
        
        if string_field in ["name", "description", "title"]:
            data[string_field] = string_value
        
        response = await async_client.post("/items/", json=data)
        
        # 边界值应该被接受
        assert response.status_code in [201, 422]  # 422是验证错误

# 组合参数化测试
class TestCombinatorialTesting:
    """组合测试"""
    
    @pytest.mark.parametrize("email_domain,password_complexity,rate_limit_status", [
        ("gmail.com", "simple", "normal"),
        ("company.co.uk", "complex", "normal"),
        ("free.email", "simple", "limited"),
        ("enterprise.com", "very_complex", "normal"),
    ])
    @pytest.mark.asyncio
    async def test_combination_scenarios(
        self, async_client, email_domain, password_complexity, rate_limit_status
    ):
        """测试组合场景"""
        # 根据参数组合创建不同的测试场景
        password_chars = {
            "simple": "password123",
            "complex": "ComplexPass1!",
            "very_complex": "VeryComplexPass123!@#"
        }
        
        user_data = {
            "email": f"user@{email_domain}",
            "password": password_chars[password_complexity],
            "name": f"Test User {email_domain}"
        }
        
        response = await async_client.post("/users/", json=user_data)
        assert response.status_code in [201, 409, 422]

性能边界测试

# tests/test_performance_boundaries.py
import pytest
import asyncio
import time

class TestPerformanceBoundaries:
    """性能边界测试"""
    
    @pytest.mark.performance
    @pytest.mark.asyncio
    async def test_response_time_under_load(self, async_client):
        """测试负载下的响应时间"""
        start_time = time.time()
        
        # 并发发送多个请求
        tasks = []
        for i in range(50):
            task = async_client.get("/health")
            tasks.append(task)
        
        responses = await asyncio.gather(*tasks)
        
        end_time = time.time()
        total_time = end_time - start_time
        
        print(f"Completed 50 requests in {total_time:.2f} seconds")
        print(f"Average response time: {total_time/50:.3f} seconds")
        
        # 验证所有请求都成功
        for response in responses:
            assert response.status_code == 200
        
        # 验证总时间在可接受范围内(根据系统性能调整)
        assert total_time < 10.0  # 10秒内完成50个请求
    
    @pytest.mark.performance
    @pytest.mark.asyncio
    async def test_memory_usage(self, async_client):
        """测试内存使用"""
        import psutil
        import os
        
        process = psutil.Process(os.getpid())
        initial_memory = process.memory_info().rss / 1024 / 1024  # MB
        
        # 执行内存密集型操作
        for i in range(100):
            response = await async_client.get("/health")
            assert response.status_code == 200
        
        final_memory = process.memory_info().rss / 1024 / 1024  # MB
        memory_increase = final_memory - initial_memory
        
        print(f"Memory usage increase: {memory_increase:.2f} MB")
        
        # 验证内存增长在可接受范围内
        assert memory_increase < 50.0  # 小于50MB
    
    @pytest.mark.stress
    @pytest.mark.asyncio
    async def test_stress_test(self, async_client):
        """压力测试"""
        import asyncio
        from concurrent.futures import ThreadPoolExecutor
        
        # 配置压力测试参数
        num_requests = 200
        concurrency_level = 20
        
        async def make_request():
            try:
                response = await async_client.get("/health")
                return response.status_code == 200
            except Exception:
                return False
        
        # 创建并发请求
        tasks = [make_request() for _ in range(num_requests)]
        results = await asyncio.gather(*tasks)
        
        success_count = sum(results)
        success_rate = success_count / num_requests * 100
        
        print(f"Stress test results: {success_count}/{num_requests} successful ({success_rate:.1f}%)")
        
        # 验证成功率
        assert success_rate >= 95.0  # 95%成功率
    
    @pytest.mark.performance
    @pytest.mark.asyncio
    async def test_database_connection_pool(self, test_session):
        """测试数据库连接池"""
        import asyncio
        
        # 并发数据库操作测试
        async def db_operation(i):
            from app.crud import get_user_by_email
            # 执行数据库查询
            try:
                await get_user_by_email(test_session, f"test{i}@example.com")
                return True
            except Exception:
                return False
        
        # 并发执行多个数据库操作
        tasks = [db_operation(i) for i in range(10)]
        results = await asyncio.gather(*tasks)
        
        success_count = sum(results)
        assert success_count == 10  # 所有操作都应该成功

测试覆盖率分析

覆盖率配置与执行

# coverage_setup.sh - 覆盖率设置脚本
#!/bin/bash

echo "🚀 Setting up test coverage analysis..."

# 安装覆盖率工具
pip install pytest-cov coverage

# 运行测试并生成覆盖率报告
echo "📊 Running tests with coverage analysis..."
pytest --cov=app --cov-report=html --cov-report=term-missing --cov-fail-under=80

# 生成XML报告用于CI/CD
pytest --cov=app --cov-report=xml

echo "✅ Coverage analysis completed!"
echo "📄 HTML report available at: htmlcov/index.html"
echo "📊 XML report available at: coverage.xml"

详细的覆盖率配置

# .coveragerc - 覆盖率配置文件
[run]
# 源代码路径
source = app/
omit = 
    */venv/*
    */env/*
    */.venv/*
    */tests/*
    */migrations/*
    */config/*
    */__init__.py
    */settings.py
    */alembic/*
    */venv/*/**
    */env/*/**
    */.venv/*/**

# 包含分支覆盖率
branch = True

# 平行模式(用于并行测试)
parallel = True

[report]
# 忽略的行
exclude_lines =
    pragma: no cover
    def __repr__
    raise AssertionError
    raise NotImplementedError
    if __name__ == .__main__.:
    if TYPE_CHECKING:
    if settings.debug:
    if DEBUG:
    @abstractmethod
    @overload
    @runtime_checkable
    if sys.version_info <
    if typing.TYPE_CHECKING:
    if MYPY:
    assert_never\(
    assert False
    assert 0
    if 0:
    if __debug__:
    if not __debug__:
    if typing.TYPE_CHECKING:
    if TYPE_CHECKING:
    if "NO PYTEST COVERAGE" in __doc__:
    if "PYTEST_NO_COVERAGE" in __file__:
    if "COVERAGE DISABLE" in __file__:

# 覆盖率阈值
precision = 2
show_missing = True
skip_covered = False
skip_empty = True

[html]
directory = htmlcov
title = FastAPI Application Test Coverage Report

[xml]
output = coverage.xml

覆盖率分析工具

# tools/coverage_analyzer.py - 覆盖率分析工具
import subprocess
import json
import xml.etree.ElementTree as ET
from pathlib import Path
import sys

class CoverageAnalyzer:
    """覆盖率分析器"""
    
    def __init__(self, source_path="app/", coverage_threshold=80):
        self.source_path = source_path
        self.threshold = coverage_threshold
        self.results = {}
    
    def run_coverage_analysis(self):
        """运行覆盖率分析"""
        print("🔍 Running coverage analysis...")
        
        # 运行pytest with coverage
        cmd = [
            "pytest", 
            f"--cov={self.source_path}",
            "--cov-report=json",
            "--cov-report=term-missing",
            f"--cov-fail-under={self.threshold}"
        ]
        
        result = subprocess.run(cmd, capture_output=True, text=True)
        
        print("STDOUT:", result.stdout)
        if result.stderr:
            print("STDERR:", result.stderr)
        
        # 读取JSON报告
        json_report = Path("coverage.json")
        if json_report.exists():
            with open(json_report, 'r') as f:
                self.results = json.load(f)
        
        return result.returncode == 0
    
    def get_coverage_summary(self):
        """获取覆盖率摘要"""
        if not self.results:
            return None
        
        summary = {
            "total_coverage": self.results.get("totals", {}).get("percent_covered", 0),
            "num_files": len(self.results.get("files", {})),
            "missing_lines": self.results.get("totals", {}).get("missing_lines", 0),
            "excluded_lines": self.results.get("totals", {}).get("excluded", 0)
        }
        
        return summary
    
    def get_low_coverage_files(self, threshold=80):
        """获取低覆盖率文件"""
        low_coverage_files = []
        
        for filepath, file_data in self.results.get("files", {}).items():
            coverage_percent = file_data.get("summary", {}).get("percent_covered", 0)
            if coverage_percent < threshold:
                low_coverage_files.append({
                    "file": filepath,
                    "coverage": coverage_percent,
                    "missing_lines": file_data.get("summary", {}).get("missing_lines", 0)
                })
        
        return sorted(low_coverage_files, key=lambda x: x["coverage"])
    
    def generate_report(self):
        """生成覆盖率报告"""
        summary = self.get_coverage_summary()
        if not summary:
            print("❌ No coverage data available")
            return
        
        print("\n" + "="*60)
        print("📋 COVERAGE ANALYSIS REPORT")
        print("="*60)
        print(f"Total Coverage: {summary['total_coverage']:.2f}%")
        print(f"Number of Files: {summary['num_files']}")
        print(f"Missing Lines: {summary['missing_lines']}")
        print(f"Excluded Lines: {summary['excluded_lines']}")
        print("-"*60)
        
        # 显示低覆盖率文件
        low_coverage = self.get_low_coverage_files()
        if low_coverage:
            print("⚠️  LOW COVERAGE FILES (< 80%):")
            for file_info in low_coverage[:10]:  # 只显示前10个
                print(f"  {file_info['file']}: {file_info['coverage']:.1f}% "
                      f"({file_info['missing_lines']} lines missing)")
        else:
            print("✅ All files have good coverage!")
        
        print("="*60)
        
        # 检查是否达到阈值
        if summary['total_coverage'] >= self.threshold:
            print(f"✅ Coverage meets threshold ({self.threshold}%)")
            return True
        else:
            print(f"❌ Coverage below threshold ({self.threshold}%)")
            return False

def main():
    """主函数"""
    analyzer = CoverageAnalyzer(source_path="app", coverage_threshold=80)
    success = analyzer.run_coverage_analysis()
    analyzer.generate_report()
    
    if not success:
        sys.exit(1)

if __name__ == "__main__":
    main()

覆盖率优化策略

# tests/test_coverage_optimization.py
"""
覆盖率优化策略:

1. 业务逻辑覆盖率:确保核心业务逻辑100%覆盖
2. 错误路径覆盖:测试所有可能的错误情况
3. 边界条件覆盖:测试输入参数的边界值
4. 异常处理覆盖:确保所有异常都被捕获和处理
5. 集成路径覆盖:测试模块间的交互
"""

class TestCoverageOptimization:
    """覆盖率优化测试"""
    
    def test_all_exception_paths(self):
        """测试所有异常路径"""
        # 示例:测试除零异常
        def divide(a, b):
            if b == 0:
                raise ZeroDivisionError("Cannot divide by zero")
            return a / b
        
        # 测试正常情况
        assert divide(10, 2) == 5
        
        # 测试异常情况
        with pytest.raises(ZeroDivisionError):
            divide(10, 0)
    
    def test_all_branches(self):
        """测试所有分支"""
        def process_user(user_type, is_active):
            if user_type == "admin":
                if is_active:
                    return "Admin Active"
                else:
                    return "Admin Inactive"
            elif user_type == "user":
                if is_active:
                    return "User Active"
                else:
                    return "User Inactive"
            else:
                return "Unknown"
        
        # 测试所有分支
        assert process_user("admin", True) == "Admin Active"
        assert process_user("admin", False) == "Admin Inactive"
        assert process_user("user", True) == "User Active"  
        assert process_user("user", False) == "User Inactive"
        assert process_user("guest", True) == "Unknown"
    
    def test_boundary_conditions(self):
        """测试边界条件"""
        def validate_age(age):
            if age < 0:
                raise ValueError("Age cannot be negative")
            elif age > 150:
                raise ValueError("Age seems unrealistic")
            elif age < 18:
                return "Minor"
            elif age >= 18 and age <= 65:
                return "Adult"
            else:
                return "Senior"
        
        # 测试边界值
        assert validate_age(0) == "Minor"      # 边界
        assert validate_age(17) == "Minor"    # 边界
        assert validate_age(18) == "Adult"    # 边界
        assert validate_age(65) == "Adult"    # 边界
        assert validate_age(66) == "Senior"   # 边界
        assert validate_age(150) == "Senior"  # 边界
        
        # 测试异常边界
        with pytest.raises(ValueError, match="negative"):
            validate_age(-1)
        
        with pytest.raises(ValueError, match="unrealistic"):
            validate_age(151)
    
    def test_all_input_combinations(self):
        """测试所有输入组合"""
        def calculate_discount(customer_type, purchase_amount, is_member):
            discount = 0
            
            if customer_type == "premium":
                discount += 15
            elif customer_type == "regular":
                discount += 5
            
            if is_member:
                discount += 10
            
            if purchase_amount > 1000:
                discount += 5
            elif purchase_amount > 500:
                discount += 3
            
            return min(discount, 50)  # 最大折扣50%
        
        # 测试各种组合
        test_cases = [
            # (customer_type, purchase_amount, is_member, expected_discount)
            ("premium", 100, False, 15),
            ("premium", 100, True, 25),
            ("premium", 600, True, 28),
            ("premium", 1100, True, 35),
            ("regular", 100, False, 5),
            ("regular", 100, True, 15),
            ("regular", 600, True, 18),
            ("regular", 1100, True, 25),
            ("guest", 100, False, 0),
            ("guest", 1100, True, 15),
        ]
        
        for customer_type, amount, is_member, expected in test_cases:
            result = calculate_discount(customer_type, amount, is_member)
            assert result == expected, f"Failed for {customer_type}, {amount}, {is_member}"

# 专门的覆盖率提高测试
class TestIncreaseCoverage:
    """提高覆盖率的测试"""
    
    def test_repr_methods(self):
        """测试repr方法"""
        class TestClass:
            def __init__(self, value):
                self.value = value
            
            def __repr__(self):
                return f"TestClass(value={self.value!r})"
        
        obj = TestClass("test")
        repr_str = repr(obj)
        assert "TestClass" in repr_str
        assert "test" in repr_str
    
    def test_eq_methods(self):
        """测试相等性方法"""
        class TestClass:
            def __init__(self, value):
                self.value = value
            
            def __eq__(self, other):
                if not isinstance(other, TestClass):
                    return NotImplemented
                return self.value == other.value
        
        obj1 = TestClass("test")
        obj2 = TestClass("test")
        obj3 = TestClass("different")
        
        assert obj1 == obj2
        assert obj1 != obj3
        assert obj1 != "not_a_TestClass"
    
    def test_property_methods(self):
        """测试属性方法"""
        class TestClass:
            def __init__(self):
                self._value = 0
            
            @property
            def value(self):
                return self._value
            
            @value.setter
            def value(self, val):
                if val < 0:
                    raise ValueError("Value cannot be negative")
                self._value = val
        
        obj = TestClass()
        assert obj.value == 0
        
        obj.value = 10
        assert obj.value == 10
        
        with pytest.raises(ValueError):
            obj.value = -1

TDD开发实践

TDD基础实践

"""
TDD (Test-Driven Development) 开发实践:

1. RED - 编写失败的测试
2. GREEN - 编写最少的代码使测试通过
3. REFACTOR - 重构代码,保持测试通过
"""

# 示例:开发一个用户服务类

# 1. RED: 先写测试
# tests/test_user_service_tdd.py
import pytest
from unittest.mock import AsyncMock, MagicMock

class TestUserServiceTDD:
    """TDD方式开发用户服务"""
    
    @pytest.fixture
    def mock_db_session(self):
        """模拟数据库会话"""
        session = AsyncMock()
        session.commit = AsyncMock()
        session.refresh = AsyncMock()
        return session
    
    @pytest.fixture
    def user_service(self, mock_db_session):
        """用户服务实例"""
        from app.services.user_service import UserService
        return UserService(db_session=mock_db_session)
    
    def test_create_user_with_valid_data(self, user_service, mock_db_session):
        """测试使用有效数据创建用户 - RED阶段"""
        # 这个测试应该失败,因为UserService还没实现
        user_data = {
            "email": "test@example.com",
            "password": "SecurePassword123!",
            "name": "Test User"
        }
        
        # 预期会失败,因为我们还没有实现UserService
        with pytest.raises(Exception):  # 期望抛出异常直到实现完成
            result = user_service.create_user(user_data)
    
    def test_get_user_by_email(self, user_service):
        """测试通过邮箱获取用户"""
        # RED: 编写失败的测试
        with pytest.raises(Exception):
            user_service.get_user_by_email("test@example.com")

# 2. GREEN: 实现最简单的代码使测试通过
"""
# app/services/user_service.py - 实现UserService
from typing import Optional, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession

class UserService:
    def __init__(self, db_session: AsyncSession):
        self.db_session = db_session
    
    async def create_user(self, user_data: Dict[str, Any]):
        # 简单实现使测试通过
        return {
            "id": 1,
            "email": user_data["email"],
            "name": user_data["name"],
            "created_at": "2024-01-01T00:00:00Z"
        }
    
    async def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]:
        # 简单实现使测试通过
        if email == "test@example.com":
            return {
                "id": 1,
                "email": email,
                "name": "Test User"
            }
        return None
"""

# 3. REFACTOR: 重构代码
"""
# 重构后的UserService实现
from typing import Optional, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import User
from app.schemas import UserCreate
from passlib.context import CryptContext

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

class UserService:
    def __init__(self, db_session: AsyncSession):
        self.db_session = db_session
    
    async def create_user(self, user_data: UserCreate) -> User:
        hashed_password = pwd_context.hash(user_data.password)
        db_user = User(
            email=user_data.email,
            name=user_data.name,
            hashed_password=hashed_password
        )
        self.db_session.add(db_user)
        await self.db_session.commit()
        await self.db_session.refresh(db_user)
        return db_user
    
    async def get_user_by_email(self, email: str) -> Optional[User]:
        result = await self.db_session.execute(
            select(User).filter(User.email == email)
        )
        return result.scalars().first()
    
    async def authenticate_user(self, email: str, password: str) -> Optional[User]:
        user = await self.get_user_by_email(email)
        if user and pwd_context.verify(password, user.hashed_password):
            return user
        return None
"""

# TDD完整示例:开发购物车功能
class TestShoppingCartTDD:
    """购物车功能的TDD开发示例"""
    
    def test_cart_starts_empty(self):
        """测试购物车初始化时空的 - RED"""
        # 首先编写测试
        cart = ShoppingCart()
        assert len(cart.items) == 0
        assert cart.total_price == 0
    
    def test_add_item_to_cart(self):
        """测试向购物车添加商品 - RED"""
        cart = ShoppingCart()
        item = {"id": 1, "name": "Product", "price": 10.0, "quantity": 1}
        cart.add_item(item)
        
        assert len(cart.items) == 1
        assert cart.items[0]["name"] == "Product"
        assert cart.total_price == 10.0
    
    def test_remove_item_from_cart(self):
        """测试从购物车移除商品 - RED"""
        cart = ShoppingCart()
        item = {"id": 1, "name": "Product", "price": 10.0, "quantity": 1}
        cart.add_item(item)
        
        cart.remove_item(1)
        assert len(cart.items) == 0
        assert cart.total_price == 0
    
    def test_update_item_quantity(self):
        """测试更新商品数量 - RED"""
        cart = ShoppingCart()
        item = {"id": 1, "name": "Product", "price": 10.0, "quantity": 1}
        cart.add_item(item)
        
        cart.update_quantity(1, 3)
        assert cart.items[0]["quantity"] == 3
        assert cart.total_price == 30.0

"""
# GREEN: 实现功能使测试通过
class ShoppingCart:
    def __init__(self):
        self.items = []
    
    @property
    def total_price(self):
        return sum(item["price"] * item["quantity"] for item in self.items)
    
    def add_item(self, item):
        self.items.append(item)
    
    def remove_item(self, item_id):
        self.items = [item for item in self.items if item["id"] != item_id]
    
    def update_quantity(self, item_id, quantity):
        for item in self.items:
            if item["id"] == item_id:
                item["quantity"] = quantity
                break
"""

# REFACTOR: 重构为更健壮的实现
"""
class ShoppingCart:
    def __init__(self):
        self._items = {}  # 使用字典以ID为键,便于查找
    
    def add_item(self, item):
        item_id = item["id"]
        if item_id in self._items:
            # 如果商品已存在,更新数量
            self._items[item_id]["quantity"] += item["quantity"]
        else:
            self._items[item_id] = item.copy()
    
    def remove_item(self, item_id):
        if item_id in self._items:
            del self._items[item_id]
    
    def update_quantity(self, item_id, quantity):
        if item_id in self._items:
            if quantity <= 0:
                self.remove_item(item_id)
            else:
                self._items[item_id]["quantity"] = quantity
    
    def get_item(self, item_id):
        return self._items.get(item_id)
    
    @property
    def items(self):
        return list(self._items.values())
    
    @property
    def total_price(self):
        return sum(item["price"] * item["quantity"] 
                  for item in self._items.values())
    
    @property
    def item_count(self):
        return sum(item["quantity"] for item in self._items.values())
"""

# TDD测试策略
class TestTDDStrategies:
    """TDD测试策略"""
    
    def test_outside_in_tdd(self):
        """外部到内部TDD - 从API开始测试"""
        # 先测试外部接口
        response = client.get("/cart")
        assert response.status_code == 200
        
        # 然后逐步深入到内部实现
        # 这种方式适合API驱动开发
    
    def test_inside_out_tdd(self):
        """内部到外部TDD - 从核心逻辑开始测试"""
        # 先测试核心业务逻辑
        calculator = DiscountCalculator()
        assert calculator.calculate(100, "regular") == 5  # 5%折扣
        
        # 然后测试使用该逻辑的上层组件
        # 这种方式适合算法和业务逻辑驱动开发
    
    def test_mockist_tdd(self):
        """Mockist TDD - 使用模拟对象测试"""
        mock_payment_gateway = Mock()
        mock_payment_gateway.process.return_value = True
        
        processor = PaymentProcessor(mock_payment_gateway)
        result = processor.pay(100, "test_card")
        
        assert result is True
        mock_payment_gateway.process.assert_called_once()
    
    def test_classic_tdd(self):
        """Classic TDD - 使用真实对象测试"""
        # 使用真实的数据库、真实的外部服务
        # 更接近生产环境,但测试速度较慢

# TDD最佳实践
"""
TDD Best Practices:

1. Test One Thing: 每个测试只验证一个行为
2. Fast Tests: 测试应该快速运行
3. Independent Tests: 测试之间不应该相互依赖
4. Repeatable Tests: 测试结果应该是一致的
5. Self-Validating: 测试应该自动判断通过还是失败
6. Timely Tests: 测试应该在代码之前或同时编写
"""

## 测试性能优化 \{#测试性能优化}

### 测试并行化

```python
# pytest_parallel_example.py - 并行测试示例
"""
pytest-xdist 插件可以实现测试并行化:

安装:
pip install pytest-xdist

运行:
pytest -n auto          # 自动检测CPU核心数
pytest -n 4            # 指定4个进程
pytest --dist worksteal # 工作窃取模式
"""

# conftest.py 中的并行化配置
@pytest.fixture(scope="session")
def shared_resource():
    """共享资源,只初始化一次"""
    print("Initializing shared resource")
    return {"connection": "mock_connection"}

# 使用pytest-benchmark进行性能测试
"""
pip install pytest-benchmark
"""

def test_with_benchmark(benchmark):
    """使用benchmark进行性能测试"""
    def expensive_function():
        # 模拟耗时操作
        result = sum(i**2 for i in range(1000))
        return result
    
    result = benchmark(expensive_function)
    assert result == sum(i**2 for i in range(1000))

# 性能回归测试
def test_performance_regression(benchmark):
    """性能回归测试"""
    # 设定性能阈值
    threshold = 0.1  # 100ms
    
    def target_function():
        import time
        time.sleep(0.05)  # 模拟50ms操作
        return "result"
    
    result = benchmark(target_function)
    
    # 检查是否超过阈值
    assert benchmark.stats.stats.max < threshold

测试数据优化

# test_data_optimization.py
from pytest_lazyfixture import lazy_fixture

# 使用工厂函数创建测试数据
@pytest.fixture
def user_factory():
    """用户工厂"""
    def _create_user(email=None, name=None, role="user"):
        return {
            "email": email or f"user{random.randint(1, 1000)}@example.com",
            "name": name or f"User {random.randint(1, 1000)}",
            "role": role
        }
    return _create_user

# 使用Faker生成测试数据
from faker import Faker
fake = Faker()

@pytest.fixture
def realistic_user_data():
    """生成逼真的测试数据"""
    return {
        "email": fake.email(),
        "name": fake.name(),
        "address": fake.address(),
        "phone": fake.phone_number(),
        "company": fake.company()
    }

# 测试数据重用
@pytest.fixture(scope="module")
def test_dataset():
    """模块级别的测试数据集"""
    # 生成大量测试数据,只生成一次
    dataset = []
    for i in range(1000):
        dataset.append({
            "id": i,
            "name": f"Item {i}",
            "value": random.random() * 100
        })
    return dataset

# 数据清理优化
@pytest.fixture(autouse=True)
def cleanup_test_data():
    """自动清理测试数据"""
    yield
    # 测试后清理
    # 清理会话、临时文件等
    pass

测试缓存和跳过

# test_caching_and_skipping.py
import pytest

# 使用缓存避免重复计算
@pytest.fixture(scope="session")
def expensive_computation():
    """昂贵的计算,只执行一次"""
    print("Performing expensive computation...")
    # 模拟昂贵的计算
    result = sum(i**2 for i in range(100000))
    return result

# 条件跳过测试
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+")
def test_new_feature():
    """只在Python 3.8+上运行的测试"""
    # 使用Python 3.8+的新特性
    pass

# 跳过不稳定的测试
@pytest.mark.skip(reason="Unstable test, needs investigation")
def test_unstable_feature():
    """暂时跳过的不稳定测试"""
    pass

# 预期失败的测试
@pytest.mark.xfail(strict=True)
def test_expected_failure():
    """预期会失败的测试"""
    assert False  # 预期失败

与其他测试框架对比

Pytest vs Unittest

特性PytestUnittest
语法简洁性⭐⭐⭐⭐⭐ 简洁直观⭐⭐⭐ 传统但冗长
Fixtures⭐⭐⭐⭐⭐ 强大的Fixtures系统⭐⭐⭐ setUp/tearDown
参数化测试⭐⭐⭐⭐⭐ 简单易用⭐⭐⭐ 较复杂
插件生态系统⭐⭐⭐⭐⭐ 丰富的插件⭐⭐⭐ 有限
异步测试支持⭐⭐⭐⭐⭐ 原生支持⭐⭐ 需要额外库
断言⭐⭐⭐⭐⭐ 简单的assert⭐⭐⭐ 多种assert方法
学习曲线⭐⭐⭐⭐⭐ 容易上手⭐⭐⭐⭐ 标准库

Pytest vs Nose

特性PytestNose
活跃度⭐⭐⭐⭐⭐ 活跃维护⭐ 不再维护
Python 3支持⭐⭐⭐⭐⭐ 完全支持⭐⭐⭐⭐ 部分支持
插件系统⭐⭐⭐⭐⭐ 优秀⭐⭐⭐ 一般
性能⭐⭐⭐⭐⭐ 快速⭐⭐⭐⭐ 可接受

选择建议

"""
选择测试框架的决策树:

1. 如果使用Python 3.6+ → 优先选择Pytest
2. 如果需要与unittest集成 → Pytest兼容性好
3. 如果是大型企业项目 → Pytest + 插件生态系统
4. 如果是小型项目 → Pytest简单够用
5. 如果有特殊需求 → 查看Pytest插件能否满足

Pytest优势:
- 语法简洁
- 功能强大
- 插件丰富
- 社区活跃
- 与主流工具集成好
"""

总结

Pytest是Python中最强大的测试框架之一,特别适合FastAPI应用的测试:

  1. 简洁语法:使用简单的assert语句
  2. 强大Fixtures:灵活的依赖管理
  3. 参数化测试:轻松测试多种场景
  4. 丰富的插件:扩展功能强大
  5. 异步支持:天然支持async/await
  6. 集成友好:与CI/CD无缝集成

💡 关键要点:良好的测试策略是应用质量的保障。Pytest提供了完整的测试解决方案,从单元测试到集成测试,从简单验证到复杂场景,都能胜任。


SEO优化建议

为了提高这篇Pytest单元测试教程在搜索引擎中的排名,以下是几个关键的SEO优化建议:

标题优化

  • 主标题:使用包含核心关键词的标题,如"FastAPI Pytest单元测试完全指南"
  • 二级标题:每个章节标题都包含相关的长尾关键词
  • H1-H6层次结构:保持正确的标题层级,便于搜索引擎理解内容结构

内容优化

  • 关键词密度:在内容中自然地融入关键词如"Pytest"、"单元测试"、"FastAPI"、"异步测试"等
  • 元描述:在文章开头的元数据中包含吸引人的描述
  • 内部链接:链接到其他相关教程,如FastAPI依赖注入系统
  • 外部权威链接:引用官方文档和权威资源

技术SEO

  • 页面加载速度:优化代码块和图片加载
  • 移动端适配:确保在移动设备上良好显示
  • 结构化数据:使用适当的HTML标签和语义化元素

用户体验优化

  • 内容可读性:使用清晰的段落结构和代码示例
  • 互动元素:提供实际可运行的代码示例
  • 更新频率:定期更新内容以保持时效性

常见问题解答(FAQ)

Q1: Pytest和unittest有什么区别?

A: Pytest相比unittest有更简洁的语法、强大的fixture系统、原生的参数化测试支持,以及丰富的插件生态系统。Pytest可以直接使用简单的assert语句,而unittest需要使用特定的assert方法。

Q2: 如何在FastAPI中测试异步函数?

A: 使用pytest-asyncio插件,添加@pytest.mark.asyncio装饰器,然后使用httpx.AsyncClient进行异步测试。

Q3: 如何测试数据库相关的功能?

A: 使用内存数据库(如SQLite in-memory)进行测试,通过fixture提供测试数据库会话,确保每个测试后数据库状态清洁。

Q4: 什么是TDD,为什么要使用它?

A: TDD(Test-Driven Development)是测试驱动开发,先写测试再写实现代码。它能提高代码质量、减少bug、改善设计并提供文档作用。

Q5: 如何提高测试覆盖率?

A: 使用pytest-cov工具分析覆盖率,重点关注业务逻辑、错误路径、边界条件和异常处理的测试覆盖。


🔗 相关教程推荐

🏷️ 标签云: FastAPI测试 Pytest教程 单元测试 异步测试 TDD开发 API测试 测试覆盖率 Fixtures 参数化测试 数据库测试