Quantize Speech Recognition Models using NNCF PTQ API

This tutorial is also available as a Jupyter notebook that can be cloned directly from GitHub. See the installation guide for instructions to run this tutorial locally on Windows, Linux or macOS.


This tutorial demonstrates how to apply INT8 quantization to the Wav2Vec2 speech recognition model using the 8-bit post-training quantization of NNCF (Neural Network Compression Framework), which does not require a fine-tuning pipeline. This notebook uses the Wav2Vec2-Base-960h PyTorch model fine-tuned on the LibriSpeech ASR corpus. The tutorial is designed to be extendable to custom models and datasets. It consists of the following steps:

  • Download and prepare the Wav2Vec2 model and LibriSpeech dataset.

  • Define data loading and accuracy validation functionality.

  • Model quantization.

  • Compare the accuracy of the original PyTorch model and the OpenVINO FP16 and INT8 models.

  • Compare performance of the original and quantized models.


!pip install -q "openvino==2023.1.0.dev20230811" "nncf>=2.5.0"
!pip install -q soundfile librosa transformers onnx

Imports

import os
import sys
import re
import numpy as np
import openvino as ov
import tarfile
import torch
from itertools import groupby
import soundfile as sf
import IPython.display as ipd

from transformers import Wav2Vec2ForCTC

sys.path.append("../utils")
from notebook_utils import download_file

Settings

from pathlib import Path

# Set the model and data directories.
MODEL_DIR = Path("model")
DATA_DIR = Path("../data/datasets/librispeech")
MODEL_DIR.mkdir(exist_ok=True)
# parents=True also creates ../data/datasets if it does not exist yet.
DATA_DIR.mkdir(parents=True, exist_ok=True)

Prepare the Model

Perform the following:

  • Download the weights and configuration of the pre-trained Wav2Vec2 model.

  • Convert the model to ONNX.

  • Run the model conversion API (ov.convert_model) to convert the model from the ONNX representation to the OpenVINO Intermediate Representation (OpenVINO IR), then save it with ov.save_model.

download_file("https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/pytorch_model.bin", directory=Path(MODEL_DIR) / 'pytorch', show_progress=True)
download_file("https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json", directory=Path(MODEL_DIR) / 'pytorch', show_progress=False)

Load the original PyTorch model and convert it to the ONNX representation.

BATCH_SIZE = 1
MAX_SEQ_LENGTH = 30480


def export_model_to_onnx(model, path):
    with torch.no_grad():
        default_input = torch.zeros([1, MAX_SEQ_LENGTH], dtype=torch.float)
        inputs = {
            "inputs": default_input
        }
        # The logits time axis counts frames, not input samples,
        # so it gets its own symbolic name.
        torch.onnx.export(
            model,
            (inputs["inputs"],),
            path,
            opset_version=11,
            input_names=["inputs"],
            output_names=["logits"],
            dynamic_axes={
                "inputs": {0: "batch_size", 1: "sequence_len"},
                "logits": {0: "batch_size", 1: "num_frames"},
            },
        )
        print("ONNX model saved to {}".format(path))


torch_model = Wav2Vec2ForCTC.from_pretrained(Path(MODEL_DIR) / 'pytorch')
onnx_model_path = Path(MODEL_DIR) / "wav2vec2_base.onnx"
if not onnx_model_path.exists():
    export_model_to_onnx(torch_model, onnx_model_path)
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at model/pytorch and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-499/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:595: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-499/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:634: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
ONNX model saved to model/wav2vec2_base.onnx
ov_model = ov.convert_model(onnx_model_path)

ir_model_path = MODEL_DIR / "wav2vec2_base.xml"
ov.save_model(ov_model, str(ir_model_path))

Prepare LibriSpeech Dataset

Use the code below to download and unpack the archives with the ‘dev-clean’ and ‘test-clean’ subsets of the LibriSpeech dataset. The ‘dev-clean’ subset is used for calibration and ‘test-clean’ for accuracy evaluation.

download_file("http://openslr.elda.org/resources/12/dev-clean.tar.gz", directory=DATA_DIR, show_progress=True)
download_file("http://openslr.elda.org/resources/12/test-clean.tar.gz", directory=DATA_DIR, show_progress=True)

if not os.path.exists(f'{DATA_DIR}/LibriSpeech/dev-clean'):
    with tarfile.open(f"{DATA_DIR}/dev-clean.tar.gz") as tar:
        tar.extractall(path=DATA_DIR)
if not os.path.exists(f'{DATA_DIR}/LibriSpeech/test-clean'):
    with tarfile.open(f"{DATA_DIR}/test-clean.tar.gz") as tar:
        tar.extractall(path=DATA_DIR)

Define DataLoader

The Wav2Vec2 model accepts the raw waveform of a speech signal as input and produces per-frame scores (logits) over its character vocabulary as output. Since the dataset contains audio files in FLAC format, use the soundfile package to decode them into waveforms.
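To see how the input length maps to the number of output frames, the sketch below applies the unpadded 1-D convolution length formula, out = floor((in - kernel) / stride) + 1, to the seven convolutional layers of the feature encoder. It assumes the default Wav2Vec2Config values for wav2vec2-base (conv_kernel=(10, 3, 3, 3, 3, 2, 2), conv_stride=(5, 2, 2, 2, 2, 2, 2)). For the 30480-sample input used in this tutorial it gives 95 frames, which is consistent with the [1,95,32] logits shape that benchmark_app reports.

```python
# Assumed feature-encoder geometry (default Wav2Vec2Config values for wav2vec2-base).
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)
SAMPLE_RATE = 16000  # LibriSpeech audio is sampled at 16 kHz


def num_output_frames(num_samples: int) -> int:
    """Number of logit frames produced for a waveform of num_samples samples."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        # Unpadded 1-D convolution: floor((L - k) / s) + 1
        length = (length - kernel) // stride + 1
    return length


frames = num_output_frames(30480)
print(f"{30480 / SAMPLE_RATE:.3f} s of audio -> {frames} frames")
```

The product of the strides is 5 * 2**6 = 320 samples. At 16 kHz, the model therefore emits roughly one frame every 20 ms, or about 49 frames per second of audio.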

Note

Consider increasing samples_limit to get more precise results. A value of 300 or more is suggested, although processing will take longer.

class LibriSpeechDataLoader:

    @staticmethod
    def read_flac(file_name):
        speech, samplerate = sf.read(file_name)
        assert samplerate == 16000, "read_flac: only 16kHz supported!"
        return speech

    # Required methods
    def __init__(self, config, samples_limit=300):
        """Constructor
        :param config: data loader specific config
        """
        self.samples_limit = samples_limit
        self._data_dir = config["data_source"]
        self._ds = []
        self._prepare_dataset()

    def __len__(self):
        """Returns size of the dataset"""
        return len(self._ds)

    def __getitem__(self, index):
        """
        Returns annotation, data and metadata at the specified index.
        Possible formats:
        (index, annotation), data
        (index, annotation), data, metadata
        """
        label = self._ds[index][0]
        inputs = {'inputs': np.expand_dims(self._ds[index][1], axis=0)}
        return label, inputs

    # Methods specific to the current implementation
    def _prepare_dataset(self):
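        # Each line of a LibriSpeech *.trans.txt file is "<utterance-id> <TRANSCRIPT>".
        pattern = re.compile(r'([0-9\-]+)\s+(.+)')
        data_folder = Path(self._data_dir)
        txts = list(data_folder.glob('**/*.txt'))
        counter = 0
        for txt in txts:
            # Read the transcript file and close it right away.
            with txt.open() as f:
                content = f.readlines()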
            for line in content:
                res = pattern.search(line)
                if not res:
                    continue
                name = res.group(1)
                transcript = res.group(2)
                fname = txt.parent / name
                fname = fname.with_suffix('.flac')
                identifier = str(fname.relative_to(data_folder))
                self._ds.append(((counter, transcript.upper()), LibriSpeechDataLoader.read_flac(os.path.join(self._data_dir, identifier))))
                counter += 1
                if counter >= self.samples_limit:
                    # Limit exceeded
                    return
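You can check the transcript parsing in _prepare_dataset in isolation. The sketch below applies the same regular expression to a made-up line in the LibriSpeech "<utterance-id> <TRANSCRIPT>" format and derives the path of the matching FLAC file. The helper parse_transcript_line, the transcript line, and the directory are hypothetical and used for illustration only.

```python
import re
from pathlib import PurePosixPath

# Same pattern as in LibriSpeechDataLoader._prepare_dataset
pattern = re.compile(r'([0-9\-]+)\s+(.+)')


def parse_transcript_line(line, transcript_file):
    """Return (flac_path, transcript) for one transcript line, or None if it does not match."""
    res = pattern.search(line)
    if not res:
        return None
    name, transcript = res.group(1), res.group(2)
    # The audio file lives next to the transcript file and is named <utterance-id>.flac
    flac_path = (PurePosixPath(transcript_file).parent / name).with_suffix('.flac')
    return flac_path, transcript.upper()


# Hypothetical transcript file and line, for illustration only
txt_file = "LibriSpeech/dev-clean/1234/5678/1234-5678.trans.txt"
print(parse_transcript_line("1234-5678-0001 AN EXAMPLE TRANSCRIPT\n", txt_file))
```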

Run Quantization

NNCF provides a suite of advanced algorithms for optimizing neural network inference in OpenVINO with minimal accuracy drop.

Create a quantized model from the pre-trained OpenVINO model and the calibration dataset. The optimization process contains the following steps:

  1. Create a Dataset for quantization.

  2. Run nncf.quantize to obtain an optimized model. The nncf.quantize function provides an interface for model quantization. It requires an instance of the OpenVINO Model and a quantization dataset. Optionally, you can provide additional parameters that configure the quantization process (number of samples for quantization, preset, ignored scope, etc.). Here, model_type=ModelType.TRANSFORMER enables transformer-specific patterns. The ignored_scope parameter keeps a few operations in floating-point precision to improve accuracy: two feature-extractor convolutions and one feed-forward MatMul in encoder layer 7. You can use advanced_parameters to specify advanced parameters for fine-tuning the quantization algorithm. In this tutorial, it passes range estimator parameters for activations: the lower bound is the minimum value, and the upper bound is a quantile that discards outliers with probability 0.0001, averaged over the calibration samples. For more information, see Tune quantization parameters.

  3. Serialize the quantized OpenVINO IR model using the openvino.save_model function.
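The following small sketch illustrates the activation range settings passed through advanced_parameters. It assumes two statistics per calibration sample. The lower bound is the tensor minimum, aggregated over samples with MIN. The upper bound is the (1 - quantile_outlier_prob) quantile, aggregated with MEAN. The sketch then maps the resulting range onto 8-bit integers with a generic asymmetric scale and zero point. It illustrates the idea only and is not NNCF's implementation. The helper names are hypothetical, and the quantization scheme NNCF actually selects may differ. In the toy data, a single outlier of 1e6 barely moves the estimated upper bound, which stays close to 10,000.

```python
import math


def quantile(values, q):
    """Linear-interpolation quantile of a list of numbers (0 <= q <= 1)."""
    ordered = sorted(values)
    pos = q * (len(ordered) - 1)
    lo_idx = math.floor(pos)
    hi_idx = min(lo_idx + 1, len(ordered) - 1)
    frac = pos - lo_idx
    return ordered[lo_idx] + frac * (ordered[hi_idx] - ordered[lo_idx])


def estimate_range(samples, outlier_prob=0.0001):
    """Hypothetical helper: min via MIN/MIN, max via QUANTILE/MEAN over calibration samples."""
    mins = [min(s) for s in samples]
    maxs = [quantile(s, 1.0 - outlier_prob) for s in samples]
    return min(mins), sum(maxs) / len(maxs)


def affine_uint8_params(low, high):
    """Generic asymmetric mapping of [low, high] onto the integer range 0..255."""
    low, high = min(low, 0.0), max(high, 0.0)  # the range must contain zero
    if high == low:
        return 1.0, 0
    scale = (high - low) / 255
    zero_point = round(-low / scale)
    return scale, zero_point


# Two toy "activation" tensors; the second one contains a single large outlier.
calibration = [
    [float(v) for v in range(-5, 10000)],
    [float(v) for v in range(-3, 10000)] + [1e6],
]
low, high = estimate_range(calibration)
scale, zero_point = affine_uint8_params(low, high)
print(f"range: [{low}, {high:.2f}], scale={scale:.4f}, zero_point={zero_point}")
```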

import nncf
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters, RangeEstimatorParameters
from nncf.quantization.range_estimator import StatisticsCollectorParameters, StatisticsType, AggregatorType
from nncf.parameters import ModelType


def transform_fn(data_item):
    """
    Extract the model's input from the data item.
    The data item here is the data item that is returned from the data source per iteration.
    This function should be passed when the data item cannot be used as model's input.
    """
    _, inputs = data_item

    return inputs["inputs"]


dataset_config = {"data_source": os.path.join(DATA_DIR, "LibriSpeech/dev-clean")}
data_loader = LibriSpeechDataLoader(dataset_config, samples_limit=300)
calibration_dataset = nncf.Dataset(data_loader, transform_fn)


quantized_model = nncf.quantize(
    ov_model,
    calibration_dataset,
    model_type=ModelType.TRANSFORMER,  # specify additional transformer patterns in the model
    ignored_scope=nncf.IgnoredScope(
        names=[
            '/wav2vec2/feature_extractor/conv_layers.1/conv/Conv',
            '/wav2vec2/feature_extractor/conv_layers.2/conv/Conv',
            '/wav2vec2/encoder/layers.7/feed_forward/output_dense/MatMul'
        ],
    ),
    advanced_parameters=AdvancedQuantizationParameters(
        activations_range_estimator_params=RangeEstimatorParameters(
            min=StatisticsCollectorParameters(
                statistics_type=StatisticsType.MIN,
                aggregator_type=AggregatorType.MIN
            ),
            max=StatisticsCollectorParameters(
                statistics_type=StatisticsType.QUANTILE,
                aggregator_type=AggregatorType.MEAN,
                quantile_outlier_prob=0.0001
            ),
        )
    )
)
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, tensorflow, onnx, openvino
INFO:nncf:3 ignored nodes was found by name in the NNCFGraph
INFO:nncf:193 ignored nodes was found by types in the NNCFGraph
INFO:nncf:24 ignored nodes was found by name in the NNCFGraph
INFO:nncf:Not adding activation input quantizer for operation: 5 MVN_224
INFO:nncf:Not adding activation input quantizer for operation: 7 /wav2vec2/feature_extractor/conv_layers.0/layer_norm/Mul
8 /wav2vec2/feature_extractor/conv_layers.0/layer_norm/Add

INFO:nncf:Not adding activation input quantizer for operation: 10 /wav2vec2/feature_extractor/conv_layers.1/conv/Conv
INFO:nncf:Not adding activation input quantizer for operation: 12 /wav2vec2/feature_extractor/conv_layers.2/conv/Conv
INFO:nncf:Not adding activation input quantizer for operation: 23 /wav2vec2/feature_projection/layer_norm/Div
24 /wav2vec2/feature_projection/layer_norm/Mul
25 /wav2vec2/feature_projection/layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 28 /wav2vec2/encoder/Add
INFO:nncf:Not adding activation input quantizer for operation: 30 /wav2vec2/encoder/layer_norm/Div
32 /wav2vec2/encoder/layer_norm/Mul
34 /wav2vec2/encoder/layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 36 /wav2vec2/encoder/layers.0/Add
INFO:nncf:Not adding activation input quantizer for operation: 42 /wav2vec2/encoder/layers.0/layer_norm/Div
49 /wav2vec2/encoder/layers.0/layer_norm/Mul
58 /wav2vec2/encoder/layers.0/layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 66 /wav2vec2/encoder/layers.0/Add_1
INFO:nncf:Not adding activation input quantizer for operation: 74 /wav2vec2/encoder/layers.0/final_layer_norm/Div
79 /wav2vec2/encoder/layers.0/final_layer_norm/Mul
82 /wav2vec2/encoder/layers.0/final_layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 84 /wav2vec2/encoder/layers.1/Add
INFO:nncf:Not adding activation input quantizer for operation: 90 /wav2vec2/encoder/layers.1/layer_norm/Div
96 /wav2vec2/encoder/layers.1/layer_norm/Mul
105 /wav2vec2/encoder/layers.1/layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 113 /wav2vec2/encoder/layers.1/Add_1
INFO:nncf:Not adding activation input quantizer for operation: 121 /wav2vec2/encoder/layers.1/final_layer_norm/Div
126 /wav2vec2/encoder/layers.1/final_layer_norm/Mul
129 /wav2vec2/encoder/layers.1/final_layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 131 /wav2vec2/encoder/layers.2/Add
INFO:nncf:Not adding activation input quantizer for operation: 137 /wav2vec2/encoder/layers.2/layer_norm/Div
143 /wav2vec2/encoder/layers.2/layer_norm/Mul
152 /wav2vec2/encoder/layers.2/layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 160 /wav2vec2/encoder/layers.2/Add_1
INFO:nncf:Not adding activation input quantizer for operation: 168 /wav2vec2/encoder/layers.2/final_layer_norm/Div
173 /wav2vec2/encoder/layers.2/final_layer_norm/Mul
176 /wav2vec2/encoder/layers.2/final_layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 178 /wav2vec2/encoder/layers.3/Add
INFO:nncf:Not adding activation input quantizer for operation: 184 /wav2vec2/encoder/layers.3/layer_norm/Div
190 /wav2vec2/encoder/layers.3/layer_norm/Mul
199 /wav2vec2/encoder/layers.3/layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 207 /wav2vec2/encoder/layers.3/Add_1
INFO:nncf:Not adding activation input quantizer for operation: 215 /wav2vec2/encoder/layers.3/final_layer_norm/Div
220 /wav2vec2/encoder/layers.3/final_layer_norm/Mul
223 /wav2vec2/encoder/layers.3/final_layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 225 /wav2vec2/encoder/layers.4/Add
INFO:nncf:Not adding activation input quantizer for operation: 231 /wav2vec2/encoder/layers.4/layer_norm/Div
237 /wav2vec2/encoder/layers.4/layer_norm/Mul
246 /wav2vec2/encoder/layers.4/layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 254 /wav2vec2/encoder/layers.4/Add_1
INFO:nncf:Not adding activation input quantizer for operation: 262 /wav2vec2/encoder/layers.4/final_layer_norm/Div
267 /wav2vec2/encoder/layers.4/final_layer_norm/Mul
270 /wav2vec2/encoder/layers.4/final_layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 272 /wav2vec2/encoder/layers.5/Add
INFO:nncf:Not adding activation input quantizer for operation: 278 /wav2vec2/encoder/layers.5/layer_norm/Div
284 /wav2vec2/encoder/layers.5/layer_norm/Mul
293 /wav2vec2/encoder/layers.5/layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 301 /wav2vec2/encoder/layers.5/Add_1
INFO:nncf:Not adding activation input quantizer for operation: 309 /wav2vec2/encoder/layers.5/final_layer_norm/Div
314 /wav2vec2/encoder/layers.5/final_layer_norm/Mul
317 /wav2vec2/encoder/layers.5/final_layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 319 /wav2vec2/encoder/layers.6/Add
INFO:nncf:Not adding activation input quantizer for operation: 325 /wav2vec2/encoder/layers.6/layer_norm/Div
331 /wav2vec2/encoder/layers.6/layer_norm/Mul
340 /wav2vec2/encoder/layers.6/layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 348 /wav2vec2/encoder/layers.6/Add_1
INFO:nncf:Not adding activation input quantizer for operation: 356 /wav2vec2/encoder/layers.6/final_layer_norm/Div
361 /wav2vec2/encoder/layers.6/final_layer_norm/Mul
364 /wav2vec2/encoder/layers.6/final_layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 366 /wav2vec2/encoder/layers.7/Add
INFO:nncf:Not adding activation input quantizer for operation: 372 /wav2vec2/encoder/layers.7/layer_norm/Div
378 /wav2vec2/encoder/layers.7/layer_norm/Mul
387 /wav2vec2/encoder/layers.7/layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 412 /wav2vec2/encoder/layers.7/feed_forward/output_dense/MatMul
418 /wav2vec2/encoder/layers.7/feed_forward/output_dense/Add

INFO:nncf:Not adding activation input quantizer for operation: 395 /wav2vec2/encoder/layers.7/Add_1
INFO:nncf:Not adding activation input quantizer for operation: 403 /wav2vec2/encoder/layers.7/final_layer_norm/Div
408 /wav2vec2/encoder/layers.7/final_layer_norm/Mul
411 /wav2vec2/encoder/layers.7/final_layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 413 /wav2vec2/encoder/layers.8/Add
INFO:nncf:Not adding activation input quantizer for operation: 419 /wav2vec2/encoder/layers.8/layer_norm/Div
425 /wav2vec2/encoder/layers.8/layer_norm/Mul
434 /wav2vec2/encoder/layers.8/layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 442 /wav2vec2/encoder/layers.8/Add_1
INFO:nncf:Not adding activation input quantizer for operation: 450 /wav2vec2/encoder/layers.8/final_layer_norm/Div
455 /wav2vec2/encoder/layers.8/final_layer_norm/Mul
458 /wav2vec2/encoder/layers.8/final_layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 460 /wav2vec2/encoder/layers.9/Add
INFO:nncf:Not adding activation input quantizer for operation: 466 /wav2vec2/encoder/layers.9/layer_norm/Div
472 /wav2vec2/encoder/layers.9/layer_norm/Mul
481 /wav2vec2/encoder/layers.9/layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 489 /wav2vec2/encoder/layers.9/Add_1
INFO:nncf:Not adding activation input quantizer for operation: 497 /wav2vec2/encoder/layers.9/final_layer_norm/Div
502 /wav2vec2/encoder/layers.9/final_layer_norm/Mul
505 /wav2vec2/encoder/layers.9/final_layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 507 /wav2vec2/encoder/layers.10/Add
INFO:nncf:Not adding activation input quantizer for operation: 513 /wav2vec2/encoder/layers.10/layer_norm/Div
519 /wav2vec2/encoder/layers.10/layer_norm/Mul
528 /wav2vec2/encoder/layers.10/layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 536 /wav2vec2/encoder/layers.10/Add_1
INFO:nncf:Not adding activation input quantizer for operation: 544 /wav2vec2/encoder/layers.10/final_layer_norm/Div
549 /wav2vec2/encoder/layers.10/final_layer_norm/Mul
552 /wav2vec2/encoder/layers.10/final_layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 554 /wav2vec2/encoder/layers.11/Add
INFO:nncf:Not adding activation input quantizer for operation: 560 /wav2vec2/encoder/layers.11/layer_norm/Div
566 /wav2vec2/encoder/layers.11/layer_norm/Mul
575 /wav2vec2/encoder/layers.11/layer_norm/Add_1

INFO:nncf:Not adding activation input quantizer for operation: 583 /wav2vec2/encoder/layers.11/Add_1
INFO:nncf:Not adding activation input quantizer for operation: 591 /wav2vec2/encoder/layers.11/final_layer_norm/Div
596 /wav2vec2/encoder/layers.11/final_layer_norm/Mul
599 /wav2vec2/encoder/layers.11/final_layer_norm/Add_1
Statistics collection: 100%|██████████| 300/300 [02:51<00:00,  1.75it/s]
Biases correction: 100%|██████████| 74/74 [00:25<00:00,  2.96it/s]
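nncf.Dataset wraps the data loader together with transform_fn. Each item the loader yields, a ((index, transcript), {"inputs": waveform}) pair, is converted into the tensor the model expects. The toy class below only sketches that idea. ToyDataset and extract_inputs are hypothetical stand-ins, not NNCF's implementation, and short lists replace the real waveforms.

```python
class ToyDataset:
    """Hypothetical stand-in that mimics how a transform function is applied lazily."""

    def __init__(self, data_source, transform_fn=None):
        self._data_source = data_source
        self._transform_fn = transform_fn or (lambda item: item)

    def __iter__(self):
        for item in self._data_source:
            # The transformation runs per item, only when the item is requested.
            yield self._transform_fn(item)


def extract_inputs(data_item):
    # Same logic as transform_fn in this tutorial: drop the label, keep the model input.
    _, inputs = data_item
    return inputs["inputs"]


# Items shaped like LibriSpeechDataLoader.__getitem__ output (toy waveforms).
items = [((0, "HELLO"), {"inputs": [[0.1, 0.2]]}), ((1, "WORLD"), {"inputs": [[0.3]]})]
print(list(ToyDataset(items, extract_inputs)))
```

Because LibriSpeechDataLoader defines __getitem__, Python can iterate over it without an explicit __iter__. Indexing past the end of its internal list raises IndexError, which ends the loop.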
MODEL_NAME = 'quantized_wav2vec2_base'
quantized_model_path = Path(f"{MODEL_NAME}_openvino_model/{MODEL_NAME}_quantized.xml")
ov.save_model(quantized_model, str(quantized_model_path))

Model Usage Example with Inference Pipeline

Both the initial (FP16) and quantized (INT8) models are used in exactly the same way.

Start by taking one example from the dataset to show the inference steps for it.

Next, load the quantized model into the inference pipeline.

audio = LibriSpeechDataLoader.read_flac(f'{DATA_DIR}/LibriSpeech/test-clean/121/127105/121-127105-0017.flac')

ipd.Audio(audio, rate=16000)
core = ov.Core()

compiled_model = core.compile_model(model=quantized_model, device_name='CPU')

input_data = np.expand_dims(audio, axis=0)
output_layer = compiled_model.outputs[0]

Next, make a prediction.

predictions = compiled_model([input_data])[output_layer]

Validate Model Accuracy on the Dataset

The code below evaluates the models on the test subset of the dataset. It contains the following steps:

  • Define the MetricWER class to calculate the Word Error Rate (WER).

  • Define a data loader for the test dataset.

  • Define inference functions for the PyTorch and OpenVINO models.

  • Define a function to compute the WER over the dataset.

class MetricWER:
    alphabet = [
        "<pad>", "<s>", "</s>", "<unk>", "|",
        "e", "t", "a", "o", "n", "i", "h", "s", "r", "d", "l", "u",
        "m", "w", "c", "f", "g", "y", "p", "b", "v", "k", "'", "x", "j", "q", "z"]
    words_delimiter = '|'
    pad_token = '<pad>'

    # Required methods
    def __init__(self):
        self._name = "WER"
        self._sum_score = 0
        self._sum_words = 0
        self._cur_score = 0
        self._decoding_vocab = dict(enumerate(self.alphabet))

    @property
    def value(self):
        """Returns accuracy metric value for the last model output."""
        return {self._name: self._cur_score}

    @property
    def avg_value(self):
        """Returns accuracy metric value for all model outputs."""
        return {self._name: self._sum_score / self._sum_words if self._sum_words != 0 else 0}

    def update(self, output, target):
        """
        Updates prediction matches.

        :param output: model output
        :param target: annotations
        """
        decoded = [decode_logits(i) for i in output]
        target = [i.lower() for i in target]
        assert len(output) == len(target), "sizes of output and target mismatch!"
        for i in range(len(output)):
            # The reference transcript goes first: WER is normalized by its word count.
            self._get_metric_per_sample(target[i], decoded[i])

    def reset(self):
        """
        Resets collected matches
        """
        self._sum_score = 0
        self._sum_words = 0
        self._cur_score = 0

    def get_attributes(self):
        """
        Returns a dictionary of metric attributes {metric_name: {attribute_name: value}}.
        Required attributes: 'direction': 'higher-better' or 'higher-worse'
                             'type': metric type
        """
        return {self._name: {"direction": "higher-worse", "type": "WER"}}

    # Methods specific to the current implementation
    def _get_metric_per_sample(self, annotation, prediction):
        cur_score = self._editdistance_eval(annotation.split(), prediction.split())
        cur_words = len(annotation.split())

        self._sum_score += cur_score
        self._sum_words += cur_words
        # Guard against an empty reference transcript
        self._cur_score = cur_score / cur_words if cur_words != 0 else 0

        return self._cur_score

    def _editdistance_eval(self, source, target):
        n, m = len(source), len(target)

        distance = np.zeros((n + 1, m + 1), dtype=int)
        distance[:, 0] = np.arange(0, n + 1)
        distance[0, :] = np.arange(0, m + 1)

        for i in range(1, n + 1):
            for j in range(1, m + 1):
                cost = 0 if source[i - 1] == target[j - 1] else 1

                distance[i][j] = min(distance[i - 1][j] + 1,
                                     distance[i][j - 1] + 1,
                                     distance[i - 1][j - 1] + cost)
        return distance[n][m]
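WER is the word-level edit distance (substitutions + deletions + insertions) between the reference transcript and the prediction, divided by the number of words in the reference. The standalone sketch below restates that definition with a compact dynamic-programming edit distance. It also shows why the argument order matters: the same one-word difference gives a different rate depending on which sentence is the reference. The function names and example sentences are illustrative only.

```python
def word_edit_distance(ref_words, hyp_words):
    """Levenshtein distance over words, computed row by row."""
    prev = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        cur = [i] + [0] * len(hyp_words)
        for j, hyp_word in enumerate(hyp_words, start=1):
            cost = 0 if ref_word == hyp_word else 1
            cur[j] = min(prev[j] + 1,         # deletion of a reference word
                         cur[j - 1] + 1,      # insertion of a hypothesis word
                         prev[j - 1] + cost)  # substitution or match
        prev = cur
    return prev[-1]


def wer(reference, hypothesis):
    """Word Error Rate normalized by the number of reference words."""
    ref, hyp = reference.split(), hypothesis.split()
    return word_edit_distance(ref, hyp) / len(ref) if ref else 0.0


print(wer("the cat sat", "the cat sat down"))   # one insertion over three reference words
print(wer("the cat sat down", "the cat sat"))   # one deletion over four reference words
```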

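Next, decode the predicted logits to text with the decode_logits function. It performs greedy CTC decoding: it takes the most probable token for each frame, merges repeated tokens, drops padding tokens, and turns the word delimiter | into a space.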

Alternatively, use a built-in Wav2Vec2Processor tokenizer from the transformers package.
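You can trace greedy CTC decoding on a toy example. The sketch assumes a hypothetical four-symbol vocabulary (<pad>, |, a, b) instead of the model's 32 tokens, and builds one-hot score rows instead of real logits. It shows why the padding token matters: a <pad> between two identical tokens keeps them from being merged, which is how doubled letters survive decoding.

```python
from itertools import groupby

# Hypothetical toy vocabulary; the real model uses MetricWER.alphabet (32 tokens).
VOCAB = ["<pad>", "|", "a", "b"]
PAD, DELIMITER = "<pad>", "|"


def greedy_ctc_decode(frame_logits):
    """Greedy CTC decoding of per-frame scores (a list of rows, one score per token)."""
    # 1. Most likely token for every frame
    ids = [max(range(len(row)), key=row.__getitem__) for row in frame_logits]
    tokens = [VOCAB[i] for i in ids]
    # 2. Merge runs of the same token
    tokens = [token for token, _ in groupby(tokens)]
    # 3. Drop padding (blank) tokens
    tokens = [t for t in tokens if t != PAD]
    # 4. Word delimiters become spaces; collapse repeated whitespace
    text = "".join(" " if t == DELIMITER else t for t in tokens)
    return " ".join(text.split())


def one_hot(token):
    """Build a score row that peaks at the given token."""
    return [1.0 if t == token else 0.0 for t in VOCAB]


frames = ["a", "a", "<pad>", "a", "|", "b", "b", "<pad>"]
print(greedy_ctc_decode([one_hot(t) for t in frames]))
```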

def decode_logits(logits):
    decoding_vocab = dict(enumerate(MetricWER.alphabet))
    token_ids = np.squeeze(np.argmax(logits, -1))
    tokens = [decoding_vocab[idx] for idx in token_ids]
    tokens = [token_group[0] for token_group in groupby(tokens)]
    tokens = [t for t in tokens if t != MetricWER.pad_token]
    res_string = ''.join([t if t != MetricWER.words_delimiter else ' ' for t in tokens]).strip()
    res_string = ' '.join(res_string.split())  # collapse repeated spaces left by consecutive delimiters
    res_string = res_string.lower()
    return res_string


predicted_text = decode_logits(predictions)
predicted_text
'it was almost the tone of hope everybody will stay'
from tqdm.notebook import tqdm


dataset_config = {"data_source": os.path.join(DATA_DIR, "LibriSpeech/test-clean")}
test_data_loader = LibriSpeechDataLoader(dataset_config, samples_limit=300)


# inference function for PyTorch
def torch_infer(model, sample):
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        output = model(torch.Tensor(sample[1]['inputs'])).logits
    output = output.detach().cpu().numpy()

    return output


# inference function for OpenVINO
def ov_infer(model, sample):
    output_layer = model.output(0)
    output = model(np.array(sample[1]['inputs']))[output_layer]

    return output


def compute_wer(dataset, model, infer_fn):
    wer = MetricWER()
    for sample in tqdm(dataset):
        # run infer function on sample
        output = infer_fn(model, sample)
        # update metric on sample result
        wer.update(output, [sample[0][1]])

    return wer.avg_value

Now, compute WER for the original PyTorch model, OpenVINO IR model and quantized model.

compiled_fp32_ov_model = core.compile_model(ov_model)

pt_result = compute_wer(test_data_loader, torch_model, torch_infer)
ov_fp32_result = compute_wer(test_data_loader, compiled_fp32_ov_model, ov_infer)
quantized_result = compute_wer(test_data_loader, compiled_model, ov_infer)

print(f'[PyTorch]   Word Error Rate: {pt_result["WER"]:.4f}')
print(f'[OpenVino]  Word Error Rate: {ov_fp32_result["WER"]:.4f}')
print(f'[Quantized OpenVino]  Word Error Rate: {quantized_result["WER"]:.4f}')
[PyTorch]   Word Error Rate: 0.0292
[OpenVino]  Word Error Rate: 0.0292
[Quantized OpenVino]  Word Error Rate: 0.0422

Compare Performance of the Original and Quantized Models

Finally, use the Benchmark Tool (benchmark_app) to measure the inference performance of the FP16 and INT8 models.

Note

For more accurate performance measurements, run benchmark_app in a terminal/command prompt after closing other applications. Run benchmark_app -m model.xml -d CPU to benchmark async inference on CPU for one minute. Change CPU to GPU to benchmark on GPU. Run benchmark_app --help to see an overview of all command-line options.

# Inference FP16 model (OpenVINO IR)
! benchmark_app -m $ir_model_path -shape [1,30480] -d CPU -api async
[Step 1/11] Parsing and validating input arguments
[ INFO ] Parsing input parameters
[Step 2/11] Loading OpenVINO Runtime
[ INFO ] OpenVINO:
[ INFO ] Build ................................. 2023.1.0-12050-e33de350633
[ INFO ]
[ INFO ] Device info:
[ INFO ] CPU
[ INFO ] Build ................................. 2023.1.0-12050-e33de350633
[ INFO ]
[ INFO ]
[Step 3/11] Setting device configuration
[ WARNING ] Performance hint was not explicitly specified in command line. Device(CPU) performance hint will be set to PerformanceMode.THROUGHPUT.
[Step 4/11] Reading model files
[ INFO ] Loading model files
[ INFO ] Read model took 61.48 ms
[ INFO ] Original model I/O parameters:
[ INFO ] Model inputs:
[ INFO ]     inputs (node: inputs) : f32 / [...] / [?,?]
[ INFO ] Model outputs:
[ INFO ]     logits (node: logits) : f32 / [...] / [?,?,32]
[Step 5/11] Resizing model to match image sizes and given batch
[ INFO ] Model batch size: 1
[ INFO ] Reshaping model: 'inputs': [1,30480]
[ INFO ] Reshape model took 28.87 ms
[Step 6/11] Configuring input of the model
[ INFO ] Model inputs:
[ INFO ]     inputs (node: inputs) : f32 / [...] / [1,30480]
[ INFO ] Model outputs:
[ INFO ]     logits (node: logits) : f32 / [...] / [1,95,32]
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 644.15 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] Model:
[ INFO ]   NETWORK_NAME: torch_jit
[ INFO ]   OPTIMAL_NUMBER_OF_INFER_REQUESTS: 6
[ INFO ]   NUM_STREAMS: 6
[ INFO ]   AFFINITY: Affinity.CORE
[ INFO ]   INFERENCE_NUM_THREADS: 24
[ INFO ]   PERF_COUNT: False
[ INFO ]   INFERENCE_PRECISION_HINT: <Type: 'float32'>
[ INFO ]   PERFORMANCE_HINT: PerformanceMode.THROUGHPUT
[ INFO ]   EXECUTION_MODE_HINT: ExecutionMode.PERFORMANCE
[ INFO ]   PERFORMANCE_HINT_NUM_REQUESTS: 0
[ INFO ]   ENABLE_CPU_PINNING: True
[ INFO ]   SCHEDULING_CORE_TYPE: SchedulingCoreType.ANY_CORE
[ INFO ]   ENABLE_HYPER_THREADING: True
[ INFO ]   EXECUTION_DEVICES: ['CPU']
[ INFO ]   CPU_DENORMALS_OPTIMIZATION: False
[ INFO ]   CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE: 1.0
[Step 9/11] Creating infer requests and preparing input tensors
[ WARNING ] No input files were given for input 'inputs'!. This input will be filled with random values!
[ INFO ] Fill input 'inputs' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 6 inference requests, limits: 60000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 69.35 ms
[Step 11/11] Dumping statistics report
[ INFO ] Execution Devices:['CPU']
[ INFO ] Count:            2748 iterations
[ INFO ] Duration:         60151.82 ms
[ INFO ] Latency:
[ INFO ]    Median:        131.23 ms
[ INFO ]    Average:       131.13 ms
[ INFO ]    Min:           67.66 ms
[ INFO ]    Max:           145.43 ms
[ INFO ] Throughput:   45.68 FPS
# Inference INT8 model (OpenVINO IR)
! benchmark_app -m $quantized_model_path -shape [1,30480] -d CPU -api async
[Step 1/11] Parsing and validating input arguments
[ INFO ] Parsing input parameters
[Step 2/11] Loading OpenVINO Runtime
[ INFO ] OpenVINO:
[ INFO ] Build ................................. 2023.1.0-12050-e33de350633
[ INFO ]
[ INFO ] Device info:
[ INFO ] CPU
[ INFO ] Build ................................. 2023.1.0-12050-e33de350633
[ INFO ]
[ INFO ]
[Step 3/11] Setting device configuration
[ WARNING ] Performance hint was not explicitly specified in command line. Device(CPU) performance hint will be set to PerformanceMode.THROUGHPUT.
[Step 4/11] Reading model files
[ INFO ] Loading model files
[ INFO ] Read model took 81.97 ms
[ INFO ] Original model I/O parameters:
[ INFO ] Model inputs:
[ INFO ]     inputs (node: inputs) : f32 / [...] / [?,?]
[ INFO ] Model outputs:
[ INFO ]     logits (node: logits) : f32 / [...] / [?,?,32]
[Step 5/11] Resizing model to match image sizes and given batch
[ INFO ] Model batch size: 1
[ INFO ] Reshaping model: 'inputs': [1,30480]
[ INFO ] Reshape model took 35.47 ms
[Step 6/11] Configuring input of the model
[ INFO ] Model inputs:
[ INFO ]     inputs (node: inputs) : f32 / [...] / [1,30480]
[ INFO ] Model outputs:
[ INFO ]     logits (node: logits) : f32 / [...] / [1,95,32]
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 920.18 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] Model:
[ INFO ]   NETWORK_NAME: torch_jit
[ INFO ]   OPTIMAL_NUMBER_OF_INFER_REQUESTS: 6
[ INFO ]   NUM_STREAMS: 6
[ INFO ]   AFFINITY: Affinity.CORE
[ INFO ]   INFERENCE_NUM_THREADS: 24
[ INFO ]   PERF_COUNT: False
[ INFO ]   INFERENCE_PRECISION_HINT: <Type: 'float32'>
[ INFO ]   PERFORMANCE_HINT: PerformanceMode.THROUGHPUT
[ INFO ]   EXECUTION_MODE_HINT: ExecutionMode.PERFORMANCE
[ INFO ]   PERFORMANCE_HINT_NUM_REQUESTS: 0
[ INFO ]   ENABLE_CPU_PINNING: True
[ INFO ]   SCHEDULING_CORE_TYPE: SchedulingCoreType.ANY_CORE
[ INFO ]   ENABLE_HYPER_THREADING: True
[ INFO ]   EXECUTION_DEVICES: ['CPU']
[ INFO ]   CPU_DENORMALS_OPTIMIZATION: False
[ INFO ]   CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE: 1.0
[Step 9/11] Creating infer requests and preparing input tensors
[ WARNING ] No input files were given for input 'inputs'!. This input will be filled with random values!
[ INFO ] Fill input 'inputs' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 6 inference requests, limits: 60000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 52.31 ms
[Step 11/11] Dumping statistics report
[ INFO ] Execution Devices:['CPU']
[ INFO ] Count:            4500 iterations
[ INFO ] Duration:         60105.34 ms
[ INFO ] Latency:
[ INFO ]    Median:        79.88 ms
[ INFO ]    Average:       79.99 ms
[ INFO ]    Min:           47.16 ms
[ INFO ]    Max:           106.32 ms
[ INFO ] Throughput:   74.87 FPS
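You can compare the two reports directly. The snippet below restates the throughput and median latency values printed above for this particular run and machine. For this run, the INT8 model shows about 1.64x the throughput and a roughly 39% lower median latency. The ratios will differ on other hardware.

```python
# Values copied from the benchmark_app reports above (CPU, async, input shape [1,30480])
fp16 = {"throughput_fps": 45.68, "median_latency_ms": 131.23}
int8 = {"throughput_fps": 74.87, "median_latency_ms": 79.88}

throughput_speedup = int8["throughput_fps"] / fp16["throughput_fps"]
latency_reduction = 1 - int8["median_latency_ms"] / fp16["median_latency_ms"]

print(f"Throughput speedup: {throughput_speedup:.2f}x")
print(f"Median latency reduction: {latency_reduction:.0%}")
```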