转换 PyTorch 模型

通过导出为 ONNX 格式可支持 PyTorch 框架。为优化并部署使用此框架训练的模型:

  1. 将 PyTorch 模型导出为 ONNX

  2. 根据训练的网络拓扑、权重和偏差值 转换 ONNX 模型 ,以生成优化的模型 中间表示

将 PyTorch 模型导出为 ONNX 格式

PyTorch 模型在 Python 中定义。要导出这些模型,请使用 torch.onnx.export() 方法。用于 评估或测试模型的代码通常随附自己的代码,可用于初始化和导出操作。 导出为 ONNX 是该流程中至关重要的一步,但 PyTorch 框架中会涵盖此步骤,因此此处不再详细阐述。 如欲获取更多信息,请参阅 将 PyTorch 模型导出为 ONNX 格式 指南。

要导出 PyTorch 模型,需要将模型作为 torch.nn.Module 类的实例获取,并调用 export 函数。

import torch

# Instantiate your model. This is just a regular PyTorch model that will be exported in the following steps.
model = SomeModel()
# Evaluate the model to switch some operations from training mode to inference.
model.eval()
# Create dummy input for the model. It will be used to run the model inside export function.
dummy_input = torch.randn(1, 3, 224, 224)
# Call the export function
torch.onnx.export(model, (dummy_input, ), 'model.onnx')

已知问题

  • 自版本 1.8.1 起,并非全部 PyTorch 操作都会导出为默认使用的 ONNX opset 9。

    建议在未能导出为默认的 opset 9 的情况下,将模型导出为 opset 11 或更高版本。在这种情况下,请使用 opset_version 选项 (torch.onnx.export)。如欲了解有关 ONNX opset 的详细信息,请参阅 算子模式 页面。

其他资源

请参阅 模型转换教程 页面获取一系列教程,了解转换特定 PyTorch 模型相关的分步指导。部分示例如下: