RAG 后端系统集成:FastAPI 与向量数据库(Milvus/Chroma)的联动

📂 所属阶段:第六阶段 — 2026 特色专题(AI 集成篇)
🔗 相关章节:流式响应 StreamingResponse · 异步任务队列 Celery


1. 什么是 RAG?

1.1 RAG 核心流程

用户问题 → 检索(向量相似度)→ 相关文档片段 → 注入 Prompt → LLM 生成答案

          文档先被分块 → 向量化 → 存入向量数据库

RAG = Retrieval(检索)+ Augmented(增强)+ Generation(生成)

1.2 为什么需要 RAG?

| 方案 | 优点 | 缺点 |
| --- | --- | --- |
| 纯 Prompt | 简单 | 模型知识有限、容易幻觉 |
| Fine-tuning | 深度定制 | 成本高、更新慢 |
| RAG | 实时性、可溯源、成本低 | 依赖检索质量 |

2. 向量数据库选型

2.1 Chroma(轻量,适合中小规模)

pip install chromadb sentence-transformers
# vector_store/chroma_store.py
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer

class ChromaStore:
    """Thin wrapper around a persistent Chroma collection with
    sentence-transformer embeddings."""

    def __init__(self, collection_name: str = "documents"):
        # Persist on disk under ./chroma_data; telemetry disabled.
        self.client = chromadb.PersistentClient(
            path="./chroma_data",
            settings=Settings(anonymized_telemetry=False)
        )
        self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
        # Cosine distance so results can be mapped to a similarity score.
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"}
        )

    def add_document(
        self,
        doc_id: str,
        text: str,
        metadata: dict | None = None
    ):
        """Embed *text* and store it in the collection under *doc_id*."""
        vector = self.embedding_model.encode(text).tolist()
        self.collection.add(
            ids=[doc_id],
            documents=[text],
            embeddings=[vector],
            metadatas=[metadata or {}]
        )

    def search(self, query: str, top_k: int = 5) -> list[dict]:
        """Semantic search: return up to *top_k* hits with similarity scores."""
        query_vec = self.embedding_model.encode(query).tolist()
        raw = self.collection.query(
            query_embeddings=[query_vec],
            n_results=top_k,
            include=["documents", "metadatas", "distances"]
        )
        hits: list[dict] = []
        for idx, hit_id in enumerate(raw["ids"][0]):
            hits.append({
                "id": hit_id,
                "text": raw["documents"][0][idx],
                "metadata": raw["metadatas"][0][idx],
                # Convert cosine distance to a similarity-style score.
                "score": 1 - raw["distances"][0][idx],
            })
        return hits

    def delete(self, doc_id: str):
        """Remove a document from the collection by id."""
        self.collection.delete(ids=[doc_id])

# Module-level singleton shared by the ingestion and retrieval sections.
vector_store = ChromaStore()

2.2 Milvus(大规模生产环境)

pip install pymilvus
# vector_store/milvus_store.py
from pymilvus import MilvusClient, DataType

class MilvusStore:
    """Milvus-backed vector store for large-scale production deployments."""

    def __init__(self, uri: str = "http://localhost:19530"):
        self.client = MilvusClient(uri=uri)
        self.collection_name = "documents"

    def create_collection(self):
        """Create the collection with an explicit schema and a cosine index.

        NOTE(review): calling this twice raises if the collection already
        exists — guard with ``client.has_collection`` in real deployments.
        """
        schema = MilvusClient.create_schema(
            auto_id=True,               # Milvus assigns the INT64 primary key
            enable_dynamic_field=True,  # allow extra per-row fields
        )
        schema.add_field("id", DataType.INT64, is_primary=True)
        # 384 = embedding dim of all-MiniLM-L6-v2 used by the ingestion code.
        schema.add_field("vector", DataType.FLOAT_VECTOR, dim=384)
        schema.add_field("text", DataType.VARCHAR, max_length=4096)
        schema.add_field("source", DataType.VARCHAR, max_length=255)

        # Bug fix: index_params must be built via prepare_index_params();
        # the old literal {"vector": {"type": "IP", "metric_type": "COSINE"}}
        # mixed an index type with a conflicting metric and is not a valid
        # argument shape for MilvusClient.create_collection.
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            index_type="AUTOINDEX",
            metric_type="COSINE",
        )
        self.client.create_collection(
            collection_name=self.collection_name,
            schema=schema,
            index_params=index_params,
        )

    def insert(self, documents: list[dict], embeddings: list[list[float]]):
        """Insert documents alongside their precomputed embeddings."""
        data = [
            {"vector": emb, "text": doc["text"], "source": doc.get("source", "")}
            for doc, emb in zip(documents, embeddings)
        ]
        self.client.insert(collection_name=self.collection_name, data=data)

    def search(self, query_embedding: list[float], top_k: int = 5):
        """Return the top_k nearest rows for a query embedding."""
        return self.client.search(
            collection_name=self.collection_name,
            data=[query_embedding],
            limit=top_k,
            output_fields=["text", "source"]
        )

3. 文档处理管道

3.1 文档分块策略

# processing/chunker.py
from typing import List
import re

class TextChunker:
    """Paragraph-based text chunker with a sliding overlap between chunks."""

    def __init__(self, chunk_size: int = 500, overlap: int = 50):
        # Sizes are measured in characters, not tokens.
        self.chunk_size = chunk_size
        self.overlap = overlap

    def chunk_text(self, text: str, source: str = "") -> List[dict]:
        """Split *text* into paragraph-aligned chunks with overlap context.

        Returns a list of ``{"id", "text", "metadata"}`` dicts; ids are
        ``"<source>_<n>"`` in order of appearance.
        """
        # Split on runs of newlines; drop empty paragraphs.
        paragraphs = [p.strip() for p in re.split(r'\n+', text) if p.strip()]
        chunks: List[dict] = []
        current_chunk = ""
        chunk_id = 0

        for para in paragraphs:
            if len(current_chunk) + len(para) < self.chunk_size:
                current_chunk += para + "\n"
            else:
                if current_chunk.strip():
                    chunks.append({
                        "id": f"{source}_{chunk_id}",
                        "text": current_chunk.strip(),
                        "metadata": {"source": source, "chunk_id": chunk_id}
                    })
                    chunk_id += 1
                # Bug fix: carry the tail of the *previous* chunk as overlap.
                # The old code used para[-overlap:], which duplicated the new
                # paragraph's own tail in front of itself.
                overlap_text = current_chunk.strip()[-self.overlap:]
                current_chunk = (overlap_text + "\n" if overlap_text else "") + para + "\n"

        # Flush the final partial chunk.
        if current_chunk.strip():
            chunks.append({
                "id": f"{source}_{chunk_id}",
                "text": current_chunk.strip(),
                "metadata": {"source": source, "chunk_id": chunk_id}
            })

        return chunks

    def chunk_from_file(self, file_path: str, source: str = "") -> List[dict]:
        """Read a UTF-8 text file and chunk its contents."""
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()
        # Bug fix: was self.chunker_text (AttributeError — no such method).
        return self.chunk_text(text, source=source or file_path)

3.2 文档入库流程

# processing/ingest.py
from vector_store.chroma_store import vector_store
from processing.chunker import TextChunker
from sentence_transformers import SentenceTransformer

# Shared embedding model (must match the one used at query time) and chunker.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
chunker = TextChunker(chunk_size=500)

async def ingest_document(file_path: str, source_name: str):
    """Chunk a text file, embed the chunks, and store them in the vector DB.

    Returns a summary dict: ``{"chunks": <count>, "source": <source_name>}``.
    NOTE(review): the embedding call is synchronous and will block the event
    loop for large files — consider run_in_executor or a task queue.
    """
    # 1. Split the document into overlapping chunks.
    chunks = chunker.chunk_from_file(file_path, source=source_name)
    # Guard: Chroma rejects add() with empty id/document lists.
    if not chunks:
        return {"chunks": 0, "source": source_name}

    # 2. Embed all chunk texts in one batched call.
    texts = [c["text"] for c in chunks]
    embeddings = embedding_model.encode(texts, show_progress_bar=True).tolist()

    # 3. One batched insert instead of a per-chunk round trip.
    vector_store.collection.add(
        ids=[c["id"] for c in chunks],
        documents=texts,
        embeddings=embeddings,
        metadatas=[c["metadata"] for c in chunks],
    )

    return {"chunks": len(chunks), "source": source_name}

4. RAG 检索与生成

4.1 RAG 检索

# services/rag_service.py
from vector_store.chroma_store import vector_store
from sentence_transformers import SentenceTransformer
from openai import AsyncOpenAI
import httpx

class RAGService:
    """Retrieval-augmented generation: vector search + LLM answer synthesis."""

    def __init__(self):
        import os  # local import keeps this tutorial section self-contained
        self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
        # Security fix: never hard-code API keys; read from the environment.
        self.openai = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    async def retrieve(self, query: str, top_k: int = 5, min_score: float = 0.5):
        """Return up to *top_k* chunks whose similarity is >= *min_score*."""
        # 1. Embed the query with the same model used at ingestion time.
        query_embedding = self.embedding_model.encode(query).tolist()

        # 2. Similarity search against the shared vector store.
        results = vector_store.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            include=["documents", "metadatas", "distances"]
        )

        # 3. Drop low-confidence hits.
        docs = []
        for i in range(len(results["ids"][0])):
            score = 1 - results["distances"][0][i]  # distance -> similarity
            if score >= min_score:
                docs.append({
                    "text": results["documents"][0][i],
                    "source": results["metadatas"][0][i].get("source", ""),
                    "score": round(score, 4),
                })
        return docs

    async def generate_answer(self, query: str, context_docs: list[dict]) -> str:
        """Build an augmented prompt from *context_docs* and query the LLM."""
        # Number each source so the model can cite them.
        context = "\n\n".join([
            f"[来源 {i+1}: {doc['source']}]\n{doc['text']}"
            for i, doc in enumerate(context_docs)
        ])

        prompt = f"""你是一个知识库问答助手。请根据以下参考资料回答用户问题。

参考资料:
{context}

用户问题:{query}

请基于参考资料给出准确、简洁的回答。如果资料中没有相关信息,请如实说明。"""

        # Low temperature for factual, grounded answers.
        response = await self.openai.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
            max_tokens=1000,
        )
        return response.choices[0].message.content

4.2 FastAPI RAG 路由

# routers/rag.py
from fastapi import APIRouter, UploadFile, File, HTTPException
from pydantic import BaseModel
from services.rag_service import RAGService

# Router for all RAG endpoints, mounted under /rag.
router = APIRouter(prefix="/rag", tags=["RAG 检索增强"])
# Single service instance shared by every request handler below.
rag_service = RAGService()

class QueryRequest(BaseModel):
    """Request body for the /rag/query and /rag/query/stream endpoints."""
    query: str              # user question in natural language
    top_k: int = 5          # max number of chunks to retrieve
    min_score: float = 0.5  # similarity threshold for keeping a chunk
    stream: bool = False    # NOTE(review): not read by the handlers shown here

# ── 检索问答 ──────────────────────────────────────
@router.post("/query")
async def rag_query(request: QueryRequest):
    """Retrieve relevant chunks, then generate an answer with cited sources."""
    # 1. Vector search for supporting documents.
    docs = await rag_service.retrieve(
        query=request.query,
        top_k=request.top_k,
        min_score=request.min_score,
    )

    # No hit above the threshold: answer honestly instead of hallucinating.
    if not docs:
        return {"answer": "抱歉,未找到相关内容。", "sources": []}

    # 2. LLM generation augmented with the retrieved context.
    answer = await rag_service.generate_answer(request.query, docs)

    return {
        "answer": answer,
        "sources": [
            {
                # Bug fix: only append an ellipsis when the preview was
                # actually truncated (old code added "..." unconditionally).
                "text": doc["text"][:200] + ("..." if len(doc["text"]) > 200 else ""),
                "source": doc["source"],
                "score": doc["score"],
            }
            for doc in docs
        ]
    }

# ── 流式 RAG 问答 ──────────────────────────────────
@router.post("/query/stream")
async def rag_query_stream(request: QueryRequest):
    """Stream the RAG answer character-by-character (typewriter effect) via SSE."""
    from fastapi.responses import StreamingResponse
    import asyncio
    import json

    docs = await rag_service.retrieve(request.query, request.top_k, request.min_score)

    if not docs:
        return {"answer": "抱歉,未找到相关文档。", "sources": []}

    answer = await rag_service.generate_answer(request.query, docs)

    async def sse_events():
        # One SSE "token" event per character, then a final "done" event
        # carrying the retrieved source documents.
        for ch in answer:
            payload = json.dumps({'type': 'token', 'content': ch})
            yield f"data: {payload}\n\n"
            await asyncio.sleep(0.02)  # pacing for the typewriter effect
        yield f"data: {json.dumps({'type': 'done', 'sources': docs})}\n\n"

    return StreamingResponse(
        sse_events(),
        media_type="text/event-stream",
        headers={"X-Accel-Buffering": "no"}  # disable proxy buffering for SSE
    )

# ── 文档上传入库 ──────────────────────────────────
@router.post("/ingest")
async def ingest_document(file: UploadFile = File(...)):
    """Upload a .txt document, chunk + embed it, and stream ingest progress.

    Streams one JSON object per ingested chunk (NDJSON, one per line).
    """
    from fastapi.responses import StreamingResponse  # not imported above in this section
    import json

    # Guard: filename may be None; only plain-text uploads are supported.
    if not file.filename or not file.filename.endswith(".txt"):
        raise HTTPException(400, "目前只支持 .txt 文件")

    content = await file.read()
    text = content.decode("utf-8")

    # Split into chunks; ids/metadata carry the original filename.
    chunks = TextChunker().chunk_text(text, source=file.filename)

    async def chunk_generator():
        for chunk in chunks:
            vector_store.collection.add(
                ids=[chunk["id"]],
                documents=[chunk["text"]],
                embeddings=[embedding_model.encode(chunk["text"]).tolist()],
                metadatas=[chunk["metadata"]]
            )
            # Bug fix: StreamingResponse requires str/bytes chunks — the old
            # code yielded raw dicts, which fails at response time.
            yield json.dumps(chunk, ensure_ascii=False) + "\n"

    return StreamingResponse(
        chunk_generator(),
        media_type="application/x-ndjson"
    )

5. 小结

# RAG 完整流程

# 1. 文档入库
文本 → 分块 → 向量化 → 存入向量数据库

# 2. 查询检索
用户问题 → 向量化 → 相似度搜索 → Top-K 相关文档

# 3. 增强生成
相关文档 + 用户问题 → LLM → 答案(带来源引用)

💡 RAG 优化技巧:分块大小影响检索质量(太大 → 噪声多,太小 → 上下文缺失)。建议 300-500 tokens。检索后对 Top-K 结果做重排序(rerank)可显著提升准确率。


🔗 扩展阅读