Post-Training Quantization of PyTorch models with NNCF

This tutorial is also available as a Jupyter notebook that can be cloned directly from GitHub. See the installation guide for instructions to run this tutorial locally on Windows, Linux or macOS.


The goal of this notebook is to demonstrate how to use the Neural Network Compression Framework (NNCF) 8-bit quantization in post-training mode (without a fine-tuning pipeline) to optimize a PyTorch model for high-speed inference via the OpenVINO Toolkit. The optimization process consists of the following steps:

  1. Evaluate the original model

  2. Transform the original model to a quantized one

  3. Export optimized and original models to ONNX and then to OpenVINO IR

  4. Compare the performance of the obtained FP32 and INT8 models

This tutorial uses a ResNet-50 model pre-trained on Tiny ImageNet, which contains 100,000 images of 200 classes (500 per class) downsized to 64×64 color images. We will demonstrate that only a small subset of the dataset is needed for post-training quantization, without requiring fine-tuning of the model.

NOTE: This notebook requires a C++ compiler to be accessible on the default binary search path of the OS on which you are running the notebook.

Preparations

# On Windows, this script adds the directory that contains cl.exe to the PATH to enable PyTorch to find the
# required C++ tools. This code assumes that Visual Studio 2019 is installed in the default
# directory. If you have a different C++ compiler, please add the correct path to os.environ["PATH"]
# directly.

# Adding the path to os.environ["LIB"] is not always required - it depends on the system's configuration

import sys

if sys.platform == "win32":
    import distutils.command.build_ext
    import distutils.core  # needed for distutils.core.Distribution() below
    import os
    from pathlib import Path

    VS_INSTALL_DIR = r"C:/Program Files (x86)/Microsoft Visual Studio"
    cl_paths = sorted(list(Path(VS_INSTALL_DIR).glob("**/Hostx86/x64/cl.exe")))
    if len(cl_paths) == 0:
        raise ValueError(
            "Cannot find Visual Studio. This notebook requires C++. If you installed "
            "a C++ compiler, please add the directory that contains cl.exe to "
            "`os.environ['PATH']`"
        )
    else:
        # If multiple versions of MSVC are installed, get the most recent version
        cl_path = cl_paths[-1]
        vs_dir = str(cl_path.parent)
        os.environ["PATH"] += f"{os.pathsep}{vs_dir}"
        # Code for finding the library dirs from
        # https://stackoverflow.com/questions/47423246/get-pythons-lib-path
        d = distutils.core.Distribution()
        b = distutils.command.build_ext.build_ext(d)
        b.finalize_options()
        os.environ["LIB"] = os.pathsep.join(b.library_dirs)
        print(f"Added {vs_dir} to PATH")

Preparing model files

NOTE: All NNCF logging messages below ERROR level (INFO and WARNING) are disabled to simplify the tutorial. For production use, it is recommended to enable logging by removing set_log_level(logging.ERROR).

Imports

import logging
import os
import sys
import time
import warnings
import zipfile
from pathlib import Path
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from nncf import NNCFConfig  # Important - should be imported directly after torch
from nncf.common.utils.logger import set_log_level

set_log_level(logging.ERROR)  # Disables all NNCF info and warning messages
from nncf.torch import create_compressed_model, register_default_init_args
from openvino.runtime import Core
from torch.jit import TracerWarning

sys.path.append("../utils")
from notebook_utils import download_file
/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/nncf/torch/__init__.py:23: UserWarning: NNCF provides best results with torch==1.9.1, while current torch version is 1.7.1+cpu - consider switching to torch==1.9.1
  warnings.warn("NNCF provides best results with torch=={bkc}, "
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/nncf/torch/dynamic_graph/patch_pytorch.py:163: UserWarning: Not patching unique_dim since it is missing in this version of PyTorch
  warnings.warn("Not patching {} since it is missing in this version of PyTorch".format(op_name))

Settings

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

MODEL_DIR = Path("model")
OUTPUT_DIR = Path("output")
BASE_MODEL_NAME = "resnet50"
IMAGE_SIZE = [64, 64]

OUTPUT_DIR.mkdir(exist_ok=True)
MODEL_DIR.mkdir(exist_ok=True)

# Paths where PyTorch, ONNX and OpenVINO IR models will be stored
fp32_checkpoint_filename = Path(BASE_MODEL_NAME + "_fp32").with_suffix(".pth")
fp32_onnx_path = Path(OUTPUT_DIR / (BASE_MODEL_NAME + "_fp32")).with_suffix(".onnx")
fp32_ir_path = fp32_onnx_path.with_suffix(".xml")
int8_onnx_path = Path(OUTPUT_DIR / (BASE_MODEL_NAME + "_int8")).with_suffix(".onnx")
int8_ir_path = int8_onnx_path.with_suffix(".xml")


fp32_pth_url = "https://storage.openvinotoolkit.org/repositories/nncf/openvino_notebook_ckpts/304_resnet50_fp32.pth"
download_file(fp32_pth_url, directory=MODEL_DIR, filename=fp32_checkpoint_filename)
Using cpu device
model/resnet50_fp32.pth:   0%|          | 0.00/91.5M [00:00<?, ?B/s]
PosixPath('/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/112-pytorch-post-training-quantization-nncf/model/resnet50_fp32.pth')

Download and Prepare Tiny ImageNet dataset

  • 100k images of shape 3x64x64

  • 200 different classes: snake, spider, cat, truck, grasshopper, gull, etc.

def download_tiny_imagenet_200(
    output_dir: Path,
    url: str = "http://cs231n.stanford.edu/tiny-imagenet-200.zip",
    tarname: str = "tiny-imagenet-200.zip",
):
    archive_path = output_dir / tarname
    download_file(url, directory=output_dir, filename=tarname)
    with zipfile.ZipFile(archive_path, "r") as zip_ref:
        zip_ref.extractall(path=output_dir)
    print(f"Successfully downloaded and extracted dataset to: {output_dir}")


def create_validation_dir(dataset_dir: Path):
    VALID_DIR = dataset_dir / "val"
    val_img_dir = VALID_DIR / "images"

    fp = open(VALID_DIR / "val_annotations.txt", "r")
    data = fp.readlines()

    val_img_dict = {}
    for line in data:
        words = line.split("\t")
        val_img_dict[words[0]] = words[1]
    fp.close()

    for img, folder in val_img_dict.items():
        newpath = val_img_dir / folder
        if not newpath.exists():
            os.makedirs(newpath)
        if (val_img_dir / img).exists():
            os.rename(val_img_dir / img, newpath / img)


DATASET_DIR = OUTPUT_DIR / "tiny-imagenet-200"
if not DATASET_DIR.exists():
    download_tiny_imagenet_200(OUTPUT_DIR)
    create_validation_dir(DATASET_DIR)
output/tiny-imagenet-200.zip:   0%|          | 0.00/237M [00:00<?, ?B/s]
Successfully downloaded and extracted dataset to: output

Helper classes and functions

These will help us compute accuracy and visualize the validation process.

class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name: str, fmt: str = ":f"):
        self.name = name
        self.fmt = fmt
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val: float, n: int = 1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    """Displays the progress of validation process"""

    def __init__(self, num_batches: int, meters: List[AverageMeter], prefix: str = ""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch: int):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print("\t".join(entries))

    def _get_batch_fmtstr(self, num_batches: int):
        num_digits = len(str(num_batches))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


def accuracy(output: torch.Tensor, target: torch.Tensor, topk: Tuple[int] = (1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

Validation function

def validate(val_loader: torch.utils.data.DataLoader, model: torch.nn.Module):
    """Compute the metrics using data from val_loader for the model"""
    batch_time = AverageMeter("Time", ":3.3f")
    top1 = AverageMeter("Acc@1", ":2.2f")
    top5 = AverageMeter("Acc@5", ":2.2f")
    progress = ProgressMeter(len(val_loader), [batch_time, top1, top5], prefix="Test: ")

    # switch to evaluate mode
    model.eval()
    model.to(device)

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            print_frequency = 10
            if i % print_frequency == 0:
                progress.display(i)

        print(
            " * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}".format(top1=top1, top5=top5)
        )
    return top1.avg

Create and load original uncompressed model

ResNet-50 from the torchvision repository is pre-trained on ImageNet, which has more prediction classes than Tiny ImageNet, so we adjust the model by replacing the last fully connected (FC) layer with one that has fewer output values.

def create_model(model_path: Path):
    """Creates the ResNet-50 model and loads the pretrained weights"""
    model = models.resnet50()
    # update the last FC layer for Tiny ImageNet number of classes
    NUM_CLASSES = 200
    model.fc = nn.Linear(in_features=2048, out_features=NUM_CLASSES, bias=True)
    model.to(device)
    if model_path.exists():
        checkpoint = torch.load(str(model_path), map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"], strict=True)
    else:
        raise RuntimeError("There is no checkpoint to load")
    return model


model = create_model(MODEL_DIR / fp32_checkpoint_filename)

Create train and validation dataloaders

def create_dataloaders(batch_size: int = 128):
    """Creates train dataloader that is used for quantization initialization and validation dataloader for computing the model accruacy"""
    train_dir = DATASET_DIR / "train"
    val_dir = DATASET_DIR / "val" / "images"
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_dataset = datasets.ImageFolder(
        train_dir,
        transforms.Compose(
            [
                transforms.Resize(IMAGE_SIZE),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )
    val_dataset = datasets.ImageFolder(
        val_dir,
        transforms.Compose(
            [transforms.Resize(IMAGE_SIZE), transforms.ToTensor(), normalize]
        ),
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        sampler=None,
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )
    return train_loader, val_loader


train_loader, val_loader = create_dataloaders()

Now that the training and validation pipelines and the model files are prepared, we will walk through the major steps of performing post-training quantization with NNCF.

I. Evaluate the loaded model

acc1 = validate(val_loader, model)
print(f"Test accuracy of FP32 model: {acc1:.3f}")
Test: [ 0/79]   Time 0.487 (0.487)  Acc@1 81.25 (81.25) Acc@5 92.19 (92.19)
Test: [10/79]   Time 0.290 (0.308)  Acc@1 56.25 (66.97) Acc@5 86.72 (87.50)
Test: [20/79]   Time 0.276 (0.295)  Acc@1 67.97 (64.29) Acc@5 85.16 (87.35)
Test: [30/79]   Time 0.278 (0.290)  Acc@1 53.12 (62.37) Acc@5 77.34 (85.33)
Test: [40/79]   Time 0.270 (0.286)  Acc@1 67.19 (60.86) Acc@5 90.62 (84.51)
Test: [50/79]   Time 0.280 (0.285)  Acc@1 60.16 (60.80) Acc@5 88.28 (84.42)
Test: [60/79]   Time 0.306 (0.284)  Acc@1 66.41 (60.46) Acc@5 86.72 (83.79)
Test: [70/79]   Time 0.267 (0.285)  Acc@1 52.34 (60.21) Acc@5 80.47 (83.33)
 * Acc@1 60.740 Acc@5 83.960
Test accuracy of FP32 model: 60.740

Export the FP32 model to ONNX, which is supported by OpenVINO™ Toolkit, to benchmark it in comparison with the INT8 model.

dummy_input = torch.randn(1, 3, *IMAGE_SIZE).to(device)
torch.onnx.export(model, dummy_input, fp32_onnx_path, opset_version=10)
print(f"FP32 ONNX model was exported to {fp32_onnx_path}.")
FP32 ONNX model was exported to output/resnet50_fp32.onnx.

II. Create and initialize quantization

NNCF enables post-training quantization by adding the quantization layers into the model graph and then using a subset of the training dataset to initialize the parameters of these additional quantization layers. The framework is designed so that modifications to your original training code are minor. Quantization is the simplest scenario and requires only 3 modifications.

Configure NNCF parameters to specify compression

nncf_config_dict = {
    "input_info": {"sample_size": [1, 3, *IMAGE_SIZE]},
    "log_dir": str(OUTPUT_DIR),
    "compression": {
        "algorithm": "quantization",
        "initializer": {
            "range": {"num_init_samples": 15000},
            "batchnorm_adaptation": {"num_bn_adaptation_samples": 4000},
        },
    },
}

nncf_config = NNCFConfig.from_dict(nncf_config_dict)

Provide a data loader to initialize the quantization ranges and to determine which activations should be signed or unsigned, based on statistics collected from a given number of samples.

nncf_config = register_default_init_args(nncf_config, train_loader)

Create a quantized model from a pre-trained FP32 model and configuration object.

compression_ctrl, model = create_compressed_model(model, nncf_config)

Evaluate the new model on the validation set after initialization of quantization. The accuracy should be close to the accuracy of the floating-point FP32 model for a simple case like the one we are demonstrating now.

acc1 = validate(val_loader, model)
print(f"Accuracy of initialized INT8 model: {acc1:.3f}")
Test: [ 0/79]   Time 0.924 (0.924)  Acc@1 81.25 (81.25) Acc@5 91.41 (91.41)
Test: [10/79]   Time 0.393 (0.442)  Acc@1 59.38 (68.47) Acc@5 87.50 (87.78)
Test: [20/79]   Time 0.391 (0.421)  Acc@1 67.97 (64.58) Acc@5 85.16 (87.46)
Test: [30/79]   Time 0.393 (0.412)  Acc@1 52.34 (62.32) Acc@5 74.22 (85.11)
Test: [40/79]   Time 0.391 (0.407)  Acc@1 67.97 (60.94) Acc@5 88.28 (84.13)
Test: [50/79]   Time 0.392 (0.405)  Acc@1 60.16 (60.71) Acc@5 85.94 (83.96)
Test: [60/79]   Time 0.392 (0.404)  Acc@1 65.62 (60.34) Acc@5 83.59 (83.31)
Test: [70/79]   Time 0.384 (0.402)  Acc@1 55.47 (60.18) Acc@5 78.91 (82.88)
 * Acc@1 60.750 Acc@5 83.450
Accuracy of initialized INT8 model: 60.750

Export INT8 model to ONNX

warnings.filterwarnings("ignore", category=TracerWarning)  # Ignore export warnings
warnings.filterwarnings("ignore", category=UserWarning)
compression_ctrl.export_model(int8_onnx_path)
print(f"INT8 ONNX model exported to {int8_onnx_path}.")
INT8 ONNX model exported to output/resnet50_int8.onnx.

III. Convert ONNX models to OpenVINO™ Intermediate Representation (IR)

Call the OpenVINO Model Optimizer tool to convert the ONNX models to OpenVINO IR with FP16 precision. The models are saved to the output directory. We embed the mean values in the model and scale the input with the standard deviation by using the --mean_values and --scale_values arguments. With these options, it is not necessary to normalize input data before propagating it through the network.
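
These values are simply the torchvision normalization constants used earlier in create_dataloaders, rescaled from the [0, 1] range to the [0, 255] pixel range. A minimal check of that correspondence (not part of the original notebook):

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
print([round(m * 255, 3) for m in mean])  # [123.675, 116.28, 103.53]
print([round(s * 255, 3) for s in std])   # [58.395, 57.12, 57.375]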

See the Model Optimizer Developer Guide for more information about Model Optimizer.

Executing this command may take a while. There may be some errors or warnings in the output. Model Optimizer has successfully converted the model to IR if the last lines of the output include: [ SUCCESS ] Generated IR version 11 model

input_shape = [1, 3, *IMAGE_SIZE]
if not fp32_ir_path.exists():
    !mo --input_model "$fp32_onnx_path" --input_shape "$input_shape" --mean_values "[123.675, 116.28 , 103.53]" --scale_values "[58.395, 57.12 , 57.375]" --data_type FP16 --output_dir "$OUTPUT_DIR"
    assert fp32_ir_path.exists(), "The IR of FP32 model wasn't created"
Model Optimizer arguments:
Common parameters:
    - Path to the Input Model:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/112-pytorch-post-training-quantization-nncf/output/resnet50_fp32.onnx
    - Path for generated IR:    /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/112-pytorch-post-training-quantization-nncf/output
    - IR output name:   resnet50_fp32
    - Log level:    ERROR
    - Batch:    Not specified, inherited from the model
    - Input layers:     Not specified, inherited from the model
    - Output layers:    Not specified, inherited from the model
    - Input shapes:     [1, 3, 64, 64]
    - Source layout:    Not specified
    - Target layout:    Not specified
    - Layout:   Not specified
    - Mean values:  [123.675, 116.28 , 103.53]
    - Scale values:     [58.395, 57.12 , 57.375]
    - Scale factor:     Not specified
    - Precision of IR:  FP16
    - Enable fusing:    True
    - User transformations:     Not specified
    - Reverse input channels:   False
    - Enable IR generation for fixed input shape:   False
    - Use the transformations config file:  None
Advanced parameters:
    - Force the usage of legacy Frontend of Model Optimizer for model conversion into IR:   False
    - Force the usage of new Frontend of Model Optimizer for model conversion into IR:  False
OpenVINO runtime found in:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/openvino
OpenVINO runtime version:   2022.1.0-7019-cdb9bec7210-releases/2022/1
Model Optimizer version:    2022.1.0-7019-cdb9bec7210-releases/2022/1
[ SUCCESS ] Generated IR version 11 model.
[ SUCCESS ] XML file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/112-pytorch-post-training-quantization-nncf/output/resnet50_fp32.xml
[ SUCCESS ] BIN file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/112-pytorch-post-training-quantization-nncf/output/resnet50_fp32.bin
[ SUCCESS ] Total execution time: 0.76 seconds.
[ SUCCESS ] Memory consumed: 266 MB.
It's been a while, check for a new version of Intel(R) Distribution of OpenVINO(TM) toolkit here https://software.intel.com/content/www/us/en/develop/tools/openvino-toolkit/download.html?cid=other&source=prod&campid=ww_2022_bu_IOTG_OpenVINO-2022-1&content=upg_all&medium=organic or on the GitHub*
[ INFO ] The model was converted to IR v11, the latest model format that corresponds to the source DL framework input/output format. While IR v11 is backwards compatible with OpenVINO Inference Engine API v1.0, please use API v2.0 (as of 2022.1) to take advantage of the latest improvements in IR v11.
Find more information about API v2.0 and IR v11 at https://docs.openvino.ai
if not int8_ir_path.exists():
    !mo --input_model "$int8_onnx_path" --input_shape "$input_shape" --mean_values "[123.675, 116.28 , 103.53]" --scale_values "[58.395, 57.12 , 57.375]" --data_type FP16 --output_dir "$OUTPUT_DIR"
    assert int8_ir_path.exists(), "The IR of INT8 model wasn't created"
Model Optimizer arguments:
Common parameters:
    - Path to the Input Model:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/112-pytorch-post-training-quantization-nncf/output/resnet50_int8.onnx
    - Path for generated IR:    /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/112-pytorch-post-training-quantization-nncf/output
    - IR output name:   resnet50_int8
    - Log level:    ERROR
    - Batch:    Not specified, inherited from the model
    - Input layers:     Not specified, inherited from the model
    - Output layers:    Not specified, inherited from the model
    - Input shapes:     [1, 3, 64, 64]
    - Source layout:    Not specified
    - Target layout:    Not specified
    - Layout:   Not specified
    - Mean values:  [123.675, 116.28 , 103.53]
    - Scale values:     [58.395, 57.12 , 57.375]
    - Scale factor:     Not specified
    - Precision of IR:  FP16
    - Enable fusing:    True
    - User transformations:     Not specified
    - Reverse input channels:   False
    - Enable IR generation for fixed input shape:   False
    - Use the transformations config file:  None
Advanced parameters:
    - Force the usage of legacy Frontend of Model Optimizer for model conversion into IR:   False
    - Force the usage of new Frontend of Model Optimizer for model conversion into IR:  False
OpenVINO runtime found in:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/openvino
OpenVINO runtime version:   2022.1.0-7019-cdb9bec7210-releases/2022/1
Model Optimizer version:    2022.1.0-7019-cdb9bec7210-releases/2022/1
[ SUCCESS ] Generated IR version 11 model.
[ SUCCESS ] XML file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/112-pytorch-post-training-quantization-nncf/output/resnet50_int8.xml
[ SUCCESS ] BIN file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/112-pytorch-post-training-quantization-nncf/output/resnet50_int8.bin
[ SUCCESS ] Total execution time: 1.61 seconds.
[ SUCCESS ] Memory consumed: 270 MB.
It's been a while, check for a new version of Intel(R) Distribution of OpenVINO(TM) toolkit here https://software.intel.com/content/www/us/en/develop/tools/openvino-toolkit/download.html?cid=other&source=prod&campid=ww_2022_bu_IOTG_OpenVINO-2022-1&content=upg_all&medium=organic or on the GitHub*
[ INFO ] The model was converted to IR v11, the latest model format that corresponds to the source DL framework input/output format. While IR v11 is backwards compatible with OpenVINO Inference Engine API v1.0, please use API v2.0 (as of 2022.1) to take advantage of the latest improvements in IR v11.
Find more information about API v2.0 and IR v11 at https://docs.openvino.ai

IV. Compare performance of the INT8 model and FP32 model in OpenVINO

Finally, we will measure the inference performance of the FP32 and INT8 models. To do this, we use Benchmark Tool - OpenVINO’s inference performance measurement tool. By default, Benchmark Tool runs inference for 60 seconds in asynchronous mode on CPU. It returns inference speed as latency (milliseconds per image) and throughput (frames per second) values.

NOTE: In this notebook, we run benchmark_app for 15 seconds to give a quick indication of performance. For more accurate performance, it is recommended to run benchmark_app in a terminal/command prompt after closing other applications. Run benchmark_app -m model.xml -d CPU to benchmark async inference on CPU for one minute. Change CPU to GPU to benchmark on GPU. Run benchmark_app --help to see an overview of all command-line options.

def parse_benchmark_output(benchmark_output: str):
    """Prints the output from benchmark_app in human-readable format"""
    parsed_output = [line for line in benchmark_output if not (line.startswith(r"[") or line.startswith("  ") or line == "")]
    print(*parsed_output, sep='\n')


print('Benchmark FP32 model (IR)')
benchmark_output = ! benchmark_app -m "$fp32_ir_path" -d CPU -api async -t 15
parse_benchmark_output(benchmark_output)

print('Benchmark INT8 model (IR)')
benchmark_output = ! benchmark_app -m "$int8_ir_path" -d CPU -api async -t 15
parse_benchmark_output(benchmark_output)
Benchmark FP32 model (IR)
Count:          15912 iterations
Duration:       15009.10 ms
Latency:
Throughput: 1060.16 FPS
Benchmark INT8 model (IR)
Count:          47598 iterations
Duration:       15003.13 ms
Latency:
Throughput: 3172.54 FPS
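
As a rough sanity check on the numbers above, throughput is simply the iteration count divided by the duration, which also gives the approximate speedup from quantization (a quick calculation based on this particular run; results depend on the hardware):

fp32_fps = 15912 / (15009.10 / 1000)  # ~1060 FPS
int8_fps = 47598 / (15003.13 / 1000)  # ~3173 FPS
print(f"INT8 speedup: {int8_fps / fp32_fps:.2f}x")  # ~3x on this machine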

Show CPU Information for reference:

ie = Core()
ie.get_property(device_name="CPU", name="FULL_DEVICE_NAME")
'Intel(R) Core(TM) i9-10920X CPU @ 3.50GHz'