📂 所属阶段:第五阶段 — 工业落地与部署(实战篇) 🔗 相关章节:模型轻量化 · Web 视觉应用
import torch import onnx # PyTorch 模型转 ONNX model = torch.load("model.pth") dummy_input = torch.randn(1, 3, 224, 224) torch.onnx.export( model, dummy_input, "model.onnx", input_names=['input'], output_names=['output'], opset_version=11 ) # 验证 ONNX 模型 onnx_model = onnx.load("model.onnx") onnx.checker.check_model(onnx_model)
import onnxruntime as ort import numpy as np # 加载 ONNX 模型 sess = ort.InferenceSession("model.onnx") # 推理 input_data = np.random.randn(1, 3, 224, 224).astype(np.float32) output = sess.run(None, {'input': input_data}) print(output[0].shape)
import tensorrt as trt # 创建 TensorRT 引擎 logger = trt.Logger(trt.Logger.WARNING) builder = trt.Builder(logger) network = builder.create_network() # 解析 ONNX parser = trt.OnnxParser(network, logger) with open("model.onnx", "rb") as f: parser.parse(f.read()) # 构建引擎 config = builder.create_builder_config() config.max_workspace_size = 1 << 30 engine = builder.build_engine(network, config) # 保存引擎 with open("model.trt", "wb") as f: f.write(engine.serialize())
推理加速三步: 1. 转 ONNX:统一格式 2. ONNX Runtime:跨平台推理 3. TensorRT:GPU 优化 性能提升: - ONNX Runtime:1.5-2 倍 - TensorRT:3-5 倍
💡 记住:ONNX 是推理的标准格式。学会用它,你就能在任何平台部署模型。
🔗 扩展阅读