Quantize Speech Recognition Models with accuracy control using NNCF PTQ API

This Jupyter notebook can be launched after a local installation only.


This tutorial demonstrates how to apply INT8 quantization with accuracy control to a speech recognition model, Wav2Vec2, using the NNCF (Neural Network Compression Framework) 8-bit post-training quantization flow (no fine-tuning pipeline required). The notebook uses a fine-tuned Wav2Vec2-Base-960h PyTorch model trained on the LibriSpeech ASR corpus, and the tutorial is designed to be extendable to custom models and datasets. It consists of the following steps:

  • Download and prepare the Wav2Vec2 model and LibriSpeech dataset.

  • Define data loading and accuracy validation functionality.

  • Quantize the model with accuracy control.

  • Compare the accuracy of the original PyTorch model with the OpenVINO FP16 and INT8 models.

  • Compare performance of the original and quantized models.

The advanced quantization flow applies 8-bit quantization to the model while controlling the accuracy metric. This is achieved by keeping the most impactful operations within the model in the original precision. The flow is based on the Basic 8-bit quantization and has the following differences:

  • Besides the calibration dataset, a validation dataset is required to compute the accuracy metric. Both datasets can refer to the same data in the simplest case.

  • A validation function, used to compute the accuracy metric, is required. It can be a function that is already available in the source framework or a custom function.

  • Since accuracy validation is run several times during the quantization process, quantization with accuracy control can take more time than the Basic 8-bit quantization flow.

  • The resulting model may provide a smaller performance improvement than the Basic 8-bit quantization flow, because some of the operations are kept in the original precision.

NOTE: Currently, 8-bit quantization with accuracy control in NNCF is available only for models in OpenVINO representation.

The steps for the quantization with accuracy control are described below.


%pip install -q "openvino>=2023.1.0"
%pip install -q "nncf>=2.6.0"
%pip install -q --extra-index-url https://download.pytorch.org/whl/cpu soundfile librosa transformers torch datasets torchmetrics

Imports

import numpy as np
import torch

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
2023-10-10 09:32:06.465943: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable TF_ENABLE_ONEDNN_OPTS=0.
2023-10-10 09:32:06.505459: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-10-10 09:32:07.113533: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT

Prepare the Model

To instantiate the PyTorch model class, use the Wav2Vec2ForCTC.from_pretrained method and provide the model ID for downloading from the Hugging Face hub. Model weights and configuration files are downloaded automatically on first use. Keep in mind that downloading the files can take several minutes, depending on your internet connection.

Additionally, we create a processor class that is responsible for model-specific pre- and post-processing steps.

BATCH_SIZE = 1
MAX_SEQ_LENGTH = 30480


torch_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h", ctc_loss_reduction="mean")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Convert it to the OpenVINO Intermediate Representation (OpenVINO IR)

import openvino as ov


default_input = torch.zeros([1, MAX_SEQ_LENGTH], dtype=torch.float)
ov_model = ov.convert_model(torch_model, example_input=default_input)
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.base has been moved to tensorflow.python.trackable.base. The old module will be deleted in version 2.11.
[ WARNING ]  Please fix your imports. Module %s has been moved to %s. The old module will be deleted in version %s.
/home/ea/work/ov_venv/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:595: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
/home/ea/work/ov_venv/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:634: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
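
Optionally, the converted model can be serialized to OpenVINO IR so that the conversion step does not have to be repeated on the next run. Note that ov.save_model compresses weights to FP16 by default; the file name below is only an example.

# Optional: save the converted model to disk (weights are compressed to FP16 by default)
ov.save_model(ov_model, "wav2vec2_base.xml")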

Prepare LibriSpeech Dataset

For demonstration purposes, we will use a short dummy version of the LibriSpeech dataset, patrickvonplaten/librispeech_asr_dummy, to speed up model evaluation. Model accuracy can differ from the values reported in the paper. To reproduce the original accuracy, use the librispeech_asr dataset.

from datasets import load_dataset


dataset = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
test_sample = dataset[0]["audio"]


# define preprocessing function for converting audio to input values for model
def map_to_input(batch):
    preprocessed_signal = processor(batch["audio"]["array"], return_tensors="pt", padding="longest", sampling_rate=batch['audio']['sampling_rate'])
    input_values = preprocessed_signal.input_values
    batch['input_values'] = input_values
    return batch


# apply preprocessing function to dataset and remove audio column, to save memory as we do not need it anymore
dataset = dataset.map(map_to_input, batched=False, remove_columns=["audio"])
Found cached dataset librispeech_asr_dummy (/home/ea/.cache/huggingface/datasets/patrickvonplaten___librispeech_asr_dummy/clean/2.1.0/f2c70a4d03ab4410954901bde48c54b85ca1b7f9bf7d616e7e2a72b5ee6ddbfc)
Loading cached processed dataset at /home/ea/.cache/huggingface/datasets/patrickvonplaten___librispeech_asr_dummy/clean/2.1.0/f2c70a4d03ab4410954901bde48c54b85ca1b7f9bf7d616e7e2a72b5ee6ddbfc/cache-dcb48242e67b91b1.arrow

Prepare calibration dataset

import nncf


def transform_fn(data_item):
    """
    Extract the model's input from the data item.
    The data item here is the data item that is returned from the data source per iteration.
    This function should be passed when the data item cannot be used as model's input.
    """
    return np.array(data_item["input_values"])


calibration_dataset = nncf.Dataset(dataset, transform_fn)
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, tensorflow, onnx, openvino
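
As a quick sanity check (illustrative, not part of the original flow), transform_fn can be called on a single dataset item to confirm that it returns a 2-D array with a leading batch dimension, which the OpenVINO model accepts directly.

# Illustrative check: the calibration transform should produce the model input directly
print(transform_fn(dataset[0]).shape)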

Prepare validation function

Define the validation function.

from torchmetrics import WordErrorRate
from tqdm.notebook import tqdm


def validation_fn(model, dataset):
    """
    Calculate and returns a metric for the model.
    """
    wer = WordErrorRate()
    for sample in dataset:
        # run infer function on sample
        output = model.output(0)
        logits = model(np.array(sample['input_values']))[output]
        predicted_ids = np.argmax(logits, axis=-1)
        transcription = processor.batch_decode(torch.from_numpy(predicted_ids))

        # update metric on sample result
        wer.update(transcription, [sample['text']])

    result = wer.compute()

    # return an accuracy-like value (higher is better) so that NNCF can track the accuracy drop
    return 1 - result
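
During quantization with accuracy control, NNCF calls this function with a compiled OpenVINO model and the validation dataset. To see the baseline value before running quantization, the function can also be exercised manually; a minimal sketch, assuming inference on the CPU device:

# Illustrative baseline run: evaluate the converted FP32 model with the same
# validation function that NNCF will use internally.
compiled_fp32 = ov.Core().compile_model(ov_model, "CPU")
print(f"1 - WER of the original model: {validation_fn(compiled_fp32, dataset):.4f}")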

Run quantization with accuracy control

You should provide the calibration dataset and the validation dataset. They can be the same dataset.

  • max_drop defines the accuracy drop threshold. The quantization process stops when the degradation of the accuracy metric on the validation dataset is less than max_drop. The default value is 0.01. NNCF will stop the quantization and report an error if the max_drop value cannot be reached.

  • drop_type defines how the accuracy drop is calculated: ABSOLUTE (default) or RELATIVE.

  • ranking_subset_size is the size of the subset used to rank layers by their contribution to the accuracy drop. The default value is 300; more samples potentially give a better ranking. Here, we use the value 25 to speed up execution.

NOTE: Execution can take tens of minutes and requires up to 10 GB of free memory.

from nncf.quantization.advanced_parameters import AdvancedAccuracyRestorerParameters
from nncf.parameters import ModelType

quantized_model = nncf.quantize_with_accuracy_control(
    ov_model,
    calibration_dataset=calibration_dataset,
    validation_dataset=calibration_dataset,
    validation_fn=validation_fn,
    max_drop=0.01,
    drop_type=nncf.DropType.ABSOLUTE,
    model_type=ModelType.TRANSFORMER,
    advanced_accuracy_restorer_parameters=AdvancedAccuracyRestorerParameters(
        ranking_subset_size=25
    ),
)
Statistics collection:  24%|███████████████████████████████████▎                                                                                                             | 73/300 [00:12<00:37,  5.98it/s]
Applying Smooth Quant: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:01<00:00, 41.01it/s]
INFO:nncf:36 ignored nodes was found by name in the NNCFGraph
Statistics collection:  24%|███████████████████████████████████▎                                                                                                             | 73/300 [00:22<01:08,  3.31it/s]
Applying Fast Bias correction: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 74/74 [00:23<00:00,  3.09it/s]
INFO:nncf:Validation of initial model was started
INFO:nncf:Elapsed Time: 00:00:00
INFO:nncf:Elapsed Time: 00:00:11
INFO:nncf:Metric of initial model: 0.9469565153121948
INFO:nncf:Collecting values for each data item using the initial model
INFO:nncf:Elapsed Time: 00:00:09
INFO:nncf:Validation of quantized model was started
INFO:nncf:Elapsed Time: 00:00:22
INFO:nncf:Elapsed Time: 00:00:11
INFO:nncf:Metric of quantized model: 0.5
INFO:nncf:Collecting values for each data item using the quantized model
INFO:nncf:Elapsed Time: 00:00:06
INFO:nncf:Accuracy drop: 0.4469565153121948 (DropType.ABSOLUTE)
INFO:nncf:Accuracy drop: 0.4469565153121948 (DropType.ABSOLUTE)
INFO:nncf:Total number of quantized operations in the model: 94
INFO:nncf:Number of parallel processes to rank quantized operations: 14
INFO:nncf:ORIGINAL metric is used to rank quantizers
INFO:nncf:Calculating ranking score for groups of quantizers
INFO:nncf:Elapsed Time: 00:04:36
INFO:nncf:Changing the scope of quantizer nodes was started
INFO:nncf:Reverted 1 operations to the floating-point precision:
    __module.wav2vec2.feature_extractor.conv_layers.2.conv/aten::_convolution/Convolution_11
INFO:nncf:Accuracy drop with the new quantization scope is 0.06173914670944214 (DropType.ABSOLUTE)
INFO:nncf:Reverted 1 operations to the floating-point precision:
    __module.wav2vec2.feature_extractor.conv_layers.1.conv/aten::_convolution/Convolution_10
INFO:nncf:Accuracy drop with the new quantization scope is 0.010434746742248535 (DropType.ABSOLUTE)
INFO:nncf:Reverted 1 operations to the floating-point precision:
    __module.wav2vec2.feature_extractor.conv_layers.3.conv/aten::_convolution/Convolution_12
INFO:nncf:Algorithm completed: achieved required accuracy drop 0.006956517696380615 (DropType.ABSOLUTE)
INFO:nncf:3 out of 94 were reverted back to the floating-point precision:
    __module.wav2vec2.feature_extractor.conv_layers.2.conv/aten::_convolution/Convolution_11
    __module.wav2vec2.feature_extractor.conv_layers.1.conv/aten::_convolution/Convolution_10
    __module.wav2vec2.feature_extractor.conv_layers.3.conv/aten::_convolution/Convolution_12
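
The returned quantized model is a regular ov.Model, so it can be saved to OpenVINO IR in the same way as the converted model if you want to reuse it later (the file name below is only an example).

# Optional: serialize the INT8 model for later reuse
ov.save_model(quantized_model, "quantized_wav2vec2_base.xml")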

Model Usage Example

import IPython.display as ipd


# listen to the test audio sample
ipd.Audio(test_sample["array"], rate=16000)
core = ov.Core()

# compile the quantized model for inference on CPU
compiled_quantized_model = core.compile_model(model=quantized_model, device_name='CPU')

# add a batch dimension to the raw waveform
input_data = np.expand_dims(test_sample["array"], axis=0)

Next, make a prediction.

predictions = compiled_quantized_model([input_data])[0]
predicted_ids = np.argmax(predictions, axis=-1)
transcription = processor.batch_decode(torch.from_numpy(predicted_ids))
transcription
['I E O WE WORD I O O FAGGI  FARE E BO']

Compare Accuracy of the Original and Quantized Models

  • Define a dataloader for the test dataset.

  • Define inference functions for the PyTorch and OpenVINO models.

  • Define a function to compute the Word Error Rate.

# inference function for pytorch
def torch_infer(model, sample):
    logits = model(torch.Tensor(sample['input_values'])).logits
    # take argmax and decode
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    return transcription


# inference function for openvino
def ov_infer(model, sample):
    output = model.output(0)
    logits = model(np.array(sample['input_values']))[output]
    predicted_ids = np.argmax(logits, axis=-1)
    transcription = processor.batch_decode(torch.from_numpy(predicted_ids))
    return transcription


def compute_wer(dataset, model, infer_fn):
    wer = WordErrorRate()
    for sample in tqdm(dataset):
        # run infer function on sample
        transcription = infer_fn(model, sample)
        # update metric on sample result
        wer.update(transcription, [sample['text']])
    # finalize metric calculation
    result = wer.compute()
    return result

Now, compute WER for the original PyTorch model and quantized model.

pt_result = compute_wer(dataset, torch_model, torch_infer)
quantized_result = compute_wer(dataset, compiled_quantized_model, ov_infer)

print(f'[PyTorch]   Word Error Rate: {pt_result:.4f}')
print(f'[Quantized OpenVINO]  Word Error Rate: {quantized_result:.4f}')
0%|          | 0/73 [00:00<?, ?it/s]
0%|          | 0/73 [00:00<?, ?it/s]
[PyTorch]   Word Error Rate: 0.0530
[Quantized OpenVINO]  Word Error Rate: 0.0600
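
To complete the comparison outlined in the introduction, the converted (non-quantized) OpenVINO model can be evaluated with the same helpers; a short sketch, assuming the CPU device used above:

# Illustrative: evaluate the converted OpenVINO model with the same WER helper
compiled_fp32_model = core.compile_model(model=ov_model, device_name='CPU')
fp32_result = compute_wer(dataset, compiled_fp32_model, ov_infer)

print(f'[OpenVINO FP32]       Word Error Rate: {fp32_result:.4f}')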