Quantize Speech Recognition Models using NNCF PTQ API
This tutorial demonstrates how to use NNCF (Neural Network Compression Framework) 8-bit quantization in post-training mode (without a fine-tuning pipeline) to optimize the Data2Vec speech recognition model for high-speed inference with the OpenVINO™ Toolkit. This notebook uses a fine-tuned data2vec-audio-base-960h PyTorch model trained on the LibriSpeech ASR corpus. The tutorial is designed to be extendable to custom models and datasets. It consists of the following steps:
Download and prepare the model.
Define data loading and accuracy validation functionality.
Prepare the model for quantization and quantize it.
Compare performance of the original and quantized models.
Compare accuracy of the original and quantized models.
Download and prepare model
data2vec is a framework for self-supervised representation learning for images, speech, and text as described in data2vec: A General Framework for Self-supervised Learning in Speech, Vision and Language (Baevski et al., 2022). The algorithm uses the same learning mechanism for different modalities.
In our case, we will use the data2vec-audio-base-960h model, which was fine-tuned on 960 hours of audio from the LibriSpeech Automatic Speech Recognition corpus and is distributed as part of Hugging Face Transformers.
Obtain PyTorch model representation
To instantiate the PyTorch model class, use the Data2VecAudioForCTC.from_pretrained method and provide the model ID for downloading from the Hugging Face Hub. Model weights and configuration files are downloaded automatically on first use. Keep in mind that downloading the files can take several minutes, depending on your internet connection.
Additionally, we create a processor object, which is responsible for model-specific pre- and post-processing steps.
%pip install -q "openvino>=2023.3.0" "nncf>=2.7"
%pip install datasets "torchmetrics>=0.11.0" "torch>=2.1.0" --extra-index-url https://download.pytorch.org/whl/cpu
%pip install -q soundfile librosa "transformers>=4.36.2" --extra-index-url https://download.pytorch.org/whl/cpu
from transformers import Wav2Vec2Processor, Data2VecAudioForCTC
processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h")
model = Data2VecAudioForCTC.from_pretrained("facebook/data2vec-audio-base-960h")
Convert model to OpenVINO Intermediate Representation
from pathlib import Path
# Set model directory
MODEL_DIR = Path("model")
MODEL_DIR.mkdir(exist_ok=True)
import openvino as ov
import torch
core = ov.Core()
BATCH_SIZE = 1
MAX_SEQ_LENGTH = 30480
ir_model_path = MODEL_DIR / "data2vec-audio-base.xml"
if not ir_model_path.exists():
    # trace the PyTorch model with a dummy waveform and save it as OpenVINO IR
    ov_model = ov.convert_model(model, example_input=torch.zeros([BATCH_SIZE, MAX_SEQ_LENGTH], dtype=torch.float))
    ov.save_model(ov_model, str(ir_model_path))
    print("IR model saved to {}".format(ir_model_path))
else:
    # reuse the IR converted in a previous run
    print("Read IR model from {}".format(ir_model_path))
    ov_model = core.read_model(ir_model_path)
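MAX_SEQ_LENGTH = 30480 samples corresponds to about 1.9 seconds of 16 kHz audio. The example input is only used for tracing: the converted model keeps dynamic input dimensions, as the [?,?] input shape in the benchmark_app log later in this notebook shows. To see how the number of output frames depends on the input length, the sketch below applies the length formula of an unpadded 1D convolution layer by layer. It assumes the default wav2vec2-style feature-extractor configuration (the CONV_KERNEL and CONV_STRIDE values hard-coded below); compare them with model.config for your checkpoint.

```python
# Minimal sketch: frames produced by a wav2vec2-style convolutional feature
# extractor for a given number of audio samples.
# ASSUMPTION: these are the default conv_kernel / conv_stride values of
# Data2VecAudioConfig; compare with model.config for your checkpoint.
CONV_KERNEL = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDE = (5, 2, 2, 2, 2, 2, 2)
SAMPLING_RATE = 16_000  # LibriSpeech audio is sampled at 16 kHz


def output_frames(num_samples: int) -> int:
    """Apply the unpadded Conv1d length formula, one layer at a time."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNEL, CONV_STRIDE):
        length = (length - kernel) // stride + 1
    return length


num_samples = 30480  # same value as MAX_SEQ_LENGTH above
print(f"{num_samples / SAMPLING_RATE:.3f} s of audio -> {output_frames(num_samples)} frames")
```

For 30480 samples this gives 95 frames, which matches the [1,95,32] logits shape in the benchmark_app report; the last dimension, 32, is the size of the CTC vocabulary.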
Prepare inference data
For demonstration purposes, we will use a short dummy version of the LibriSpeech dataset, patrickvonplaten/librispeech_asr_dummy, to speed up model evaluation. Model accuracy may therefore differ from the value reported in the paper. To reproduce the original accuracy, use the librispeech_asr dataset.
from datasets import load_dataset
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
# define preprocessing function for converting audio to input values for model
def map_to_input(batch):
    preprocessed_signal = processor(batch["audio"]["array"], return_tensors="pt", padding="longest", sampling_rate=batch['audio']['sampling_rate'])
    input_values = preprocessed_signal.input_values
    batch['input_values'] = input_values
    return batch
# apply preprocessing function to dataset and remove audio column, to save memory as we do not need it anymore
dataset = ds.map(map_to_input, batched=False, remove_columns=["audio"])
test_sample = ds[0]["audio"]
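map_to_input delegates feature extraction to the processor. For wav2vec2-style checkpoints, the feature extractor can normalize each waveform to zero mean and unit variance before it reaches the model. The plain-Python sketch below shows that kind of normalization, assuming the checkpoint's preprocessor configuration enables it; the small epsilon (1e-7 here) is an illustrative assumption that avoids division by zero on silent input.

```python
import math


def zero_mean_unit_var_norm(samples, eps=1e-7):
    """Normalize a waveform to zero mean and (approximately) unit variance.

    ASSUMPTION: mirrors the kind of normalization a wav2vec2-style feature
    extractor applies when normalization is enabled; eps is illustrative.
    """
    n = len(samples)
    mean = sum(samples) / n
    # population variance (divide by n, not n - 1)
    variance = sum((x - mean) ** 2 for x in samples) / n
    std = math.sqrt(variance + eps)
    return [(x - mean) / std for x in samples]


waveform = [0.1, 0.3, -0.2, 0.4, 0.0]  # toy samples, not real audio
normalized = zero_mean_unit_var_norm(waveform)
print([round(x, 3) for x in normalized])
```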
Check model inference result
The code below runs model inference on a single sample from the dataset. It contains the following steps:
Get the input_values tensor as model input.
Run model inference and obtain logits.
Find the token IDs with the highest logits, using argmax.
Decode the predicted token IDs, using the processor.
For comparison, the same function is also provided for the OpenVINO model.
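data2vec-audio-base-960h is used here through a CTC head (Data2VecAudioForCTC), so the argmax-then-decode steps above amount to greedy CTC decoding: pick the most likely token per frame, collapse consecutive repeats, drop the blank token, and turn the word delimiter into spaces. The sketch below implements that idea in plain Python with a tiny, hypothetical vocabulary (TOY_VOCAB, PAD_ID, and WORD_DELIMITER are illustrative assumptions; the real vocabulary comes from processor.tokenizer). It is not the exact code behind processor.batch_decode.

```python
from itertools import groupby

# Hypothetical toy vocabulary; the real mapping lives in processor.tokenizer.
TOY_VOCAB = {0: "<pad>", 1: "|", 2: "C", 3: "A", 4: "T", 5: "S"}
PAD_ID = 0            # assumed CTC blank token
WORD_DELIMITER = "|"  # assumed word separator token


def argmax(scores):
    """Index of the largest score in one frame of logits."""
    return max(range(len(scores)), key=scores.__getitem__)


def greedy_ctc_decode(frame_logits):
    """Greedy CTC decoding of per-frame logits into a transcription."""
    frame_ids = [argmax(frame) for frame in frame_logits]
    # 1. collapse runs of the same token emitted on consecutive frames
    collapsed = [token_id for token_id, _ in groupby(frame_ids)]
    # 2. drop blank tokens, which only separate genuine repeats
    tokens = [TOY_VOCAB[token_id] for token_id in collapsed if token_id != PAD_ID]
    # 3. map the word delimiter to spaces
    text = "".join(" " if token == WORD_DELIMITER else token for token in tokens)
    return text.strip()


def one_hot(token_id, size=len(TOY_VOCAB)):
    """Fake logits that make token_id the most likely token for a frame."""
    return [1.0 if i == token_id else 0.0 for i in range(size)]


frame_ids = [2, 2, 0, 3, 4, 4, 1, 2, 3, 0, 0, 4, 5]
print(greedy_ctc_decode([one_hot(i) for i in frame_ids]))  # CAT CATS
```

The blank token is what lets CTC emit a genuinely doubled letter: frames [C, A, blank, A] decode to "CAA", while [C, A, A] decode to "CA".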
import numpy as np
# inference function for pytorch
def torch_infer(model, sample):
    # gradients are not needed for inference
    with torch.no_grad():
        logits = model(torch.Tensor(sample['input_values'])).logits
    # take argmax and decode
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    return transcription
# inference function for openvino
def ov_infer(model, sample):
    output = model.output(0)
    logits = model(np.array(sample['input_values']))[output]
    predicted_ids = np.argmax(logits, axis=-1)
    transcription = processor.batch_decode(torch.from_numpy(predicted_ids))
    return transcription
pt_transcription = torch_infer(model, dataset[0])
# reuse the ov.Core instance created during model conversion
compiled_model = core.compile_model(ov_model)
ov_transcription = ov_infer(compiled_model, dataset[0])
import IPython.display as ipd
print(f"[Reference]: {dataset[0]['text']}")
print(f"[PyTorch]: {pt_transcription[0]}")
print(f"[OpenVINO FP16]: {ov_transcription[0]}")
ipd.Audio(test_sample["array"], rate=16000)
[Reference]: BECAUSE YOU WERE SLEEPING INSTEAD OF CONQUERING THE LOVELY ROSE PRINCESS HAS BECOME A FIDDLE WITHOUT A BOW WHILE POOR SHAGGY SITS THERE A COOING DOVE
[PyTorch]: BECAUSE YOU WERE SLEEPING INSTEAD OF CONQUERING THE LOVELY RUSE PRINCESS HAS BECOME A FIDDLE WITHOUT A BOW A POOR SHAGGY SITS THERE ACCOOING DOVE
[OpenVINO FP16]: BECAUSE YOU WERE SLEEPING INSTEAD OF CONQUERING THE LOVELY RUSE PRINCESS HAS BECOME A FIDDLE WITHOUT A BOW A POOR SHAGGY SITS THERE ACCOOING DOVE
Validate model accuracy on dataset
Word Error Rate (WER) can be used to evaluate model accuracy. WER is the ratio of word-level errors (substitutions, deletions, and insertions) in a transcript to the total number of words in the reference. A lower WER means better speech recognition accuracy.
For WER calculation, we will use the torchmetrics library.
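Written as a formula, WER = (S + D + I) / N, where S, D, and I are the numbers of word substitutions, deletions, and insertions needed to turn the reference into the hypothesis, and N is the number of words in the reference. The plain-Python sketch below computes this with a word-level Levenshtein distance. To our understanding, torchmetrics' WordErrorRate accumulates error and word counts across update() calls, so corpus_wer likewise divides total errors by total reference words instead of averaging per-sample rates. The helper names are our own, and the sketch applies no text normalization.

```python
def word_edit_distance(reference: str, hypothesis: str) -> int:
    """Minimum number of word substitutions, deletions, and insertions."""
    ref, hyp = reference.split(), hypothesis.split()
    # prev[j] holds the distance between the reference words seen so far and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, hyp_word in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,                           # deletion
                curr[j - 1] + 1,                       # insertion
                prev[j - 1] + (ref_word != hyp_word),  # substitution or match
            )
        prev = curr
    return prev[-1]


def corpus_wer(references, hypotheses):
    """Total word errors divided by total reference words."""
    errors = sum(word_edit_distance(r, h) for r, h in zip(references, hypotheses))
    total_words = sum(len(r.split()) for r in references)
    return errors / total_words


refs = ["the cat sat", "on the mat"]
hyps = ["the cat sat down", "on a mat"]
print(corpus_wer(refs, hyps))  # 2 errors / 6 reference words
```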
from torchmetrics.text import WordErrorRate
from tqdm.notebook import tqdm
def compute_wer(dataset, model, infer_fn):
    wer = WordErrorRate()
    for sample in tqdm(dataset):
        # run infer function on sample
        transcription = infer_fn(model, sample)
        # update metric on sample result
        wer.update(transcription, [sample['text']])
    # finalize metric calculation
    result = wer.compute()
    return result
pt_result = compute_wer(dataset, model, torch_infer)
ov_result = compute_wer(dataset, compiled_model, ov_infer)
print(f'[PyTorch] Word Error Rate: {pt_result:.4f}')
print(f'[OpenVino] Word Error Rate: {ov_result:.4f}')
[PyTorch] Word Error Rate: 0.0383
[OpenVino] Word Error Rate: 0.0383
Quantization
NNCF provides a suite of advanced algorithms for neural network inference optimization in OpenVINO with minimal accuracy drop.
Create a quantized model from the pre-trained FP16 model and the calibration dataset. The optimization process contains the following steps:
Create a Dataset for quantization.
Run nncf.quantize to obtain an optimized model. The nncf.quantize function provides an interface for model quantization. It requires an instance of the OpenVINO Model and a quantization dataset. Optionally, additional parameters for configuring the quantization process (number of samples for quantization, preset, ignored scope, etc.) can be provided. For more accurate results, some operations are kept in floating-point precision using the ignored_scope parameter; in this notebook, one convolution of the feature extractor is excluded from quantization. For more information, see Tune quantization parameters.
Serialize the OpenVINO IR model using the ov.save_model function.
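To make the idea behind 8-bit quantization concrete, the toy sketch below maps floating-point values to integers in [-127, 127] with a single symmetric scale derived from the largest absolute value, then maps them back. It is a simplified illustration, not NNCF's algorithm: NNCF's actual scheme is more elaborate and depends on options such as the preset. For values inside the calibrated range, the round-trip error is at most half a quantization step (scale / 2).

```python
def quantize_symmetric_int8(values):
    """Per-tensor symmetric quantization to signed 8-bit integers (toy version)."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    # clamp defensively; with this scale, no value falls outside [-127, 127]
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return quantized, scale


def dequantize(quantized, scale):
    """Map integers back to approximate floating-point values."""
    return [q * scale for q in quantized]


weights = [0.82, -1.27, 0.003, 0.5, -0.041]  # toy values
q, scale = quantize_symmetric_int8(weights)
restored = dequantize(q, scale)
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(q, f"scale={scale:.5f}", f"max_error={max_error:.5f}")
```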
import nncf
from nncf.parameters import ModelType
def transform_fn(data_item):
    """
    Extract the model's input from the data item.
    The data item here is the data item that is returned from the data source per iteration.
    This function should be passed when the data item cannot be used as model's input.
    """
    return np.array(data_item["input_values"])
calibration_dataset = nncf.Dataset(dataset, transform_fn)
quantized_model = nncf.quantize(
    ov_model,
    calibration_dataset,
    model_type=ModelType.TRANSFORMER,  # specify additional transformer patterns in the model
    subset_size=len(dataset),
    ignored_scope=nncf.IgnoredScope(
        names=[
            "__module.data2vec_audio.feature_extractor.conv_layers.1.conv/aten::_convolution/Convolution_96",
        ],
    ),
)
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, openvino
INFO:nncf:1 ignored nodes were found by name in the NNCFGraph
INFO:nncf:36 ignored nodes were found by name in the NNCFGraph
INFO:nncf:Not adding activation input quantizer for operation: 10 __module.data2vec_audio.feature_extractor.conv_layers.1.conv/aten::_convolution/Convolution_96
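nncf.Dataset pairs a data source with a transform function, and subset_size bounds how many items are used for calibration (here, the whole dummy dataset). As a rough, hypothetical illustration of what calibration gathers, the sketch below walks the first subset_size items, applies a transform function, tracks the observed minimum and maximum, and turns that range into an asymmetric 8-bit scale and zero point. toy_transform_fn, collect_min_max, and asymmetric_uint8_params are illustrative names, and plain min-max statistics are a simplifying assumption; NNCF's statistics collection is more involved.

```python
from itertools import islice


def toy_transform_fn(data_item):
    """Hypothetical stand-in for transform_fn: return the model input."""
    # returns nested lists instead of np.array to keep the sketch dependency-free
    return data_item["input_values"]


def collect_min_max(data_source, transform_fn, subset_size):
    """Track the global min and max over the first subset_size items."""
    low, high = float("inf"), float("-inf")
    for item in islice(data_source, subset_size):
        for row in transform_fn(item):  # input_values has shape [1, num_samples]
            low = min(low, min(row))
            high = max(high, max(row))
    return low, high


def asymmetric_uint8_params(low, high):
    """Scale and zero point mapping [low, high] onto the integers 0..255."""
    low, high = min(low, 0.0), max(high, 0.0)  # the range must contain 0.0
    scale = (high - low) / 255 if high > low else 1.0
    zero_point = round(-low / scale)
    return scale, zero_point


toy_dataset = [
    {"input_values": [[-1.0, 0.5, 0.25]]},
    {"input_values": [[2.0, 0.0, -0.5]]},
    {"input_values": [[100.0]]},  # outside the subset below, so it is ignored
]
low, high = collect_min_max(toy_dataset, toy_transform_fn, subset_size=2)
scale, zero_point = asymmetric_uint8_params(low, high)
print(low, high, round(scale, 5), zero_point)
```

Note how an outlier in the calibration subset (the third item, if subset_size were 3) would stretch the range and coarsen the quantization step for every other value.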
After quantization is finished, the compressed model representation can be saved using the ov.save_model function.
MODEL_NAME = 'quantized_data2vec_base'
quantized_model_path = Path(f"{MODEL_NAME}_openvino_model/{MODEL_NAME}_quantized.xml")
ov.save_model(quantized_model, quantized_model_path)
Check INT8 model inference result
The INT8 model is used in the same way as the original one. A model saved to disk can be read with the core.read_model method and loaded onto a device with core.compile_model; here, the in-memory quantized_model is compiled directly. After that, we can reuse the same ov_infer function to get the inference result on the test sample.
int8_compiled_model = core.compile_model(quantized_model)
transcription = ov_infer(int8_compiled_model, dataset[0])
print(f"[Reference]: {dataset[0]['text']}")
print(f"[OpenVINO INT8]: {transcription[0]}")
ipd.Audio(test_sample["array"], rate=16000)
[Reference]: BECAUSE YOU WERE SLEEPING INSTEAD OF CONQUERING THE LOVELY ROSE PRINCESS HAS BECOME A FIDDLE WITHOUT A BOW WHILE POOR SHAGGY SITS THERE A COOING DOVE
[OpenVINO INT8]: BECAUSE YOU WERE SLEEPING INSTEAD OF CONQUERING THE LOVELY RUSE PRINCESS HAS BECOME A FIDDLE WITHOUT A BOW ALE POORA SHAGGY SITS THERE ACCOOING DOVE
Compare Performance of the Original and Quantized Models
The Benchmark Tool is used to measure the inference performance of the FP16 and INT8 models.
NOTE: For more accurate performance, it is recommended to run benchmark_app in a terminal/command prompt after closing other applications. Run benchmark_app -m model.xml -d CPU to benchmark async inference on CPU for one minute. Change CPU to GPU to benchmark on GPU. Run benchmark_app --help to see an overview of all command-line options.
# Inference FP16 model (OpenVINO IR)
! benchmark_app -m $ir_model_path -shape [1,30480] -d CPU -api async -t 15
[Step 1/11] Parsing and validating input arguments
[ INFO ] Parsing input parameters
[Step 2/11] Loading OpenVINO Runtime
[ INFO ] OpenVINO:
[ INFO ] Build ................................. 2023.2.0-13089-cfd42bd2cb0-HEAD
[ INFO ]
[ INFO ] Device info:
[ INFO ] CPU
[ INFO ] Build ................................. 2023.2.0-13089-cfd42bd2cb0-HEAD
[ INFO ]
[ INFO ]
[Step 3/11] Setting device configuration
[ WARNING ] Performance hint was not explicitly specified in command line. Device(CPU) performance hint will be set to PerformanceMode.THROUGHPUT.
[Step 4/11] Reading model files
[ INFO ] Loading model files
[ INFO ] Read model took 56.07 ms
[ INFO ] Original model I/O parameters:
[ INFO ] Model inputs:
[ INFO ] input_values (node: input_values) : f32 / [...] / [?,?]
[ INFO ] Model outputs:
[ INFO ] 1289 , logits (node: __module.lm_head/aten::linear/Add) : f32 / [...] / [?,?,32]
[Step 5/11] Resizing model to match image sizes and given batch
[ INFO ] Model batch size: 1
[ INFO ] Reshaping model: 'input_values': [1,30480]
[ INFO ] Reshape model took 36.04 ms
[Step 6/11] Configuring input of the model
[ INFO ] Model inputs:
[ INFO ] input_values (node: input_values) : f32 / [...] / [1,30480]
[ INFO ] Model outputs:
[ INFO ] 1289 , logits (node: __module.lm_head/aten::linear/Add) : f32 / [...] / [1,95,32]
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 729.25 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] Model:
[ INFO ] NETWORK_NAME: Model0
[ INFO ] OPTIMAL_NUMBER_OF_INFER_REQUESTS: 12
[ INFO ] NUM_STREAMS: 12
[ INFO ] AFFINITY: Affinity.CORE
[ INFO ] INFERENCE_NUM_THREADS: 36
[ INFO ] PERF_COUNT: False
[ INFO ] INFERENCE_PRECISION_HINT: <Type: 'float32'>
[ INFO ] PERFORMANCE_HINT: PerformanceMode.THROUGHPUT
[ INFO ] EXECUTION_MODE_HINT: ExecutionMode.PERFORMANCE
[ INFO ] PERFORMANCE_HINT_NUM_REQUESTS: 0
[ INFO ] ENABLE_CPU_PINNING: True
[ INFO ] SCHEDULING_CORE_TYPE: SchedulingCoreType.ANY_CORE
[ INFO ] ENABLE_HYPER_THREADING: True
[ INFO ] EXECUTION_DEVICES: ['CPU']
[ INFO ] CPU_DENORMALS_OPTIMIZATION: False
[ INFO ] CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE: 1.0
[Step 9/11] Creating infer requests and preparing input tensors
[ WARNING ] No input files were given for input 'input_values'!. This input will be filled with random values!
[ INFO ] Fill input 'input_values' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 12 inference requests, limits: 15000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 82.91 ms
[Step 11/11] Dumping statistics report
[ INFO ] Execution Devices:['CPU']
[ INFO ] Count: 732 iterations
[ INFO ] Duration: 15378.93 ms
[ INFO ] Latency:
[ INFO ] Median: 251.03 ms
[ INFO ] Average: 251.09 ms
[ INFO ] Min: 121.95 ms
[ INFO ] Max: 298.62 ms
[ INFO ] Throughput: 47.60 FPS
# Inference INT8 model (OpenVINO IR)
! benchmark_app -m $quantized_model_path -shape [1,30480] -d CPU -api async -t 15
[Step 1/11] Parsing and validating input arguments
[ INFO ] Parsing input parameters
[Step 2/11] Loading OpenVINO Runtime
[ INFO ] OpenVINO:
[ INFO ] Build ................................. 2023.2.0-13089-cfd42bd2cb0-HEAD
[ INFO ]
[ INFO ] Device info:
[ INFO ] CPU
[ INFO ] Build ................................. 2023.2.0-13089-cfd42bd2cb0-HEAD
[ INFO ]
[ INFO ]
[Step 3/11] Setting device configuration
[ WARNING ] Performance hint was not explicitly specified in command line. Device(CPU) performance hint will be set to PerformanceMode.THROUGHPUT.
[Step 4/11] Reading model files
[ INFO ] Loading model files
[ INFO ] Read model took 56.45 ms
[ INFO ] Original model I/O parameters:
[ INFO ] Model inputs:
[ INFO ] input_values (node: input_values) : f32 / [...] / [?,?]
[ INFO ] Model outputs:
[ INFO ] logits , 1289 (node: __module.lm_head/aten::linear/Add) : f32 / [...] / [?,?,32]
[Step 5/11] Resizing model to match image sizes and given batch
[ INFO ] Model batch size: 1
[ INFO ] Reshaping model: 'input_values': [1,30480]
[ INFO ] Reshape model took 56.70 ms
[Step 6/11] Configuring input of the model
[ INFO ] Model inputs:
[ INFO ] input_values (node: input_values) : f32 / [...] / [1,30480]
[ INFO ] Model outputs:
[ INFO ] logits , 1289 (node: __module.lm_head/aten::linear/Add) : f32 / [...] / [1,95,32]
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 1359.11 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] Model:
[ INFO ] NETWORK_NAME: Model0
[ INFO ] OPTIMAL_NUMBER_OF_INFER_REQUESTS: 12
[ INFO ] NUM_STREAMS: 12
[ INFO ] AFFINITY: Affinity.CORE
[ INFO ] INFERENCE_NUM_THREADS: 36
[ INFO ] PERF_COUNT: False
[ INFO ] INFERENCE_PRECISION_HINT: <Type: 'float32'>
[ INFO ] PERFORMANCE_HINT: PerformanceMode.THROUGHPUT
[ INFO ] EXECUTION_MODE_HINT: ExecutionMode.PERFORMANCE
[ INFO ] PERFORMANCE_HINT_NUM_REQUESTS: 0
[ INFO ] ENABLE_CPU_PINNING: True
[ INFO ] SCHEDULING_CORE_TYPE: SchedulingCoreType.ANY_CORE
[ INFO ] ENABLE_HYPER_THREADING: True
[ INFO ] EXECUTION_DEVICES: ['CPU']
[ INFO ] CPU_DENORMALS_OPTIMIZATION: False
[ INFO ] CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE: 1.0
[Step 9/11] Creating infer requests and preparing input tensors
[ WARNING ] No input files were given for input 'input_values'!. This input will be filled with random values!
[ INFO ] Fill input 'input_values' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 12 inference requests, limits: 15000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 80.96 ms
[Step 11/11] Dumping statistics report
[ INFO ] Execution Devices:['CPU']
[ INFO ] Count: 1104 iterations
[ INFO ] Duration: 15213.04 ms
[ INFO ] Latency:
[ INFO ] Median: 164.32 ms
[ INFO ] Average: 164.68 ms
[ INFO ] Min: 100.60 ms
[ INFO ] Max: 253.01 ms
[ INFO ] Throughput: 72.57 FPS
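The throughput figures can be cross-checked from the iteration counts and durations in the two reports. Treating the 12 parallel infer requests as always busy (an approximation based on Little's law), the expected per-request latency is roughly the number of requests divided by the throughput. The sketch below recomputes these numbers from values copied out of the logs; it does not rerun the benchmark.

```python
# Values copied from the benchmark_app reports above (CPU, 15 s runs).
REPORTS = {
    "FP16": {"count": 732, "duration_ms": 15378.93, "median_latency_ms": 251.03},
    "INT8": {"count": 1104, "duration_ms": 15213.04, "median_latency_ms": 164.32},
}
NUM_REQUESTS = 12  # "12 inference requests" in Step 10/11 of both reports


def throughput_fps(report):
    """Completed iterations per second (batch size 1)."""
    return report["count"] / (report["duration_ms"] / 1000)


def littles_law_latency_ms(report, num_requests=NUM_REQUESTS):
    """Approximate latency if num_requests requests are always in flight."""
    return num_requests / throughput_fps(report) * 1000


for name, report in REPORTS.items():
    print(
        f"{name}: {throughput_fps(report):.2f} FPS, "
        f"estimated latency {littles_law_latency_ms(report):.1f} ms "
        f"(median reported: {report['median_latency_ms']} ms)"
    )

speedup = throughput_fps(REPORTS["INT8"]) / throughput_fps(REPORTS["FP16"])
print(f"INT8 / FP16 throughput: {speedup:.2f}x")
```

Both latency estimates land within about 1 ms of the reported medians, and the INT8 model delivers roughly 1.5x the FP16 throughput on the machine used for these runs.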
Compare Accuracy of the Original and Quantized Models
Finally, calculate the WER metric for the INT8 model representation and compare it with the FP16 result.
int8_ov_result = compute_wer(dataset, int8_compiled_model, ov_infer)
print(f'[OpenVino FP16] Word Error Rate: {ov_result:.4}')
print(f'[OpenVino INT8] Word Error Rate: {int8_ov_result:.4f}')
[OpenVino FP16] Word Error Rate: 0.03826
[OpenVino INT8] Word Error Rate: 0.0452
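To put the accuracy cost next to the speed-up, the short calculation below expresses the WER change as both an absolute difference in percentage points and a relative increase, using the two values printed above. wer_change is an illustrative helper, not part of any library.

```python
def wer_change(baseline_wer, candidate_wer):
    """Absolute (percentage points) and relative change of WER."""
    absolute_points = (candidate_wer - baseline_wer) * 100
    relative = (candidate_wer - baseline_wer) / baseline_wer
    return absolute_points, relative


fp16_wer, int8_wer = 0.03826, 0.0452  # values printed above
points, relative = wer_change(fp16_wer, int8_wer)
print(f"WER change: {points:+.2f} percentage points ({relative:.1%} relative)")
```

In the notebook itself, wer_change(float(ov_result), float(int8_ov_result)) applies the same calculation to the unrounded metric values.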