Convert and Optimize YOLOv11 instance segmentation model with OpenVINO™#
Instance segmentation goes a step further than object detection and
involves identifying individual objects in an image and segmenting them
from the rest of the image. Like object detection, instance segmentation is
often used as a key component in computer vision systems.
Applications that use real-time instance segmentation models include
video analytics, robotics, autonomous vehicles, multi-object tracking
and object counting, medical image analysis, and many others.
This tutorial demonstrates step-by-step instructions on how to run and
optimize PyTorch YOLOv11 with OpenVINO. We consider the steps required
for the instance segmentation scenario. You can find more details about
the model on its model page in the
Ultralytics documentation.
The tutorial consists of the following steps:

- Prepare the PyTorch model.
- Download and prepare a dataset.
- Validate the original model.
- Convert the PyTorch model to OpenVINO IR.
- Validate the converted model.
- Prepare and run the optimization pipeline.
- Compare performance of the FP32 and quantized models.
- Compare accuracy of the FP32 and quantized models.
- Live demo
Generally, PyTorch models represent an instance of the
torch.nn.Module
class, initialized by a state dictionary with model weights. We will use
the YOLOv11 nano model (also known as yolo11n-seg) pre-trained on the
COCO dataset, which is available in this
repo. Similar steps are
also applicable to other YOLOv11 models. Typical steps to obtain a
pre-trained model are:

1. Create an instance of a model class.
2. Load a checkpoint state dict, which contains the pre-trained model weights.
3. Switch the model to evaluation mode, which puts some operations into inference mode.
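For illustration only, a minimal sketch of these generic steps (MyModel and checkpoint.pt are hypothetical placeholders, not part of this tutorial):

import torch

# 1. Create an instance of a model class (MyModel is a hypothetical placeholder).
model = MyModel()
# 2. Load a checkpoint state dict containing the pre-trained weights.
state_dict = torch.load("checkpoint.pt", map_location="cpu")
model.load_state_dict(state_dict)
# 3. Switch the model to evaluation mode.
model.eval()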
In this case, the creators of the model provide an API that enables
converting the YOLOv11 model to OpenVINO IR. Therefore, we do not need
to do these steps manually.
# Download a test sample
IMAGE_PATH = Path("./data/coco_bike.jpg")

download_file(
    url="https://storage.openvinotoolkit.org/repositories/openvino_notebooks/data/data/image/coco_bike.jpg",
    filename=IMAGE_PATH.name,
    directory=IMAGE_PATH.parent,
)
To load the model, you need to specify a path to the model checkpoint.
It can be a local path or a name available on the models hub (in that
case, the model checkpoint will be downloaded automatically). You can
select the model using the widget below:
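The loading step itself could look roughly like this (a sketch, assuming the widget's value is stored in SEG_MODEL_NAME and the checkpoint is fetched from the models hub if it is not found locally):

from ultralytics import YOLO

SEG_MODEL_NAME = "yolo11n-seg"  # value taken from the selection widget

# Load the pre-trained segmentation checkpoint; downloaded automatically if missing.
seg_model = YOLO(f"{SEG_MODEL_NAME}.pt")
label_map = seg_model.model.names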
When making a prediction, the model accepts a path to the input image and
returns a list of Results class objects. Results contains boxes for object
detection models, and boxes and masks for segmentation models. It also
contains utilities for processing results, for example, the plot()
method for drawing.
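For example, a short sketch (seg_model and IMAGE_PATH come from the cells above; plot() returns a BGR array, so we flip channels for display):

from PIL import Image

res = seg_model(IMAGE_PATH)  # list with one Results object for the single input image
Image.fromarray(res[0].plot()[:, :, ::-1])  # convert BGR -> RGB for display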
Ultralytics provides an API for conveniently exporting models to different
formats, including OpenVINO IR. model.export is responsible for model
conversion. We need to specify the format, and additionally, we can
preserve dynamic shapes in the model.
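A possible call is sketched below; the dynamic and half flags and the IR path are assumptions and may differ from the exact arguments used in the notebook:

import openvino as ov

# Export to OpenVINO IR; creates a <model_name>_openvino_model directory with .xml/.bin files.
seg_model.export(format="openvino", dynamic=True, half=True)

core = ov.Core()
seg_ov_model = core.read_model(f"{SEG_MODEL_NAME}_openvino_model/{SEG_MODEL_NAME}.xml")  # path assumed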
Great! The result is the same as the one produced by the original model.
Optimize model using NNCF Post-training Quantization API#
NNCF provides a suite of
advanced algorithms for Neural Networks inference optimization in
OpenVINO with minimal accuracy drop. We will use 8-bit quantization in
post-training mode (without the fine-tuning pipeline) to optimize
YOLOv11.
The optimization process contains the following steps:

1. Create a Dataset for quantization.
2. Run nncf.quantize to get an optimized model.
3. Serialize the OpenVINO IR model using the ov.save_model function.
Please select below whether you would like to run quantization to
improve model inference speed.
We reuse the validation dataloader from accuracy testing for quantization.
For that, it should be wrapped into the nncf.Dataset object, and a
transformation function should be defined for extracting only the input tensors.
# %%skip not $to_quantize.value

import nncf
from typing import Dict
from zipfile import ZipFile

from ultralytics.data.utils import DATASETS_DIR
from ultralytics.utils import DEFAULT_CFG
from ultralytics.cfg import get_cfg
from ultralytics.data.converter import coco80_to_coco91_class
from ultralytics.data.utils import check_det_dataset
from ultralytics.utils import ops


if not int8_model_seg_path.exists():
    DATA_URL = "http://images.cocodataset.org/zips/val2017.zip"
    LABELS_URL = "https://github.com/ultralytics/yolov5/releases/download/v1.0/coco2017labels-segments.zip"
    CFG_URL = "https://raw.githubusercontent.com/ultralytics/ultralytics/v8.1.0/ultralytics/cfg/datasets/coco.yaml"

    OUT_DIR = DATASETS_DIR

    DATA_PATH = OUT_DIR / "val2017.zip"
    LABELS_PATH = OUT_DIR / "coco2017labels-segments.zip"
    CFG_PATH = OUT_DIR / "coco.yaml"

    download_file(DATA_URL, DATA_PATH.name, DATA_PATH.parent)
    download_file(LABELS_URL, LABELS_PATH.name, LABELS_PATH.parent)
    download_file(CFG_URL, CFG_PATH.name, CFG_PATH.parent)

    if not (OUT_DIR / "coco/labels").exists():
        with ZipFile(LABELS_PATH, "r") as zip_ref:
            zip_ref.extractall(OUT_DIR)
        with ZipFile(DATA_PATH, "r") as zip_ref:
            zip_ref.extractall(OUT_DIR / "coco/images")

    args = get_cfg(cfg=DEFAULT_CFG)
    args.data = str(CFG_PATH)

    seg_validator = seg_model.task_map[seg_model.task]["validator"](args=args)
    seg_validator.data = check_det_dataset(args.data)
    seg_validator.stride = 32
    seg_data_loader = seg_validator.get_dataloader(OUT_DIR / "coco/", 1)

    seg_validator.is_coco = True
    seg_validator.class_map = coco80_to_coco91_class()
    seg_validator.names = label_map
    seg_validator.metrics.names = seg_validator.names
    seg_validator.nc = 80
    seg_validator.nm = 32
    seg_validator.process = ops.process_mask
    seg_validator.plot_masks = []

    def transform_fn(data_item: Dict):
        """
        Quantization transform function. Extracts and preprocesses input data from the
        dataloader item for quantization.
        Parameters:
            data_item: Dict with data item produced by DataLoader during iteration
        Returns:
            input_tensor: Input data for quantization
        """
        input_tensor = seg_validator.preprocess(data_item)["img"].numpy()
        return input_tensor

    quantization_dataset = nncf.Dataset(seg_data_loader, transform_fn)
The nncf.quantize function provides an interface for model
quantization. It requires an instance of the OpenVINO Model and
quantization dataset. Optionally, some additional parameters for
configuring the quantization process (number of samples for quantization,
preset, ignored scope, etc.) can be provided. Ultralytics models contain
non-ReLU activation functions, which require asymmetric quantization of
activations. To achieve a better result, we will use a mixed
quantization preset. It provides symmetric quantization of weights and
asymmetric quantization of activations. For more accurate results, we
should keep the operation in the postprocessing subgraph in floating
point precision, using the ignored_scope parameter.
Note: Model post-training quantization is a time-consuming process.
Be patient, it can take several minutes depending on your hardware.
%%skip not $to_quantize.value

if not int8_model_seg_path.exists():
    ignored_scope = nncf.IgnoredScope(  # post-processing
        subgraphs=[
            nncf.Subgraph(
                inputs=[
                    f"__module.model.{22 if 'v8' in SEG_MODEL_NAME else 23}/aten::cat/Concat",
                    f"__module.model.{22 if 'v8' in SEG_MODEL_NAME else 23}/aten::cat/Concat_1",
                    f"__module.model.{22 if 'v8' in SEG_MODEL_NAME else 23}/aten::cat/Concat_2",
                    f"__module.model.{22 if 'v8' in SEG_MODEL_NAME else 23}/aten::cat/Concat_7",
                ],
                outputs=[f"__module.model.{22 if 'v8' in SEG_MODEL_NAME else 23}/aten::cat/Concat_8"],
            )
        ]
    )

    # Segmentation model
    quantized_seg_model = nncf.quantize(
        seg_ov_model,
        quantization_dataset,
        preset=nncf.QuantizationPreset.MIXED,
        ignored_scope=ignored_scope,
    )

    print(f"Quantized segmentation model will be saved to {int8_model_seg_path}")
    ov.save_model(quantized_seg_model, str(int8_model_seg_path))
nncf.quantize returns the OpenVINO Model class instance, which is
suitable for loading on a device to make predictions. The INT8 model's
input data and output result formats do not differ from the
floating-point model representation. Therefore, we can reuse the same
detect function defined above for getting the INT8 model result
on the image.
%%skip not $to_quantize.value
device
%%skip not $to_quantize.value

if quantized_seg_model is None:
    quantized_seg_model = core.read_model(int8_model_seg_path)

ov_config = {}
if device.value != "CPU":
    quantized_seg_model.reshape({0: [1, 3, 640, 640]})
if "GPU" in device.value or ("AUTO" in device.value and "GPU" in core.available_devices):
    ov_config = {"GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"}

quantized_seg_compiled_model = core.compile_model(quantized_seg_model, device.value, ov_config)
Compare performance of the Original and Quantized Models#
Finally, use the OpenVINO
Benchmark
Tool
to measure the inference performance of the FP32 and INT8
models.
Note: For more accurate performance, it is recommended to run
benchmark_app in a terminal/command prompt after closing other
applications. Run
benchmark_app -m <model_path> -d CPU -shape "<input_shape>" to
benchmark async inference on CPU on a specific input data shape for one
minute. Change CPU to GPU to benchmark on GPU. Run
benchmark_app --help to see an overview of all command-line
options.
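The two logs below correspond to the FP32 and INT8 models, respectively. The commands would look roughly like the following sketch (seg_model_path is assumed to hold the FP32 IR path from earlier cells; int8_model_seg_path is defined above):

# FP32 model
!benchmark_app -m $seg_model_path -d $device.value -api async -shape "[1,3,640,640]" -t 15
# INT8 model
!benchmark_app -m $int8_model_seg_path -d $device.value -api async -shape "[1,3,640,640]" -t 15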
[Step 1/11] Parsing and validating input arguments
[ INFO ] Parsing input parameters
[Step 2/11] Loading OpenVINO Runtime
[ INFO ] OpenVINO:
[ INFO ] Build ................................. 2024.5.0-16993-9c432a3641a
[ INFO ]
[ INFO ] Device info:
[ INFO ] AUTO
[ INFO ] Build ................................. 2024.5.0-16993-9c432a3641a
[ INFO ]
[ INFO ]
[Step 3/11] Setting device configuration
[ WARNING ] Performance hint was not explicitly specified in command line. Device(AUTO) performance hint will be set to PerformanceMode.THROUGHPUT.
[Step 4/11] Reading model files
[ INFO ] Loading model files
[ INFO ] Read model took 19.89 ms
[ INFO ] Original model I/O parameters:
[ INFO ] Model inputs:
[ INFO ] x (node: x) : f32 / [...] / [?,3,?,?]
[ INFO ] Model outputs:
[ INFO ] *NO_NAME* (node: __module.model.23/aten::cat/Concat_8) : f32 / [...] / [?,116,21..]
[ INFO ] input.255 (node: __module.model.23.cv4.2.1.act/aten::silu_/Swish_46) : f32 / [...] / [?,32,8..,8..]
[Step 5/11] Resizing model to match image sizes and given batch
[ INFO ] Model batch size: 1
[ INFO ] Reshaping model: 'x': [1,3,640,640]
[ INFO ] Reshape model took 8.57 ms
[Step 6/11] Configuring input of the model
[ INFO ] Model inputs:
[ INFO ] x (node: x) : u8 / [N,C,H,W] / [1,3,640,640]
[ INFO ] Model outputs:
[ INFO ] *NO_NAME* (node: __module.model.23/aten::cat/Concat_8) : f32 / [...] / [1,116,8400]
[ INFO ] input.255 (node: __module.model.23.cv4.2.1.act/aten::silu_/Swish_46) : f32 / [...] / [1,32,160,160]
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 368.93 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] Model:
[ INFO ] NETWORK_NAME: Model0
[ INFO ] EXECUTION_DEVICES: ['CPU']
[ INFO ] PERFORMANCE_HINT: PerformanceMode.THROUGHPUT
[ INFO ] OPTIMAL_NUMBER_OF_INFER_REQUESTS: 6
[ INFO ] MULTI_DEVICE_PRIORITIES: CPU
[ INFO ] CPU:
[ INFO ] AFFINITY: Affinity.CORE
[ INFO ] CPU_DENORMALS_OPTIMIZATION: False
[ INFO ] CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE: 1.0
[ INFO ] DYNAMIC_QUANTIZATION_GROUP_SIZE: 32
[ INFO ] ENABLE_CPU_PINNING: True
[ INFO ] ENABLE_HYPER_THREADING: True
[ INFO ] EXECUTION_DEVICES: ['CPU']
[ INFO ] EXECUTION_MODE_HINT: ExecutionMode.PERFORMANCE
[ INFO ] INFERENCE_NUM_THREADS: 24
[ INFO ] INFERENCE_PRECISION_HINT: <Type: 'float32'>
[ INFO ] KV_CACHE_PRECISION: <Type: 'float16'>
[ INFO ] LOG_LEVEL: Level.NO
[ INFO ] MODEL_DISTRIBUTION_POLICY: set()
[ INFO ] NETWORK_NAME: Model0
[ INFO ] NUM_STREAMS: 6
[ INFO ] OPTIMAL_NUMBER_OF_INFER_REQUESTS: 6
[ INFO ] PERFORMANCE_HINT: THROUGHPUT
[ INFO ] PERFORMANCE_HINT_NUM_REQUESTS: 0
[ INFO ] PERF_COUNT: NO
[ INFO ] SCHEDULING_CORE_TYPE: SchedulingCoreType.ANY_CORE
[ INFO ] MODEL_PRIORITY: Priority.MEDIUM
[ INFO ] LOADED_FROM_CACHE: False
[ INFO ] PERF_COUNT: False
[Step 9/11] Creating infer requests and preparing input tensors
[ WARNING ] No input files were given for input 'x'!. This input will be filled with random values!
[ INFO ] Fill input 'x' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 6 inference requests, limits: 15000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 40.94 ms
[Step 11/11] Dumping statistics report
[ INFO ] Execution Devices:['CPU']
[ INFO ] Count: 1758 iterations
[ INFO ] Duration: 15078.04 ms
[ INFO ] Latency:
[ INFO ] Median: 48.01 ms
[ INFO ] Average: 51.29 ms
[ INFO ] Min: 39.82 ms
[ INFO ] Max: 142.69 ms
[ INFO ] Throughput: 116.59 FPS
[Step 1/11] Parsing and validating input arguments
[ INFO ] Parsing input parameters
[Step 2/11] Loading OpenVINO Runtime
[ INFO ] OpenVINO:
[ INFO ] Build ................................. 2024.5.0-16993-9c432a3641a
[ INFO ]
[ INFO ] Device info:
[ INFO ] AUTO
[ INFO ] Build ................................. 2024.5.0-16993-9c432a3641a
[ INFO ]
[ INFO ]
[Step 3/11] Setting device configuration
[ WARNING ] Performance hint was not explicitly specified in command line. Device(AUTO) performance hint will be set to PerformanceMode.THROUGHPUT.
[Step 4/11] Reading model files
[ INFO ] Loading model files
[ INFO ] Read model took 29.86 ms
[ INFO ] Original model I/O parameters:
[ INFO ] Model inputs:
[ INFO ] x (node: x) : f32 / [...] / [1,3,640,640]
[ INFO ] Model outputs:
[ INFO ] *NO_NAME* (node: __module.model.23/aten::cat/Concat_8) : f32 / [...] / [1,116,8400]
[ INFO ] input.255 (node: __module.model.23.cv4.2.1.act/aten::silu_/Swish_46) : f32 / [...] / [1,32,160,160]
[Step 5/11] Resizing model to match image sizes and given batch
[ INFO ] Model batch size: 1
[ INFO ] Reshaping model: 'x': [1,3,640,640]
[ INFO ] Reshape model took 0.04 ms
[Step 6/11] Configuring input of the model
[ INFO ] Model inputs:
[ INFO ] x (node: x) : u8 / [N,C,H,W] / [1,3,640,640]
[ INFO ] Model outputs:
[ INFO ] *NO_NAME* (node: __module.model.23/aten::cat/Concat_8) : f32 / [...] / [1,116,8400]
[ INFO ] input.255 (node: __module.model.23.cv4.2.1.act/aten::silu_/Swish_46) : f32 / [...] / [1,32,160,160]
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 609.11 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] Model:
[ INFO ] NETWORK_NAME: Model0
[ INFO ] EXECUTION_DEVICES: ['CPU']
[ INFO ] PERFORMANCE_HINT: PerformanceMode.THROUGHPUT
[ INFO ] OPTIMAL_NUMBER_OF_INFER_REQUESTS: 6
[ INFO ] MULTI_DEVICE_PRIORITIES: CPU
[ INFO ] CPU:
[ INFO ] AFFINITY: Affinity.CORE
[ INFO ] CPU_DENORMALS_OPTIMIZATION: False
[ INFO ] CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE: 1.0
[ INFO ] DYNAMIC_QUANTIZATION_GROUP_SIZE: 32
[ INFO ] ENABLE_CPU_PINNING: True
[ INFO ] ENABLE_HYPER_THREADING: True
[ INFO ] EXECUTION_DEVICES: ['CPU']
[ INFO ] EXECUTION_MODE_HINT: ExecutionMode.PERFORMANCE
[ INFO ] INFERENCE_NUM_THREADS: 24
[ INFO ] INFERENCE_PRECISION_HINT: <Type: 'float32'>
[ INFO ] KV_CACHE_PRECISION: <Type: 'float16'>
[ INFO ] LOG_LEVEL: Level.NO
[ INFO ] MODEL_DISTRIBUTION_POLICY: set()
[ INFO ] NETWORK_NAME: Model0
[ INFO ] NUM_STREAMS: 6
[ INFO ] OPTIMAL_NUMBER_OF_INFER_REQUESTS: 6
[ INFO ] PERFORMANCE_HINT: THROUGHPUT
[ INFO ] PERFORMANCE_HINT_NUM_REQUESTS: 0
[ INFO ] PERF_COUNT: NO
[ INFO ] SCHEDULING_CORE_TYPE: SchedulingCoreType.ANY_CORE
[ INFO ] MODEL_PRIORITY: Priority.MEDIUM
[ INFO ] LOADED_FROM_CACHE: False
[ INFO ] PERF_COUNT: False
[Step 9/11] Creating infer requests and preparing input tensors
[ WARNING ] No input files were given for input 'x'!. This input will be filled with random values!
[ INFO ] Fill input 'x' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 6 inference requests, limits: 15000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 26.42 ms
[Step 11/11] Dumping statistics report
[ INFO ] Execution Devices:['CPU']
[ INFO ] Count: 3708 iterations
[ INFO ] Duration: 15029.82 ms
[ INFO ] Latency:
[ INFO ] Median: 24.08 ms
[ INFO ] Average: 24.20 ms
[ INFO ] Min: 18.24 ms
[ INFO ] Max: 40.66 ms
[ INFO ] Throughput: 246.71 FPS
In this run, the INT8 model roughly doubles the throughput of the FP32
model (about 247 vs. 117 FPS). Performance could also be improved by other
OpenVINO techniques, such as the async inference pipeline or the Preprocessing API.
The async inference pipeline helps to utilize the device more efficiently. The
key advantage of the Async API is that when a device is busy with
inference, the application can perform other tasks in parallel (for
example, populating inputs or scheduling other requests) rather than
waiting for the current inference to complete first. To understand how to
perform async inference using OpenVINO, refer to the Async API
tutorial.
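For reference, a minimal sketch of asynchronous inference with ov.AsyncInferQueue; frames_to_process is a hypothetical iterable of already preprocessed [1, 3, 640, 640] input tensors:

import openvino as ov

core = ov.Core()
compiled = core.compile_model(seg_ov_model, device.value)

# Pool of parallel infer requests; the callback fires as soon as each request finishes.
infer_queue = ov.AsyncInferQueue(compiled, jobs=4)
results = []
infer_queue.set_callback(lambda request, userdata: results.append(request.get_output_tensor(0).data.copy()))

for frame_tensor in frames_to_process:  # hypothetical iterable of preprocessed frames
    infer_queue.start_async({0: frame_tensor})
infer_queue.wait_all()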
The Preprocessing API enables making preprocessing a part of the model,
reducing application code and dependency on additional image processing
libraries. The main advantage of the Preprocessing API is that preprocessing
steps will be integrated into the execution graph and will be performed
on the selected device (CPU/GPU, etc.) rather than always being executed on
CPU as part of the application. This will also improve utilization of the
selected device. For more information, refer to the overview of the
Preprocessing API
tutorial. To
see how it could be used with a YOLOv8 object detection model, please
see the Convert and Optimize YOLOv8 real-time object detection with
OpenVINO tutorial.
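A sketch of how the Preprocessing API could embed conversion, resizing, and normalization into the model; seg_model_path is assumed from earlier cells, and the exact steps needed for YOLO preprocessing are shown in the linked tutorial:

import openvino as ov
from openvino.preprocess import PrePostProcessor, ResizeAlgorithm

core = ov.Core()
model = core.read_model(seg_model_path)  # path to the FP32 IR, assumed from earlier cells

ppp = PrePostProcessor(model)
# The application will provide u8 HWC frames of arbitrary spatial size.
ppp.input().tensor().set_element_type(ov.Type.u8).set_layout(ov.Layout("NHWC")).set_spatial_dynamic_shape()
# Convert, resize, and scale on the target device instead of in application code.
ppp.input().preprocess().convert_element_type(ov.Type.f32).resize(ResizeAlgorithm.RESIZE_LINEAR).scale(255.0)
ppp.input().model().set_layout(ov.Layout("NCHW"))

model_with_preprocessing = ppp.build()
compiled = core.compile_model(model_with_preprocessing, device.value)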
The following code runs model inference on a video:
import collections
import time

import cv2
import numpy as np
from IPython import display


def run_instance_segmentation(
    source=0,
    flip=False,
    use_popup=False,
    skip_first_frames=0,
    model=seg_model,
    device=device.value,
    video_width: int = None,  # if not set, the original size is used
):
    player = None

    ov_config = {}
    if device != "CPU":
        model.reshape({0: [1, 3, 640, 640]})
    if "GPU" in device or ("AUTO" in device and "GPU" in core.available_devices):
        ov_config = {"GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"}
    compiled_model = core.compile_model(model, device, ov_config)

    if seg_model.predictor is None:
        custom = {"conf": 0.25, "batch": 1, "save": False, "mode": "predict"}  # method defaults
        args = {**seg_model.overrides, **custom}
        seg_model.predictor = seg_model._smart_load("predictor")(overrides=args, _callbacks=seg_model.callbacks)
        seg_model.predictor.setup_model(model=seg_model.model)

    seg_model.predictor.model.ov_compiled_model = compiled_model

    try:
        # Create a video player to play with target fps.
        player = VideoPlayer(source=source, flip=flip, fps=30, skip_first_frames=skip_first_frames)
        # Start capturing.
        player.start()
        if use_popup:
            title = "Press ESC to Exit"
            cv2.namedWindow(winname=title, flags=cv2.WINDOW_GUI_NORMAL | cv2.WINDOW_AUTOSIZE)

        processing_times = collections.deque()
        while True:
            # Grab the frame.
            frame = player.next()
            if frame is None:
                print("Source ended")
                break
            if video_width:
                # If the frame is larger than video_width, reduce size to improve the performance.
                # If smaller, increase size for a better demo experience.
                scale = video_width / max(frame.shape)
                frame = cv2.resize(
                    src=frame,
                    dsize=None,
                    fx=scale,
                    fy=scale,
                    interpolation=cv2.INTER_AREA,
                )

            # Get the results.
            input_image = np.array(frame)

            start_time = time.time()
            detections = seg_model(input_image, verbose=False)
            stop_time = time.time()
            frame = detections[0].plot()

            processing_times.append(stop_time - start_time)
            # Use processing times from the last 200 frames.
            if len(processing_times) > 200:
                processing_times.popleft()

            _, f_width = frame.shape[:2]
            # Mean processing time [ms].
            processing_time = np.mean(processing_times) * 1000
            fps = 1000 / processing_time
            cv2.putText(
                img=frame,
                text=f"Inference time: {processing_time:.1f}ms ({fps:.1f} FPS)",
                org=(20, 40),
                fontFace=cv2.FONT_HERSHEY_COMPLEX,
                fontScale=f_width / 1000,
                color=(0, 0, 255),
                thickness=1,
                lineType=cv2.LINE_AA,
            )
            # Use this workaround if there is flickering.
            if use_popup:
                cv2.imshow(winname=title, mat=frame)
                key = cv2.waitKey(1)
                # escape = 27
                if key == 27:
                    break
            else:
                # Encode numpy array to jpg.
                _, encoded_img = cv2.imencode(ext=".jpg", img=frame, params=[cv2.IMWRITE_JPEG_QUALITY, 100])
                # Create an IPython image.
                i = display.Image(data=encoded_img)
                # Display the image in this notebook.
                display.clear_output(wait=True)
                display.display(i)
    # ctrl-c
    except KeyboardInterrupt:
        print("Interrupted")
    # any different error
    except RuntimeError as e:
        print(e)
    finally:
        if player is not None:
            # Stop capturing.
            player.stop()
        if use_popup:
            cv2.destroyAllWindows()
Use a webcam as the video input. By default, the primary webcam is set
with source=0. If you have multiple webcams, each one will be
assigned a consecutive number starting at 0. Set flip=True when
using a front-facing camera. Some web browsers, especially Mozilla
Firefox, may cause flickering. If you experience flickering,
set use_popup=True.
NOTE: To use this notebook with a webcam, you need to run the
notebook on a computer with a webcam. If you run the notebook on a
remote server (for example, in Binder or Google Colab service), the
webcam will not work. By default, the lower cell will run model
inference on a video file. If you want to try live inference on your
webcam, set WEBCAM_INFERENCE = True.
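A sketch of how the demo could be launched; VIDEO_URL is a hypothetical placeholder for the sample video used by the notebook, and seg_ov_model and device come from the cells above:

WEBCAM_INFERENCE = False  # set to True to run on your webcam

if WEBCAM_INFERENCE:
    source = 0  # primary webcam
else:
    source = VIDEO_URL  # hypothetical placeholder for the sample video path or URL

run_instance_segmentation(
    source=source,
    flip=True,
    use_popup=False,
    model=seg_ov_model,
    device=device.value,
)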