Pytest 单元测试:使用 TestClient 编写高覆盖率的接口测试

📂 所属阶段:第五阶段 — 工程化与部署(实战篇)
🔗 相关章节:依赖注入系统 · Pydantic Settings 多环境配置


1. 测试基础

1.1 为什么需要测试?

无测试:  改代码 → 手动测试 → 上线 → 发现 bug → 回滚 → 用户流失 😱
有测试:  改代码 → pytest → 快速反馈 → 自信上线 🚀

1.2 安装测试依赖

pip install pytest pytest-asyncio pytest-cov httpx
# pytest-asyncio → async test support (required by asyncio_mode = "auto")
# pytest-cov     → provides the --cov options used for coverage reports
# httpx          → dependency of FastAPI's TestClient; also used directly as AsyncClient

1.3 pytest 配置文件

# pyproject.toml — the [tool.pytest.ini_options] table is pyproject-only syntax;
# in a pytest.ini file use a [pytest] section with INI-style values instead.
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = "test_*.py"
python_functions = "test_*"
# Requires pytest-asyncio; "auto" lets plain async def tests/fixtures run without markers
asyncio_mode = "auto"
addopts = "-v --tb=short"
filterwarnings = [
    "ignore::DeprecationWarning",
]

2. TestClient 基础

2.1 第一个测试

# tests/test_main.py
from fastapi.testclient import TestClient
from main import app

client = TestClient(app)


def _get_ok_json(path):
    """GET *path* through the test client, require HTTP 200, return the JSON body."""
    resp = client.get(path)
    assert resp.status_code == 200
    return resp.json()


def test_root():
    """The root endpoint greets the caller."""
    assert _get_ok_json("/")["message"] == "Welcome to DaomanAPI"


def test_health():
    """The health endpoint reports status "ok"."""
    assert _get_ok_json("/health")["status"] == "ok"

2.2 运行测试

pytest                          # run the whole test suite
pytest tests/test_main.py       # run a single file
pytest -v                       # verbose output
pytest --cov=app                # coverage report (needs the pytest-cov plugin)
pytest -k "user"                # only run tests whose name matches "user"
pytest --tb=short               # shorter tracebacks

3. 异步测试

3.1 async 路由测试

# tests/test_users.py
from httpx import AsyncClient, ASGITransport
from main import app
import pytest


# A plain @pytest.fixture on an async generator relies on asyncio_mode = "auto"
@pytest.fixture
async def async_client():
    """Yield an httpx AsyncClient that drives the ASGI app in-process (no real network)."""
    async with AsyncClient(
        transport=ASGITransport(app=app),
        base_url="http://test",
    ) as http:
        yield http


@pytest.mark.asyncio
async def test_list_users(async_client: AsyncClient):
    """GET /users/ answers 200 with a JSON array."""
    resp = await async_client.get("/users/")
    assert resp.status_code == 200
    assert isinstance(resp.json(), list)

3.2 使用 dependency_overrides 覆盖依赖(模拟数据库,测试认证路由)

# Test setup: replace the database dependency with a mock
import pytest
from fastapi.testclient import TestClient
from sqlalchemy.ext.asyncio import AsyncSession
from unittest.mock import AsyncMock

from main import app
from database import get_db, AsyncSessionLocal

client = TestClient(app)


@pytest.fixture
def mock_db():
    """Swap the real get_db dependency for one that yields a mocked AsyncSession.

    Only the get_db override is removed on teardown, so overrides installed by
    other fixtures are left untouched.
    """
    async def override_get_db():
        # spec=AsyncSession makes the mock reject attributes a real session lacks
        yield AsyncMock(spec=AsyncSession)

    app.dependency_overrides[get_db] = override_get_db
    yield
    # Runs after the test, even if it failed; pop() instead of clear() keeps other overrides
    app.dependency_overrides.pop(get_db, None)


def test_protected_route_without_auth(mock_db):
    """A protected route called without a token must answer 401."""
    response = client.get("/users/me")
    assert response.status_code == 401


def test_protected_route_with_auth(mock_db):
    """A protected route called with a bearer token must answer 200."""
    # Log in first. Send form fields (data=), the same way the login tests
    # in the CRUD test module call /auth/login, rather than a JSON body.
    login_resp = client.post("/auth/login", data={
        "username": "alice@example.com",
        "password": "password123",
    })
    # NOTE(review): with get_db mocked, the login route presumably cannot load a
    # real user, so this step may fail; overriding get_current_user is the more
    # reliable way to fake an authenticated caller — confirm for your app.
    assert login_resp.status_code == 200, login_resp.text
    token = login_resp.json()["access_token"]

    # Call the protected route with the token
    response = client.get(
        "/users/me",
        headers={"Authorization": f"Bearer {token}"},
    )
    assert response.status_code == 200

4. 完整测试用例

4.1 用户 CRUD 测试

# tests/test_users_crud.py
import pytest
from httpx import AsyncClient, ASGITransport
from main import app
from database import engine, Base

# Test data
TEST_USER = {
    "email": "pytest@example.com",
    "password": "SecurePass123!",
    "name": "Pytest User",
}


# NOTE(review): this creates and later DROPS every table on the app's own `engine`.
# Make sure the settings active under pytest point it at a throw-away test
# database, never at a real one.
# NOTE(review): an async fixture with scope wider than "function" needs an event
# loop of matching scope; depending on the pytest-asyncio version this may need
# e.g. loop_scope="module" — confirm against the installed version.
@pytest.fixture(scope="module", autouse=True)
async def setup_db():
    """Create all tables before this module's tests and drop them afterwards."""
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)


@pytest.fixture
async def client():
    """Yield an httpx AsyncClient bound in-process to the ASGI app."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as c:
        yield c


@pytest.fixture
async def auth_token(client: AsyncClient):
    """Ensure TEST_USER exists, log in as it, and return its access_token."""
    # 201 on first registration, 409 if an earlier test already created the user
    reg = await client.post("/auth/register", json=TEST_USER)
    assert reg.status_code in (201, 409), reg.text
    # Log in with form fields, matching the login tests below
    resp = await client.post("/auth/login", data={
        "username": TEST_USER["email"],
        "password": TEST_USER["password"],
    })
    # Surface the server response instead of a bare KeyError on "access_token"
    assert resp.status_code == 200, resp.text
    return resp.json()["access_token"]


class TestUserAuth:
    """Registration and login flows for TEST_USER.

    These tests run in file order and build on one another: test_register
    creates the user that the duplicate-email and login checks depend on.
    """

    async def test_register(self, client: AsyncClient):
        resp = await client.post("/auth/register", json=TEST_USER)
        assert resp.status_code == 201
        assert resp.json()["email"] == TEST_USER["email"]

    async def test_register_duplicate_email(self, client: AsyncClient):
        # Same e-mail as test_register, so the server must refuse it
        resp = await client.post("/auth/register", json=TEST_USER)
        assert resp.status_code == 409

    async def test_login_success(self, client: AsyncClient):
        credentials = {"username": TEST_USER["email"], "password": TEST_USER["password"]}
        resp = await client.post("/auth/login", data=credentials)
        assert resp.status_code == 200
        assert "access_token" in resp.json()

    async def test_login_wrong_password(self, client: AsyncClient):
        bad_credentials = {"username": TEST_USER["email"], "password": "wrongpassword"}
        resp = await client.post("/auth/login", data=bad_credentials)
        assert resp.status_code == 401


class TestUserProfile:
    """GET /users/me with and without a bearer token."""

    async def test_get_me(self, client: AsyncClient, auth_token: str):
        auth_header = {"Authorization": f"Bearer {auth_token}"}
        resp = await client.get("/users/me", headers=auth_header)
        assert resp.status_code == 200
        assert resp.json()["email"] == TEST_USER["email"]

    async def test_get_me_without_token(self, client: AsyncClient):
        resp = await client.get("/users/me")
        assert resp.status_code == 401

4.2 参数验证测试(与 4.1 写在同一测试文件中,复用其中的 client fixture)

class TestValidation:
    """Request-body validation on /auth/register; invalid input must yield HTTP 422."""

    async def test_register_invalid_email(self, client: AsyncClient):
        response = await client.post("/auth/register", json={
            "email": "not-an-email",
            "password": "password123",
            "name": "Test",
        })
        assert response.status_code == 422  # FastAPI's request validation returns 422

    async def test_register_short_password(self, client: AsyncClient):
        response = await client.post("/auth/register", json={
            "email": "test@example.com",
            "password": "123",  # too short
            "name": "Test",
        })
        assert response.status_code == 422
        # FastAPI's default validation-error body is {"detail": [...]}, where each
        # error's "loc" names the offending field, e.g. ["body", "password"].
        # Adjust the key if a custom RequestValidationError handler reshapes the body.
        errors = response.json()["detail"]
        assert any("password" in err["loc"] for err in errors)

5. 测试覆盖率

5.1 生成覆盖率报告

pip install pytest-cov

# Generate an HTML coverage report
pytest --cov=app --cov-report=html
# then open htmlcov/index.html for the detailed per-file view

# Show the report in the terminal, listing uncovered line numbers
pytest --cov=app --cov-report=term-missing

# Coverage targets: core business logic 80%+, simple functions 100%

5.2 覆盖率输出示例

---------- coverage: platform win32 ----------
Name                      Stmts   Miss  Cover
---------------------------------------------
app/main.py                  45      5    89%
app/routers/users.py         80     10    88%
app/services/user.py         60     12    80%
app/auth/jwt.py              30      0   100%
---------------------------------------------
TOTAL                       215     27    87%

6. 常用 pytest 技巧

# Reuse setup through fixtures
@pytest.fixture
def sample_user():
    """Return a fresh credentials dict to every test that requests it."""
    return dict(email="sample@test.com", password="Test1234!")

# Parametrized test: pytest runs it once per (email, expected) pair
@pytest.mark.parametrize(
    ("email", "expected"),
    [
        ("a@b.com", True),
        ("invalid", False),
        ("", False),
    ],
)
async def test_email_validation(email, expected):
    # Placeholder: replace `...` with a real check that compares a validator's
    # result against `expected` — as written, every case passes trivially.
    ...

# Skipping a test
@pytest.mark.skip(reason="待实现")
async def test_future_feature():
    ...

# Expecting an exception
def test_divide_by_zero():
    with pytest.raises(ZeroDivisionError):
        _ = 1 / 0

7. 小结

# 测试速查

# 同步测试
client = TestClient(app)
response = client.get("/path")
assert response.status_code == 200
assert response.json() == {"expected": "data"}

# 异步测试
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport) as client:
    response = await client.get("/path")

# 模拟认证
app.dependency_overrides[get_current_user] = lambda: mock_user

# 覆盖率
pytest --cov=app --cov-report=html

💡 测试驱动开发(TDD):建议先写测试,再写功能。测试即文档,让代码质量有保障。


🔗 扩展阅读