推理加速框架:ONNX Runtime、TensorRT

📂 所属阶段:第五阶段 — 工业落地与部署(实战篇)
🔗 相关章节:模型轻量化 · Web 视觉应用


1. ONNX 转换

import torch
import onnx

# PyTorch 模型转 ONNX
model = torch.load("model.pth")
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model, dummy_input, "model.onnx",
    input_names=['input'],
    output_names=['output'],
    opset_version=11
)

# 验证 ONNX 模型
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)

2. ONNX Runtime 推理

import onnxruntime as ort
import numpy as np

# 加载 ONNX 模型
sess = ort.InferenceSession("model.onnx")

# 推理
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
output = sess.run(None, {'input': input_data})

print(output[0].shape)

3. TensorRT 优化

import tensorrt as trt

# 创建 TensorRT 引擎
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network()

# 解析 ONNX
parser = trt.OnnxParser(network, logger)
with open("model.onnx", "rb") as f:
    parser.parse(f.read())

# 构建引擎
config = builder.create_builder_config()
config.max_workspace_size = 1 << 30
engine = builder.build_engine(network, config)

# 保存引擎
with open("model.trt", "wb") as f:
    f.write(engine.serialize())

4. 小结

推理加速三步:

1. 转 ONNX:统一格式
2. ONNX Runtime:跨平台推理
3. TensorRT:GPU 优化

性能提升:
- ONNX Runtime:1.5-2 倍
- TensorRT:3-5 倍

💡 记住:ONNX 是推理的标准格式。学会用它,你就能在任何平台部署模型。


🔗 扩展阅读