转换 PyTorch 模型¶
通过导出为 ONNX 格式可支持 PyTorch 框架。为优化并部署使用此框架训练的模型:
根据训练的网络拓扑、权重和偏差值 转换 ONNX 模型 ,以生成优化的模型 中间表示 。
将 PyTorch 模型导出为 ONNX 格式¶
PyTorch 模型在 Python 中定义。要导出这些模型,请使用 torch.onnx.export() 方法。用于
评估或测试模型的代码通常随附自己的代码,可用于初始化和导出操作。
导出为 ONNX 是该流程中至关重要的一步,但 PyTorch 框架中会涵盖此步骤,因此此处不再详细阐述。
如欲获取更多信息,请参阅 将 PyTorch 模型导出为 ONNX 格式 指南。
要导出 PyTorch 模型,需要将模型作为 torch.nn.Module 类的实例获取,并调用 export 函数。
import torch
# Instantiate your model. This is just a regular PyTorch model that will be exported in the following steps.
model = SomeModel()
# Evaluate the model to switch some operations from training mode to inference.
model.eval()
# Create dummy input for the model. It will be used to run the model inside export function.
dummy_input = torch.randn(1, 3, 224, 224)
# Call the export function
torch.onnx.export(model, (dummy_input, ), 'model.onnx')
已知问题¶
- 自版本 1.8.1 起,并非全部 PyTorch 操作都会导出为默认使用的 ONNX opset 9。
建议在未能导出为默认的 opset 9 的情况下,将模型导出为 opset 11 或更高版本。在这种情况下,请使用
opset_version选项 (torch.onnx.export)。如欲了解有关 ONNX opset 的详细信息,请参阅 算子模式 页面。