Convert a PyTorch Model to ONNX and OpenVINO IR

This tutorial is also available as a Jupyter notebook that can be cloned directly from GitHub. See the installation guide for instructions to run this tutorial locally on Windows, Linux or macOS.

This tutorial shows, step by step, how to run inference on a PyTorch semantic segmentation model using OpenVINO’s Inference Engine.

First, the PyTorch model is converted to ONNX and OpenVINO Intermediate Representation (IR) formats. Then the ONNX and IR models are loaded in OpenVINO Inference Engine to show model predictions. The model is pre-trained on the CityScapes dataset. The source of the model is FastSeg.

Preparation

Imports

import sys
import time
from pathlib import Path

import cv2
import numpy as np
import torch
from IPython.display import Markdown, display
from fastseg import MobileV3Large
from openvino.runtime import Core

sys.path.append("../utils")
from notebook_utils import CityScapesSegmentation, segmentation_map_to_image, viz_result_image

Settings

Set the name for the model, and the image width and height that will be used for the network. The model is pretrained on CityScapes images of 2048x1024 pixels. Smaller dimensions reduce model accuracy but improve inference speed.

IMAGE_WIDTH = 1024  # Suggested values: 2048, 1024 or 512. The minimum width is 512.
# Set IMAGE_HEIGHT manually for custom input sizes. Minimum height is 512
IMAGE_HEIGHT = 1024 if IMAGE_WIDTH == 2048 else 512
DIRECTORY_NAME = "model"
BASE_MODEL_NAME = DIRECTORY_NAME + f"/fastseg{IMAGE_WIDTH}"

# Paths where PyTorch, ONNX and OpenVINO IR models will be stored
model_path = Path(BASE_MODEL_NAME).with_suffix(".pth")
onnx_path = model_path.with_suffix(".onnx")
ir_path = model_path.with_suffix(".xml")
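
The three paths above differ only in their file extension. A minimal sketch of how `with_suffix` derives them from the base name, using `PurePosixPath` (an assumption made here so the printed separators are the same on every operating system; the tutorial itself uses `Path`):

```python
from pathlib import PurePosixPath

# Stand-in for the notebook settings above.
IMAGE_WIDTH = 1024
base_model_name = f"model/fastseg{IMAGE_WIDTH}"

# The base name has no extension yet, so with_suffix appends one.
model_path = PurePosixPath(base_model_name).with_suffix(".pth")
# On a path that already has an extension, with_suffix replaces it.
onnx_path = model_path.with_suffix(".onnx")
ir_path = model_path.with_suffix(".xml")

print(model_path, onnx_path, ir_path)
```

Changing `IMAGE_WIDTH` therefore changes all three file names, so models converted for different input sizes do not overwrite each other.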

Download the Fastseg Model

Download, load and save the model with pretrained weights. This may take some time if you have not downloaded the model before.

print("Downloading the Fastseg model (if it has not been downloaded before)....")
model = MobileV3Large.from_pretrained().cpu().eval()
print("Loaded PyTorch Fastseg model")

# Save the model
model_path.parent.mkdir(exist_ok=True)
torch.save(model.state_dict(), str(model_path))
print(f"Model saved at {model_path}")
Downloading the Fastseg model (if it has not been downloaded before)....
Loading pretrained model mobilev3large-lraspp with F=128...
Loaded PyTorch Fastseg model
Model saved at model/fastseg1024.pth

ONNX Model Conversion

Convert PyTorch model to ONNX

The output for this cell will show some warnings. These are most likely harmless. Conversion succeeded if the last line of the output says ONNX model exported to model/fastseg1024.onnx.

if not onnx_path.exists():
    dummy_input = torch.randn(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH)

    # For the Fastseg model, setting do_constant_folding to False is required
    # for PyTorch>1.5.1
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        opset_version=11,
        do_constant_folding=False,
    )
    print(f"ONNX model exported to {onnx_path}.")
else:
    print(f"ONNX model {onnx_path} already exists.")
/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/geffnet/conv2d_layers.py:39: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/geffnet/conv2d_layers.py:39: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/geffnet/conv2d_layers.py:63: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if pad_h > 0 or pad_w > 0:
ONNX model exported to model/fastseg1024.onnx.

Convert ONNX Model to OpenVINO IR Format

Call the OpenVINO Model Optimizer tool to convert the ONNX model to OpenVINO IR with FP16 precision. The models are saved to the directory given by --output_dir, which here is the model directory. We add the mean values to the model with --mean_values and scale the input with the standard deviation with --scale_values. With these options, it is not necessary to normalize input data before propagating it through the network.

See the Model Optimizer Developer Guide for more information about Model Optimizer.
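
The `--mean_values` and `--scale_values` used in this tutorial appear to be the mean and standard deviation from the tutorial's `normalize` function, rescaled from the 0-1 range to the 0-255 pixel range. A minimal sketch that checks this arithmetic, with an arbitrary sample pixel value chosen for illustration:

```python
import math

# Statistics used by normalize() in this tutorial (pixels scaled to 0-1).
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# Rescale to the 0-255 pixel range that the IR model receives directly.
mean_values = [round(m * 255, 3) for m in mean]
scale_values = [round(s * 255, 3) for s in std]
print(mean_values)
print(scale_values)

# Both normalizations should give the same result for any pixel value.
pixel = 200.0  # arbitrary sample value
for m, s, mv, sv in zip(mean, std, mean_values, scale_values):
    in_python = (pixel / 255 - m) / s   # what normalize() computes
    in_model = (pixel - mv) / sv        # what the embedded mean/scale computes
    assert math.isclose(in_python, in_model, rel_tol=1e-6)
```

This is why the IR model can later be fed the resized image as-is, while the ONNX and PyTorch models need the normalized image.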

Executing this command may take a while. There may be some errors or warnings in the output. The conversion was successful if the output includes [ SUCCESS ] Generated IR version 11 model.

# Construct the command for Model Optimizer
mo_command = f"""mo
                 --input_model "{onnx_path}"
                 --input_shape "[1,3, {IMAGE_HEIGHT}, {IMAGE_WIDTH}]"
                 --mean_values="[123.675, 116.28 , 103.53]"
                 --scale_values="[58.395, 57.12 , 57.375]"
                 --data_type FP16
                 --output_dir "{model_path.parent}"
                 """
mo_command = " ".join(mo_command.split())
print("Model Optimizer command to convert the ONNX model to OpenVINO:")
display(Markdown(f"`{mo_command}`"))
Model Optimizer command to convert the ONNX model to OpenVINO:

mo --input_model "model/fastseg1024.onnx" --input_shape "[1,3, 512, 1024]" --mean_values="[123.675, 116.28 , 103.53]" --scale_values="[58.395, 57.12 , 57.375]" --data_type FP16 --output_dir "model"

if not ir_path.exists():
    print("Exporting ONNX model to IR... This may take a few minutes.")
    mo_result = %sx $mo_command
    print("\n".join(mo_result))
else:
    print(f"IR model {ir_path} already exists.")
Exporting ONNX model to IR... This may take a few minutes.
Model Optimizer arguments:
Common parameters:
    - Path to the Input Model:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/102-pytorch-onnx-to-openvino/model/fastseg1024.onnx
    - Path for generated IR:    /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/102-pytorch-onnx-to-openvino/model
    - IR output name:   fastseg1024
    - Log level:    ERROR
    - Batch:    Not specified, inherited from the model
    - Input layers:     Not specified, inherited from the model
    - Output layers:    Not specified, inherited from the model
    - Input shapes:     [1,3, 512, 1024]
    - Source layout:    Not specified
    - Target layout:    Not specified
    - Layout:   Not specified
    - Mean values:  [123.675, 116.28 , 103.53]
    - Scale values:     [58.395, 57.12 , 57.375]
    - Scale factor:     Not specified
    - Precision of IR:  FP16
    - Enable fusing:    True
    - User transformations:     Not specified
    - Reverse input channels:   False
    - Enable IR generation for fixed input shape:   False
    - Use the transformations config file:  None
Advanced parameters:
    - Force the usage of legacy Frontend of Model Optimizer for model conversion into IR:   False
    - Force the usage of new Frontend of Model Optimizer for model conversion into IR:  False
OpenVINO runtime found in:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/openvino
OpenVINO runtime version:   2022.1.0-7019-cdb9bec7210-releases/2022/1
Model Optimizer version:    2022.1.0-7019-cdb9bec7210-releases/2022/1
[ SUCCESS ] Generated IR version 11 model.
[ SUCCESS ] XML file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/102-pytorch-onnx-to-openvino/model/fastseg1024.xml
[ SUCCESS ] BIN file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/102-pytorch-onnx-to-openvino/model/fastseg1024.bin
[ SUCCESS ] Total execution time: 0.80 seconds.
[ SUCCESS ] Memory consumed: 109 MB.
It's been a while, check for a new version of Intel(R) Distribution of OpenVINO(TM) toolkit here https://software.intel.com/content/www/us/en/develop/tools/openvino-toolkit/download.html?cid=other&source=prod&campid=ww_2022_bu_IOTG_OpenVINO-2022-1&content=upg_all&medium=organic or on the GitHub*
[ INFO ] The model was converted to IR v11, the latest model format that corresponds to the source DL framework input/output format. While IR v11 is backwards compatible with OpenVINO Inference Engine API v1.0, please use API v2.0 (as of 2022.1) to take advantage of the latest improvements in IR v11.
Find more information about API v2.0 and IR v11 at https://docs.openvino.ai
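
The `%sx` magic used above only works inside IPython or Jupyter. Outside a notebook, a rough equivalent could split the command string with `shlex` and run it with `subprocess`. The sketch below is an assumption about how one might do this, not part of the original tutorial; `run_model_optimizer` is a hypothetical helper and is only defined here, not called.

```python
import shlex
import subprocess

# The flattened command string, copied from the tutorial output above.
mo_command = (
    'mo --input_model "model/fastseg1024.onnx" '
    '--input_shape "[1,3, 512, 1024]" '
    '--mean_values="[123.675, 116.28 , 103.53]" '
    '--scale_values="[58.395, 57.12 , 57.375]" '
    '--data_type FP16 --output_dir "model"'
)

# shlex.split honours the quotes, so each bracketed list stays one argument.
mo_args = shlex.split(mo_command)
print(mo_args)


def run_model_optimizer(args):
    """Hypothetical helper: run the command and return its stdout lines."""
    completed = subprocess.run(args, capture_output=True, text=True, check=True)
    return completed.stdout.splitlines()
```

Passing a list of arguments avoids going through a shell, and `check=True` raises an exception if the command exits with an error.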

Show Results

Confirm that the segmentation results look as expected by comparing model predictions on the ONNX, IR and PyTorch models.

Load and Preprocess an Input Image

For the OpenVINO model, normalization is moved to the model. For the ONNX and PyTorch models, images need to be normalized before propagating through the network. A sample image from the Mapillary Vistas dataset is provided for inference.

def normalize(image: np.ndarray) -> np.ndarray:
    """
    Normalize the image to the given mean and standard deviation
    for CityScapes models.
    """
    image = image.astype(np.float32)
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    image /= 255.0
    image -= mean
    image /= std
    return image


image_filename = "data/street.jpg"
image = cv2.cvtColor(cv2.imread(image_filename), cv2.COLOR_BGR2RGB)

resized_image = cv2.resize(image, (IMAGE_WIDTH, IMAGE_HEIGHT))
normalized_image = normalize(resized_image)

# Convert the resized images to network input shape
input_image = np.expand_dims(np.transpose(resized_image, (2, 0, 1)), 0)
normalized_input_image = np.expand_dims(np.transpose(normalized_image, (2, 0, 1)), 0)

Load the ONNX and OpenVINO IR Models and Run Inference

Inference Engine can load ONNX models directly. We first load the ONNX model, do inference and show the results. After that we load the model that was converted to Intermediate Representation (IR) with Model Optimizer and do inference on that model and show the results on an image from Mapillary Vistas.

1. ONNX Model in Inference Engine

# Load network to Inference Engine
ie = Core()
model_onnx = ie.read_model(model=onnx_path)
compiled_model_onnx = ie.compile_model(model=model_onnx, device_name="CPU")

output_layer_onnx = compiled_model_onnx.output(0)

# Run inference on the input image
res_onnx = compiled_model_onnx([normalized_input_image])[output_layer_onnx]
# Convert network result to segmentation map and display the result
result_mask_onnx = np.squeeze(np.argmax(res_onnx, axis=1)).astype(np.uint8)
viz_result_image(
    image,
    segmentation_map_to_image(result_mask_onnx, CityScapesSegmentation.get_colormap()),
    resize=True,
)
../_images/102-pytorch-onnx-to-openvino-with-output_19_0.png
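
The conversion from network output to segmentation map relies on taking the argmax over the class axis. A small sketch with made-up scores (1 image, 3 classes, 2x2 pixels; the real model output has more classes and a much larger spatial size) illustrates the step:

```python
import numpy as np

# Toy network output with shape (N, C, H, W) = (1, 3, 2, 2).
logits = np.array(
    [[
        [[0.1, 2.0], [0.3, 0.2]],  # class 0 scores
        [[1.5, 0.5], [0.1, 0.9]],  # class 1 scores
        [[0.2, 0.1], [2.2, 0.4]],  # class 2 scores
    ]],
    dtype=np.float32,
)

# For every pixel, pick the class with the highest score (axis 1),
# then drop the batch dimension to get an (H, W) map of class IDs.
mask = np.squeeze(np.argmax(logits, axis=1)).astype(np.uint8)
print(mask.shape)
print(mask)
```

Each value in `mask` is a class ID, which `segmentation_map_to_image` then maps to a color using the colormap.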

2. IR Model in Inference Engine

# Load the network in Inference Engine
ie = Core()
model_ir = ie.read_model(model=ir_path)
compiled_model_ir = ie.compile_model(model=model_ir, device_name="CPU")

# Get the output layer
output_layer_ir = compiled_model_ir.output(0)

# Run inference on the input image
res_ir = compiled_model_ir([input_image])[output_layer_ir]
result_mask_ir = np.squeeze(np.argmax(res_ir, axis=1)).astype(np.uint8)
viz_result_image(
    image,
    segmentation_map_to_image(result=result_mask_ir, colormap=CityScapesSegmentation.get_colormap()),
    resize=True,
)
../_images/102-pytorch-onnx-to-openvino-with-output_22_0.png

PyTorch Comparison

Do inference on the PyTorch model to verify that the output visually looks the same as the output on the ONNX/IR models.

with torch.no_grad():
    result_torch = model(torch.as_tensor(normalized_input_image).float())

result_mask_torch = torch.argmax(result_torch, dim=1).squeeze(0).numpy().astype(np.uint8)
viz_result_image(
    image,
    segmentation_map_to_image(result=result_mask_torch, colormap=CityScapesSegmentation.get_colormap()),
    resize=True,
)
../_images/102-pytorch-onnx-to-openvino-with-output_24_0.png
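
A visual check can be complemented with a numeric one. The sketch below computes the fraction of pixels on which two segmentation masks agree; `mask_agreement` is a hypothetical helper, and the toy arrays stand in for masks such as `result_mask_ir` and `result_mask_torch`.

```python
import numpy as np


def mask_agreement(mask_a, mask_b):
    """Return the fraction of pixels with the same class ID in both masks."""
    if mask_a.shape != mask_b.shape:
        raise ValueError("masks must have the same shape")
    return float(np.mean(mask_a == mask_b))


# Toy masks standing in for the real segmentation results.
toy_ir = np.array([[1, 0], [2, 1]], dtype=np.uint8)
toy_torch = np.array([[1, 0], [2, 2]], dtype=np.uint8)
print(mask_agreement(toy_ir, toy_torch))
```

With the real masks, a call such as `mask_agreement(result_mask_ir, result_mask_torch)` would give a single number to track. A value slightly below 1.0 may still be acceptable, since the IR model here uses FP16 precision.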

Performance Comparison

Measure the time it takes to do inference on twenty images. This gives an indication of performance. For more accurate benchmarking, use the OpenVINO Benchmark Tool. Note that many optimizations are possible to improve the performance.

num_images = 20

start = time.perf_counter()
for _ in range(num_images):
    compiled_model_onnx([normalized_input_image])
end = time.perf_counter()
time_onnx = end - start
print(
    f"ONNX model in Inference Engine/CPU: {time_onnx/num_images:.3f} "
    f"seconds per image, FPS: {num_images/time_onnx:.2f}"
)

start = time.perf_counter()
for _ in range(num_images):
    compiled_model_ir([input_image])
end = time.perf_counter()
time_ir = end - start
print(
    f"IR model in Inference Engine/CPU: {time_ir/num_images:.3f} "
    f"seconds per image, FPS: {num_images/time_ir:.2f}"
)

with torch.no_grad():
    start = time.perf_counter()
    for _ in range(num_images):
        model(torch.as_tensor(normalized_input_image).float())
    end = time.perf_counter()
    time_torch = end - start
print(
    f"PyTorch model on CPU: {time_torch/num_images:.3f} seconds per image, "
    f"FPS: {num_images/time_torch:.2f}"
)

if "GPU" in ie.available_devices:
    compiled_model_onnx_gpu = ie.compile_model(model=model_onnx, device_name="GPU")
    start = time.perf_counter()
    for _ in range(num_images):
        compiled_model_onnx_gpu([normalized_input_image])
    end = time.perf_counter()
    time_onnx_gpu = end - start
    print(
        f"ONNX model in Inference Engine/GPU: {time_onnx_gpu/num_images:.3f} "
        f"seconds per image, FPS: {num_images/time_onnx_gpu:.2f}"
    )

    compiled_model_ir_gpu = ie.compile_model(model=model_ir, device_name="GPU")
    start = time.perf_counter()
    for _ in range(num_images):
        compiled_model_ir_gpu([input_image])
    end = time.perf_counter()
    time_ir_gpu = end - start
    print(
        f"IR model in Inference Engine/GPU: {time_ir_gpu/num_images:.3f} "
        f"seconds per image, FPS: {num_images/time_ir_gpu:.2f}"
    )
ONNX model in Inference Engine/CPU: 0.069 seconds per image, FPS: 14.40
IR model in Inference Engine/CPU: 0.069 seconds per image, FPS: 14.40
PyTorch model on CPU: 0.348 seconds per image, FPS: 2.87
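
The timing loops above repeat the same pattern for every model. One way to factor this out, sketched here as an assumption rather than as part of the original notebook, is a small helper that also runs a few untimed warm-up calls so that one-time initialization costs are less likely to skew the result. `benchmark` and its parameters are hypothetical names.

```python
import time


def benchmark(run_inference, num_images=20, warmup=2):
    """Return (seconds per image, FPS) for a zero-argument callable."""
    for _ in range(warmup):  # untimed warm-up calls
        run_inference()
    start = time.perf_counter()
    for _ in range(num_images):
        run_inference()
    elapsed = time.perf_counter() - start
    return elapsed / num_images, num_images / elapsed


# Stand-in workload so the sketch runs without a model.
calls = []
seconds_per_image, fps = benchmark(lambda: calls.append(sum(range(10_000))))
print(f"{seconds_per_image:.6f} seconds per image, FPS: {fps:.2f}")
```

In the notebook, the stand-in workload would be replaced by a call such as `lambda: compiled_model_ir([input_image])`.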

Show Device Information

devices = ie.available_devices
for device in devices:
    device_name = ie.get_property(device_name=device, name="FULL_DEVICE_NAME")
    print(f"{device}: {device_name}")
CPU: Intel(R) Core(TM) i9-10920X CPU @ 3.50GHz