Quantize Speech Recognition Models with OpenVINO Post-Training Optimization Tool

This tutorial is also available as a Jupyter notebook that can be cloned directly from GitHub. See the installation guide for instructions to run this tutorial locally on Windows, Linux or macOS.


This tutorial demonstrates how to apply INT8 quantization to the Wav2Vec2 speech recognition model using the Post-Training Optimization Tool API (POT API), which is part of the OpenVINO Toolkit. We will use a fine-tuned Wav2Vec2-Base-960h PyTorch model trained on the LibriSpeech ASR corpus. The tutorial is designed to be extendable to custom models and datasets. It consists of the following steps:

  • Download and prepare the Wav2Vec2 model and LibriSpeech dataset

  • Define data loading and accuracy validation functionality

  • Prepare the model for quantization

  • Run optimization pipeline

  • Compare performance of the original and quantized models

Imports

import os
import sys
import time
import re
import numpy as np
import torch
import tarfile
from pathlib import Path
from itertools import groupby
import soundfile as sf
import IPython.display as ipd

from transformers import Wav2Vec2ForCTC
from openvino.runtime import Core
from openvino.tools.pot import Metric, DataLoader, IEEngine, \
    load_model, save_model, compress_model_weights, create_pipeline

sys.path.append("../utils")
from notebook_utils import download_file

Settings

# Set the data and model directories and create them if they do not exist
DATA_DIR = "data"
MODEL_DIR = "model"

os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

Prepare the Model

The next steps are:

  • Download the pre-trained Wav2Vec2 model (weights and config)

  • Convert the model to ONNX

  • Run the OpenVINO Model Optimizer tool to convert the model from the ONNX representation to the OpenVINO Intermediate Representation (IR)

download_file("https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/pytorch_model.bin", directory=Path(MODEL_DIR) / 'pytorch', show_progress=True)
download_file("https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json", directory=Path(MODEL_DIR) / 'pytorch', show_progress=False)
PosixPath('/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/107-speech-recognition-quantization/model/pytorch/config.json')

Load the original PyTorch model and convert it to the ONNX representation.

BATCH_SIZE = 1
MAX_SEQ_LENGTH = 30480


def export_model_to_onnx(model, path):
    with torch.no_grad():
        default_input = torch.zeros([1, MAX_SEQ_LENGTH], dtype=torch.float)
        inputs = {
            "inputs": default_input
        }
        symbolic_names = {0: "batch_size", 1: "sequence_len"}
        torch.onnx.export(
            model,
            (inputs["inputs"]),
            path,
            opset_version=11,
            input_names=["inputs"],
            output_names=["logits"],
            dynamic_axes={
                "inputs": symbolic_names,
                "logits": symbolic_names,
            },
        )
        print("ONNX model saved to {}".format(path))


torch_model = Wav2Vec2ForCTC.from_pretrained(Path(MODEL_DIR) / 'pytorch')
onnx_model_path = Path(MODEL_DIR) / "wav2vec2_base.onnx"
if not onnx_model_path.exists():
    export_model_to_onnx(torch_model, onnx_model_path)
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at model/pytorch and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
ONNX model saved to model/wav2vec2_base.onnx
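The export above marks the second axis of both inputs and logits as dynamic, but the two axes do not have the same length. Wav2Vec2's convolutional feature encoder downsamples the raw waveform, so the number of logit frames is a function of the number of input samples. The sketch below computes that length, assuming the kernel sizes and strides from the model's config.json (the conv_kernel and conv_stride fields of facebook/wav2vec2-base-960h; check them against your downloaded config). For the 30480-sample input used in this tutorial it gives 95 frames, which matches the 1 × 95 × 32 logits shape that benchmark_app reports at the end of this tutorial.

```python
# ASSUMPTION: kernel sizes and strides copied from the conv_kernel / conv_stride
# fields of the facebook/wav2vec2-base-960h config.json; verify against your config.
CONV_KERNEL = [10, 3, 3, 3, 3, 2, 2]
CONV_STRIDE = [5, 2, 2, 2, 2, 2, 2]
SAMPLE_RATE = 16000  # LibriSpeech audio is 16 kHz


def num_output_frames(num_samples: int) -> int:
    """Number of logit frames the feature encoder produces for num_samples samples.

    Each unpadded 1-D convolution maps a length L to floor((L - kernel) / stride) + 1.
    """
    length = num_samples
    for kernel, stride in zip(CONV_KERNEL, CONV_STRIDE):
        length = (length - kernel) // stride + 1
    return length


max_seq_length = 30480  # same value as MAX_SEQ_LENGTH above
frames = num_output_frames(max_seq_length)
print(f"{max_seq_length} samples = {max_seq_length / SAMPLE_RATE:.3f} s of audio -> {frames} frames")
```

With a total stride of 5 · 2⁶ = 320 samples, each frame covers roughly 20 ms of 16 kHz audio, which is why a one-second clip yields about 50 frames.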

Convert the ONNX Model to OpenVINO IR

ir_model_xml = onnx_model_path.with_suffix(".xml")
ir_model_bin = onnx_model_path.with_suffix(".bin")

if not ir_model_xml.exists():
    !mo --input_model $onnx_model_path --output_dir $MODEL_DIR --input_shape [1,-1] --data_type FP16
Model Optimizer arguments:
Common parameters:
    - Path to the Input Model:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/107-speech-recognition-quantization/model/wav2vec2_base.onnx
    - Path for generated IR:    /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/107-speech-recognition-quantization/model
    - IR output name:   wav2vec2_base
    - Log level:    ERROR
    - Batch:    Not specified, inherited from the model
    - Input layers:     Not specified, inherited from the model
    - Output layers:    Not specified, inherited from the model
    - Input shapes:     [1,-1]
    - Source layout:    Not specified
    - Target layout:    Not specified
    - Layout:   Not specified
    - Mean values:  Not specified
    - Scale values:     Not specified
    - Scale factor:     Not specified
    - Precision of IR:  FP16
    - Enable fusing:    True
    - User transformations:     Not specified
    - Reverse input channels:   False
    - Enable IR generation for fixed input shape:   False
    - Use the transformations config file:  None
Advanced parameters:
    - Force the usage of legacy Frontend of Model Optimizer for model conversion into IR:   False
    - Force the usage of new Frontend of Model Optimizer for model conversion into IR:  False
OpenVINO runtime found in:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/openvino
OpenVINO runtime version:   2022.1.0-7019-cdb9bec7210-releases/2022/1
Model Optimizer version:    2022.1.0-7019-cdb9bec7210-releases/2022/1
[ SUCCESS ] Generated IR version 11 model.
[ SUCCESS ] XML file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/107-speech-recognition-quantization/model/wav2vec2_base.xml
[ SUCCESS ] BIN file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/107-speech-recognition-quantization/model/wav2vec2_base.bin
[ SUCCESS ] Total execution time: 2.69 seconds.
[ SUCCESS ] Memory consumed: 818 MB.
[ INFO ] The model was converted to IR v11, the latest model format that corresponds to the source DL framework input/output format. While IR v11 is backwards compatible with OpenVINO Inference Engine API v1.0, please use API v2.0 (as of 2022.1) to take advantage of the latest improvements in IR v11.
Find more information about API v2.0 and IR v11 at https://docs.openvino.ai

Prepare LibriSpeech Dataset

The code below downloads and unpacks the archive with the ‘test-clean’ subset of the LibriSpeech dataset.

download_file("http://openslr.elda.org/resources/12/test-clean.tar.gz", directory=DATA_DIR, show_progress=True)

if not os.path.exists(f'{DATA_DIR}/LibriSpeech'):
    with tarfile.open(f"{DATA_DIR}/test-clean.tar.gz") as tar:
        tar.extractall(path=DATA_DIR)

Define DataLoader for POT

In this step, we define a DataLoader based on the POT API. It is used to collect statistics for quantization and to run model evaluation. The Wav2Vec2 model accepts a raw speech waveform as input and produces per-frame vocabulary class estimations (logits) as output. Since the dataset contains audio files in FLAC format, we use the soundfile package to decode them into waveforms.

NOTE: Consider increasing samples_limit to get more precise results. The suggested value is 300 or more, although processing will take longer.

class LibriSpeechDataLoader(DataLoader):
    samples_limit = 50

    @staticmethod
    def read_flac(file_name):
        speech, samplerate = sf.read(file_name)
        assert samplerate == 16000, "read_flac: only 16kHz supported!"
        return speech

    # Required methods
    def __init__(self, config):
        """Constructor
        :param config: data loader specific config
        """
        super().__init__(config)
        self._data_dir = config["data_source"]
        self._ds = []
        self._prepare_dataset()

    def __len__(self):
        """Returns size of the dataset"""
        return len(self._ds)

    def __getitem__(self, index):
        """
        Returns annotation, data and metadata at the specified index.
        Possible formats:
        (index, annotation), data
        (index, annotation), data, metadata
        """
        label = self._ds[index][0]
        inputs = {'inputs': np.expand_dims(self._ds[index][1], axis=0)}
        return label, inputs

    # Methods specific to the current implementation
    def _prepare_dataset(self):
        pattern = re.compile(r'([0-9\-]+)\s+(.+)')
        data_folder = Path(self._data_dir)
        txts = list(data_folder.glob('**/*.txt'))
        counter = 0
        for txt in txts:
            content = txt.open().readlines()
            for line in content:
                res = pattern.search(line)
                if not res:
                    continue
                name = res.group(1)
                transcript = res.group(2)
                fname = txt.parent / name
                fname = fname.with_suffix('.flac')
                identifier = str(fname.relative_to(data_folder))
                self._ds.append(((counter, transcript.upper()), LibriSpeechDataLoader.read_flac(os.path.join(self._data_dir, identifier))))
                counter += 1
                if counter >= self.samples_limit:
                    # Limit exceeded
                    return
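_prepare_dataset relies on the LibriSpeech layout: each chapter folder holds a transcript file (named like 121-127105.trans.txt) in which every line is an utterance ID followed by its transcript, and the matching audio file sits in the same folder under the utterance ID with a .flac suffix. The standalone sketch below isolates that parsing step with the same regular expression. The helper name parse_transcript_line is hypothetical, and the transcript text in the usage example is illustrative only.

```python
import re
from pathlib import PurePosixPath

# Same pattern as LibriSpeechDataLoader._prepare_dataset: an utterance ID made of
# digits and dashes, whitespace, then the rest of the line as the transcript.
PATTERN = re.compile(r'([0-9\-]+)\s+(.+)')


def parse_transcript_line(txt_path: PurePosixPath, line: str):
    """Hypothetical helper: return (flac_path, transcript) for one transcript line, or None."""
    match = PATTERN.search(line)
    if not match:
        return None
    utterance_id, transcript = match.group(1), match.group(2)
    # The audio file lives next to the transcript file and is named after the utterance ID.
    flac_path = (txt_path.parent / utterance_id).with_suffix('.flac')
    return flac_path, transcript.upper()


txt = PurePosixPath("data/LibriSpeech/test-clean/121/127105/121-127105.trans.txt")
print(parse_transcript_line(txt, "121-127105-0017 it was almost\n"))
```

Because `.` in the pattern does not match a newline, the trailing line break from readlines() is excluded from the transcript automatically.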

Define WER Metric Calculation

In this step, we implement the Metric interface for the WER metric, which is used to validate the accuracy of the model. WER stands for Word Error Rate; you can find more details on the Wikipedia page.
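WER is the word-level edit distance between the reference transcript and the model's hypothesis, normalized by the reference length: WER = (S + D + I) / N, where S, D and I are the substitutions, deletions and insertions in a minimal alignment and N is the number of words in the reference. The standalone sketch below computes it with the same Levenshtein recurrence that MetricWER._editdistance_eval uses, but keeps only two rows instead of the full NumPy matrix. The function names are illustrative and not part of the POT API.

```python
def word_edit_distance(reference, hypothesis):
    """Minimum number of word substitutions, deletions and insertions (Levenshtein distance)."""
    n, m = len(reference), len(hypothesis)
    # previous[j] = distance between reference[:i-1] and hypothesis[:j]; row 0 is j insertions.
    previous = list(range(m + 1))
    for i in range(1, n + 1):
        current = [i] + [0] * m  # reference[:i] vs. empty hypothesis: i deletions
        for j in range(1, m + 1):
            cost = 0 if reference[i - 1] == hypothesis[j - 1] else 1
            current[j] = min(previous[j] + 1,         # delete reference[i-1]
                             current[j - 1] + 1,      # insert hypothesis[j-1]
                             previous[j - 1] + cost)  # substitute or match
        previous = current
    return previous[m]


def word_error_rate(reference, hypothesis):
    """WER = (S + D + I) / N, where N is the number of words in the reference."""
    ref_words, hyp_words = reference.split(), hypothesis.split()
    if not ref_words:
        return 0.0  # mirrors the zero guard in MetricWER._get_metric_per_sample
    return word_edit_distance(ref_words, hyp_words) / len(ref_words)


print(word_error_rate("the cat sat on the mat", "the cat sit on mat"))  # 1 S + 1 D over 6 words
print(word_error_rate("hi", "oh hi there"))  # 2 insertions over 1 word: WER can exceed 1
```

Since insertions are counted but N only covers the reference, WER is not bounded by 1, and dividing by the hypothesis length instead would give a different (incorrect) value.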

class MetricWER(Metric):
    alphabet = [
        "<pad>", "<s>", "</s>", "<unk>", "|",
        "e", "t", "a", "o", "n", "i", "h", "s", "r", "d", "l", "u",
        "m", "w", "c", "f", "g", "y", "p", "b", "v", "k", "'", "x", "j", "q", "z"]
    words_delimiter = '|'
    pad_token = '<pad>'

    @staticmethod
    def decode_logits(logits):
        decoding_vocab = dict(enumerate(MetricWER.alphabet))
        token_ids = np.squeeze(np.argmax(logits, -1))
        tokens = [decoding_vocab[idx] for idx in token_ids]
        tokens = [token_group[0] for token_group in groupby(tokens)]
        tokens = [t for t in tokens if t != MetricWER.pad_token]
        res_string = ''.join([t if t != MetricWER.words_delimiter else ' ' for t in tokens]).strip()
        # Collapse runs of whitespace into single spaces
        res_string = ' '.join(res_string.split())
        res_string = res_string.lower()
        return res_string

    # Required methods
    def __init__(self):
        super().__init__()
        self._name = "WER"
        self._sum_score = 0
        self._sum_words = 0
        self._cur_score = 0
        self._decoding_vocab = dict(enumerate(self.alphabet))

    @property
    def value(self):
        """Returns accuracy metric value for the last model output."""
        return {self._name: self._cur_score}

    @property
    def avg_value(self):
        """Returns accuracy metric value for all model outputs."""
        return {self._name: self._sum_score / self._sum_words if self._sum_words != 0 else 0}

    def update(self, output, target):
        """
        Updates prediction matches.

        :param output: model output
        :param target: annotations
        """
        decoded = [self.decode_logits(i) for i in output]
        target = [i.lower() for i in target]
        assert len(output) == len(target), "sizes of output and target mismatch!"
        for i in range(len(output)):
            # Pass the reference (annotation) first so WER is normalized by the reference length
            self._get_metric_per_sample(target[i], decoded[i])

    def reset(self):
        """
        Resets collected matches
        """
        self._sum_score = 0
        self._sum_words = 0
        self._cur_score = 0

    def get_attributes(self):
        """
        Returns a dictionary of metric attributes {metric_name: {attribute_name: value}}.
        Required attributes: 'direction': 'higher-better' or 'higher-worse'
                             'type': metric type
        """
        return {self._name: {"direction": "higher-worse", "type": "WER"}}

    # Methods specific to the current implementation
    def _get_metric_per_sample(self, annotation, prediction):
        cur_score = self._editdistance_eval(annotation.split(), prediction.split())
        cur_words = len(annotation.split())

        self._sum_score += cur_score
        self._sum_words += cur_words

        # Guard against empty reference transcripts to avoid division by zero
        result = cur_score / cur_words if cur_words != 0 else 0
        self._cur_score = result
        return result

    def _editdistance_eval(self, source, target):
        n, m = len(source), len(target)

        distance = np.zeros((n + 1, m + 1), dtype=int)
        distance[:, 0] = np.arange(0, n + 1)
        distance[0, :] = np.arange(0, m + 1)

        for i in range(1, n + 1):
            for j in range(1, m + 1):
                cost = 0 if source[i - 1] == target[j - 1] else 1

                distance[i][j] = min(distance[i - 1][j] + 1,
                                     distance[i][j - 1] + 1,
                                     distance[i - 1][j - 1] + cost)
        return distance[n][m]
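decode_logits implements greedy (best-path) CTC decoding: take the most likely token in every frame, merge consecutive repeats, then drop the <pad> token, which Wav2Vec2ForCTC uses as the CTC blank, and finally turn the "|" word delimiter into spaces. The standalone sketch below applies the same steps to a hand-written list of per-frame tokens (a made-up example, not real model output, and with the argmax step already done) and shows why the blank is needed to spell double letters.

```python
from itertools import groupby

PAD = "<pad>"      # CTC blank token, as in MetricWER.pad_token
DELIMITER = "|"    # word delimiter, as in MetricWER.words_delimiter


def ctc_greedy_decode(frame_tokens):
    """Greedy CTC decoding of per-frame tokens (i.e. the argmax of each frame's logits)."""
    # 1. Merge consecutive repeats of the same token: "l", "l" -> "l".
    collapsed = [token for token, _ in groupby(frame_tokens)]
    # 2. Drop the blank token; a blank between two equal letters keeps both of them.
    kept = [token for token in collapsed if token != PAD]
    # 3. Map the word delimiter to a space and normalize whitespace.
    text = "".join(" " if token == DELIMITER else token for token in kept)
    return " ".join(text.split())


frames = ["h", "h", "e", PAD, "l", "l", PAD, "l", "o", "o", "|",
          "w", "o", "r", "l", "d", "d"]
print(ctc_greedy_decode(frames))                          # hello world
print(ctc_greedy_decode(["h", "e", "l", "l", "l", "o"]))  # helo: no blank between the l's
```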

Run Quantization Pipeline

Here we define a configuration for the quantization pipeline and run it. Note that we use the built-in IEEngine implementation of the Engine interface from the POT API for model inference.

model_config = {"model_name": "wav2vec2_base", "model": ir_model_xml, "weights": ir_model_bin}

engine_config = {"device": "CPU"}

dataset_config = {"data_source": os.path.join(DATA_DIR, "LibriSpeech/test-clean")}

algorithms = [
    {
        "name": "DefaultQuantization",
        "params": {
            "target_device": "ANY",
            "model_type": "transformer",
            "preset": "performance",
            "stat_subset_size": 300,
            "activations": {
                "range_estimator": {
                    "min": {
                        "aggregator": "min",
                        "type": "min"
                    },
                    "max": {
                        "aggregator": "mean",
                        "type": "quantile",
                        "outlier_prob": 0.0001
                    }
                }
            },
            "ignored": {
                "scope": ["214"]
            }
        }
    }
]

# Step 1: Load the model
model = load_model(model_config=model_config)

# Step 2: Initialize the data loader
data_loader = LibriSpeechDataLoader(config=dataset_config)

# Step 3 (Optional. Required for AccuracyAwareQuantization): Initialize the metric
metric = MetricWER()

# Step 4: Initialize the engine for metric calculation and statistics collection
engine = IEEngine(config=engine_config, data_loader=data_loader, metric=metric)

# Step 5: Create a pipeline of compression algorithms
pipeline = create_pipeline(algo_config=algorithms, engine=engine)

# Step 6 (Optional): Evaluate the original model. Print the results
start_time = time.perf_counter()
fp_results = pipeline.evaluate(model=model)
end_time = time.perf_counter()
print(f"Evaluation finished in {end_time - start_time:.2f} seconds")
if fp_results:
    print("FP16 model results:")
    for name, value in fp_results.items():
        print(f"{name}: {value:.5f}")
Evaluation finished in 11.14 seconds
FP16 model results:
WER: 0.01724
# Step 7: Execute the pipeline
print(f"Quantizing model with {algorithms[0]['params']['preset']} preset and {algorithms[0]['name']}")
start_time = time.perf_counter()
compressed_model = pipeline.run(model=model)
end_time = time.perf_counter()
print(f"Quantization finished in {end_time - start_time:.2f} seconds")

# Step 8 (Optional): Compress model weights to quantized precision
#                    in order to reduce the size of final .bin file
compress_model_weights(model=compressed_model)

# Step 9: Save the compressed model to the desired path
compressed_model_paths = save_model(model=compressed_model, save_path=MODEL_DIR, model_name="quantized_wav2vec2_base")
compressed_model_xml = compressed_model_paths[0]["model"]
Quantizing model with performance preset and DefaultQuantization
Quantization finished in 89.63 seconds
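Step 8 calls compress_model_weights so that the weights of the already quantized layers are stored in low precision, which shrinks the final .bin file. The standalone sketch below illustrates the idea with a deliberately simplified scheme: one symmetric scale per tensor and 255 integer levels ([-127, 127]). Both choices are assumptions for illustration; POT's actual scheme (for example, per-channel scales) depends on the preset and target device. The weight values are made up.

```python
import struct
from array import array


def quantize_symmetric_int8(weights):
    """Simplified symmetric per-tensor INT8 quantization: w ≈ scale * q, q in [-127, 127]."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = array("b", (max(-127, min(127, round(w / scale))) for w in weights))
    return q, scale


def dequantize(q, scale):
    """Map INT8 values back to floating point."""
    return [scale * v for v in q]


weights = [0.5, -1.27, 0.01, 0.9, -0.33, 0.07]  # made-up example values
q, scale = quantize_symmetric_int8(weights)
restored = dequantize(q, scale)

fp16_bytes = struct.calcsize("e") * len(weights)  # 2 bytes per weight, as in the FP16 IR
int8_bytes = q.itemsize * len(q)                  # 1 byte per weight after compression
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(f"scale={scale:.5f}, q={list(q)}")
print(f"FP16: {fp16_bytes} bytes, INT8: {int8_bytes} bytes, max abs error: {max_error:.5f}")
```

Rounding to the nearest level bounds the per-weight error by half a quantization step (scale / 2), while storage drops to half of FP16.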
# Step 10 (Optional): Evaluate the compressed model and print the results
int_results = pipeline.evaluate(model=compressed_model)

if int_results:
    print("INT8 model results:")
    for name, value in int_results.items():
        print(f"{name}: {value:.5f}")
INT8 model results:
WER: 0.07338
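The activations.range_estimator settings in the configuration above decide which value range each activation's INT8 grid covers: the lower bound is the per-batch "min" aggregated with "min", while the upper bound is a per-batch quantile (1 − outlier_prob) aggregated with "mean", so rare outliers do not stretch the grid. The standalone sketch below is a simplified analogue of that idea under those assumptions, not POT's implementation. It uses synthetic activations and a FakeQuantize-style clamp-and-round with 256 levels. All function names are illustrative.

```python
import math
import random


def quantile(values, q):
    """Quantile of a list with linear interpolation between neighbours (0 <= q <= 1)."""
    ordered = sorted(values)
    pos = q * (len(ordered) - 1)
    lower = math.floor(pos)
    upper = min(lower + 1, len(ordered) - 1)
    return ordered[lower] + (pos - lower) * (ordered[upper] - ordered[lower])


def estimate_range(batches, outlier_prob=1e-4):
    """min: per-batch min aggregated with min; max: per-batch quantile aggregated with mean."""
    low = min(min(batch) for batch in batches)
    highs = [quantile(batch, 1.0 - outlier_prob) for batch in batches]
    return low, sum(highs) / len(highs)


def fake_quantize(x, low, high, levels=256):
    """Clamp x to [low, high] and snap it to one of `levels` evenly spaced values."""
    if high <= low:
        return low
    step = (high - low) / (levels - 1)
    x = min(max(x, low), high)
    return low + round((x - low) / step) * step


# Synthetic activations: Gaussian values plus one large outlier per batch.
rng = random.Random(0)
batches = [[rng.gauss(0.0, 1.0) for _ in range(10_000)] + [50.0] for _ in range(4)]

low, high = estimate_range(batches)
print(f"plain max: {max(max(b) for b in batches):.2f}, estimated range: [{low:.2f}, {high:.2f}]")
print(f"fake_quantize(50.0) -> {fake_quantize(50.0, low, high):.2f}")  # clamped to the upper bound
```

Ignoring the outlier keeps the quantization step small for the bulk of the values, at the cost of clipping the rare large activations.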

Model Usage Example with Inference Pipeline

The original (FP16) and quantized (INT8) models are used in exactly the same way.

We start by taking one example from the dataset to demonstrate the inference steps.

audio = LibriSpeechDataLoader.read_flac(f'{DATA_DIR}/LibriSpeech/test-clean/121/127105/121-127105-0017.flac')

ipd.Audio(audio, rate=16000)

Next, we load the quantized model into the inference pipeline.

ie = Core()

model = ie.read_model(compressed_model_xml)
compiled_model = ie.compile_model(model=model, device_name='CPU')

input_data = np.expand_dims(audio, axis=0)
output_layer = compiled_model.outputs[0]

Time to make a prediction

predictions = compiled_model([input_data])[output_layer]

Now we just need to decode the predicted logits into text using the decode_logits method of the MetricWER class.

Alternatively, you can use Wav2Vec2Processor from the transformers package, which includes a built-in tokenizer.

predicted_text = MetricWER.decode_logits(predictions)

predicted_text
'it was almost the tone of hope everybody will stay'

Compare Performance of the Original and Quantized Models

Finally, we measure the inference performance of the FP16 and INT8 models using Benchmark Tool, OpenVINO’s inference performance measurement tool.

NOTE: For more accurate performance measurements, we recommend running benchmark_app in a terminal/command prompt after closing other applications. Run benchmark_app -m model.xml -d CPU to benchmark async inference on CPU for one minute. Change CPU to GPU to benchmark on GPU. Run benchmark_app --help to see an overview of all command line options.

# Inference FP16 model (IR)
! benchmark_app -m $ir_model_xml -shape [1,30480] -d CPU -api async
[Step 1/11] Parsing and validating input arguments
[ WARNING ]  -nstreams default value is determined automatically for a device. Although the automatic selection usually provides a reasonable performance, but it still may be non-optimal for some cases, for more information look at README.
[Step 2/11] Loading OpenVINO
[ WARNING ] PerformanceMode was not explicitly specified in command line. Device CPU performance hint will be set to THROUGHPUT.
[ INFO ] OpenVINO:
         API version............. 2022.1.0-7019-cdb9bec7210-releases/2022/1
[ INFO ] Device info
         CPU
         openvino_intel_cpu_plugin version 2022.1
         Build................... 2022.1.0-7019-cdb9bec7210-releases/2022/1

[Step 3/11] Setting device configuration
[ WARNING ] -nstreams default value is determined automatically for CPU device. Although the automatic selection usually provides a reasonable performance, but it still may be non-optimal for some cases, for more information look at README.
[Step 4/11] Reading network files
[ INFO ] Read model took 168.08 ms
[Step 5/11] Resizing network to match image sizes and given batch
[ INFO ] Reshaping model: 'inputs': {1,30480}
[ INFO ] Reshape model took 19.01 ms
[ INFO ] Network batch size: 1
[Step 6/11] Configuring input of the model
[ INFO ] Model input 'inputs' precision f32, dimensions ([...]): 1 30480
[ INFO ] Model output 'logits' precision f32, dimensions ([...]): 1 95 32
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 668.98 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] DEVICE: CPU
[ INFO ]   AVAILABLE_DEVICES  , ['']
[ INFO ]   RANGE_FOR_ASYNC_INFER_REQUESTS  , (1, 1, 1)
[ INFO ]   RANGE_FOR_STREAMS  , (1, 24)
[ INFO ]   FULL_DEVICE_NAME  , Intel(R) Core(TM) i9-10920X CPU @ 3.50GHz
[ INFO ]   OPTIMIZATION_CAPABILITIES  , ['WINOGRAD', 'FP32', 'FP16', 'INT8', 'BIN', 'EXPORT_IMPORT']
[ INFO ]   CACHE_DIR  ,
[ INFO ]   NUM_STREAMS  , 6
[ INFO ]   INFERENCE_NUM_THREADS  , 0
[ INFO ]   PERF_COUNT  , False
[ INFO ]   PERFORMANCE_HINT_NUM_REQUESTS  , 0
[Step 9/11] Creating infer requests and preparing input data
[ INFO ] Create 6 infer requests took 1.94 ms
[ WARNING ] No input files were given for input 'inputs'!. This input will be filled with random values!
[ INFO ] Fill input 'inputs' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 6 inference requests using 6 streams for CPU, inference only: True, limits: 60000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 105.21 ms
[Step 11/11] Dumping statistics report
Count:          2196 iterations
Duration:       60213.74 ms
Latency:
    Median:     166.55 ms
    AVG:        164.32 ms
    MIN:        124.75 ms
    MAX:        190.24 ms
Throughput: 36.47 FPS
# Inference INT8 model (IR)
! benchmark_app -m $compressed_model_xml -shape [1,30480] -d CPU -api async
[Step 1/11] Parsing and validating input arguments
[ WARNING ]  -nstreams default value is determined automatically for a device. Although the automatic selection usually provides a reasonable performance, but it still may be non-optimal for some cases, for more information look at README.
[Step 2/11] Loading OpenVINO
[ WARNING ] PerformanceMode was not explicitly specified in command line. Device CPU performance hint will be set to THROUGHPUT.
[ INFO ] OpenVINO:
         API version............. 2022.1.0-7019-cdb9bec7210-releases/2022/1
[ INFO ] Device info
         CPU
         openvino_intel_cpu_plugin version 2022.1
         Build................... 2022.1.0-7019-cdb9bec7210-releases/2022/1

[Step 3/11] Setting device configuration
[ WARNING ] -nstreams default value is determined automatically for CPU device. Although the automatic selection usually provides a reasonable performance, but it still may be non-optimal for some cases, for more information look at README.
[Step 4/11] Reading network files
[ INFO ] Read model took 144.82 ms
[Step 5/11] Resizing network to match image sizes and given batch
[ INFO ] Reshaping model: 'inputs': {1,30480}
[ INFO ] Reshape model took 21.97 ms
[ INFO ] Network batch size: 1
[Step 6/11] Configuring input of the model
[ INFO ] Model input 'inputs' precision f32, dimensions ([...]): 1 30480
[ INFO ] Model output 'logits' precision f32, dimensions ([...]): 1 95 32
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 433.75 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] DEVICE: CPU
[ INFO ]   AVAILABLE_DEVICES  , ['']
[ INFO ]   RANGE_FOR_ASYNC_INFER_REQUESTS  , (1, 1, 1)
[ INFO ]   RANGE_FOR_STREAMS  , (1, 24)
[ INFO ]   FULL_DEVICE_NAME  , Intel(R) Core(TM) i9-10920X CPU @ 3.50GHz
[ INFO ]   OPTIMIZATION_CAPABILITIES  , ['WINOGRAD', 'FP32', 'FP16', 'INT8', 'BIN', 'EXPORT_IMPORT']
[ INFO ]   CACHE_DIR  ,
[ INFO ]   NUM_STREAMS  , 6
[ INFO ]   INFERENCE_NUM_THREADS  , 0
[ INFO ]   PERF_COUNT  , False
[ INFO ]   PERFORMANCE_HINT_NUM_REQUESTS  , 0
[Step 9/11] Creating infer requests and preparing input data
[ INFO ] Create 6 infer requests took 1.65 ms
[ WARNING ] No input files were given for input 'inputs'!. This input will be filled with random values!
[ INFO ] Fill input 'inputs' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 6 inference requests using 6 streams for CPU, inference only: True, limits: 60000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 26.00 ms
[Step 11/11] Dumping statistics report
Count:          7764 iterations
Duration:       60062.31 ms
Latency:
    Median:     45.92 ms
    AVG:        46.31 ms
    MIN:        40.88 ms
    MAX:        61.79 ms
Throughput: 129.27 FPS
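To summarize the comparison, the sketch below extracts the "Throughput" and "Median" lines from the two benchmark_app reports above and computes the INT8-over-FP16 ratios. The report excerpts are copied from the logs in this tutorial. These numbers were measured in inference-only mode with random [1,30480] inputs on the Intel Core i9-10920X listed in the logs, so the ratios will differ on other hardware. The parse_report helper is illustrative and assumes only the summary line format shown above.

```python
import re

# Summary lines copied from the two benchmark_app reports above.
FP16_REPORT = """
Latency:
    Median:     166.55 ms
Throughput: 36.47 FPS
"""
INT8_REPORT = """
Latency:
    Median:     45.92 ms
Throughput: 129.27 FPS
"""


def parse_report(report):
    """Return (throughput in FPS, median latency in ms) from a benchmark_app summary."""
    throughput = re.search(r"Throughput:\s*([\d.]+)\s*FPS", report)
    median = re.search(r"Median:\s*([\d.]+)\s*ms", report)
    if throughput is None or median is None:
        raise ValueError("report does not contain Throughput/Median lines")
    return float(throughput.group(1)), float(median.group(1))


fp16_fps, fp16_median = parse_report(FP16_REPORT)
int8_fps, int8_median = parse_report(INT8_REPORT)
throughput_speedup = int8_fps / fp16_fps
latency_speedup = fp16_median / int8_median
print(f"Throughput: {throughput_speedup:.2f}x higher")  # 3.54x
print(f"Median latency: {latency_speedup:.2f}x lower")  # 3.63x
```

On this machine, the INT8 model delivers roughly 3.5 times the throughput of the FP16 model, while its WER rose from 0.01724 to 0.07338 in the evaluation above, so the trade-off is worth checking on a larger sample before deployment.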