INT8 Quantization with Post-training Optimization Tool (POT) in Simplified Mode tutorial

This tutorial is also available as a Jupyter notebook that can be cloned directly from GitHub. See the installation guide for instructions to run this tutorial locally on Windows, Linux or macOS.


This tutorial shows how to quantize a ResNet20 image classification model, trained on the CIFAR10 dataset, using the Simplified Mode of the OpenVINO Post-Training Optimization Tool (POT).

Simplified Mode is designed to simplify data preparation before model optimization. It is implemented as an engine in the POT API and reads data from an arbitrary folder specified by the user. Currently, Simplified Mode supports only image data in PNG or JPEG format, stored in a single folder.
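The single-folder constraint can be illustrated with a small helper. This is a hypothetical sketch of the kind of file selection Simplified Mode implies, not POT's actual loader: it assumes files are matched by a case-insensitive .png, .jpg, or .jpeg extension and that subfolders are ignored.

```python
from pathlib import Path
from typing import List

# Extensions accepted by this sketch (an assumption mirroring "PNG or JPEG")
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def list_calibration_images(folder: str) -> List[Path]:
    """Return the image files stored directly in `folder`, sorted by path.

    Hypothetical helper: subdirectories and non-image files are skipped,
    because Simplified Mode expects all images in one flat folder.
    """
    root = Path(folder)
    if not root.is_dir():
        raise NotADirectoryError(f"{root} is not a directory")
    return sorted(
        p for p in root.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )
```

For example, list_calibration_images("calib") would list the PNG files created later in this tutorial.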

Note: This mode cannot be used with the accuracy-aware method, so it is not possible to control accuracy after optimization. However, Simplified Mode can be useful for estimating the performance gains of an optimized model.

This tutorial includes the following steps:

  • Downloading and saving the CIFAR10 dataset

  • Preparing the model for quantization

  • Compressing the prepared model

  • Measuring and comparing the performance of the original and quantized models

  • Demonstrating the use of the quantized model for image classification

import os
from pathlib import Path
import warnings

import torch
from torchvision import transforms as T
from torchvision.datasets import CIFAR10

import matplotlib.pyplot as plt
import numpy as np

from openvino.runtime import Core, Tensor

warnings.filterwarnings("ignore")

# Set the data and model directories
MODEL_DIR = 'model'
CALIB_DIR = 'calib'
CIFAR_DIR = 'cifar'
CALIB_SET_SIZE = 300
MODEL_NAME = 'resnet20'

os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(CALIB_DIR, exist_ok=True)

Prepare the calibration dataset

The following steps are required to prepare the calibration dataset:

  • Download the CIFAR10 test split from torchvision.datasets

  • Save the selected number of images from this dataset as .png files in a separate folder

transform = T.Compose([T.ToTensor()])
dataset = CIFAR10(root=CIFAR_DIR, train=False, transform=transform, download=True)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar/cifar-10-python.tar.gz
Extracting cifar/cifar-10-python.tar.gz to cifar
pil_converter = T.ToPILImage(mode="RGB")

# Save the first CALIB_SET_SIZE test images as PNG files named <label>_<index>.png
for idx, (image, label) in enumerate(dataset):
    if idx >= CALIB_SET_SIZE:
        break
    pil_converter(image).save(Path(CALIB_DIR) / f'{label}_{idx}.png')
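Because the loop above encodes the class index in each file name as <label>_<index>.png, the class balance of the calibration set can be checked from the file names alone. This is an optional, hypothetical helper (class_histogram is not part of POT); labels are not needed by DefaultQuantization, so the check is purely informative.

```python
from collections import Counter
from pathlib import Path

# Same class order as the CIFAR10 label indices used in this tutorial
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def class_histogram(calib_dir: str) -> Counter:
    """Count calibration images per class using the label prefix of each file name."""
    counts = Counter()
    for path in Path(calib_dir).glob("*.png"):
        label = int(path.stem.split("_")[0])  # "3_17.png" -> 3
        counts[CIFAR10_CLASSES[label]] += 1
    return counts
```

Calling print(class_histogram(CALIB_DIR).most_common()) after the loop shows how many of the 300 images fall into each class.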

Prepare the Model

Model preparation includes the following steps:

  • Download the pretrained PyTorch model from the chenyaofo/pytorch-cifar-models repository using Torch Hub

  • Convert the model to ONNX format

  • Run the OpenVINO Model Optimizer tool to convert the ONNX model to OpenVINO Intermediate Representation (IR)

model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=True)
dummy_input = torch.randn(1, 3, 32, 32)

onnx_model_path = Path(MODEL_DIR) / '{}.onnx'.format(MODEL_NAME)
ir_model_xml = onnx_model_path.with_suffix('.xml')
ir_model_bin = onnx_model_path.with_suffix('.bin')

torch.onnx.export(model, dummy_input, onnx_model_path)
Using cache found in /opt/home/k8sworker/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master

Now we convert this model into the OpenVINO IR using the Model Optimizer:

!mo --framework=onnx --data_type=FP32 --input_shape=[1,3,32,32] -m $onnx_model_path --output_dir $MODEL_DIR
Model Optimizer arguments:
Common parameters:
    - Path to the Input Model:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/114-quantization-simplified-mode/model/resnet20.onnx
    - Path for generated IR:    /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/114-quantization-simplified-mode/model
    - IR output name:   resnet20
    - Log level:    ERROR
    - Batch:    Not specified, inherited from the model
    - Input layers:     Not specified, inherited from the model
    - Output layers:    Not specified, inherited from the model
    - Input shapes:     [1,3,32,32]
    - Source layout:    Not specified
    - Target layout:    Not specified
    - Layout:   Not specified
    - Mean values:  Not specified
    - Scale values:     Not specified
    - Scale factor:     Not specified
    - Precision of IR:  FP32
    - Enable fusing:    True
    - User transformations:     Not specified
    - Reverse input channels:   False
    - Enable IR generation for fixed input shape:   False
    - Use the transformations config file:  None
Advanced parameters:
    - Force the usage of legacy Frontend of Model Optimizer for model conversion into IR:   False
    - Force the usage of new Frontend of Model Optimizer for model conversion into IR:  False
OpenVINO runtime found in:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/openvino
OpenVINO runtime version:   2022.1.0-7019-cdb9bec7210-releases/2022/1
Model Optimizer version:    2022.1.0-7019-cdb9bec7210-releases/2022/1
[ SUCCESS ] Generated IR version 11 model.
[ SUCCESS ] XML file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/114-quantization-simplified-mode/model/resnet20.xml
[ SUCCESS ] BIN file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/114-quantization-simplified-mode/model/resnet20.bin
[ SUCCESS ] Total execution time: 0.37 seconds.
[ SUCCESS ] Memory consumed: 73 MB.
[ INFO ] The model was converted to IR v11, the latest model format that corresponds to the source DL framework input/output format. While IR v11 is backwards compatible with OpenVINO Inference Engine API v1.0, please use API v2.0 (as of 2022.1) to take advantage of the latest improvements in IR v11.
Find more information about API v2.0 and IR v11 at https://docs.openvino.ai
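The --input_shape=[1,3,32,32] argument fixes the IR input to a single image in NCHW layout (batch, channels, height, width), the same layout POT reports later ([N,C,H,W]). Decoded image files are height x width x channels arrays, so feeding one directly to the IR needs a transpose and a batch axis. A minimal numpy sketch, assuming the image is already 32x32 RGB and that the model expects values in [0, 1] as produced by T.ToTensor() in this tutorial; to_nchw_batch is a hypothetical name.

```python
import numpy as np


def to_nchw_batch(image_hwc: np.ndarray) -> np.ndarray:
    """Convert an H x W x C uint8 image into a 1 x C x H x W float32 batch.

    Scaling to [0, 1] mirrors torchvision's ToTensor(); whether a given
    model needs this scaling is an assumption to verify per model.
    """
    if image_hwc.ndim != 3:
        raise ValueError("expected an H x W x C image")
    chw = np.transpose(image_hwc, (2, 0, 1))   # HWC -> CHW
    chw = chw.astype(np.float32) / 255.0       # uint8 [0, 255] -> float [0, 1]
    return np.expand_dims(chw, axis=0)         # CHW -> NCHW with N = 1
```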

Compression stage

Compress the model with the following command:

pot -q default -m <path_to_xml> -w <path_to_bin> --engine simplified --data-source <path_to_data>
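To launch the same compression from a Python script instead of a notebook ! cell, the arguments can be passed to subprocess.run as a list, which avoids shell-quoting problems with paths. A minimal sketch, assuming pot is on the PATH; the helper names build_pot_command and run_pot are hypothetical, and the flags are exactly those used in the notebook cell that follows.

```python
import shlex
import subprocess
from pathlib import Path
from typing import List


def build_pot_command(xml: Path, weights: Path, data_dir: Path,
                      output_dir: str = "compressed", name: str = "resnet20") -> List[str]:
    """Assemble the Simplified Mode POT command used in this tutorial."""
    return [
        "pot", "-q", "default",
        "-m", str(xml), "-w", str(weights),
        "--engine", "simplified",
        "--data-source", str(data_dir),
        "--output-dir", output_dir,
        "--direct-dump",
        "--name", name,
    ]


def run_pot(cmd: List[str]) -> None:
    """Run POT; raises CalledProcessError if it exits with a non-zero code."""
    print("Running:", shlex.join(cmd))
    subprocess.run(cmd, check=True)
```

Usage: run_pot(build_pot_command(ir_model_xml, ir_model_bin, Path(CALIB_DIR))).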

!pot -q default -m $ir_model_xml -w $ir_model_bin --engine simplified --data-source $CALIB_DIR --output-dir compressed --direct-dump --name $MODEL_NAME
INFO:openvino.tools.pot.app.run:Output log dir: compressed
INFO:openvino.tools.pot.app.run:Creating pipeline:
 Algorithm: DefaultQuantization
 Parameters:
    preset                     : performance
    stat_subset_size           : 300
    target_device              : ANY
    model_type                 : None
    dump_intermediate_model    : False
    inplace_statistics         : True
    exec_log_dir               : compressed
 ===========================================================================
INFO:openvino.tools.pot.data_loaders.image_loader:Layout value is set [N,C,H,W]
INFO:openvino.tools.pot.pipeline.pipeline:Inference Engine version:                2022.1.0-7019-cdb9bec7210-releases/2022/1
INFO:openvino.tools.pot.pipeline.pipeline:Model Optimizer version:                 2022.1.0-7019-cdb9bec7210-releases/2022/1
INFO:openvino.tools.pot.pipeline.pipeline:Post-Training Optimization Tool version: 2022.1.0-7019-cdb9bec7210-releases/2022/1
INFO:openvino.tools.pot.statistics.collector:Start computing statistics for algorithms : DefaultQuantization
INFO:openvino.tools.pot.statistics.collector:Computing statistics finished
INFO:openvino.tools.pot.pipeline.pipeline:Start algorithm: DefaultQuantization
INFO:openvino.tools.pot.algorithms.quantization.default.algorithm:Start computing statistics for algorithm : ActivationChannelAlignment
INFO:openvino.tools.pot.algorithms.quantization.default.algorithm:Computing statistics finished
INFO:openvino.tools.pot.algorithms.quantization.default.algorithm:Start computing statistics for algorithms : MinMaxQuantization,FastBiasCorrection
INFO:openvino.tools.pot.algorithms.quantization.default.algorithm:Computing statistics finished
INFO:openvino.tools.pot.pipeline.pipeline:Finished: DefaultQuantization
 ===========================================================================
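The log shows that DefaultQuantization runs MinMaxQuantization (together with ActivationChannelAlignment and FastBiasCorrection) with the performance preset, which uses symmetric quantization. The core idea of min-max quantization can be illustrated with a simplified per-tensor, symmetric INT8 scheme in numpy. This illustrates the general technique under those simplifying assumptions; it is not POT's implementation, which also applies the bias correction listed in the log and may use per-channel ranges for weights.

```python
import numpy as np


def quantize_symmetric_int8(x: np.ndarray):
    """Symmetric per-tensor INT8 quantization based on the observed value range.

    Returns the int8 tensor and the scale needed to map it back to float.
    """
    max_abs = float(np.max(np.abs(x)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0  # avoid division by zero
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    """Map int8 values back to float32 approximations of the originals."""
    return q.astype(np.float32) * scale


# Example: quantize random "activations" and measure the worst-case error
rng = np.random.default_rng(0)
activations = rng.normal(size=(1, 16, 8, 8)).astype(np.float32)
q, scale = quantize_symmetric_int8(activations)
print("max abs error:", np.max(np.abs(activations - dequantize(q, scale))))
```

Because values are rounded to the nearest step, the reconstruction error of every element is at most half the scale; calibration data matters because it determines the range, and therefore the scale, used for activations.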

Compare Performance of the Original and Quantized Models

Finally, we measure the inference performance of the FP32 and INT8 models with Benchmark Tool, the inference performance measurement tool in OpenVINO.

NOTE: For more accurate performance results, we recommend running benchmark_app in a terminal/command prompt after closing other applications. Run benchmark_app -m model.xml -d CPU to benchmark async inference on CPU for one minute. Change CPU to GPU to benchmark on GPU. Run benchmark_app --help to see an overview of all command line options.

optimized_model_path = Path('compressed/optimized')
optimized_model_xml = optimized_model_path / '{}.xml'.format(MODEL_NAME)
optimized_model_bin = optimized_model_path / '{}.bin'.format(MODEL_NAME)
# Benchmark the FP32 model (IR)
!benchmark_app -m $ir_model_xml -d CPU -api async
[Step 1/11] Parsing and validating input arguments
[ WARNING ]  -nstreams default value is determined automatically for a device. Although the automatic selection usually provides a reasonable performance, but it still may be non-optimal for some cases, for more information look at README.
[Step 2/11] Loading OpenVINO
[ WARNING ] PerformanceMode was not explicitly specified in command line. Device CPU performance hint will be set to THROUGHPUT.
[ INFO ] OpenVINO:
         API version............. 2022.1.0-7019-cdb9bec7210-releases/2022/1
[ INFO ] Device info
         CPU
         openvino_intel_cpu_plugin version 2022.1
         Build................... 2022.1.0-7019-cdb9bec7210-releases/2022/1

[Step 3/11] Setting device configuration
[ WARNING ] -nstreams default value is determined automatically for CPU device. Although the automatic selection usually provides a reasonable performance, but it still may be non-optimal for some cases, for more information look at README.
[Step 4/11] Reading network files
[ INFO ] Read model took 9.73 ms
[Step 5/11] Resizing network to match image sizes and given batch
[ INFO ] Network batch size: 1
[Step 6/11] Configuring input of the model
[ INFO ] Model input 'input.1' precision u8, dimensions ([N,C,H,W]): 1 3 32 32
[ INFO ] Model output '208' precision f32, dimensions ([...]): 1 10
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 59.86 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] DEVICE: CPU
[ INFO ]   AVAILABLE_DEVICES  , ['']
[ INFO ]   RANGE_FOR_ASYNC_INFER_REQUESTS  , (1, 1, 1)
[ INFO ]   RANGE_FOR_STREAMS  , (1, 24)
[ INFO ]   FULL_DEVICE_NAME  , Intel(R) Core(TM) i9-10920X CPU @ 3.50GHz
[ INFO ]   OPTIMIZATION_CAPABILITIES  , ['WINOGRAD', 'FP32', 'FP16', 'INT8', 'BIN', 'EXPORT_IMPORT']
[ INFO ]   CACHE_DIR  ,
[ INFO ]   NUM_STREAMS  , 6
[ INFO ]   INFERENCE_NUM_THREADS  , 0
[ INFO ]   PERF_COUNT  , False
[ INFO ]   PERFORMANCE_HINT_NUM_REQUESTS  , 0
[Step 9/11] Creating infer requests and preparing input data
[ INFO ] Create 6 infer requests took 1.81 ms
[ WARNING ] No input files were given for input 'input.1'!. This input will be filled with random values!
[ INFO ] Fill input 'input.1' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 6 inference requests using 6 streams for CPU, inference only: True, limits: 60000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 3.99 ms
[Step 11/11] Dumping statistics report
Count:          889998 iterations
Duration:       60000.59 ms
Latency:
    Median:     0.38 ms
    AVG:        0.39 ms
    MIN:        0.25 ms
    MAX:        6.49 ms
Throughput: 14833.15 FPS
# Benchmark the INT8 model (IR)
!benchmark_app -m $optimized_model_xml -d CPU -api async
[Step 1/11] Parsing and validating input arguments
[ WARNING ]  -nstreams default value is determined automatically for a device. Although the automatic selection usually provides a reasonable performance, but it still may be non-optimal for some cases, for more information look at README.
[Step 2/11] Loading OpenVINO
[ WARNING ] PerformanceMode was not explicitly specified in command line. Device CPU performance hint will be set to THROUGHPUT.
[ INFO ] OpenVINO:
         API version............. 2022.1.0-7019-cdb9bec7210-releases/2022/1
[ INFO ] Device info
         CPU
         openvino_intel_cpu_plugin version 2022.1
         Build................... 2022.1.0-7019-cdb9bec7210-releases/2022/1

[Step 3/11] Setting device configuration
[ WARNING ] -nstreams default value is determined automatically for CPU device. Although the automatic selection usually provides a reasonable performance, but it still may be non-optimal for some cases, for more information look at README.
[Step 4/11] Reading network files
[ INFO ] Read model took 7.85 ms
[Step 5/11] Resizing network to match image sizes and given batch
[ INFO ] Network batch size: 1
[Step 6/11] Configuring input of the model
[ INFO ] Model input 'input.1' precision u8, dimensions ([N,C,H,W]): 1 3 32 32
[ INFO ] Model output '208' precision f32, dimensions ([...]): 1 10
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 72.71 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] DEVICE: CPU
[ INFO ]   AVAILABLE_DEVICES  , ['']
[ INFO ]   RANGE_FOR_ASYNC_INFER_REQUESTS  , (1, 1, 1)
[ INFO ]   RANGE_FOR_STREAMS  , (1, 24)
[ INFO ]   FULL_DEVICE_NAME  , Intel(R) Core(TM) i9-10920X CPU @ 3.50GHz
[ INFO ]   OPTIMIZATION_CAPABILITIES  , ['WINOGRAD', 'FP32', 'FP16', 'INT8', 'BIN', 'EXPORT_IMPORT']
[ INFO ]   CACHE_DIR  ,
[ INFO ]   NUM_STREAMS  , 6
[ INFO ]   INFERENCE_NUM_THREADS  , 0
[ INFO ]   PERF_COUNT  , False
[ INFO ]   PERFORMANCE_HINT_NUM_REQUESTS  , 0
[Step 9/11] Creating infer requests and preparing input data
[ INFO ] Create 6 infer requests took 1.70 ms
[ WARNING ] No input files were given for input 'input.1'!. This input will be filled with random values!
[ INFO ] Fill input 'input.1' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 6 inference requests using 6 streams for CPU, inference only: True, limits: 60000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 0.61 ms
[Step 11/11] Dumping statistics report
Count:          1353888 iterations
Duration:       60000.33 ms
Latency:
    Median:     0.24 ms
    AVG:        0.25 ms
    MIN:        0.17 ms
    MAX:        3.49 ms
Throughput: 22564.68 FPS
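The two statistics reports can be compared through their Throughput lines. A small sketch that extracts the FPS value with a regular expression, assuming the report format shown above (Throughput: <value> FPS), which may change between OpenVINO releases; parse_throughput is a hypothetical helper. With the numbers from this run, the INT8 model delivers roughly 1.5 times the FP32 throughput on this CPU, and the median latency drops from 0.38 ms to 0.24 ms; results on other hardware will differ.

```python
import re

THROUGHPUT_RE = re.compile(r"Throughput:\s*([\d.]+)\s*FPS")


def parse_throughput(log_text: str) -> float:
    """Extract the FPS value from a benchmark_app statistics report."""
    match = THROUGHPUT_RE.search(log_text)
    if match is None:
        raise ValueError("no 'Throughput: ... FPS' line found")
    return float(match.group(1))


# Lines copied from the two reports above
fp32_report = "Throughput: 14833.15 FPS"
int8_report = "Throughput: 22564.68 FPS"

speedup = parse_throughput(int8_report) / parse_throughput(fp32_report)
print(f"INT8 vs FP32 throughput: {speedup:.2f}x")
```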

Demonstration of the results

This section demonstrates how to use the compressed model: it runs the optimized model on a few images from the CIFAR10 test set and shows the resulting predictions.

The first step is to load the model:

ie = Core()

compiled_model = ie.compile_model(str(optimized_model_xml), "AUTO")
# define all possible labels from CIFAR10
labels_names = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
all_images = []
all_labels = []

# get all images and their labels
for batch in dataset:
    all_images.append(torch.unsqueeze(batch[0], 0))
    all_labels.append(batch[1])

The following function displays the images at the given indexes, titled with their ground-truth labels, using the two lists created in the previous step:

def plot_pictures(indexes: list, images=all_images, labels=all_labels):
    """Plot images with the specified indexes.
    :param indexes: a list of indexes of images to be displayed.
    :param images: a list of images from the dataset.
    :param labels: a list of labels for each image.
    """
    num_pics = len(indexes)
    # squeeze=False keeps a 2D array of axes, so indexing also works for a single image
    _, axarr = plt.subplots(1, num_pics, squeeze=False)
    for idx, im_idx in enumerate(indexes):
        assert im_idx < len(images), f'Cannot get index {im_idx}, there are only {len(images)} images'
        # Convert the CHW tensor to the HWC layout expected by imshow
        pic = np.transpose(images[im_idx].squeeze(0).numpy(), (1, 2, 0))
        axarr[0][idx].imshow(pic)
        axarr[0][idx].set_title(labels_names[labels[im_idx]])

In this section we define a function that uses the optimized model to obtain predictions for the selected images:

def infer_on_images(net, indexes: list, images=all_images):
    """Run inference on a set of images.
    :param net: compiled model used for inference.
    :param indexes: a list of indexes of images to infer on.
    :param images: a list of images from the dataset.
    :return: a list of predicted label names, one per index.
    """
    predicted_labels = []
    infer_request = net.create_infer_request()
    for idx in indexes:
        assert idx < len(images), f'Cannot get index {idx}, there are only {len(images)} images'
        input_tensor = Tensor(array=images[idx].detach().numpy(), shared_memory=True)
        infer_request.set_input_tensor(input_tensor)
        infer_request.start_async()
        infer_request.wait()
        output = infer_request.get_output_tensor()
        # output.data has shape (1, 10): one score per CIFAR10 class
        predicted_labels.append(labels_names[np.argmax(output.data[0])])
    return predicted_labels
indexes_to_infer = [0, 1, 2]  # indexes of the images to plot and classify

plot_pictures(indexes_to_infer)

results_quantized = infer_on_images(compiled_model, indexes_to_infer)

print(f"Image labels using the quantized model : {results_quantized}.")
Image labels using the quantized model : ['cat', 'ship', 'ship'].
[Figure: the three selected CIFAR10 test images, titled with their ground-truth labels]
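The infer_on_images function keeps only the arg-max class. To also see how confident the model is, the 10 output scores can be converted to probabilities with a softmax. A minimal sketch, assuming output 208 contains unnormalized logits (likely for a ResNet classification head, but worth checking); top_k_predictions is a hypothetical helper that takes one row of output.data.

```python
from typing import List, Sequence, Tuple

import numpy as np


def top_k_predictions(logits: Sequence[float], class_names: Sequence[str],
                      k: int = 3) -> List[Tuple[str, float]]:
    """Return the k most likely classes with their softmax probabilities."""
    scores = np.asarray(logits, dtype=np.float64)
    exp = np.exp(scores - scores.max())   # subtract the max for numerical stability
    probs = exp / exp.sum()
    order = np.argsort(probs)[::-1][:k]   # indices sorted by descending probability
    return [(class_names[i], float(probs[i])) for i in order]
```

Inside the loop of infer_on_images, top_k_predictions(output.data[0], labels_names) would return the three most likely classes for each image.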