Quantize Wav2Vec Speech Recognition Model using NNCF PTQ API#
This tutorial demonstrates how to apply INT8 quantization to the Wav2Vec2 speech recognition model, using NNCF (Neural Network Compression Framework) 8-bit quantization in post-training mode (without the fine-tuning pipeline). This notebook uses a fine-tuned Wav2Vec2-Base-960h PyTorch model trained on the LibriSpeech ASR corpus. The tutorial is designed to be extendable to custom models and datasets. It consists of the following steps:
Download and prepare the Wav2Vec2 model and LibriSpeech dataset.
Define data loading and accuracy validation functionality.
Quantize the model.
Compare the accuracy of the original PyTorch model and the OpenVINO FP16 and INT8 models.
Compare performance of the original and quantized models.
Installation Instructions#
This is a self-contained example that relies solely on its own code.
We recommend running the notebook in a virtual environment. You only need a Jupyter server to start. For details, please refer to Installation Guide.
%pip install -q "openvino>=2023.3.0" "nncf>=2.7"
%pip install datasets "torchmetrics>=0.11.0" "torch>=2.1.0" --extra-index-url https://download.pytorch.org/whl/cpu
%pip install -q soundfile librosa "transformers>=4.36.2" --extra-index-url https://download.pytorch.org/whl/cpu
Note: you may need to restart the kernel to use updated packages.
Imports#
import numpy as np
import openvino as ov
import torch
import IPython.display as ipd
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
Settings#
from pathlib import Path
# Set the directory for the model artifacts.
MODEL_DIR = Path("model")
MODEL_DIR.mkdir(exist_ok=True)
Prepare the Model#
Perform the following:
- Download and unpack a pre-trained Wav2Vec2 model.
- Run model conversion API to convert the model from the PyTorch representation to the OpenVINO Intermediate Representation (OpenVINO IR).
torch_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h", ctc_loss_reduction="mean")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
BATCH_SIZE = 1
MAX_SEQ_LENGTH = 30480
ov_model = ov.convert_model(torch_model, example_input=torch.zeros([1, MAX_SEQ_LENGTH], dtype=torch.float))
ir_model_path = MODEL_DIR / "wav2vec2_base.xml"
ov.save_model(ov_model, ir_model_path)
/opt/home/k8sworker/ci-ai/cibuilds/jobs/ov-notebook/jobs/OVNotebookOps/builds/801/archive/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/modeling_utils.py:4779: FutureWarning: _is_quantized_training_enabled is going to be deprecated in transformers 4.39.0. Please use model.hf_quantizer.is_trainable instead
  warnings.warn(
/opt/home/k8sworker/ci-ai/cibuilds/jobs/ov-notebook/jobs/OVNotebookOps/builds/801/archive/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:871: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
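As a quick sanity check (an optional sketch, not part of the original pipeline), you can compare the logits produced by the source PyTorch model and the converted OpenVINO model on the same dummy input; the "CPU" device choice here is an assumption for illustration.
# Compare PyTorch and OpenVINO logits on the same dummy input. The models
# should agree up to small numerical differences from FP16 weight compression.
core = ov.Core()
compiled_fp_model = core.compile_model(ir_model_path, "CPU")

dummy_input = torch.zeros([1, MAX_SEQ_LENGTH], dtype=torch.float)
with torch.no_grad():
    torch_logits = torch_model(dummy_input).logits.numpy()
ov_logits = compiled_fp_model(dummy_input.numpy())[0]

print("Max absolute difference:", np.abs(torch_logits - ov_logits).max())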
Prepare LibriSpeech Dataset#
For demonstration purposes, we will use a short dummy version of the LibriSpeech dataset, patrickvonplaten/librispeech_asr_dummy, to speed up model evaluation. Model accuracy may differ from the results reported in the paper. To reproduce the original accuracy, use the full librispeech_asr dataset.
from datasets import load_dataset
dataset = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True)
test_sample = dataset[0]["audio"]
# Define a preprocessing function that converts audio to input values for the model.
def map_to_input(batch):
    preprocessed_signal = processor(
        batch["audio"]["array"],
        return_tensors="pt",
        padding="longest",
        sampling_rate=batch["audio"]["sampling_rate"],
    )
    input_values = preprocessed_signal.input_values
    batch["input_values"] = input_values
    return batch
# Apply the preprocessing function to the dataset and remove the audio column to save memory, as it is no longer needed.
dataset = dataset.map(map_to_input, batched=False, remove_columns=["audio"])
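For illustration (an optional check, not in the original notebook), you can inspect the first preprocessed item; after dataset.map, input_values is stored as a nested list, so convert it back to a NumPy array to see its shape.
# The model input is a batch of raw waveform floats with shape [1, num_samples];
# the length varies per utterance.
sample_input = np.array(dataset[0]["input_values"])
print(sample_input.shape)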
Run Quantization#
NNCF provides a suite of advanced algorithms for neural network inference optimization in OpenVINO with minimal accuracy drop.
Create a quantized model from the pre-trained FP16 model and the calibration dataset. The optimization process contains the following steps:
1. Create a Dataset for quantization.
2. Run nncf.quantize to get an optimized model. The nncf.quantize function provides an interface for model quantization. It requires an instance of the OpenVINO Model and a quantization dataset. Optionally, additional parameters for configuring the quantization process (number of samples for quantization, preset, ignored scope, etc.) can be provided. For more accurate results, the operations in the postprocessing subgraph should be kept in floating-point precision, using the ignored_scope parameter. For more information, see Tune quantization parameters. For this model, the ignored scope was selected experimentally, based on the result of quantization with accuracy control; to understand how that works, check the quantization with accuracy control notebook.
3. Serialize the OpenVINO IR model, using the ov.save_model function.
import nncf
from nncf.parameters import ModelType
def transform_fn(data_item):
    """
    Extract the model's input from the data item.
    The data item here is what the data source returns per iteration.
    This function should be passed when the data item cannot be used as the model's input directly.
    """
    return np.array(data_item["input_values"])
calibration_dataset = nncf.Dataset(dataset, transform_fn)
quantized_model = nncf.quantize(
    ov_model,
    calibration_dataset,
    model_type=ModelType.TRANSFORMER,  # specify additional transformer patterns in the model
    ignored_scope=nncf.IgnoredScope(
        names=[
            "__module.wav2vec2.feature_extractor.conv_layers.1.conv/aten::_convolution/Convolution",
            "__module.wav2vec2.feature_extractor.conv_layers.2.conv/aten::_convolution/Convolution",
            "__module.wav2vec2.feature_extractor.conv_layers.3.conv/aten::_convolution/Convolution",
            "__module.wav2vec2.feature_extractor.conv_layers.0.conv/aten::_convolution/Convolution",
        ],
    ),
)
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, tensorflow, onnx, openvino
2024-10-23 04:48:23.521146: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable TF_ENABLE_ONEDNN_OPTS=0.
2024-10-23 04:48:23.553631: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-10-23 04:48:24.182318: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
INFO:nncf:4 ignored nodes were found by names in the NNCFGraph
INFO:nncf:Not adding activation input quantizer for operation: 2 __module.wav2vec2.feature_extractor.conv_layers.0.conv/aten::_convolution/Convolution
INFO:nncf:Not adding activation input quantizer for operation: 5 __module.wav2vec2.feature_extractor.conv_layers.1.conv/aten::_convolution/Convolution
INFO:nncf:Not adding activation input quantizer for operation: 7 __module.wav2vec2.feature_extractor.conv_layers.2.conv/aten::_convolution/Convolution
INFO:nncf:Not adding activation input quantizer for operation: 9 __module.wav2vec2.feature_extractor.conv_layers.3.conv/aten::_convolution/Convolution
MODEL_NAME = "quantized_wav2vec2_base"
quantized_model_path = Path(f"{MODEL_NAME}_openvino_model/{MODEL_NAME}_quantized.xml")
ov.save_model(quantized_model, quantized_model_path)
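As an optional check (not shown in the original output), you can compare the on-disk weight sizes of the FP16 and INT8 IRs; INT8 weights are expected to be roughly half the size of FP16 ones.
# Compare the sizes of the weight files (.bin) saved next to each .xml.
fp16_size_mb = ir_model_path.with_suffix(".bin").stat().st_size / 2**20
int8_size_mb = quantized_model_path.with_suffix(".bin").stat().st_size / 2**20
print(f"FP16 model size: {fp16_size_mb:.2f} MB")
print(f"INT8 model size: {int8_size_mb:.2f} MB")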
Model Usage Example with Inference Pipeline#
Both the initial (FP16) and quantized (INT8) models can be used in exactly the same way.
Start by taking one example from the dataset to show the inference steps for it.
ipd.Audio(test_sample["array"], rate=16000)
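Below is a minimal inference sketch. The "CPU" device and greedy argmax decoding are assumptions for illustration; to run the INT8 model instead, pass quantized_model_path to compile_model.
# Compile the FP16 IR (swap in quantized_model_path for the INT8 model).
core = ov.Core()
compiled_model = core.compile_model(ir_model_path, "CPU")

# Preprocess the raw waveform into model input values.
input_values = processor(
    test_sample["array"],
    return_tensors="np",
    padding="longest",
    sampling_rate=16000,
).input_values

# Run inference and decode the logits with greedy (argmax) CTC decoding.
logits = compiled_model(input_values)[0]
predicted_ids = np.argmax(logits, axis=-1)
transcription = processor.batch_decode(predicted_ids)
print(transcription[0])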