Quantize Wav2Vec Speech Recognition Model using NNCF PTQ API

This Jupyter notebook can be launched after a local installation only.

This tutorial demonstrates how to apply INT8 quantization to the speech recognition model known as Wav2Vec2, using NNCF (Neural Network Compression Framework) 8-bit quantization in post-training mode (that is, without a fine-tuning pipeline). This notebook uses a fine-tuned Wav2Vec2-Base-960h PyTorch model trained on the LibriSpeech ASR corpus. The tutorial is designed to be extendable to custom models and datasets. It consists of the following steps:

  • Download and prepare the Wav2Vec2 model and LibriSpeech dataset.

  • Define data loading and accuracy validation functionality.

  • Quantize the model.

  • Compare the accuracy of the original PyTorch model and the OpenVINO FP16 and INT8 models.

  • Compare performance of the original and quantized models.

%pip install -q "openvino>=2023.3.0" "nncf>=2.7"
%pip install -q datasets "torchmetrics>=0.11.0" "torch>=2.1.0" --extra-index-url https://download.pytorch.org/whl/cpu
%pip install -q soundfile librosa "transformers>=4.36.2" --extra-index-url https://download.pytorch.org/whl/cpu
Note: you may need to restart the kernel to use updated packages.

Imports

import numpy as np
import openvino as ov
import torch
import IPython.display as ipd

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

Settings

from pathlib import Path

# Set the data and model directories, model source URL and model filename.
MODEL_DIR = Path("model")
MODEL_DIR.mkdir(exist_ok=True)

Prepare the Model

Perform the following:

  • Download and unpack a pre-trained Wav2Vec2 model.

  • Run the model conversion API to convert the model from the PyTorch representation to the OpenVINO Intermediate Representation (OpenVINO IR).

torch_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h", ctc_loss_reduction="mean")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
BATCH_SIZE = 1
MAX_SEQ_LENGTH = 30480
ov_model = ov.convert_model(torch_model, example_input=torch.zeros([BATCH_SIZE, MAX_SEQ_LENGTH], dtype=torch.float))

ir_model_path = MODEL_DIR / "wav2vec2_base.xml"
ov.save_model(ov_model, ir_model_path)
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-632/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/modeling_utils.py:4193: FutureWarning: _is_quantized_training_enabled is going to be deprecated in transformers 4.39.0. Please use model.hf_quantizer.is_trainable instead
  warnings.warn(
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-632/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:594: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-632/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:633: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):

Prepare LibriSpeech Dataset

For demonstration purposes, we will use a short dummy version of the LibriSpeech dataset, patrickvonplaten/librispeech_asr_dummy, to speed up model evaluation. Model accuracy may differ from the accuracy reported in the paper. To reproduce the original accuracy, use the librispeech_asr dataset.

from datasets import load_dataset


dataset = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
test_sample = dataset[0]["audio"]


# define preprocessing function for converting audio to input values for model
def map_to_input(batch):
    preprocessed_signal = processor(batch["audio"]["array"], return_tensors="pt", padding="longest", sampling_rate=batch['audio']['sampling_rate'])
    input_values = preprocessed_signal.input_values
    batch['input_values'] = input_values
    return batch


# apply preprocessing function to dataset and remove audio column, to save memory as we do not need it anymore
dataset = dataset.map(map_to_input, batched=False, remove_columns=["audio"])
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-632/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/datasets/load.py:1461: FutureWarning: The repository for patrickvonplaten/librispeech_asr_dummy contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/patrickvonplaten/librispeech_asr_dummy
You can avoid this message in future by passing the argument trust_remote_code=True.
Passing trust_remote_code=True will be mandatory to load this dataset from the next major release of datasets.
  warnings.warn(

Run Quantization

NNCF provides a suite of advanced algorithms for neural network inference optimization in OpenVINO with minimal accuracy drop.

Create a quantized model from the pre-trained FP16 model and the calibration dataset. The optimization process contains the following steps:

  1. Create a Dataset for quantization.

  2. Run nncf.quantize to obtain an optimized model. The nncf.quantize function provides an interface for model quantization. It requires an instance of the OpenVINO Model and a quantization dataset. Optionally, additional parameters for configuring the quantization process (number of samples for quantization, preset, ignored scope, etc.) can be provided; a sketch of these optional parameters follows this list. For more accurate results, some operations should be kept in floating point precision, using the ignored_scope parameter. For more information, see Tune quantization parameters. For this model, the ignored scope was selected experimentally, based on the results of quantization with accuracy control. To understand how that works, check the corresponding notebook on quantization with accuracy control.

  3. Serialize the OpenVINO IR model using the ov.save_model function.
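
For reference, here is a hedged sketch of the optional parameters mentioned in step 2. subset_size and preset are real nncf.quantize parameters, but the values below are illustrative and are not the settings used in this tutorial; the helper function is hypothetical and is never called here.

import nncf

def quantize_with_custom_settings(model, dataset):
    # Illustrative only: shows optional nncf.quantize knobs, not the call used below.
    return nncf.quantize(
        model,
        dataset,
        subset_size=300,  # number of calibration samples to use
        preset=nncf.QuantizationPreset.MIXED,  # symmetric weights, asymmetric activations
    )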

import nncf
from nncf.parameters import ModelType

def transform_fn(data_item):
    """
    Extract the model input from a data item.
    A data item here is one element returned by the data source per iteration.
    This function should be passed when a data item cannot be used as the model input directly.
    """
    return np.array(data_item["input_values"])


calibration_dataset = nncf.Dataset(dataset, transform_fn)

quantized_model = nncf.quantize(
    ov_model,
    calibration_dataset,
    model_type=ModelType.TRANSFORMER,  # specify additional transformer patterns in the model
    ignored_scope=nncf.IgnoredScope(
        names=[
            "__module.wav2vec2.feature_extractor.conv_layers.1.conv/aten::_convolution/Convolution",
            "__module.wav2vec2.feature_extractor.conv_layers.2.conv/aten::_convolution/Convolution",
            "__module.wav2vec2.feature_extractor.conv_layers.3.conv/aten::_convolution/Convolution",
            "__module.wav2vec2.feature_extractor.conv_layers.0.conv/aten::_convolution/Convolution",
        ],
    ),
)
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, tensorflow, onnx, openvino
2024-03-12 22:26:40.835104: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable TF_ENABLE_ONEDNN_OPTS=0.
2024-03-12 22:26:40.866034: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-03-12 22:26:41.461559: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
INFO:nncf:4 ignored nodes were found by name in the NNCFGraph
INFO:nncf:36 ignored nodes were found by name in the NNCFGraph
INFO:nncf:50 ignored nodes were found by name in the NNCFGraph
INFO:nncf:Not adding activation input quantizer for operation: 3 __module.wav2vec2.feature_extractor.conv_layers.0.conv/aten::_convolution/Convolution
INFO:nncf:Not adding activation input quantizer for operation: 11 __module.wav2vec2.feature_extractor.conv_layers.1.conv/aten::_convolution/Convolution
INFO:nncf:Not adding activation input quantizer for operation: 13 __module.wav2vec2.feature_extractor.conv_layers.2.conv/aten::_convolution/Convolution
INFO:nncf:Not adding activation input quantizer for operation: 15 __module.wav2vec2.feature_extractor.conv_layers.3.conv/aten::_convolution/Convolution
MODEL_NAME = 'quantized_wav2vec2_base'
quantized_model_path = Path(f"{MODEL_NAME}_openvino_model/{MODEL_NAME}_quantized.xml")
ov.save_model(quantized_model, quantized_model_path)

Model Usage Example with Inference Pipeline

Both the initial (FP16) and quantized (INT8) models are used in exactly the same way.
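
If the notebook session is restarted, the saved IR files can be read back from disk instead of reusing the in-memory model objects. A minimal sketch, assuming the ir_model_path and quantized_model_path values defined earlier:

import openvino as ov

core = ov.Core()
# Read the FP16 and INT8 IR models saved earlier in this notebook.
ov_model = core.read_model(ir_model_path)
quantized_model = core.read_model(quantized_model_path)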

Start by taking one example from the dataset to show the inference steps for it.

ipd.Audio(test_sample["array"], rate=16000)

To run the model with OpenVINO, first select an inference device. Please select one of the available devices from the dropdown list:

import ipywidgets as widgets

core = ov.Core()
device = widgets.Dropdown(
    options=core.available_devices + ["AUTO"],
    value='AUTO',
    description='Device:',
    disabled=False,
)

device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')

Next, load the quantized model to the inference pipeline.

compiled_model = core.compile_model(model=quantized_model, device_name=device.value)

input_data = np.expand_dims(test_sample["array"], axis=0)

Next, make a prediction.

predictions = compiled_model(input_data)[0]
predicted_ids = np.argmax(predictions, axis=-1)
transcription = processor.batch_decode(torch.from_numpy(predicted_ids))
print(transcription)
['A MAN SAID TO THE UNIVERSE SIR I EXIST']

Validate model accuracy on dataset

For model accuracy evaluation, the Word Error Rate (WER) metric can be used. WER is the ratio of errors in a transcript to the total number of words spoken. A lower WER in speech-to-text means better accuracy in recognizing speech.

For WER calculation, we will use the torchmetrics library.
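
To make the metric concrete, here is a quick toy example; the sentences are made up for illustration:

from torchmetrics.text import WordErrorRate

# One substitution ("up" -> "down") against a four-word reference: WER = 1/4.
wer = WordErrorRate()
print(wer(["the cat sat down"], ["the cat sat up"]))  # tensor(0.2500)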

from torchmetrics.text import WordErrorRate
from tqdm.notebook import tqdm

# inference function for pytorch
def torch_infer(model, sample):
    logits = model(torch.Tensor(sample['input_values'])).logits
    # take argmax and decode
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    return transcription


# inference function for openvino
def ov_infer(model, sample):
    logits = model(np.array(sample['input_values']))[0]
    predicted_ids = np.argmax(logits, axis=-1)
    transcription = processor.batch_decode(torch.from_numpy(predicted_ids))
    return transcription


def compute_wer(dataset, model, infer_fn):
    wer = WordErrorRate()
    for sample in tqdm(dataset):
        # run infer function on sample
        transcription = infer_fn(model, sample)
        # update metric on sample result
        wer.update(transcription, [sample['text']])
    # finalize metric calculation
    result = wer.compute()
    return result

The transcription is obtained by decoding the predicted token probabilities to text. The inference functions above use processor.batch_decode, the built-in tokenizer of Wav2Vec2Processor from the transformers package.
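
For intuition, below is a minimal sketch of the greedy CTC decoding that happens under the hood. The helper name is hypothetical, and blank_id=0 assumes the Wav2Vec2 vocabulary, where the pad token (id 0) serves as the CTC blank; processor.batch_decode additionally maps the remaining token IDs back to characters and words.

def greedy_ctc_decode(token_ids, blank_id=0):
    # 1. Collapse runs of consecutive repeated tokens.
    collapsed = [t for i, t in enumerate(token_ids) if i == 0 or t != token_ids[i - 1]]
    # 2. Drop the CTC blank token.
    return [t for t in collapsed if t != blank_id]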

Now, compute WER for the original PyTorch model, the OpenVINO FP16 model, and the quantized INT8 model.

compiled_fp32_ov_model = core.compile_model(ov_model, device.value)

pt_result = compute_wer(dataset, torch_model, torch_infer)
ov_result = compute_wer(dataset, compiled_fp32_ov_model, ov_infer)
int8_ov_result = compute_wer(dataset, compiled_model, ov_infer)
print(f'[PyTorch]       Word Error Rate: {pt_result:.4f}')
print(f'[OpenVINO FP16] Word Error Rate: {ov_result:.4f}')
print(f'[OpenVINO INT8] Word Error Rate: {int8_ov_result:.4f}')
0%|          | 0/73 [00:00<?, ?it/s]
0%|          | 0/73 [00:00<?, ?it/s]
0%|          | 0/73 [00:00<?, ?it/s]
[PyTorch]       Word Error Rate: 0.0530
[OpenVINO FP16] Word Error Rate: 0.0530
[OpenVINO INT8] Word Error Rate: 0.0548

Compare Performance of the Original and Quantized Models

Finally, use the OpenVINO Benchmark Tool to measure the inference performance of the FP16 and INT8 models.

NOTE: For more accurate performance measurements, it is recommended to run benchmark_app in a terminal/command prompt after closing other applications. Run benchmark_app -m model.xml -d CPU to benchmark async inference on CPU for one minute. Change CPU to GPU to benchmark on GPU. Run benchmark_app --help to see an overview of all command-line options.

# Inference FP16 model (OpenVINO IR)
! benchmark_app -m $ir_model_path -shape [1,30480] -d $device.value -api async
[Step 1/11] Parsing and validating input arguments
[ INFO ] Parsing input parameters
[Step 2/11] Loading OpenVINO Runtime
[ WARNING ] Default duration 120 seconds is used for unknown device AUTO
[ INFO ] OpenVINO:
[ INFO ] Build ................................. 2024.0.0-14509-34caeefd078-releases/2024/0
[ INFO ]
[ INFO ] Device info:
[ INFO ] AUTO
[ INFO ] Build ................................. 2024.0.0-14509-34caeefd078-releases/2024/0
[ INFO ]
[ INFO ]
[Step 3/11] Setting device configuration
[ WARNING ] Performance hint was not explicitly specified in command line. Device(AUTO) performance hint will be set to PerformanceMode.THROUGHPUT.
[Step 4/11] Reading model files
[ INFO ] Loading model files
[ INFO ] Read model took 27.55 ms
[ INFO ] Original model I/O parameters:
[ INFO ] Model inputs:
[ INFO ]     input_values (node: input_values) : f32 / [...] / [?,?]
[ INFO ] Model outputs:
[ INFO ]     logits (node: __module.lm_head/aten::linear/Add) : f32 / [...] / [?,?,32]
[Step 5/11] Resizing model to match image sizes and given batch
[ INFO ] Model batch size: 1
[ INFO ] Reshaping model: 'input_values': [1,30480]
[ INFO ] Reshape model took 29.21 ms
[Step 6/11] Configuring input of the model
[ INFO ] Model inputs:
[ INFO ]     input_values (node: input_values) : f32 / [...] / [1,30480]
[ INFO ] Model outputs:
[ INFO ]     logits (node: __module.lm_head/aten::linear/Add) : f32 / [...] / [1,95,32]
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 531.22 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] Model:
[ INFO ]   NETWORK_NAME: Model0
[ INFO ]   EXECUTION_DEVICES: ['CPU']
[ INFO ]   PERFORMANCE_HINT: PerformanceMode.THROUGHPUT
[ INFO ]   OPTIMAL_NUMBER_OF_INFER_REQUESTS: 6
[ INFO ]   MULTI_DEVICE_PRIORITIES: CPU
[ INFO ]   CPU:
[ INFO ]     AFFINITY: Affinity.CORE
[ INFO ]     CPU_DENORMALS_OPTIMIZATION: False
[ INFO ]     CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE: 1.0
[ INFO ]     DYNAMIC_QUANTIZATION_GROUP_SIZE: 0
[ INFO ]     ENABLE_CPU_PINNING: True
[ INFO ]     ENABLE_HYPER_THREADING: True
[ INFO ]     EXECUTION_DEVICES: ['CPU']
[ INFO ]     EXECUTION_MODE_HINT: ExecutionMode.PERFORMANCE
[ INFO ]     INFERENCE_NUM_THREADS: 24
[ INFO ]     INFERENCE_PRECISION_HINT: <Type: 'float32'>
[ INFO ]     KV_CACHE_PRECISION: <Type: 'float16'>
[ INFO ]     LOG_LEVEL: Level.NO
[ INFO ]     NETWORK_NAME: Model0
[ INFO ]     NUM_STREAMS: 6
[ INFO ]     OPTIMAL_NUMBER_OF_INFER_REQUESTS: 6
[ INFO ]     PERFORMANCE_HINT: THROUGHPUT
[ INFO ]     PERFORMANCE_HINT_NUM_REQUESTS: 0
[ INFO ]     PERF_COUNT: NO
[ INFO ]     SCHEDULING_CORE_TYPE: SchedulingCoreType.ANY_CORE
[ INFO ]   MODEL_PRIORITY: Priority.MEDIUM
[ INFO ]   LOADED_FROM_CACHE: False
[Step 9/11] Creating infer requests and preparing input tensors
[ WARNING ] No input files were given for input 'input_values'!. This input will be filled with random values!
[ INFO ] Fill input 'input_values' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 6 inference requests, limits: 120000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 61.24 ms
[Step 11/11] Dumping statistics report
[ INFO ] Execution Devices:['CPU']
[ INFO ] Count:            5484 iterations
[ INFO ] Duration:         120219.99 ms
[ INFO ] Latency:
[ INFO ]    Median:        131.07 ms
[ INFO ]    Average:       131.36 ms
[ INFO ]    Min:           107.53 ms
[ INFO ]    Max:           328.44 ms
[ INFO ] Throughput:   45.62 FPS
# Inference INT8 model (OpenVINO IR)
! benchmark_app -m $quantized_model_path -shape [1,30480] -d $device.value -api async
[Step 1/11] Parsing and validating input arguments
[ INFO ] Parsing input parameters
[Step 2/11] Loading OpenVINO Runtime
[ WARNING ] Default duration 120 seconds is used for unknown device AUTO
[ INFO ] OpenVINO:
[ INFO ] Build ................................. 2024.0.0-14509-34caeefd078-releases/2024/0
[ INFO ]
[ INFO ] Device info:
[ INFO ] AUTO
[ INFO ] Build ................................. 2024.0.0-14509-34caeefd078-releases/2024/0
[ INFO ]
[ INFO ]
[Step 3/11] Setting device configuration
[ WARNING ] Performance hint was not explicitly specified in command line. Device(AUTO) performance hint will be set to PerformanceMode.THROUGHPUT.
[Step 4/11] Reading model files
[ INFO ] Loading model files
[ INFO ] Read model took 40.21 ms
[ INFO ] Original model I/O parameters:
[ INFO ] Model inputs:
[ INFO ]     input_values (node: input_values) : f32 / [...] / [?,?]
[ INFO ] Model outputs:
[ INFO ]     logits (node: __module.lm_head/aten::linear/Add) : f32 / [...] / [?,?,32]
[Step 5/11] Resizing model to match image sizes and given batch
[ INFO ] Model batch size: 1
[ INFO ] Reshaping model: 'input_values': [1,30480]
[ INFO ] Reshape model took 35.76 ms
[Step 6/11] Configuring input of the model
[ INFO ] Model inputs:
[ INFO ]     input_values (node: input_values) : f32 / [...] / [1,30480]
[ INFO ] Model outputs:
[ INFO ]     logits (node: __module.lm_head/aten::linear/Add) : f32 / [...] / [1,95,32]
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 1294.11 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] Model:
[ INFO ]   NETWORK_NAME: Model0
[ INFO ]   EXECUTION_DEVICES: ['CPU']
[ INFO ]   PERFORMANCE_HINT: PerformanceMode.THROUGHPUT
[ INFO ]   OPTIMAL_NUMBER_OF_INFER_REQUESTS: 6
[ INFO ]   MULTI_DEVICE_PRIORITIES: CPU
[ INFO ]   CPU:
[ INFO ]     AFFINITY: Affinity.CORE
[ INFO ]     CPU_DENORMALS_OPTIMIZATION: False
[ INFO ]     CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE: 1.0
[ INFO ]     DYNAMIC_QUANTIZATION_GROUP_SIZE: 0
[ INFO ]     ENABLE_CPU_PINNING: True
[ INFO ]     ENABLE_HYPER_THREADING: True
[ INFO ]     EXECUTION_DEVICES: ['CPU']
[ INFO ]     EXECUTION_MODE_HINT: ExecutionMode.PERFORMANCE
[ INFO ]     INFERENCE_NUM_THREADS: 24
[ INFO ]     INFERENCE_PRECISION_HINT: <Type: 'float32'>
[ INFO ]     KV_CACHE_PRECISION: <Type: 'float16'>
[ INFO ]     LOG_LEVEL: Level.NO
[ INFO ]     NETWORK_NAME: Model0
[ INFO ]     NUM_STREAMS: 6
[ INFO ]     OPTIMAL_NUMBER_OF_INFER_REQUESTS: 6
[ INFO ]     PERFORMANCE_HINT: THROUGHPUT
[ INFO ]     PERFORMANCE_HINT_NUM_REQUESTS: 0
[ INFO ]     PERF_COUNT: NO
[ INFO ]     SCHEDULING_CORE_TYPE: SchedulingCoreType.ANY_CORE
[ INFO ]   MODEL_PRIORITY: Priority.MEDIUM
[ INFO ]   LOADED_FROM_CACHE: False
[Step 9/11] Creating infer requests and preparing input tensors
[ WARNING ] No input files were given for input 'input_values'!. This input will be filled with random values!
[ INFO ] Fill input 'input_values' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 6 inference requests, limits: 120000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 58.50 ms
[Step 11/11] Dumping statistics report
[ INFO ] Execution Devices:['CPU']
[ INFO ] Count:            8268 iterations
[ INFO ] Duration:         120137.02 ms
[ INFO ] Latency:
[ INFO ]    Median:        87.16 ms
[ INFO ]    Average:       87.05 ms
[ INFO ]    Min:           73.56 ms
[ INFO ]    Max:           112.73 ms
[ INFO ] Throughput:   68.82 FPS