Quantize NLP models with Post-Training Optimization Tool in OpenVINO™

This tutorial is also available as a Jupyter notebook that can be cloned directly from GitHub. See the installation guide for instructions to run this tutorial locally on Windows, Linux or macOS.


This tutorial demonstrates how to apply INT8 quantization to BERT, a Natural Language Processing model, using the Post-Training Optimization Tool (POT) API, which is part of the OpenVINO Toolkit. It uses a fine-tuned HuggingFace BERT PyTorch model trained on the Microsoft Research Paraphrase Corpus (MRPC). The tutorial is designed to be extendable to custom models and datasets. It consists of the following steps:

  • Download and prepare the BERT model and MRPC dataset.

  • Define data loading and accuracy validation functionality.

  • Prepare the model for quantization.

  • Run the optimization pipeline.

  • Load and test the quantized model.

  • Compare the performance of the original, converted, and quantized models.

Imports

import os
import sys
import time
import warnings
from pathlib import Path
from zipfile import ZipFile

import numpy as np
import torch
from addict import Dict
from compression.api import DataLoader as POTDataLoader
from compression.api import Metric
from compression.engines.ie_engine import IEEngine
from compression.graph import load_model, save_model
from compression.graph.model_utils import compress_model_weights
from compression.pipeline.initializer import create_pipeline
from openvino import runtime as ov
from torch.utils.data import TensorDataset
from transformers import BertForSequenceClassification, BertTokenizer
from transformers import (
    glue_convert_examples_to_features as convert_examples_to_features,
)
from transformers import glue_output_modes as output_modes
from transformers import glue_processors as processors

sys.path.append("../utils")
from notebook_utils import download_file

Settings

# Set the data and model directories, source URL and the filename of the model.
DATA_DIR = "data"
MODEL_DIR = "model"
MODEL_LINK = "https://download.pytorch.org/tutorial/MRPC.zip"
FILE_NAME = MODEL_LINK.split("/")[-1]

os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

Prepare the Model

Perform the following:

  • Download and unpack the pre-trained BERT model for MRPC from the PyTorch download server.

  • Convert the model to ONNX.

  • Run Model Optimizer to convert the model from the ONNX representation to the OpenVINO Intermediate Representation (OpenVINO IR).

download_file(MODEL_LINK, directory=MODEL_DIR, show_progress=True)
with ZipFile(f"{MODEL_DIR}/{FILE_NAME}", "r") as zip_ref:
    zip_ref.extractall(MODEL_DIR)

Load the original PyTorch model and export it to ONNX. The batch and sequence-length dimensions of all three inputs are marked as dynamic axes in the export.

BATCH_SIZE = 1
MAX_SEQ_LENGTH = 128


def export_model_to_onnx(model, path):
    with torch.no_grad():
        default_input = torch.ones(1, MAX_SEQ_LENGTH, dtype=torch.int64)
        inputs = {
            "input_ids": default_input,
            "attention_mask": default_input,
            "token_type_ids": default_input,
        }
        symbolic_names = {0: "batch_size", 1: "max_seq_len"}
        torch.onnx.export(
            model,
            (inputs["input_ids"], inputs["attention_mask"], inputs["token_type_ids"]),
            path,
            opset_version=11,
            do_constant_folding=True,
            input_names=["input_ids", "input_mask", "segment_ids"],
            output_names=["output"],
            dynamic_axes={
                "input_ids": symbolic_names,
                "input_mask": symbolic_names,
                "segment_ids": symbolic_names,
            },
        )
        print("ONNX model saved to {}".format(path))


torch_model = BertForSequenceClassification.from_pretrained(os.path.join(MODEL_DIR, "MRPC"))
onnx_model_path = Path(MODEL_DIR) / "bert_mrpc.onnx"
if not onnx_model_path.exists():
    export_model_to_onnx(torch_model, onnx_model_path)
ONNX model saved to model/bert_mrpc.onnx

Convert the ONNX Model to OpenVINO IR

ir_model_xml = onnx_model_path.with_suffix(".xml")
ir_model_bin = onnx_model_path.with_suffix(".bin")

# Convert the ONNX model to OpenVINO IR FP32.
if not ir_model_xml.exists():
    !mo --input_model $onnx_model_path --output_dir $MODEL_DIR --model_name $ir_model_xml.stem --input input_ids,input_mask,segment_ids --input_shape [1,128],[1,128],[1,128] --output output --data_type FP32
Model Optimizer arguments:
Common parameters:
    - Path to the Input Model:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-275/.workspace/scm/ov-notebook/notebooks/105-language-quantize-bert/model/bert_mrpc.onnx
    - Path for generated IR:    /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-275/.workspace/scm/ov-notebook/notebooks/105-language-quantize-bert/model
    - IR output name:   bert_mrpc
    - Log level:    ERROR
    - Batch:    Not specified, inherited from the model
    - Input layers:     input_ids,input_mask,segment_ids
    - Output layers:    output
    - Input shapes:     [1,128],[1,128],[1,128]
    - Source layout:    Not specified
    - Target layout:    Not specified
    - Layout:   Not specified
    - Mean values:  Not specified
    - Scale values:     Not specified
    - Scale factor:     Not specified
    - Precision of IR:  FP32
    - Enable fusing:    True
    - User transformations:     Not specified
    - Reverse input channels:   False
    - Enable IR generation for fixed input shape:   False
    - Use the transformations config file:  None
Advanced parameters:
    - Force the usage of legacy Frontend of Model Optimizer for model conversion into IR:   False
    - Force the usage of new Frontend of Model Optimizer for model conversion into IR:  False
OpenVINO runtime found in:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-275/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/openvino
OpenVINO runtime version:   2022.2.0-7713-af16ea1d79a-releases/2022/2
Model Optimizer version:    2022.2.0-7713-af16ea1d79a-releases/2022/2
[ SUCCESS ] Generated IR version 11 model.
[ SUCCESS ] XML file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-275/.workspace/scm/ov-notebook/notebooks/105-language-quantize-bert/model/bert_mrpc.xml
[ SUCCESS ] BIN file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-275/.workspace/scm/ov-notebook/notebooks/105-language-quantize-bert/model/bert_mrpc.bin
[ SUCCESS ] Total execution time: 1.93 seconds.
[ SUCCESS ] Memory consumed: 908 MB.
[ INFO ] The model was converted to IR v11, the latest model format that corresponds to the source DL framework input/output format. While IR v11 is backwards compatible with OpenVINO Inference Engine API v1.0, please use API v2.0 (as of 2022.1) to take advantage of the latest improvements in IR v11.
Find more information about API v2.0 and IR v11 at https://docs.openvino.ai
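Model Optimizer is assumed to match each name in --input with the shape at the same position in --input_shape, so both lists have to stay in the same order. The helper below is a hypothetical sketch (build_mo_shape_args is not part of OpenVINO) that formats both arguments from a single mapping, which keeps names and shapes aligned when inputs are added or renamed. It only produces the flags already used in the command above.

```python
def build_mo_shape_args(input_shapes):
    """Format the --input and --input_shape values for a Model Optimizer call.

    input_shapes: dict mapping input name -> list of dimensions.
    Dicts keep insertion order, so names and shapes stay aligned by position.
    """
    names = ",".join(input_shapes)
    shapes = ",".join(
        "[" + ",".join(str(dim) for dim in shape) + "]" for shape in input_shapes.values()
    )
    return f"--input {names} --input_shape {shapes}"


# Hypothetical usage mirroring the command above (batch 1, sequence length 128).
shape_args = build_mo_shape_args(
    {
        "input_ids": [1, 128],
        "input_mask": [1, 128],
        "segment_ids": [1, 128],
    }
)
print(shape_args)
```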

Prepare MRPC Task Dataset

To run this tutorial, you need the MRPC task data from the General Language Understanding Evaluation (GLUE) benchmark. The code below downloads a helper script from the HuggingFace Transformers repository, which then fetches and formats the MRPC dataset.

download_file(
    "https://raw.githubusercontent.com/huggingface/transformers/f98ef14d161d7bcdc9808b5ec399981481411cc1/utils/download_glue_data.py",
    show_progress=False,
)
PosixPath('/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-275/.workspace/scm/ov-notebook/notebooks/105-language-quantize-bert/download_glue_data.py')
from download_glue_data import format_mrpc

format_mrpc(DATA_DIR, "")
Processing MRPC...
Local MRPC data not specified, downloading data from https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt
    Completed!
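format_mrpc writes the MRPC splits as TSV files into data/MRPC; the dev split (dev.tsv) is what processor.get_dev_examples reads later. Each row holds a label (1 for a paraphrase, 0 otherwise), two sentence IDs, and the two sentences. The sketch below parses that layout with the standard csv module, assuming the GLUE header Quality, #1 ID, #2 ID, #1 String, #2 String. read_mrpc_tsv is a hypothetical helper, not part of Transformers, and the sample rows are made up.

```python
import csv
import io


def read_mrpc_tsv(tsv_text):
    """Parse GLUE-style MRPC TSV text into (label, sentence_1, sentence_2) tuples."""
    # MRPC sentences contain quote characters, so quoting is disabled entirely.
    reader = csv.reader(io.StringIO(tsv_text), delimiter="\t", quoting=csv.QUOTE_NONE)
    next(reader)  # skip the header row
    return [(int(row[0]), row[3], row[4]) for row in reader if row]


# Made-up rows in the MRPC layout, for illustration only.
sample_tsv = (
    "Quality\t#1 ID\t#2 ID\t#1 String\t#2 String\n"
    "1\t101\t102\tThe cat sat on the mat .\tA cat was sitting on the mat .\n"
    '0\t103\t104\tShares rose 5 percent .\tThe firm said "no comment" .\n'
)
pairs = read_mrpc_tsv(sample_tsv)
for label, sentence_1, sentence_2 in pairs:
    print(label, "|", sentence_1, "|", sentence_2)
```

With the real data, the same function could be applied to the contents of os.path.join(DATA_DIR, "MRPC", "dev.tsv").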

Define DataLoader for POT

In this step, you define a DataLoader based on the POT API. It is used to collect statistics for quantization and to run model evaluation. Helper functions from HuggingFace Transformers do the data preprocessing: they take the raw text, tokenize the sentence pairs, and produce the three model inputs (input IDs, attention mask, and token type IDs). For more details about the data preprocessing and tokenization, refer to this description.

class MRPCDataLoader(POTDataLoader):
    # Required methods
    def __init__(self, config):
        """Constructor
        :param config: data loader specific config
        """
        if not isinstance(config, Dict):
            config = Dict(config)
        super().__init__(config)
        self._task = config["task"].lower()
        self._model_dir = config["model_dir"]
        self._data_dir = config["data_source"]
        self._batch_size = config["batch_size"]
        self._max_length = config["max_length"]
        self.examples = []
        self._prepare_dataset()

    def __len__(self):
        """Returns size of the dataset"""
        return len(self.dataset)

    def __getitem__(self, index):
        """
        Returns annotation, data and metadata at the specified index.
        Possible formats:
        (index, annotation), data
        (index, annotation), data, metadata
        """
        if index >= len(self):
            raise IndexError

        batch = self.dataset[index]
        batch = tuple(t.detach().cpu().numpy() for t in batch)
        inputs = {"input_ids": batch[0], "input_mask": batch[1], "segment_ids": batch[2]}
        labels = batch[3]
        return (index, labels), inputs

    # Methods specific to the current implementation
    def _prepare_dataset(self):
        """Prepare dataset"""
        tokenizer = BertTokenizer.from_pretrained(self._model_dir, do_lower_case=True)
        processor = processors[self._task]()
        output_mode = output_modes[self._task]
        label_list = processor.get_labels()
        examples = processor.get_dev_examples(self._data_dir)
        features = convert_examples_to_features(
            examples,
            tokenizer,
            label_list=label_list,
            max_length=self._max_length,
            output_mode=output_mode,
        )
        all_input_ids = torch.unsqueeze(torch.tensor([f.input_ids for f in features], dtype=torch.long), 1)
        all_attention_mask = torch.unsqueeze(torch.tensor([f.attention_mask for f in features], dtype=torch.long), 1)
        all_token_type_ids = torch.unsqueeze(torch.tensor([f.token_type_ids for f in features], dtype=torch.long), 1)
        all_labels = torch.unsqueeze(torch.tensor([f.label for f in features], dtype=torch.long), 1)
        self.dataset = TensorDataset(
            all_input_ids, all_attention_mask, all_token_type_ids, all_labels
        )
        self.examples = examples
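convert_examples_to_features does the real tokenization with BERT's WordPiece vocabulary. To make the three inputs concrete, here is a simplified sketch of how a sentence pair is laid out: [CLS] sentence 1 [SEP] sentence 2 [SEP], segment IDs of 0 for the first part and 1 for the second, an attention mask of 1 for real tokens and 0 for padding, and right padding up to max_length. The whitespace tokens and the tiny vocabulary are hypothetical (the real BERT vocabulary uses different IDs), and encode_pair is not a Transformers function. The returned keys use the IR input names from this tutorial.

```python
def encode_pair(tokens_a, tokens_b, vocab, max_length):
    """Lay out a sentence pair the way BERT expects it (simplified sketch)."""
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    if len(tokens) > max_length:
        # Real tokenizers truncate long pairs; this sketch just refuses them.
        raise ValueError("sentence pair does not fit into max_length")
    # Segment 0 covers [CLS], sentence 1 and the first [SEP]; segment 1 the rest.
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    input_ids = [vocab.get(token, vocab["[UNK]"]) for token in tokens]
    input_mask = [1] * len(input_ids)

    # Pad on the right so every sequence has exactly max_length positions.
    padding = max_length - len(input_ids)
    input_ids += [vocab["[PAD]"]] * padding
    input_mask += [0] * padding
    segment_ids += [0] * padding
    return {"input_ids": input_ids, "input_mask": input_mask, "segment_ids": segment_ids}


# Hypothetical toy vocabulary, not BERT's real one.
toy_vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "the": 4, "cat": 5, "sat": 6, "a": 7}
features = encode_pair("the cat sat".split(), "a cat sat".split(), toy_vocab, max_length=12)
for name, values in features.items():
    print(f"{name:12s} {values}")
```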

Define Accuracy Metric Calculation

In this step, you implement the POT Metric interface for the MRPC task. The metric is used to validate the accuracy of the models.

class Accuracy(Metric):

    # Required methods
    def __init__(self):
        super().__init__()
        self._name = "Accuracy"
        self._matches = []

    @property
    def value(self):
        """Returns accuracy metric value for the last model output."""
        return {self._name: self._matches[-1]}

    @property
    def avg_value(self):
        """Returns accuracy metric value for all model outputs."""
        return {self._name: np.ravel(self._matches).mean()}

    def update(self, output, target):
        """
        Updates prediction matches.

        :param output: model output
        :param target: annotations
        """
        if len(output) > 1:
            raise Exception(
                "The accuracy metric cannot be calculated " "for a model with multiple outputs"
            )
        output = np.argmax(output)
        match = output == target[0]
        self._matches.append(match)

    def reset(self):
        """
        Resets collected matches
        """
        self._matches = []

    def get_attributes(self):
        """
        Returns a dictionary of metric attributes {metric_name: {attribute_name: value}}.
        Required attributes: 'direction': 'higher-better' or 'higher-worse'
                             'type': metric type
        """
        return {self._name: {"direction": "higher-better", "type": "accuracy"}}
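Stripped of the POT interface, the Accuracy class does three things: take the index of the larger logit as the predicted label, record whether it equals the annotation, and average the recorded matches. The pure-Python sketch below (SimpleAccuracy is a hypothetical name, not a POT class) mirrors that bookkeeping without NumPy so the logic can be checked in isolation. Taking max over the indices returns the first maximal index, as np.argmax does.

```python
class SimpleAccuracy:
    """Pure-Python sketch of the match bookkeeping in the Accuracy metric."""

    def __init__(self):
        self._matches = []

    def update(self, logits, label):
        # argmax: index of the largest logit (first one wins on ties, like np.argmax).
        predicted = max(range(len(logits)), key=lambda i: logits[i])
        self._matches.append(predicted == label)

    @property
    def value(self):
        """Match result for the most recent sample."""
        return self._matches[-1]

    @property
    def avg_value(self):
        """Fraction of correct predictions over all samples seen so far."""
        return sum(self._matches) / len(self._matches)

    def reset(self):
        self._matches = []


metric_sketch = SimpleAccuracy()
metric_sketch.update([0.2, 1.3], 1)   # predicts 1, correct
metric_sketch.update([2.0, -1.0], 1)  # predicts 0, wrong
metric_sketch.update([-0.5, 0.7], 1)  # predicts 1, correct
print(metric_sketch.avg_value)
```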

Run Quantization Pipeline

Define a configuration for the quantization pipeline and run it. Model inference in the pipeline uses IEEngine, the built-in implementation of the Engine interface in the POT API.

warnings.filterwarnings("ignore")  # Suppress accuracychecker warnings.

model_config = Dict({"model_name": "bert_mrpc", "model": ir_model_xml, "weights": ir_model_bin})
engine_config = Dict({"device": "CPU"})
dataset_config = {
    "task": "mrpc",
    "data_source": os.path.join(DATA_DIR, "MRPC"),
    "model_dir": os.path.join(MODEL_DIR, "MRPC"),
    "batch_size": BATCH_SIZE,
    "max_length": MAX_SEQ_LENGTH,
}
algorithms = [
    {
        "name": "DefaultQuantization",
        "params": {
            "target_device": "ANY",
            "model_type": "transformer",
            "preset": "performance",
            "stat_subset_size": 250,
        },
    }
]


# Step 1: Load the model.
model = load_model(model_config=model_config)

# Step 2: Initialize the data loader.
data_loader = MRPCDataLoader(config=dataset_config)

# Step 3 (Optional. Required for AccuracyAwareQuantization): Initialize the metric.
metric = Accuracy()

# Step 4: Initialize the engine for metric calculation and statistics collection.
engine = IEEngine(config=engine_config, data_loader=data_loader, metric=metric)

# Step 5: Create a pipeline of compression algorithms.
pipeline = create_pipeline(algo_config=algorithms, engine=engine)

# Step 6 (Optional): Evaluate the original model. Print the results.
fp_results = pipeline.evaluate(model=model)
if fp_results:
    print("FP32 model results:")
    for name, value in fp_results.items():
        print(f"{name}: {value:.5f}")
FP32 model results:
Accuracy: 0.86029
# Step 7: Execute the pipeline.
warnings.filterwarnings("ignore")  # Suppress accuracychecker warnings.
print(
    f"Quantizing model with {algorithms[0]['params']['preset']} preset and {algorithms[0]['name']}"
)
start_time = time.perf_counter()
compressed_model = pipeline.run(model=model)
end_time = time.perf_counter()
print(f"Quantization finished in {end_time - start_time:.2f} seconds")

# Step 8 (Optional): Compress model weights to quantized precision
#                    in order to reduce the size of the final .bin file.
compress_model_weights(model=compressed_model)

# Step 9: Save the compressed model to the desired path.
compressed_model_paths = save_model(
    model=compressed_model, save_path=MODEL_DIR, model_name="quantized_bert_mrpc"
)
compressed_model_xml = compressed_model_paths[0]["model"]
Quantizing model with performance preset and DefaultQuantization
Quantization finished in 56.86 seconds
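DefaultQuantization collects activation statistics on stat_subset_size samples (250 here) and uses them to choose quantization ranges; the performance preset uses symmetric quantization for both weights and activations. The sketch below illustrates only that general idea on plain Python lists: take a min/max range from a calibration subset, derive a symmetric INT8 scale, then quantize and dequantize. It is not POT's implementation (POT inserts FakeQuantize operations and has its own range estimators), and all function names are hypothetical.

```python
QMIN, QMAX = -127, 127  # symmetric signed INT8 range used in this sketch


def collect_range(samples, subset_size):
    """Min/max over the first `subset_size` calibration samples."""
    values = [v for sample in samples[:subset_size] for v in sample]
    return min(values), max(values)


def symmetric_scale(min_val, max_val):
    """One scale per tensor so that the largest magnitude maps to QMAX."""
    abs_max = max(abs(min_val), abs(max_val))
    return abs_max / QMAX if abs_max > 0 else 1.0


def quantize(values, scale):
    # Round to the nearest integer step and clip values outside the calibrated range.
    return [max(QMIN, min(QMAX, round(v / scale))) for v in values]


def dequantize(q_values, scale):
    return [q * scale for q in q_values]


calibration = [[0.5, -1.27, 0.0], [1.0, 0.25, -0.5], [100.0]]  # made-up activations
low, high = collect_range(calibration, subset_size=2)  # the outlier sample is not seen
scale = symmetric_scale(low, high)

activations = [0.5, -1.27, 1.0, 2.0]
q = quantize(activations, scale)
print(q)
print([round(x, 4) for x in dequantize(q, scale)])
```

Because the calibration subset excludes the outlier, the value 2.0 falls outside the range and is clipped, while values inside the range are reconstructed to within half a quantization step.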
# Step 10 (Optional): Evaluate the compressed model and print the results.
int_results = pipeline.evaluate(model=compressed_model)

if int_results:
    print("INT8 model results:")
    for name, value in int_results.items():
        print(f"{name}: {value:.5f}")
INT8 model results:
Accuracy: 0.85049
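To put the two accuracy results in perspective, the short calculation below computes the absolute and relative drop from the values printed above. The MRPC dev split has 408 sentence pairs (indices 0 to 407), so the absolute drop can also be expressed as a net number of pairs.

```python
fp32_accuracy = 0.86029  # FP32 result printed above
int8_accuracy = 0.85049  # INT8 result printed above
num_dev_pairs = 408      # MRPC dev split size (sample indices 0 to 407)

absolute_drop = fp32_accuracy - int8_accuracy
relative_drop = absolute_drop / fp32_accuracy
pairs_changed = round(absolute_drop * num_dev_pairs)

print(f"Absolute drop: {absolute_drop:.5f}")
print(f"Relative drop: {relative_drop:.2%}")
print(f"Net change in correctly classified pairs: {pairs_changed}")
```

With these numbers the drop is about 0.0098 in absolute terms (about 1.14% relative), which corresponds to a net 4 fewer correctly classified pairs.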

Load and Test OpenVINO Model

To load and test the quantized model, perform the following:

  • Load the model and compile it for CPU.

  • Prepare the input.

  • Run the inference.

  • Get the answer from the model output.

core = ov.Core()

# Read the model from files.
model = core.read_model(model=compressed_model_xml)

# Assign dynamic shapes to every input layer.
for input_layer in model.inputs:
    input_shape = input_layer.partial_shape
    input_shape[1] = -1
    model.reshape({input_layer: input_shape})

# Compile the model for a specific device.
compiled_model_int8 = core.compile_model(model=model, device_name="CPU")

output_layer = compiled_model_int8.outputs[0]

The data loader returns a pair of sentences at index sample_idx. The model compares the two sentences and outputs whether they have the same meaning. To test other sentence pairs, change sample_idx to another value between 0 and 407.

sample_idx = 5

sample = data_loader.examples[sample_idx]
inputs = data_loader[sample_idx][1]

result = compiled_model_int8(inputs)[output_layer]
result = np.argmax(result)

print(f"Text 1: {sample.text_a}")
print(f"Text 2: {sample.text_b}")
print(f"The same meaning: {'yes' if result == 1 else 'no'}")
Text 1: Wal-Mart said it would check all of its million-plus domestic workers to ensure they were legally employed .
Text 2: It has also said it would review all of its domestic employees more than 1 million to ensure they have legal status .
The same meaning: yes
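The cell above reduces the output to a label with np.argmax. The two output values are logits, so a softmax turns them into probabilities that also give a confidence for the answer. interpret_logits is a hypothetical helper written in plain Python; label 1 corresponds to the same meaning, as in the printout above.

```python
import math


def interpret_logits(logits):
    """Return (label, confidence) from raw classification logits via softmax."""
    largest = max(logits)
    # Subtracting the largest logit keeps math.exp from overflowing.
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    probabilities = [e / total for e in exps]
    label = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return label, probabilities[label]


label, confidence = interpret_logits([-1.2, 2.3])  # made-up logits
print(f"The same meaning: {'yes' if label == 1 else 'no'} (confidence {confidence:.3f})")
```

With the compiled model, the logits for one sample could be obtained as compiled_model_int8(inputs)[output_layer][0].tolist() and passed to the helper.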

Compare Performance of the Original, Converted and Quantized Models

Compare the original PyTorch model with the converted (FP32) and quantized (INT8) OpenVINO models to see the difference in performance. Performance is expressed in Sentences Per Second (SPS), the text counterpart of Frames Per Second (FPS) for images.

model = core.read_model(model=ir_model_xml)

# Assign dynamic shapes to every input layer.
for input_layer in model.inputs:
    input_shape = input_layer.partial_shape
    input_shape[1] = -1
    model.reshape({input_layer: input_shape})

# Compile the model for a specific device.
compiled_model_fp32 = core.compile_model(model=model, device_name="CPU")
num_samples = 50
inputs = data_loader[0][1]

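# Map the OpenVINO input names back to the argument names of the HuggingFace model,
# so that PyTorch processes the same single sentence pair as the OpenVINO models.
torch_inputs = {
    "input_ids": torch.as_tensor(inputs["input_ids"]),
    "attention_mask": torch.as_tensor(inputs["input_mask"]),
    "token_type_ids": torch.as_tensor(inputs["segment_ids"]),
}

with torch.no_grad():
    start = time.perf_counter()
    for _ in range(num_samples):
        torch_model(**torch_inputs)
    end = time.perf_counter()
    time_torch = end - start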
print(
    f"PyTorch model on CPU: {time_torch / num_samples:.3f} seconds per sentence, "
    f"SPS: {num_samples / time_torch:.2f}"
)

start = time.perf_counter()
for _ in range(num_samples):
    compiled_model_fp32(inputs)
end = time.perf_counter()
time_ir = end - start
print(
    f"IR FP32 model in OpenVINO Runtime/CPU: {time_ir / num_samples:.3f} "
    f"seconds per sentence, SPS: {num_samples / time_ir:.2f}"
)

start = time.perf_counter()
for _ in range(num_samples):
    compiled_model_int8(inputs)
end = time.perf_counter()
time_ir = end - start
print(
    f"OpenVINO IR INT8 model in OpenVINO Runtime/CPU: {time_ir / num_samples:.3f} "
    f"seconds per sentence, SPS: {num_samples / time_ir:.2f}"
)
PyTorch model on CPU: 0.171 seconds per sentence, SPS: 5.85
IR FP32 model in OpenVINO Runtime/CPU: 0.022 seconds per sentence, SPS: 46.18
OpenVINO IR INT8 model in OpenVINO Runtime/CPU: 0.011 seconds per sentence, SPS: 88.69
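The loops above also time the first calls, which may include one-off start-up costs, and they repeat a single sample num_samples times. A small helper with untimed warm-up iterations makes such comparisons somewhat more robust. measure_sps and the warm-up count are assumptions of this sketch, not part of the original notebook.

```python
import time


def measure_sps(infer_one, num_samples, warmup=5):
    """Sentences per second for `infer_one`, a zero-argument callable that processes one sentence."""
    for _ in range(warmup):
        infer_one()  # untimed calls absorb one-off costs such as lazy initialization
    start = time.perf_counter()
    for _ in range(num_samples):
        infer_one()
    elapsed = time.perf_counter() - start
    return num_samples / elapsed


# Stand-in workload so the sketch runs on its own.
sps = measure_sps(lambda: sum(range(10_000)), num_samples=50)
print(f"SPS: {sps:.2f}")
```

In the notebook, the callable would wrap a model call, for example measure_sps(lambda: compiled_model_int8(inputs), num_samples).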

Finally, measure the inference performance of the OpenVINO FP32 and INT8 models with the OpenVINO Benchmark Tool (benchmark_app).

Note: benchmark_app can measure the performance of OpenVINO Intermediate Representation (OpenVINO IR) models only. For more accurate results, run benchmark_app in a terminal or command prompt after closing other applications. Run benchmark_app -m model.xml -d CPU to benchmark async inference on CPU for one minute, and change CPU to GPU to benchmark on GPU. Run benchmark_app --help to see an overview of all command-line options. The cells below pass -api sync to benchmark synchronous inference instead.
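benchmark_app reports the median, average, minimum and maximum latency of the individual inferences, plus a throughput figure. The helper below is a hypothetical sketch that computes the same summary from a list of latencies with the statistics module. For the synchronous, single-request runs below, the reported throughput appears to match 1000 / median latency (1000 / 22.28 ms ≈ 44.88 vs. 44.87 FPS reported), so the sketch uses that formula; treat it as an assumption, not as the tool's documented definition.

```python
import statistics


def summarize_latencies(latencies_ms):
    """Summarize per-inference latencies in the shape of a benchmark_app report."""
    median = statistics.median(latencies_ms)
    return {
        "median_ms": median,
        "avg_ms": statistics.mean(latencies_ms),
        "min_ms": min(latencies_ms),
        "max_ms": max(latencies_ms),
        # Assumption: with one synchronous request, throughput ~= 1000 / median latency.
        "throughput_fps": 1000.0 / median,
    }


report = summarize_latencies([20.0, 22.0, 25.0, 21.0, 22.0])  # made-up latencies
for key, value in report.items():
    print(f"{key}: {value:.2f}")
```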

# Inference FP32 model (OpenVINO IR)
! benchmark_app -m $ir_model_xml -d CPU -api sync
[Step 1/11] Parsing and validating input arguments
[ WARNING ]  -nstreams default value is determined automatically for a device. Although the automatic selection usually provides a reasonable performance, but it still may be non-optimal for some cases, for more information look at README.
[Step 2/11] Loading OpenVINO
[ WARNING ] PerformanceMode was not explicitly specified in command line. Device CPU performance hint will be set to LATENCY.
[ INFO ] OpenVINO:
         API version............. 2022.2.0-7713-af16ea1d79a-releases/2022/2
[ INFO ] Device info
         CPU
         openvino_intel_cpu_plugin version 2022.2
         Build................... 2022.2.0-7713-af16ea1d79a-releases/2022/2

[Step 3/11] Setting device configuration
[Step 4/11] Reading network files
[ INFO ] Read model took 165.16 ms
[Step 5/11] Resizing network to match image sizes and given batch
[ INFO ] Network batch size: 1
[Step 6/11] Configuring input of the model
[ INFO ] Model input 'input_ids' precision i64, dimensions ([...]): 1 128
[ INFO ] Model input 'input_mask' precision i64, dimensions ([...]): 1 128
[ INFO ] Model input 'segment_ids' precision i64, dimensions ([...]): 1 128
[ INFO ] Model output 'output' precision f32, dimensions ([...]): 1 2
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 164.27 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] DEVICE: CPU
[ INFO ]   AVAILABLE_DEVICES  , ['']
[ INFO ]   RANGE_FOR_ASYNC_INFER_REQUESTS  , (1, 1, 1)
[ INFO ]   RANGE_FOR_STREAMS  , (1, 24)
[ INFO ]   FULL_DEVICE_NAME  , Intel(R) Core(TM) i9-10920X CPU @ 3.50GHz
[ INFO ]   OPTIMIZATION_CAPABILITIES  , ['WINOGRAD', 'FP32', 'FP16', 'INT8', 'BIN', 'EXPORT_IMPORT']
[ INFO ]   CACHE_DIR  ,
[ INFO ]   NUM_STREAMS  , 1
[ INFO ]   AFFINITY  , Affinity.CORE
[ INFO ]   INFERENCE_NUM_THREADS  , 0
[ INFO ]   PERF_COUNT  , False
[ INFO ]   INFERENCE_PRECISION_HINT  , <Type: 'float32'>
[ INFO ]   PERFORMANCE_HINT  , PerformanceMode.LATENCY
[ INFO ]   PERFORMANCE_HINT_NUM_REQUESTS  , 0
[Step 9/11] Creating infer requests and preparing input data
[ INFO ] Create 1 infer requests took 0.11 ms
[ WARNING ] No input files were given for input 'input_ids'!. This input will be filled with random values!
[ WARNING ] No input files were given for input 'input_mask'!. This input will be filled with random values!
[ WARNING ] No input files were given for input 'segment_ids'!. This input will be filled with random values!
[ INFO ] Fill input 'input_ids' with random values
[ INFO ] Fill input 'input_mask' with random values
[ INFO ] Fill input 'segment_ids' with random values
[Step 10/11] Measuring performance (Start inference synchronously, inference only: True, limits: 60000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 21.00 ms
[Step 11/11] Dumping statistics report
Count:          2727 iterations
Duration:       60007.67 ms
Latency:
    Median:     22.28 ms
    AVG:        21.92 ms
    MIN:        20.24 ms
    MAX:        23.17 ms
Throughput: 44.87 FPS
# Inference INT8 model (OpenVINO IR)
! benchmark_app -m $compressed_model_xml -d CPU -api sync
[Step 1/11] Parsing and validating input arguments
[ WARNING ]  -nstreams default value is determined automatically for a device. Although the automatic selection usually provides a reasonable performance, but it still may be non-optimal for some cases, for more information look at README.
[Step 2/11] Loading OpenVINO
[ WARNING ] PerformanceMode was not explicitly specified in command line. Device CPU performance hint will be set to LATENCY.
[ INFO ] OpenVINO:
         API version............. 2022.2.0-7713-af16ea1d79a-releases/2022/2
[ INFO ] Device info
         CPU
         openvino_intel_cpu_plugin version 2022.2
         Build................... 2022.2.0-7713-af16ea1d79a-releases/2022/2

[Step 3/11] Setting device configuration
[Step 4/11] Reading network files
[ INFO ] Read model took 118.18 ms
[Step 5/11] Resizing network to match image sizes and given batch
[ INFO ] Network batch size: 1
[Step 6/11] Configuring input of the model
[ INFO ] Model input 'input_ids' precision i64, dimensions ([...]): 1 128
[ INFO ] Model input 'input_mask' precision i64, dimensions ([...]): 1 128
[ INFO ] Model input 'segment_ids' precision i64, dimensions ([...]): 1 128
[ INFO ] Model output 'output' precision f32, dimensions ([...]): 1 2
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 223.67 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] DEVICE: CPU
[ INFO ]   AVAILABLE_DEVICES  , ['']
[ INFO ]   RANGE_FOR_ASYNC_INFER_REQUESTS  , (1, 1, 1)
[ INFO ]   RANGE_FOR_STREAMS  , (1, 24)
[ INFO ]   FULL_DEVICE_NAME  , Intel(R) Core(TM) i9-10920X CPU @ 3.50GHz
[ INFO ]   OPTIMIZATION_CAPABILITIES  , ['WINOGRAD', 'FP32', 'FP16', 'INT8', 'BIN', 'EXPORT_IMPORT']
[ INFO ]   CACHE_DIR  ,
[ INFO ]   NUM_STREAMS  , 1
[ INFO ]   AFFINITY  , Affinity.CORE
[ INFO ]   INFERENCE_NUM_THREADS  , 0
[ INFO ]   PERF_COUNT  , False
[ INFO ]   INFERENCE_PRECISION_HINT  , <Type: 'float32'>
[ INFO ]   PERFORMANCE_HINT  , PerformanceMode.LATENCY
[ INFO ]   PERFORMANCE_HINT_NUM_REQUESTS  , 0
[Step 9/11] Creating infer requests and preparing input data
[ INFO ] Create 1 infer requests took 0.11 ms
[ WARNING ] No input files were given for input 'input_ids'!. This input will be filled with random values!
[ WARNING ] No input files were given for input 'input_mask'!. This input will be filled with random values!
[ WARNING ] No input files were given for input 'segment_ids'!. This input will be filled with random values!
[ INFO ] Fill input 'input_ids' with random values
[ INFO ] Fill input 'input_mask' with random values
[ INFO ] Fill input 'segment_ids' with random values
[Step 10/11] Measuring performance (Start inference synchronously, inference only: True, limits: 60000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 10.38 ms
[Step 11/11] Dumping statistics report
Count:          5764 iterations
Duration:       60001.73 ms
Latency:
    Median:     10.36 ms
    AVG:        10.32 ms
    MIN:        9.57 ms
    MAX:        13.52 ms
Throughput: 96.48 FPS
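Both runs use the same device and settings, so the ratio of the reported figures gives the speedup of the INT8 model over the FP32 model on this machine. The values below are copied from the two reports above.

```python
# Figures copied from the benchmark_app reports above.
fp32_report = {"median_latency_ms": 22.28, "throughput_fps": 44.87}
int8_report = {"median_latency_ms": 10.36, "throughput_fps": 96.48}

throughput_speedup = int8_report["throughput_fps"] / fp32_report["throughput_fps"]
latency_ratio = fp32_report["median_latency_ms"] / int8_report["median_latency_ms"]

print(f"INT8 throughput is {throughput_speedup:.2f}x the FP32 throughput")
print(f"FP32 median latency is {latency_ratio:.2f}x the INT8 median latency")
```

With these numbers both ratios come out at about 2.15x. The speedup depends on the CPU (an Intel Core i9-10920X in these reports) and will differ on other hardware.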