Post-Training Quantization with TensorFlow Classification Model

This tutorial is also available as a Jupyter notebook that can be cloned directly from GitHub. See the installation guide for instructions to run this tutorial locally on Windows, Linux or macOS.


This example demonstrates how to quantize the OpenVINO model created in 301-tensorflow-training-openvino.ipynb to improve inference speed. Quantization is performed with the post-training quantization API of NNCF. A custom dataloader and metric are defined, and accuracy and performance are measured for both the original IR model and the quantized model.


Preparation

The notebook requires that the training notebook has been run and that the Intermediate Representation (IR) models are created. If the IR models do not exist, running the next cell will run the training notebook. This will take a while.

from pathlib import Path

import tensorflow as tf

model_xml = Path("model/flower/flower_ir.xml")
dataset_url = (
    "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
)
data_dir = Path(tf.keras.utils.get_file("flower_photos", origin=dataset_url, untar=True))

if not model_xml.exists():
    print("Executing training notebook. This will take a while...")
    %run 301-tensorflow-training-openvino.ipynb
2023-07-05 23:54:28.962752: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable TF_ENABLE_ONEDNN_OPTS=0.
2023-07-05 23:54:28.997784: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-07-05 23:54:29.609276: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Executing training notebook. This will take a while...
3670
Found 3670 files belonging to 5 classes.
Using 2936 files for training.
2023-07-05 23:54:31.178171: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1956] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
Found 3670 files belonging to 5 classes.
Using 734 files for validation.
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
../_images/301-tensorflow-training-openvino-nncf-with-output_2_5.png
(32, 180, 180, 3)
(32,)
0.0 1.0
../_images/301-tensorflow-training-openvino-nncf-with-output_2_9.png
Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 sequential_1 (Sequential)   (None, 180, 180, 3)       0

 rescaling_2 (Rescaling)     (None, 180, 180, 3)       0

 conv2d_3 (Conv2D)           (None, 180, 180, 16)      448

 max_pooling2d_3 (MaxPooling  (None, 90, 90, 16)       0
 2D)

 conv2d_4 (Conv2D)           (None, 90, 90, 32)        4640

 max_pooling2d_4 (MaxPooling  (None, 45, 45, 32)       0
 2D)

 conv2d_5 (Conv2D)           (None, 45, 45, 64)        18496

 max_pooling2d_5 (MaxPooling  (None, 22, 22, 64)       0
 2D)

 dropout (Dropout)           (None, 22, 22, 64)        0

 flatten_1 (Flatten)         (None, 30976)             0

 dense_2 (Dense)             (None, 128)               3965056

 outputs (Dense)             (None, 5)                 645

=================================================================
Total params: 3,989,285
Trainable params: 3,989,285
Non-trainable params: 0
_________________________________________________________________
Epoch 1/15
92/92 [==============================] - 7s 66ms/step - loss: 1.2943 - accuracy: 0.4486 - val_loss: 1.0944 - val_accuracy: 0.5354
Epoch 2/15
92/92 [==============================] - 6s 63ms/step - loss: 1.0396 - accuracy: 0.5787 - val_loss: 0.9602 - val_accuracy: 0.6322
Epoch 3/15
92/92 [==============================] - 6s 64ms/step - loss: 0.9646 - accuracy: 0.6213 - val_loss: 0.9223 - val_accuracy: 0.6417
Epoch 4/15
92/92 [==============================] - 6s 64ms/step - loss: 0.8775 - accuracy: 0.6533 - val_loss: 0.8511 - val_accuracy: 0.6594
Epoch 5/15
92/92 [==============================] - 6s 64ms/step - loss: 0.8354 - accuracy: 0.6884 - val_loss: 0.8471 - val_accuracy: 0.6689
Epoch 6/15
92/92 [==============================] - 6s 64ms/step - loss: 0.7722 - accuracy: 0.7033 - val_loss: 0.8405 - val_accuracy: 0.6935
Epoch 7/15
92/92 [==============================] - 6s 64ms/step - loss: 0.7347 - accuracy: 0.7207 - val_loss: 0.8848 - val_accuracy: 0.6730
Epoch 8/15
92/92 [==============================] - 6s 63ms/step - loss: 0.6980 - accuracy: 0.7469 - val_loss: 0.7724 - val_accuracy: 0.6948
Epoch 9/15
92/92 [==============================] - 6s 64ms/step - loss: 0.6629 - accuracy: 0.7476 - val_loss: 0.7512 - val_accuracy: 0.7071
Epoch 10/15
92/92 [==============================] - 6s 63ms/step - loss: 0.6429 - accuracy: 0.7643 - val_loss: 0.7196 - val_accuracy: 0.7125
Epoch 11/15
92/92 [==============================] - 6s 64ms/step - loss: 0.5967 - accuracy: 0.7755 - val_loss: 0.7228 - val_accuracy: 0.7084
Epoch 12/15
92/92 [==============================] - 6s 63ms/step - loss: 0.5860 - accuracy: 0.7769 - val_loss: 0.7501 - val_accuracy: 0.7153
Epoch 13/15
92/92 [==============================] - 6s 64ms/step - loss: 0.5695 - accuracy: 0.7793 - val_loss: 0.7366 - val_accuracy: 0.7153
Epoch 14/15
92/92 [==============================] - 6s 63ms/step - loss: 0.5392 - accuracy: 0.7970 - val_loss: 0.7375 - val_accuracy: 0.7275
Epoch 15/15
92/92 [==============================] - 6s 64ms/step - loss: 0.5098 - accuracy: 0.8048 - val_loss: 0.6984 - val_accuracy: 0.7330
../_images/301-tensorflow-training-openvino-nncf-with-output_2_15.png
1/1 [==============================] - 0s 76ms/step
This image most likely belongs to sunflowers with a 99.23 percent confidence.
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op, _update_step_xla while saving (showing 4 of 4). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: model/flower/saved_model/assets
INFO:tensorflow:Assets written to: model/flower/saved_model/assets
output/A_Close_Up_Photo_of_a_Dandelion.jpg:   0%|          | 0.00/21.7k [00:00<?, ?B/s]
(1, 180, 180, 3)
[1,180,180,3]
This image most likely belongs to dandelion with a 99.81 percent confidence.
../_images/301-tensorflow-training-openvino-nncf-with-output_2_22.png

Imports

The Post Training Quantization API is implemented in the nncf library.

import sys

import matplotlib.pyplot as plt
import numpy as np
import nncf
from openvino.runtime import Core
from openvino.runtime import serialize
from PIL import Image
from sklearn.metrics import accuracy_score

sys.path.append("../utils")
from notebook_utils import download_file
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, tensorflow, onnx, openvino

Post-training Quantization with NNCF

NNCF provides a suite of advanced algorithms for optimizing neural network inference in OpenVINO with a minimal drop in accuracy.

Create a quantized model from the pre-trained FP32 model and the calibration dataset. The optimization process contains the following steps:

  1. Create a Dataset for quantization.

  2. Run nncf.quantize to get an optimized model.

The validation dataset is already defined in the training notebook.

img_height = 180
img_width = 180
val_dataset = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=1
)

for a, b in val_dataset:
    print(type(a), type(b))
    break
Found 3670 files belonging to 5 classes.
Using 734 files for validation.
<class 'tensorflow.python.framework.ops.EagerTensor'> <class 'tensorflow.python.framework.ops.EagerTensor'>

The validation dataset can be reused for quantization. However, it yields (images, labels) tuples, whereas the calibration dataset should yield only images. A transformation function converts each item of the validation dataset into the input that the model expects.

def transform_fn(data_item):
    """
    The transformation function transforms a data item into model input data.
    This function should be passed when the data item cannot be used as model's input.
    """
    images, _ = data_item
    return images.numpy()


calibration_dataset = nncf.Dataset(val_dataset, transform_fn)
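
As a rough illustration of the contract described above, the sketch below applies a transformation function to each (images, labels) item as the data is iterated. The class LazyCalibrationData and the function drop_labels are hypothetical names introduced for this example. This is not how nncf.Dataset is implemented, and plain Python lists stand in for tensors.

```python
from typing import Any, Callable, Iterable, Iterator


class LazyCalibrationData:
    """Hypothetical stand-in for a calibration dataset wrapper (not nncf.Dataset)."""

    def __init__(self, source: Iterable[Any], transform: Callable[[Any], Any]):
        self._source = source
        self._transform = transform

    def __iter__(self) -> Iterator[Any]:
        # Apply the transform lazily, one data item at a time.
        for item in self._source:
            yield self._transform(item)


def drop_labels(data_item):
    """Keep only the model input from an (images, labels) pair."""
    images, _ = data_item
    return images


# Toy validation data: (image, label) pairs with nested lists as "images".
toy_validation = [([[0.1, 0.2]], 0), ([[0.3, 0.4]], 3)]
calibration_items = list(LazyCalibrationData(toy_validation, drop_labels))
print(calibration_items)
```

The labels never reach the consumer of the calibration data, which matches the role transform_fn plays in the cell above.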

Load the Intermediate Representation (IR) model.

ie = Core()
ir_model = ie.read_model(model_xml)

Use the basic quantization flow. For the more advanced flow, which applies 8-bit quantization to the model under accuracy control, see Quantizing with accuracy control.

quantized_model = nncf.quantize(
    ir_model,
    calibration_dataset,
    subset_size=1000
)
Statistics collection:  73%|███████▎  | 734/1000 [00:04<00:01, 166.65it/s]
Biases correction: 100%|██████████| 5/5 [00:01<00:00,  3.99it/s]
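
To build some intuition for what 8-bit quantization does to a tensor, here is a minimal sketch of a per-tensor affine scheme in plain Python. It is a toy illustration under simplifying assumptions, not the algorithm NNCF runs (the log above shows that NNCF also collects activation statistics and corrects biases). The helper names quantize_uint8 and dequantize are made up for this example.

```python
def quantize_uint8(values):
    """Map floats to unsigned 8-bit integers with an affine (scale, zero-point) scheme."""
    lo, hi = min(values), max(values)
    # Make sure zero is representable and avoid a zero-width range.
    lo, hi = min(lo, 0.0), max(hi, 0.0)
    scale = (hi - lo) / 255 or 1.0
    zero_point = round(-lo / scale)
    # Round to the nearest integer step and clip to the 8-bit range.
    q = [max(0, min(255, round(v / scale) + zero_point)) for v in values]
    return q, scale, zero_point


def dequantize(q, scale, zero_point):
    """Recover approximate float values from the 8-bit representation."""
    return [(x - zero_point) * scale for x in q]


weights = [-1.0, -0.25, 0.0, 0.5, 1.0]
q, scale, zero_point = quantize_uint8(weights)
restored = dequantize(q, scale, zero_point)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(q, scale, zero_point)
print(f"max round-trip error: {max_error:.4f}")
```

Each value is stored in one byte, and the round-trip error stays within one quantization step (scale). That is why a well-calibrated model can keep most of its accuracy after quantization.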

Save the quantized model so that it can be benchmarked.

compressed_model_dir = Path("model/optimized")
compressed_model_dir.mkdir(parents=True, exist_ok=True)
compressed_model_xml = compressed_model_dir / "flower_ir.xml"
serialize(quantized_model, str(compressed_model_xml))

Compare Metrics

Define a metric to evaluate the prediction quality of the models.

For this demo, the validate function computes classification accuracy on the validation dataset.

def validate(model, validation_loader):
    """
    Evaluate model and compute accuracy metrics.

    :param model: Model to validate
    :param validation_loader: Validation dataset
    :returns: Accuracy scores
    """
    predictions = []
    references = []

    output = model.outputs[0]

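    for images, target in validation_loader:
        # Run inference on one batch and take the first (only) output
        pred = model(images.numpy())[output]

        # The predicted class is the index of the highest logit
        predictions.append(np.argmax(pred, axis=1))
        references.append(target.numpy())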

    predictions = np.concatenate(predictions, axis=0)
    references = np.concatenate(references, axis=0)

    scores = accuracy_score(references, predictions)

    return scores

Calculate accuracy for the original model and the quantized model.

original_compiled_model = ie.compile_model(model=ir_model, device_name="CPU")
quantized_compiled_model = ie.compile_model(model=quantized_model, device_name="CPU")

original_accuracy = validate(original_compiled_model, val_dataset)
quantized_accuracy = validate(quantized_compiled_model, val_dataset)

print(f"Accuracy of the original model: {original_accuracy:.3f}")
print(f"Accuracy of the quantized model: {quantized_accuracy:.3f}")
Accuracy of the original model: 0.733
Accuracy of the quantized model: 0.737
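
The accuracy reported above is top-1 accuracy: the fraction of samples whose highest-scoring class matches the reference label. The sketch below computes the same quantity by hand on made-up logits, without scikit-learn. The names argmax and top1_accuracy are introduced here for illustration only.

```python
def argmax(row):
    """Index of the largest value in a sequence."""
    return max(range(len(row)), key=row.__getitem__)


def top1_accuracy(logits, labels):
    """Fraction of samples whose highest-scoring class equals the reference label."""
    correct = sum(argmax(row) == label for row, label in zip(logits, labels))
    return correct / len(labels)


# Toy model outputs for four samples and three classes (values are made up).
toy_logits = [
    [2.0, 0.1, -1.0],
    [0.3, 0.2, 1.5],
    [0.0, 3.0, 0.5],
    [1.0, 0.9, 0.8],
]
toy_labels = [0, 2, 1, 1]
print(top1_accuracy(toy_logits, toy_labels))  # 3 of 4 predictions are correct
```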

Compare file size of the models.

original_model_size = model_xml.with_suffix(".bin").stat().st_size / 1024
quantized_model_size = compressed_model_xml.with_suffix(".bin").stat().st_size / 1024

print(f"Original model size: {original_model_size:.2f} KB")
print(f"Quantized model size: {quantized_model_size:.2f} KB")
Original model size: 7791.65 KB
Quantized model size: 3897.08 KB

The original and quantized models have similar accuracy (0.733 and 0.737 on the validation set), while the quantized model file is about half the size of the original.
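
The size comparison can be made concrete with a little arithmetic on the numbers reported above. The parameter count comes from the Keras model summary printed by the training notebook. The bytes-per-parameter figure is an inference from those sizes, not something the notebook states.

```python
# Sizes of the .bin files in KB, copied from the output above.
original_kb = 7791.65
quantized_kb = 3897.08
# Parameter count from the Keras model summary earlier in this tutorial.
num_params = 3_989_285

compression_ratio = original_kb / quantized_kb
size_reduction = 1 - quantized_kb / original_kb
original_bytes_per_param = original_kb * 1024 / num_params
quantized_bytes_per_param = quantized_kb * 1024 / num_params

print(f"Compression ratio: {compression_ratio:.2f}x")
print(f"Size reduction: {size_reduction:.1%}")
print(f"Bytes per parameter: {original_bytes_per_param:.2f} -> {quantized_bytes_per_param:.2f}")
```

The result is consistent with weights stored in roughly 16 bits each before quantization and 8 bits each afterwards, which would explain why the file shrinks by about 2x rather than 4x.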

Run Inference on Quantized Model

Copy the preprocessing function from the training notebook and run inference on the quantized model with OpenVINO Runtime. See the OpenVINO API tutorial for more information about running inference with the OpenVINO Runtime Python API.

def pre_process_image(imagePath, img_height=180):
    # Model input format is NHWC: batch, height, width, channels
    n, h, w, c = [1, img_height, img_height, 3]
    image = Image.open(imagePath)
    image = image.resize((w, h), resample=Image.BILINEAR)

    # Convert to array and add the batch dimension (the layout stays HWC)
    image = np.array(image)

    input_image = image.reshape((n, h, w, c))

    return input_image
# Get the names of the input and output layer
# model_pot = ie.read_model(model="model/optimized/flower_ir.xml")
input_layer = quantized_compiled_model.input(0)
output_layer = quantized_compiled_model.output(0)

# Get the class names: a list of directory names in alphabetical order
class_names = sorted([item.name for item in Path(data_dir).iterdir() if item.is_dir()])

# Run inference on an input image...
inp_img_url = (
    "https://upload.wikimedia.org/wikipedia/commons/4/48/A_Close_Up_Photo_of_a_Dandelion.jpg"
)
directory = "output"
inp_file_name = "A_Close_Up_Photo_of_a_Dandelion.jpg"
file_path = Path(directory) / inp_file_name
# Download the image if it does not exist yet
if not file_path.exists():
    download_file(inp_img_url, inp_file_name, directory=directory)

# Pre-process the image and get it ready for inference.
input_image = pre_process_image(imagePath=file_path)
print(f'input image shape: {input_image.shape}')
print(f'input layer shape: {input_layer.shape}')

res = quantized_compiled_model([input_image])[output_layer]

score = tf.nn.softmax(res[0])

# Show the results
image = Image.open(file_path)
plt.imshow(image)
print(
    "This image most likely belongs to {} with a {:.2f} percent confidence.".format(
        class_names[np.argmax(score)], 100 * np.max(score)
    )
)
input image shape: (1, 180, 180, 3)
input layer shape: [1,180,180,3]
This image most likely belongs to dandelion with a 99.82 percent confidence.
../_images/301-tensorflow-training-openvino-nncf-with-output_24_1.png
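
The confidence value printed above comes from applying softmax to the raw model outputs (logits). A minimal stdlib sketch of that step follows. The logits are made up for illustration and are not the model's real outputs.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the maximum before exponentiating."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits for the five flower classes, in alphabetical class order.
class_names = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
toy_logits = [1.2, 8.5, 0.3, 2.1, 1.0]

probabilities = softmax(toy_logits)
best = max(range(len(probabilities)), key=probabilities.__getitem__)
print(f"{class_names[best]}: {100 * probabilities[best]:.2f} percent")
```

Subtracting the maximum logit does not change the result, but it keeps math.exp from overflowing when logits are large.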

Compare Inference Speed

Measure inference speed with the OpenVINO Benchmark App.

Benchmark App is a command line tool that measures raw inference performance for a specified OpenVINO IR model. Run benchmark_app --help to see a list of available parameters. By default, Benchmark App tests the performance of the model specified with the -m parameter with asynchronous inference on CPU, for one minute. Use the -d parameter to test performance on a different device, for example an Intel integrated Graphics (iGPU), and -t to set the number of seconds to run inference. See the documentation for more information.

The cells below call benchmark_app directly through the notebook shell escape (!), so the full command and its parameters are visible in each cell.

In the next cells, inference speed will be measured for the original and quantized model on CPU. If an iGPU is available, inference speed will be measured for CPU+GPU as well. The number of seconds is set to 15.

Note

For the most accurate performance estimation, it is recommended to run benchmark_app in a terminal/command prompt after closing other applications.
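
To run the same benchmark from a terminal, it helps to assemble the command from its parameters first. The sketch below builds the argument list from the flags used in this tutorial (-m, -d, -t, -api) and prints it. The function build_benchmark_command is a hypothetical helper written for this example, not part of OpenVINO or Notebook Utils.

```python
import shlex


def build_benchmark_command(model_path, device="CPU", seconds=15, api="async"):
    """Return the benchmark_app argument list for the flags used in this tutorial."""
    return [
        "benchmark_app",
        "-m", str(model_path),
        "-d", device,
        "-t", str(seconds),
        "-api", api,
    ]


command = build_benchmark_command("model/flower/flower_ir.xml")
# Print the command so it can be copied into a terminal.
print(shlex.join(command))
```

On a system where benchmark_app is installed, the same list could be passed to subprocess.run(command, check=True) instead of being printed.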

# print the available devices on this system
print("Device information:")
print(ie.get_property("CPU", "FULL_DEVICE_NAME"))
if "GPU" in ie.available_devices:
    print(ie.get_property("GPU", "FULL_DEVICE_NAME"))
Device information:
Intel(R) Core(TM) i9-10920X CPU @ 3.50GHz
# Original model - CPU
! benchmark_app -m $model_xml -d CPU -t 15 -api async
[Step 1/11] Parsing and validating input arguments
[ INFO ] Parsing input parameters
[Step 2/11] Loading OpenVINO Runtime
[ INFO ] OpenVINO:
[ INFO ] Build ................................. 2023.0.0-10926-b4452d56304-releases/2023/0
[ INFO ]
[ INFO ] Device info:
[ INFO ] CPU
[ INFO ] Build ................................. 2023.0.0-10926-b4452d56304-releases/2023/0
[ INFO ]
[ INFO ]
[Step 3/11] Setting device configuration
[ WARNING ] Performance hint was not explicitly specified in command line. Device(CPU) performance hint will be set to PerformanceMode.THROUGHPUT.
[Step 4/11] Reading model files
[ INFO ] Loading model files
[ INFO ] Read model took 12.02 ms
[ INFO ] Original model I/O parameters:
[ INFO ] Model inputs:
[ INFO ]     sequential_1_input (node: sequential_1_input) : f32 / [...] / [1,180,180,3]
[ INFO ] Model outputs:
[ INFO ]     outputs (node: sequential_2/outputs/BiasAdd) : f32 / [...] / [1,5]
[Step 5/11] Resizing model to match image sizes and given batch
[ INFO ] Model batch size: 1
[Step 6/11] Configuring input of the model
[ INFO ] Model inputs:
[ INFO ]     sequential_1_input (node: sequential_1_input) : u8 / [N,H,W,C] / [1,180,180,3]
[ INFO ] Model outputs:
[ INFO ]     outputs (node: sequential_2/outputs/BiasAdd) : f32 / [...] / [1,5]
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 76.79 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] Model:
[ INFO ]   NETWORK_NAME: TensorFlow_Frontend_IR
[ INFO ]   OPTIMAL_NUMBER_OF_INFER_REQUESTS: 12
[ INFO ]   NUM_STREAMS: 12
[ INFO ]   AFFINITY: Affinity.CORE
[ INFO ]   INFERENCE_NUM_THREADS: 24
[ INFO ]   PERF_COUNT: False
[ INFO ]   INFERENCE_PRECISION_HINT: <Type: 'float32'>
[ INFO ]   PERFORMANCE_HINT: PerformanceMode.THROUGHPUT
[ INFO ]   EXECUTION_MODE_HINT: ExecutionMode.PERFORMANCE
[ INFO ]   PERFORMANCE_HINT_NUM_REQUESTS: 0
[ INFO ]   ENABLE_CPU_PINNING: True
[ INFO ]   SCHEDULING_CORE_TYPE: SchedulingCoreType.ANY_CORE
[ INFO ]   ENABLE_HYPER_THREADING: True
[ INFO ]   EXECUTION_DEVICES: ['CPU']
[Step 9/11] Creating infer requests and preparing input tensors
[ WARNING ] No input files were given for input 'sequential_1_input'!. This input will be filled with random values!
[ INFO ] Fill input 'sequential_1_input' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 12 inference requests, limits: 15000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 7.22 ms
[Step 11/11] Dumping statistics report
[ INFO ] Execution Devices:['CPU']
[ INFO ] Count:            57276 iterations
[ INFO ] Duration:         15002.57 ms
[ INFO ] Latency:
[ INFO ]    Median:        2.90 ms
[ INFO ]    Average:       2.95 ms
[ INFO ]    Min:           1.67 ms
[ INFO ]    Max:           234.29 ms
[ INFO ] Throughput:   3817.75 FPS
# Quantized model - CPU
! benchmark_app -m $compressed_model_xml -d CPU -t 15 -api async
[Step 1/11] Parsing and validating input arguments
[ INFO ] Parsing input parameters
[Step 2/11] Loading OpenVINO Runtime
[ INFO ] OpenVINO:
[ INFO ] Build ................................. 2023.0.0-10926-b4452d56304-releases/2023/0
[ INFO ]
[ INFO ] Device info:
[ INFO ] CPU
[ INFO ] Build ................................. 2023.0.0-10926-b4452d56304-releases/2023/0
[ INFO ]
[ INFO ]
[Step 3/11] Setting device configuration
[ WARNING ] Performance hint was not explicitly specified in command line. Device(CPU) performance hint will be set to PerformanceMode.THROUGHPUT.
[Step 4/11] Reading model files
[ INFO ] Loading model files
[ INFO ] Read model took 12.35 ms
[ INFO ] Original model I/O parameters:
[ INFO ] Model inputs:
[ INFO ]     sequential_1_input (node: sequential_1_input) : f32 / [...] / [1,180,180,3]
[ INFO ] Model outputs:
[ INFO ]     outputs (node: sequential_2/outputs/BiasAdd) : f32 / [...] / [1,5]
[Step 5/11] Resizing model to match image sizes and given batch
[ INFO ] Model batch size: 1
[Step 6/11] Configuring input of the model
[ INFO ] Model inputs:
[ INFO ]     sequential_1_input (node: sequential_1_input) : u8 / [N,H,W,C] / [1,180,180,3]
[ INFO ] Model outputs:
[ INFO ]     outputs (node: sequential_2/outputs/BiasAdd) : f32 / [...] / [1,5]
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 54.95 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] Model:
[ INFO ]   NETWORK_NAME: TensorFlow_Frontend_IR
[ INFO ]   OPTIMAL_NUMBER_OF_INFER_REQUESTS: 12
[ INFO ]   NUM_STREAMS: 12
[ INFO ]   AFFINITY: Affinity.CORE
[ INFO ]   INFERENCE_NUM_THREADS: 24
[ INFO ]   PERF_COUNT: False
[ INFO ]   INFERENCE_PRECISION_HINT: <Type: 'float32'>
[ INFO ]   PERFORMANCE_HINT: PerformanceMode.THROUGHPUT
[ INFO ]   EXECUTION_MODE_HINT: ExecutionMode.PERFORMANCE
[ INFO ]   PERFORMANCE_HINT_NUM_REQUESTS: 0
[ INFO ]   ENABLE_CPU_PINNING: True
[ INFO ]   SCHEDULING_CORE_TYPE: SchedulingCoreType.ANY_CORE
[ INFO ]   ENABLE_HYPER_THREADING: True
[ INFO ]   EXECUTION_DEVICES: ['CPU']
[Step 9/11] Creating infer requests and preparing input tensors
[ WARNING ] No input files were given for input 'sequential_1_input'!. This input will be filled with random values!
[ INFO ] Fill input 'sequential_1_input' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 12 inference requests, limits: 15000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 2.06 ms
[Step 11/11] Dumping statistics report
[ INFO ] Execution Devices:['CPU']
[ INFO ] Count:            178752 iterations
[ INFO ] Duration:         15001.22 ms
[ INFO ] Latency:
[ INFO ]    Median:        0.92 ms
[ INFO ]    Average:       0.92 ms
[ INFO ]    Min:           0.54 ms
[ INFO ]    Max:           4.90 ms
[ INFO ] Throughput:   11915.83 FPS

Benchmark on MULTI:CPU,GPU

With a recent Intel CPU, the best performance can often be achieved by running inference on both the CPU and the iGPU with OpenVINO’s Multi Device Plugin. Loading a model takes longer on GPU than on CPU, so the first run of this benchmark takes a bit longer to complete than the CPU benchmark. Benchmark App supports model caching through the --cdir parameter. The cells below do not pass this parameter, so no model cache is written.

# Original model - MULTI:CPU,GPU
if "GPU" in ie.available_devices:
    ! benchmark_app -m $model_xml -d MULTI:CPU,GPU -t 15 -api async
else:
    print("A supported integrated GPU is not available on this system.")
A supported integrated GPU is not available on this system.
# Quantized model - MULTI:CPU,GPU
if "GPU" in ie.available_devices:
    ! benchmark_app -m $compressed_model_xml -d MULTI:CPU,GPU -t 15 -api async
else:
    print("A supported integrated GPU is not available on this system.")
A supported integrated GPU is not available on this system.

Original IR model - CPU

benchmark_output = %sx benchmark_app -m $model_xml -t 15 -api async
# Remove logging info from benchmark_app output and show only the results
benchmark_result = benchmark_output[-8:]
print("\n".join(benchmark_result))
[ INFO ] Count:            58332 iterations
[ INFO ] Duration:         15005.08 ms
[ INFO ] Latency:
[ INFO ]    Median:        2.88 ms
[ INFO ]    Average:       2.89 ms
[ INFO ]    Min:           2.02 ms
[ INFO ]    Max:           8.94 ms
[ INFO ] Throughput:   3887.48 FPS

Quantized IR model - CPU

benchmark_output = %sx benchmark_app -m $compressed_model_xml -t 15 -api async
# Remove logging info from benchmark_app output and show only the results
benchmark_result = benchmark_output[-8:]
print("\n".join(benchmark_result))
[ INFO ] Count:            179124 iterations
[ INFO ] Duration:         15001.17 ms
[ INFO ] Latency:
[ INFO ]    Median:        0.92 ms
[ INFO ]    Average:       0.92 ms
[ INFO ]    Min:           0.56 ms
[ INFO ]    Max:           4.33 ms
[ INFO ] Throughput:   11940.67 FPS
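
The two CPU results can be compared programmatically by parsing the captured output lines. The sketch below uses a few result lines copied from the runs above. The function parse_metric is a hypothetical helper for this example, and it assumes the "Name: value unit" layout shown in the output.

```python
import re

# Result lines copied from the two CPU runs above.
original_output = [
    "[ INFO ] Latency:",
    "[ INFO ]    Median:        2.88 ms",
    "[ INFO ] Throughput:   3887.48 FPS",
]
quantized_output = [
    "[ INFO ] Latency:",
    "[ INFO ]    Median:        0.92 ms",
    "[ INFO ] Throughput:   11940.67 FPS",
]


def parse_metric(lines, name):
    """Return the first number that follows `name:` in the benchmark output lines."""
    pattern = re.compile(rf"{re.escape(name)}:\s+([\d.]+)")
    for line in lines:
        match = pattern.search(line)
        if match:
            return float(match.group(1))
    raise ValueError(f"{name} not found in output")


speedup = parse_metric(quantized_output, "Throughput") / parse_metric(original_output, "Throughput")
latency_ratio = parse_metric(original_output, "Median") / parse_metric(quantized_output, "Median")
print(f"Throughput speedup: {speedup:.2f}x")
print(f"Median latency ratio: {latency_ratio:.2f}x")
```

On this machine the quantized model reaches roughly three times the throughput of the original. The exact figure will vary with hardware and system load.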

Original IR model - MULTI:CPU,GPU


if "GPU" in ie.available_devices:
    benchmark_output = %sx benchmark_app -m $model_xml -d MULTI:CPU,GPU -t 15 -api async
    # Remove logging info from benchmark_app output and show only the results
    benchmark_result = benchmark_output[-8:]
    print("\n".join(benchmark_result))
else:
    print("A GPU is not available on this system.")
A GPU is not available on this system.

Quantized IR model - MULTI:CPU,GPU

if "GPU" in ie.available_devices:
    benchmark_output = %sx benchmark_app -m $compressed_model_xml -d MULTI:CPU,GPU -t 15 -api async
    # Remove logging info from benchmark_app output and show only the results
    benchmark_result = benchmark_output[-8:]
    print("\n".join(benchmark_result))
else:
    print("A GPU is not available on this system.")
A GPU is not available on this system.