Quantization Aware Training with NNCF, using TensorFlow Framework

This tutorial is also available as a Jupyter notebook that can be cloned directly from GitHub. See the installation guide for instructions to run this tutorial locally on Windows, Linux or macOS.

The goal of this notebook is to demonstrate how to use the Neural Network Compression Framework (NNCF) 8-bit quantization to optimize a TensorFlow model for inference with the OpenVINO Toolkit. The optimization process contains the following steps:

- Transform the original FP32 model to INT8
- Use fine-tuning to restore the accuracy
- Export the optimized and original models to Frozen Graph and then to OpenVINO IR
- Measure and compare the performance of the models

For more advanced usage, please refer to these examples.

We selected the ResNet-18 model with the Imagenette dataset. Imagenette is a subset of 10 easily classified classes from the ImageNet dataset. Using the smaller model and dataset speeds up training and download time.

Imports and Settings

Import NNCF and all auxiliary packages from your Python code. Set a name for the model, the input image size, the batch size, and the learning rate. Also define the paths where the Frozen Graph and OpenVINO IR versions of the models will be stored.

NOTE: All NNCF logging messages below ERROR level (INFO and WARNING) are disabled to simplify the tutorial. For production use, it is recommended to enable logging by removing set_log_level(logging.ERROR).

from pathlib import Path
import logging

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.python.keras import layers
from tensorflow.python.keras import models

from nncf import NNCFConfig
from nncf.tensorflow.helpers.model_creation import create_compressed_model
from nncf.tensorflow.initialization import register_default_init_args
from nncf.common.utils.logger import set_log_level

set_log_level(logging.ERROR)  # Disables all NNCF info and warning messages

MODEL_DIR = Path("model")
OUTPUT_DIR = Path("output")
MODEL_DIR.mkdir(exist_ok=True)
OUTPUT_DIR.mkdir(exist_ok=True)

BASE_MODEL_NAME = "ResNet-18"

fp32_h5_path = Path(MODEL_DIR / (BASE_MODEL_NAME + "_fp32")).with_suffix(".h5")
fp32_sm_path = Path(OUTPUT_DIR / (BASE_MODEL_NAME + "_fp32"))
fp32_ir_path = Path(OUTPUT_DIR / "saved_model").with_suffix(".xml")
int8_pb_path = Path(OUTPUT_DIR / (BASE_MODEL_NAME + "_int8")).with_suffix(".pb")
int8_pb_name = Path(BASE_MODEL_NAME + "_int8").with_suffix(".pb")
int8_ir_path = int8_pb_path.with_suffix(".xml")

BATCH_SIZE = 128
IMG_SIZE = (64, 64)  # Image size the model expects; Imagenette images are resized to this size
NUM_CLASSES = 10  # For Imagenette dataset

LR = 1e-5

MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)  # From Imagenet dataset
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)  # From Imagenet dataset

fp32_pth_url = "https://storage.openvinotoolkit.org/repositories/nncf/openvino_notebook_ckpts/305_resnet18_imagenette_fp32.h5"
_ = tf.keras.utils.get_file(fp32_h5_path.resolve(), fp32_pth_url)
print(f'Absolute path where the model weights are saved:\n {fp32_h5_path.resolve()}')
Downloading data from https://storage.openvinotoolkit.org/repositories/nncf/openvino_notebook_ckpts/305_resnet18_imagenette_fp32.h5
134602752/134602216 [==============================] - 31s 0us/step
Absolute path where the model weights are saved:
 /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/notebooks/305-tensorflow-quantization-aware-training/model/ResNet-18_fp32.h5

Dataset Preprocessing

Download and prepare the Imagenette 160px dataset.

- Number of classes: 10
- Download size: 94.18 MiB

| Split        | Examples |
|--------------|----------|
| 'train'      | 12,894   |
| 'validation' | 500      |

datasets, datasets_info = tfds.load('imagenette/160px', shuffle_files=True, as_supervised=True, with_info=True)
train_dataset, validation_dataset = datasets['train'], datasets['validation']
fig = tfds.show_examples(train_dataset, datasets_info)
(Figure: sample images from the Imagenette training set displayed by tfds.show_examples.)
def preprocessing(image, label):
    image = tf.image.resize(image, IMG_SIZE)
    image = image - MEAN_RGB
    image = image / STDDEV_RGB
    label = tf.one_hot(label, NUM_CLASSES)
    return image, label


train_dataset = (train_dataset.map(preprocessing, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                              .batch(BATCH_SIZE)
                              .prefetch(tf.data.experimental.AUTOTUNE))

validation_dataset = (validation_dataset.map(preprocessing, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                                        .batch(BATCH_SIZE)
                                        .prefetch(tf.data.experimental.AUTOTUNE))

Define a Floating-Point Model

def residual_conv_block(filters, stage, block, strides=(1, 1), cut='pre'):
    def layer(input_tensor):
        x = layers.BatchNormalization(epsilon=2e-5)(input_tensor)
        x = layers.Activation('relu')(x)

        # defining shortcut connection
        if cut == 'pre':
            shortcut = input_tensor
        elif cut == 'post':
            shortcut = layers.Conv2D(filters, (1, 1), strides=strides, kernel_initializer='he_uniform',
                                     use_bias=False)(x)

        # continue with convolution layers
        x = layers.ZeroPadding2D(padding=(1, 1))(x)
        x = layers.Conv2D(filters, (3, 3), strides=strides, kernel_initializer='he_uniform', use_bias=False)(x)

        x = layers.BatchNormalization(epsilon=2e-5)(x)
        x = layers.Activation('relu')(x)
        x = layers.ZeroPadding2D(padding=(1, 1))(x)
        x = layers.Conv2D(filters, (3, 3), kernel_initializer='he_uniform', use_bias=False)(x)

        # add residual connection
        x = layers.Add()([x, shortcut])
        return x

    return layer


def ResNet18(input_shape=None):
    """Instantiates the ResNet18 architecture."""
    img_input = layers.Input(shape=input_shape, name='data')

    # ResNet18 bottom
    x = layers.BatchNormalization(epsilon=2e-5, scale=False)(img_input)
    x = layers.ZeroPadding2D(padding=(3, 3))(x)
    x = layers.Conv2D(64, (7, 7), strides=(2, 2), kernel_initializer='he_uniform', use_bias=False)(x)
    x = layers.BatchNormalization(epsilon=2e-5)(x)
    x = layers.Activation('relu')(x)
    x = layers.ZeroPadding2D(padding=(1, 1))(x)
    x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='valid')(x)

    # ResNet18 body
    repetitions = (2, 2, 2, 2)
    for stage, rep in enumerate(repetitions):
        for block in range(rep):
            filters = 64 * (2 ** stage)
            if block == 0 and stage == 0:
                x = residual_conv_block(filters, stage, block, strides=(1, 1), cut='post')(x)
            elif block == 0:
                x = residual_conv_block(filters, stage, block, strides=(2, 2), cut='post')(x)
            else:
                x = residual_conv_block(filters, stage, block, strides=(1, 1), cut='pre')(x)
    x = layers.BatchNormalization(epsilon=2e-5)(x)
    x = layers.Activation('relu')(x)

    # ResNet18 top
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(NUM_CLASSES)(x)
    x = layers.Activation('softmax')(x)

    # Create model
    model = models.Model(img_input, x)

    return model
IMG_SHAPE = IMG_SIZE + (3,)
model = ResNet18(input_shape=IMG_SHAPE)

Pre-train Floating-Point Model

Using NNCF for model compression assumes that the user has a pre-trained model and a training pipeline.

NOTE: For the sake of simplicity, this tutorial skips FP32 model training and loads the provided pre-trained weights.
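
For reference, training the FP32 model from scratch with this pipeline would look roughly like the sketch below. The optimizer, learning rate, and epoch count here are illustrative assumptions and are not the original training recipe.

# Illustrative sketch only: the tutorial skips this step and loads pre-trained weights instead.
# The optimizer settings and epoch count below are assumptions.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
              metrics=[tf.keras.metrics.CategoricalAccuracy(name='acc@1')])
model.fit(train_dataset,
          validation_data=validation_dataset,
          epochs=10)
model.save_weights(fp32_h5_path)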

# Load the floating-point weights
model.load_weights(fp32_h5_path)

# Compile the floating-point model
model.compile(loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
              metrics=[tf.keras.metrics.CategoricalAccuracy(name='acc@1')])

# Validate the floating-point model
test_loss, test_acc = model.evaluate(validation_dataset,
                                     callbacks=tf.keras.callbacks.ProgbarLogger(stateful_metrics=['acc@1']))
print(f"\nAccuracy of FP32 model: {test_acc:.3f}")
4/4 [==============================] - 1s 145ms/sample - loss: 1.0358 - acc@1: 0.7820

Accuracy of FP32 model: 0.782

Save the floating-point model in the TensorFlow SavedModel format, which will later be used for conversion to OpenVINO IR and for performance measurement.

model.save(fp32_sm_path)
print(f'Absolute path where the model is saved:\n {fp32_sm_path.resolve()}')
INFO:tensorflow:Assets written to: output/ResNet-18_fp32/assets
INFO:tensorflow:Assets written to: output/ResNet-18_fp32/assets
Absolute path where the model is saved:
 /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/notebooks/305-tensorflow-quantization-aware-training/output/ResNet-18_fp32

Create and Initialize Quantization

NNCF enables compression-aware training by integrating into regular training pipelines. The framework is designed so that modifications to your original training code are minor. Quantization is the simplest scenario and requires only 3 modifications.

  1. Configure NNCF parameters to specify compression

nncf_config_dict = {
    "input_info": {"sample_size": [1, 3] + list(IMG_SIZE)},
    "log_dir": str(OUTPUT_DIR),  # log directory for NNCF-specific logging outputs
    "compression": {
        "algorithm": "quantization",  # specify the algorithm here
    },
}
nncf_config = NNCFConfig.from_dict(nncf_config_dict)
  2. Provide a data loader to initialize the quantization ranges and to determine which activations should be signed or unsigned from the statistics collected over a given number of samples.

nncf_config = register_default_init_args(nncf_config=nncf_config,
                                         data_loader=train_dataset,
                                         batch_size=BATCH_SIZE)
  3. Create a wrapped model ready for compression fine-tuning from a pre-trained FP32 model and a configuration object.

compression_ctrl, model = create_compressed_model(model, nncf_config)

Evaluate the new model on the validation set after initialization of quantization. The accuracy should not be far from the accuracy of the floating-point FP32 model for a simple case like the one demonstrated here.

# Compile the int8 model
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=LR),
              loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
              metrics=[tf.keras.metrics.CategoricalAccuracy(name='acc@1')])

# Validate the int8 model
test_loss, test_acc = model.evaluate(validation_dataset,
                                     callbacks=tf.keras.callbacks.ProgbarLogger(stateful_metrics=['acc@1']))
print(f"\nAccuracy of INT8 model after initialization: {test_acc:.3f}")
4/4 [==============================] - 1s 207ms/sample - loss: 1.0291 - acc@1: 0.7920

Accuracy of INT8 model after initialization: 0.792

Fine-tune the Compressed Model

At this step, a regular fine-tuning process is applied to recover the accuracy drop. Normally, several epochs of tuning with a small learning rate are required, the same learning rate that is usually used at the end of training the original model. No other changes to the training pipeline are required. Here is a simple example.

# Train the int8 model
model.fit(train_dataset,
          epochs=1)

# Validate the int8 model
test_loss, test_acc = model.evaluate(validation_dataset,
                                     callbacks=tf.keras.callbacks.ProgbarLogger(stateful_metrics=['acc@1']))
print(f"\nAccuracy of INT8 model after fine-tuning: {test_acc:.3f}")
101/101 [==============================] - 102s 983ms/step - loss: 0.7918 - acc@1: 0.8838
4/4 [==============================] - 0s 74ms/sample - loss: 0.9964 - acc@1: 0.8040

Accuracy of INT8 model after fine-tuning: 0.804

Save the INT8 model as a Frozen Graph (the SavedModel format does not currently support quantized models). The Frozen Graph will later be used for conversion to OpenVINO IR and for performance measurement.

compression_ctrl.export_model(int8_pb_path, 'frozen_graph')
print(f'Absolute path where the int8 model is saved:\n {int8_pb_path.resolve()}')
Absolute path where the int8 model is saved:
 /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/notebooks/305-tensorflow-quantization-aware-training/output/ResNet-18_int8.pb

Export Frozen Graph Models to OpenVINO™ Intermediate Representation (IR)

Call the OpenVINO Model Optimizer tool to convert the SavedModel and Frozen Graph models to OpenVINO IR. The converted models are saved to the output directory.

See the Model Optimizer Developer Guide for more information about Model Optimizer.

Executing this command may take a while. There may be some errors or warnings in the output. Model Optimizer has successfully exported the model to IR if the last lines of the output include: [ SUCCESS ] Generated IR version 10 model

!mo --framework=tf --input_shape=[1,64,64,3] --input=data --saved_model_dir=$fp32_sm_path --output_dir=$OUTPUT_DIR
Model Optimizer arguments:
Common parameters:
    - Path to the Input Model:  None
    - Path for generated IR:    /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/notebooks/305-tensorflow-quantization-aware-training/output
    - IR output name:   saved_model
    - Log level:    ERROR
    - Batch:    Not specified, inherited from the model
    - Input layers:     data
    - Output layers:    Not specified, inherited from the model
    - Input shapes:     [1,64,64,3]
    - Mean values:  Not specified
    - Scale values:     Not specified
    - Scale factor:     Not specified
    - Precision of IR:  FP32
    - Enable fusing:    True
    - Enable grouped convolutions fusing:   True
    - Move mean values to preprocess section:   None
    - Reverse input channels:   False
TensorFlow specific parameters:
    - Input model in text protobuf format:  False
    - Path to model dump for TensorBoard:   None
    - List of shared libraries with TensorFlow custom layers implementation:    None
    - Update the configuration file with input/output node names:   None
    - Use configuration file used to generate the model with Object Detection API:  None
    - Use the config file:  None
    - Inference Engine found in:    /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/openvino
Inference Engine version:   2021.4.2-3976-0943ed67223-refs/pull/539/head
Model Optimizer version:    2021.4.2-3976-0943ed67223-refs/pull/539/head
2022-02-18 23:08:02.643719: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py:22: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses
  import imp
2022-02-18 23:08:03.668911: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2022-02-18 23:08:03.669572: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
2022-02-18 23:08:03.699827: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 0 with properties:
pciBusID: 0000:17:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6
coreClock: 1.86GHz coreCount: 82 deviceMemorySize: 23.69GiB deviceMemoryBandwidth: 871.81GiB/s
2022-02-18 23:08:03.701292: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 1 with properties:
pciBusID: 0000:65:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6
coreClock: 1.86GHz coreCount: 82 deviceMemorySize: 23.70GiB deviceMemoryBandwidth: 871.81GiB/s
2022-02-18 23:08:03.701324: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
2022-02-18 23:08:03.704323: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
2022-02-18 23:08:03.704403: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublasLt.so.11
2022-02-18 23:08:03.706308: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcufft.so.10
2022-02-18 23:08:03.706652: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcurand.so.10
2022-02-18 23:08:03.708779: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusolver.so.10
2022-02-18 23:08:03.709555: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusparse.so.11
2022-02-18 23:08:03.709779: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH:
2022-02-18 23:08:03.709794: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1757] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
2022-02-18 23:08:03.710051: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-02-18 23:08:03.711254: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set
2022-02-18 23:08:03.711278: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1261] Device interconnect StreamExecutor with strength 1 edge matrix:
2022-02-18 23:08:03.711284: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1267]
2022-02-18 23:08:05.064346: I tensorflow/core/grappler/devices.cc:69] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 2
2022-02-18 23:08:05.064456: I tensorflow/core/grappler/clusters/single_machine.cc:356] Starting new session
2022-02-18 23:08:05.064790: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set
2022-02-18 23:08:05.157431: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 0 with properties:
pciBusID: 0000:17:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6
coreClock: 1.86GHz coreCount: 82 deviceMemorySize: 23.69GiB deviceMemoryBandwidth: 871.81GiB/s
2022-02-18 23:08:05.158033: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 1 with properties:
pciBusID: 0000:65:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6
coreClock: 1.86GHz coreCount: 82 deviceMemorySize: 23.70GiB deviceMemoryBandwidth: 871.81GiB/s
2022-02-18 23:08:05.158064: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
2022-02-18 23:08:05.158088: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
2022-02-18 23:08:05.158097: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublasLt.so.11
2022-02-18 23:08:05.158104: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcufft.so.10
2022-02-18 23:08:05.158112: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcurand.so.10
2022-02-18 23:08:05.158119: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusolver.so.10
2022-02-18 23:08:05.158126: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusparse.so.11
2022-02-18 23:08:05.158180: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH:
2022-02-18 23:08:05.158187: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1757] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
2022-02-18 23:08:05.158211: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1261] Device interconnect StreamExecutor with strength 1 edge matrix:
2022-02-18 23:08:05.158218: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1267]      0 1
2022-02-18 23:08:05.158224: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1280] 0:   N N
2022-02-18 23:08:05.158227: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1280] 1:   N N
2022-02-18 23:08:05.178571: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 3499910000 Hz
2022-02-18 23:08:05.205937: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:928] Optimization results for grappler item: graph_to_optimize
  function_optimizer: Graph size after: 410 nodes (308), 613 edges (511), time = 14.47ms.
  function_optimizer: function_optimizer did nothing. time = 0.239ms.

[ SUCCESS ] Generated IR version 10 model.
[ SUCCESS ] XML file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/notebooks/305-tensorflow-quantization-aware-training/output/saved_model.xml
[ SUCCESS ] BIN file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/notebooks/305-tensorflow-quantization-aware-training/output/saved_model.bin
[ SUCCESS ] Total execution time: 9.54 seconds.
[ SUCCESS ] Memory consumed: 1061 MB.
It's been a while, check for a new version of Intel(R) Distribution of OpenVINO(TM) toolkit here https://software.intel.com/content/www/us/en/develop/tools/openvino-toolkit/download.html?cid=other&source=prod&campid=ww_2022_bu_IOTG_OpenVINO-2022-1&content=upg_all&medium=organic or on the GitHub*
!mo --framework=tf --input_shape=[1,64,64,3] --input=Placeholder --input_model=$int8_pb_path --output_dir=$OUTPUT_DIR
Model Optimizer arguments:
Common parameters:
    - Path to the Input Model:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/notebooks/305-tensorflow-quantization-aware-training/output/ResNet-18_int8.pb
    - Path for generated IR:    /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/notebooks/305-tensorflow-quantization-aware-training/output
    - IR output name:   ResNet-18_int8
    - Log level:    ERROR
    - Batch:    Not specified, inherited from the model
    - Input layers:     Placeholder
    - Output layers:    Not specified, inherited from the model
    - Input shapes:     [1,64,64,3]
    - Mean values:  Not specified
    - Scale values:     Not specified
    - Scale factor:     Not specified
    - Precision of IR:  FP32
    - Enable fusing:    True
    - Enable grouped convolutions fusing:   True
    - Move mean values to preprocess section:   None
    - Reverse input channels:   False
TensorFlow specific parameters:
    - Input model in text protobuf format:  False
    - Path to model dump for TensorBoard:   None
    - List of shared libraries with TensorFlow custom layers implementation:    None
    - Update the configuration file with input/output node names:   None
    - Use configuration file used to generate the model with Object Detection API:  None
    - Use the config file:  None
    - Inference Engine found in:    /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/openvino
Inference Engine version:   2021.4.2-3976-0943ed67223-refs/pull/539/head
Model Optimizer version:    2021.4.2-3976-0943ed67223-refs/pull/539/head
2022-02-18 23:08:13.221914: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py:22: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses
  import imp
[ SUCCESS ] Generated IR version 10 model.
[ SUCCESS ] XML file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/notebooks/305-tensorflow-quantization-aware-training/output/ResNet-18_int8.xml
[ SUCCESS ] BIN file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/notebooks/305-tensorflow-quantization-aware-training/output/ResNet-18_int8.bin
[ SUCCESS ] Total execution time: 17.61 seconds.
[ SUCCESS ] Memory consumed: 657 MB.
It's been a while, check for a new version of Intel(R) Distribution of OpenVINO(TM) toolkit here https://software.intel.com/content/www/us/en/develop/tools/openvino-toolkit/download.html?cid=other&source=prod&campid=ww_2022_bu_IOTG_OpenVINO-2022-1&content=upg_all&medium=organic or on the GitHub*

Benchmark Model Performance by Computing Inference Time

Finally, we will measure the inference performance of the FP32 and INT8 models. To do this, we use Benchmark Tool - OpenVINO’s inference performance measurement tool. By default, Benchmark Tool runs inference for 60 seconds in asynchronous mode on CPU. It returns inference speed as latency (milliseconds per image) and throughput (frames per second) values.

NOTE: In this notebook we run benchmark_app for 15 seconds to give a quick indication of performance. For more accurate performance, it is recommended to run benchmark_app in a terminal/command prompt after closing other applications. Run benchmark_app -m model.xml -d CPU to benchmark async inference on CPU for one minute. Change CPU to GPU to benchmark on GPU. Run benchmark_app --help to see an overview of all command-line options.

def parse_benchmark_output(benchmark_output):
    parsed_output = [line for line in benchmark_output if not (line.startswith(r"[") or line.startswith("  ") or line == "")]
    print(*parsed_output, sep='\n')


print('Benchmark FP32 model (IR)')
benchmark_output = ! benchmark_app -m $fp32_ir_path -d CPU -api async -t 15
parse_benchmark_output(benchmark_output)

print('\nBenchmark INT8 model (IR)')
benchmark_output = ! benchmark_app -m $int8_ir_path -d CPU -api async -t 15
parse_benchmark_output(benchmark_output)
Benchmark FP32 model (IR)
Count:      36834 iterations
Duration:   15004.54 ms
Latency:    2.36 ms
Throughput: 2454.86 FPS

Benchmark INT8 model (IR)
Count:      128436 iterations
Duration:   15001.13 ms
Latency:    0.67 ms
Throughput: 8561.75 FPS

Show CPU Information for reference

from openvino.inference_engine import IECore

ie = IECore()
ie.get_metric(device_name='CPU', metric_name="FULL_DEVICE_NAME")
'Intel(R) Core(TM) i9-10920X CPU @ 3.50GHz'
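
As an optional sanity check, the exported INT8 IR can also be run directly with the same IECore API. The following is a minimal sketch, assuming one batch from the validation pipeline defined earlier; it is not part of the original benchmark, and the layout handling reflects the common case where Model Optimizer converts TensorFlow models to NCHW layout.

import numpy as np

# Minimal sketch: run one preprocessed validation image through the INT8 IR on CPU.
net = ie.read_network(model=str(int8_ir_path), weights=str(int8_ir_path.with_suffix(".bin")))
exec_net = ie.load_network(network=net, device_name="CPU")
input_name = next(iter(net.input_info))
output_name = next(iter(net.outputs))

# Take one image from the validation pipeline defined earlier (the IR was converted with batch size 1).
images, labels = next(iter(validation_dataset))
image = images[:1].numpy()

# Model Optimizer usually produces an NCHW input for TensorFlow models; transpose if needed.
if list(net.input_info[input_name].input_data.shape) == [1, 3, *IMG_SIZE]:
    image = image.transpose((0, 3, 1, 2))

result = exec_net.infer(inputs={input_name: image})
predicted = int(np.argmax(result[output_name]))
print(f"Predicted class index: {predicted}, reference class index: {int(np.argmax(labels[0]))}")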