Convert a PyTorch Model to ONNX and OpenVINO™ IR¶
This tutorial is also available as a Jupyter notebook that can be cloned directly from GitHub. See the installation guide for instructions to run this tutorial locally on Windows, Linux or macOS.
This tutorial demonstrates, step by step, how to do inference on a PyTorch semantic segmentation model using OpenVINO Runtime.
First, the PyTorch model is exported to ONNX format and then converted to OpenVINO IR. Then the respective ONNX and OpenVINO IR models are loaded into OpenVINO Runtime to show model predictions. In this tutorial, we will use the LR-ASPP model with a MobileNetV3 backbone.
According to the paper Searching for MobileNetV3, LR-ASPP (Lite Reduced Atrous Spatial Pyramid Pooling) is a lightweight and efficient segmentation decoder architecture. The diagram below illustrates the model architecture:
The model is pre-trained on the MS COCO dataset. Instead of training on all 80 COCO classes, the segmentation model has been trained on the 20 object classes of the PASCAL VOC dataset plus a background class: background, aeroplane, bicycle, bird, boat, bottle, bus, car, cat, chair, cow, dining table, dog, horse, motorbike, person, potted plant, sheep, sofa, train, tvmonitor.
More information about the model is available in the torchvision documentation.
Preparation¶
Imports¶
import sys
import time
import warnings
from pathlib import Path
import cv2
import numpy as np
import torch
from IPython.display import Markdown, display
from torchvision.models.segmentation import lraspp_mobilenet_v3_large, LRASPP_MobileNet_V3_Large_Weights
from openvino.runtime import Core
sys.path.append("../utils")
from notebook_utils import segmentation_map_to_image, viz_result_image, SegmentationMap, Label, download_file
Settings¶
Set a name for the model, then define the width and height of the image that will be used by the network during inference. According to the input transforms function, the model is pre-trained on images with a height of 520 and a width of 780 (you can confirm this with the snippet after the settings below).
IMAGE_WIDTH = 780
IMAGE_HEIGHT = 520
DIRECTORY_NAME = "model"
BASE_MODEL_NAME = DIRECTORY_NAME + "/lraspp_mobilenet_v3_large"
weights_path = Path(BASE_MODEL_NAME + ".pt")
# Paths where ONNX and OpenVINO IR models will be stored.
onnx_path = weights_path.with_suffix('.onnx')
if not onnx_path.parent.exists():
    onnx_path.parent.mkdir()
ir_path = onnx_path.with_suffix(".xml")
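The 520×780 input size used above comes from the preprocessing preset that is bundled with the pre-trained weights. If you want to confirm it yourself, the weights enum exposes both the preset and the class metadata. A quick check, assuming torchvision 0.13 or later with the multi-weight API:
weights = LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1
print(weights.transforms())        # the preprocessing preset used at training time
print(weights.meta["categories"])  # the 21 class names listed in the introduction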
Load Model¶
Generally, PyTorch models represent an instance of the torch.nn.Module class, initialized by a state dictionary with model weights. The typical steps for getting a pre-trained model are:
1. Create an instance of the model class.
2. Load the checkpoint state dict, which contains the pre-trained model weights.
3. Turn the model to evaluation mode, to switch some operations to inference mode.
The torchvision module provides a ready-to-use set of functions for model class initialization. We will use torchvision.models.segmentation.lraspp_mobilenet_v3_large. You could directly pass pre-trained model weights to the model initialization function using the weights enum LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1. However, for demonstration purposes, we will create it separately: download the pre-trained weights and load the model. This may take some time if you have not downloaded the model before.
print("Downloading the LRASPP MobileNetV3 model (if it has not been downloaded already)...")
download_file(LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1.url, filename=weights_path.name, directory=weights_path.parent)
# create model object
model = lraspp_mobilenet_v3_large()
# read the state dict; use the map_location argument to avoid a situation where weights are saved on CUDA (which may not be available on the system)
state_dict = torch.load(weights_path, map_location='cpu')
# load state dict to model
model.load_state_dict(state_dict)
# switch model from training to inference mode
model.eval()
print("Loaded PyTorch LRASPP MobileNetV3 model")
Downloading the LRASPP MobileNetV3 model (if it has not been downloaded already)...
model/lraspp_mobilenet_v3_large.pt: 0%| | 0.00/12.5M [00:00<?, ?B/s]
Loaded PyTorch LRASPP MobileNetV3 model
ONNX Model Conversion¶
Convert PyTorch model to ONNX¶
OpenVINO supports PyTorch models that are exported in ONNX format. We will use the torch.onnx.export function to obtain the ONNX model; you can learn more about this feature in the PyTorch documentation. We need to provide a model object, example input for model tracing, and a path where the model will be saved. When providing example input, it is not necessary to use real data; dummy input data with the specified shape is sufficient. Optionally, we can provide a target ONNX opset for conversion and/or other parameters specified in the documentation (for example, input and output names or dynamic shapes).
Sometimes a warning will be shown, but in most cases it is harmless, so let us just filter it out. When the conversion is successful, the last line of the output will read:
ONNX model exported to model/lraspp_mobilenet_v3_large.onnx.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")
    if not onnx_path.exists():
        dummy_input = torch.randn(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH)
        torch.onnx.export(
            model,
            dummy_input,
            onnx_path,
        )
        print(f"ONNX model exported to {onnx_path}.")
    else:
        print(f"ONNX model {onnx_path} already exists.")
ONNX model exported to model/lraspp_mobilenet_v3_large.onnx.
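Optionally, you can verify that the exported file is a well-formed ONNX graph before converting it. A small sanity check, assuming the onnx Python package is installed:
import onnx

onnx_model = onnx.load(str(onnx_path))
onnx.checker.check_model(onnx_model)  # raises an exception if the model is structurally invalid
print(f"ONNX model {onnx_path} passed the checker.")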
Convert ONNX Model to OpenVINO IR Format¶
Use Model Optimizer to convert the ONNX model to OpenVINO IR with FP16 precision. The models are saved inside the current directory. For more information about Model Optimizer, see the Model Optimizer Developer Guide.
Executing this command may take a while. There may be some errors or warnings in the output. When model optimization is successful, the last lines of the output will include:
[ SUCCESS ] Generated IR version 11 model.
# Construct the command for Model Optimizer.
mo_command = f"""mo
--input_model "{onnx_path}"
--compress_to_fp16
--output_dir "{ir_path.parent}"
"""
mo_command = " ".join(mo_command.split())
print("Model Optimizer command to convert the ONNX model to OpenVINO:")
display(Markdown(f"`{mo_command}`"))
Model Optimizer command to convert the ONNX model to OpenVINO:
mo --input_model "model/lraspp_mobilenet_v3_large.onnx" --compress_to_fp16 --output_dir "model"
if not ir_path.exists():
    print("Exporting ONNX model to IR... This may take a few minutes.")
    mo_result = %sx $mo_command
    print("\n".join(mo_result))
else:
    print(f"IR model {ir_path} already exists.")
Exporting ONNX model to IR... This may take a few minutes.
Check for a new version of Intel(R) Distribution of OpenVINO(TM) toolkit here https://software.intel.com/content/www/us/en/develop/tools/openvino-toolkit/download.html?cid=other&source=prod&campid=ww_2023_bu_IOTG_OpenVINO-2022-3&content=upg_all&medium=organic or on https://github.com/openvinotoolkit/openvino
[ INFO ] The model was converted to IR v11, the latest model format that corresponds to the source DL framework input/output format. While IR v11 is backwards compatible with OpenVINO Inference Engine API v1.0, please use API v2.0 (as of 2022.1) to take advantage of the latest improvements in IR v11.
Find more information about API v2.0 and IR v11 at https://docs.openvino.ai/latest/openvino_2_0_transition_guide.html
[ SUCCESS ] Generated IR version 11 model.
[ SUCCESS ] XML file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-408/.workspace/scm/ov-notebook/notebooks/102-pytorch-onnx-to-openvino/model/lraspp_mobilenet_v3_large.xml
[ SUCCESS ] BIN file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-408/.workspace/scm/ov-notebook/notebooks/102-pytorch-onnx-to-openvino/model/lraspp_mobilenet_v3_large.bin
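As an alternative to the mo command line, recent OpenVINO releases also expose Model Optimizer as a Python API. A minimal sketch, assuming a release where openvino.tools.mo.convert_model is available (2022.1 and later):
from openvino.runtime import serialize
from openvino.tools.mo import convert_model

ov_model = convert_model(str(onnx_path), compress_to_fp16=True)  # in-memory openvino.runtime.Model
serialize(ov_model, str(ir_path))  # writes the .xml/.bin IR pair next to the ONNX file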
Show Results¶
Confirm that the segmentation results look as expected by comparing model predictions on the ONNX, OpenVINO IR and PyTorch models.
Load and Preprocess an Input Image¶
Images need to be normalized before propagating through the network.
def normalize(image: np.ndarray) -> np.ndarray:
    """
    Normalize the image with the ImageNet mean and standard deviation
    used during model training.
    """
    image = image.astype(np.float32)
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    image /= 255.0
    image -= mean
    image /= std
    return image
image_filename = "../data/image/coco.jpg"
image = cv2.cvtColor(cv2.imread(image_filename), cv2.COLOR_BGR2RGB)
resized_image = cv2.resize(image, (IMAGE_WIDTH, IMAGE_HEIGHT))
normalized_image = normalize(resized_image)
# Convert the resized images to network input shape.
input_image = np.expand_dims(np.transpose(resized_image, (2, 0, 1)), 0)
normalized_input_image = np.expand_dims(np.transpose(normalized_image, (2, 0, 1)), 0)
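As a quick check, both arrays are now in the NCHW layout that the network expects; with the settings above, each has shape (1, 3, 520, 780):
print(input_image.shape)             # (1, 3, 520, 780)
print(normalized_input_image.shape)  # (1, 3, 520, 780)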
Load the OpenVINO IR Network and Run Inference on the ONNX model¶
OpenVINO Runtime can load ONNX models directly. First, load the ONNX model, do inference, and show the results. Then, load the model that was converted to OpenVINO Intermediate Representation (OpenVINO IR) with Model Optimizer, do inference on that model, and show the results on an image.
1. ONNX Model in OpenVINO Runtime¶
# Load the network to OpenVINO Runtime.
ie = Core()
model_onnx = ie.read_model(model=onnx_path)
compiled_model_onnx = ie.compile_model(model=model_onnx, device_name="CPU")
output_layer_onnx = compiled_model_onnx.output(0)
# Run inference on the input image.
res_onnx = compiled_model_onnx([normalized_input_image])[output_layer_onnx]
The model predicts a score for each label at every pixel, indicating how well the pixel corresponds to that label. To get the label with the highest score for each pixel, the argmax operation should be applied. After that, color coding can be applied to each label for more convenient visualization.
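As a small self-contained illustration (not part of the pipeline below), applying argmax over the class axis of an (N, C, H, W) score tensor yields an (H, W) label mask:
toy_scores = np.random.rand(1, 21, 2, 2)            # 21 class scores for a 2x2 image
toy_mask = np.argmax(toy_scores, axis=1).squeeze()  # shape (2, 2), values in 0..20
print(toy_mask)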
voc_labels = [
    Label(index=0, color=(0, 0, 0), name="background"),
    Label(index=1, color=(128, 0, 0), name="aeroplane"),
    Label(index=2, color=(0, 128, 0), name="bicycle"),
    Label(index=3, color=(128, 128, 0), name="bird"),
    Label(index=4, color=(0, 0, 128), name="boat"),
    Label(index=5, color=(128, 0, 128), name="bottle"),
    Label(index=6, color=(0, 128, 128), name="bus"),
    Label(index=7, color=(128, 128, 128), name="car"),
    Label(index=8, color=(64, 0, 0), name="cat"),
    Label(index=9, color=(192, 0, 0), name="chair"),
    Label(index=10, color=(64, 128, 0), name="cow"),
    Label(index=11, color=(192, 128, 0), name="dining table"),
    Label(index=12, color=(64, 0, 128), name="dog"),
    Label(index=13, color=(192, 0, 128), name="horse"),
    Label(index=14, color=(64, 128, 128), name="motorbike"),
    Label(index=15, color=(192, 128, 128), name="person"),
    Label(index=16, color=(0, 64, 0), name="potted plant"),
    Label(index=17, color=(128, 64, 0), name="sheep"),
    Label(index=18, color=(0, 192, 0), name="sofa"),
    Label(index=19, color=(128, 192, 0), name="train"),
    Label(index=20, color=(0, 64, 128), name="tv monitor"),
]
VOCLabels = SegmentationMap(voc_labels)
# Convert the network result to a segmentation map and display the result.
result_mask_onnx = np.squeeze(np.argmax(res_onnx, axis=1)).astype(np.uint8)
viz_result_image(
    image,
    segmentation_map_to_image(result_mask_onnx, VOCLabels.get_colormap()),
    resize=True,
)
2. OpenVINO IR Model in OpenVINO Runtime¶
# Load the network in OpenVINO Runtime.
ie = Core()
model_ir = ie.read_model(model=ir_path)
compiled_model_ir = ie.compile_model(model=model_ir, device_name="CPU")
# Get input and output layers.
output_layer_ir = compiled_model_ir.output(0)
# Run inference on the input image.
res_ir = compiled_model_ir([normalized_input_image])[output_layer_ir]
result_mask_ir = np.squeeze(np.argmax(res_ir, axis=1)).astype(np.uint8)
viz_result_image(
    image,
    segmentation_map_to_image(result=result_mask_ir, colormap=VOCLabels.get_colormap()),
    resize=True,
)
PyTorch Comparison¶
Do inference on the PyTorch model to verify that the output visually looks the same as the output on the ONNX/OpenVINO IR models.
model.eval()
with torch.no_grad():
    result_torch = model(torch.as_tensor(normalized_input_image).float())

result_mask_torch = torch.argmax(result_torch["out"], dim=1).squeeze(0).numpy().astype(np.uint8)
viz_result_image(
    image,
    segmentation_map_to_image(result=result_mask_torch, colormap=VOCLabels.get_colormap()),
    resize=True,
)
Performance Comparison¶
Measure the time it takes to do inference on one hundred images. This gives an indication of performance. For more accurate benchmarking, use the Benchmark Tool (see the example invocation after the timing results below). Keep in mind that many optimizations are possible to improve the performance.
num_images = 100
with torch.no_grad():
    start = time.perf_counter()
    for _ in range(num_images):
        model(torch.as_tensor(input_image).float())
    end = time.perf_counter()
    time_torch = end - start
print(
    f"PyTorch model on CPU: {time_torch/num_images:.3f} seconds per image, "
    f"FPS: {num_images/time_torch:.2f}"
)

start = time.perf_counter()
for _ in range(num_images):
    compiled_model_onnx([normalized_input_image])
end = time.perf_counter()
time_onnx = end - start
print(
    f"ONNX model in OpenVINO Runtime/CPU: {time_onnx/num_images:.3f} "
    f"seconds per image, FPS: {num_images/time_onnx:.2f}"
)

start = time.perf_counter()
for _ in range(num_images):
    compiled_model_ir([input_image])
end = time.perf_counter()
time_ir = end - start
print(
    f"OpenVINO IR model in OpenVINO Runtime/CPU: {time_ir/num_images:.3f} "
    f"seconds per image, FPS: {num_images/time_ir:.2f}"
)
if "GPU" in ie.available_devices:
compiled_model_onnx_gpu = ie.compile_model(model=model_onnx, device_name="GPU")
start = time.perf_counter()
for _ in range(num_images):
compiled_model_onnx_gpu([input_image])
end = time.perf_counter()
time_onnx_gpu = end - start
print(
f"ONNX model in OpenVINO/GPU: {time_onnx_gpu/num_images:.3f} "
f"seconds per image, FPS: {num_images/time_onnx_gpu:.2f}"
)
compiled_model_ir_gpu = ie.compile_model(model=model_ir, device_name="GPU")
start = time.perf_counter()
for _ in range(num_images):
compiled_model_ir_gpu([input_image])
end = time.perf_counter()
time_ir_gpu = end - start
print(
f"IR model in OpenVINO/GPU: {time_ir_gpu/num_images:.3f} "
f"seconds per image, FPS: {num_images/time_ir_gpu:.2f}"
)
PyTorch model on CPU: 0.038 seconds per image, FPS: 26.34
ONNX model in OpenVINO Runtime/CPU: 0.029 seconds per image, FPS: 34.02
OpenVINO IR model in OpenVINO Runtime/CPU: 0.029 seconds per image, FPS: 34.69
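For a more thorough measurement with the Benchmark Tool mentioned above, the IR model can be passed to benchmark_app directly. An example invocation (adjust the device and duration to your setup):
!benchmark_app -m model/lraspp_mobilenet_v3_large.xml -d CPU -t 15 -api async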
Show Device Information¶
devices = ie.available_devices
for device in devices:
    device_name = ie.get_property(device, "FULL_DEVICE_NAME")
    print(f"{device}: {device_name}")
CPU: Intel(R) Core(TM) i9-10920X CPU @ 3.50GHz