Quantize Speech Recognition Models with OpenVINO™ Post-Training Optimization Tool
This tutorial is also available as a Jupyter notebook that can be cloned directly from GitHub. See the installation guide for instructions to run this tutorial locally on Windows, Linux or macOS.
This tutorial demonstrates how to apply INT8 quantization to the speech recognition model known as Wav2Vec2, using the Post-Training Optimization Tool API (POT API), part of the OpenVINO Toolkit. This notebook uses a fine-tuned Wav2Vec2-Base-960h PyTorch model trained on the LibriSpeech ASR corpus. The tutorial is designed to be extendable to custom models and datasets. It consists of the following steps:
Download and prepare the Wav2Vec2 model and LibriSpeech dataset.
Define data loading and accuracy validation functionality.
Prepare the model for quantization.
Run the optimization pipeline.
Compare performance of the original and quantized models.
Imports
import os
import sys
import time
import re
import numpy as np
import torch
import tarfile
from pathlib import Path
from itertools import groupby
import soundfile as sf
import IPython.display as ipd
from transformers import Wav2Vec2ForCTC
from openvino.runtime import Core
from openvino.tools.pot import Metric, DataLoader, IEEngine, \
load_model, save_model, compress_model_weights, create_pipeline
sys.path.append("../utils")
from notebook_utils import download_file
Settings
# Set the data and model directories and create them if they do not exist.
DATA_DIR = "../data/datasets/librispeech"
MODEL_DIR = "model"
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)
Prepare the Model
Perform the following steps:
- Download the pre-trained Wav2Vec2 model weights and configuration.
- Convert the model to ONNX.
- Run Model Optimizer to convert the model from the ONNX representation to the OpenVINO Intermediate Representation (OpenVINO IR).
download_file("https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/pytorch_model.bin", directory=Path(MODEL_DIR) / 'pytorch', show_progress=True)
download_file("https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json", directory=Path(MODEL_DIR) / 'pytorch', show_progress=False)
Load the original PyTorch model and export it to the ONNX representation. The export uses a dummy input of MAX_SEQ_LENGTH samples and marks the batch and time dimensions as dynamic.
BATCH_SIZE = 1
MAX_SEQ_LENGTH = 30480


def export_model_to_onnx(model, path):
    # Trace the model with a silent waveform of shape [BATCH_SIZE, MAX_SEQ_LENGTH].
    with torch.no_grad():
        default_input = torch.zeros([BATCH_SIZE, MAX_SEQ_LENGTH], dtype=torch.float)
        torch.onnx.export(
            model,
            default_input,
            path,
            opset_version=11,
            input_names=["inputs"],
            output_names=["logits"],
            dynamic_axes={
                "inputs": {0: "batch_size", 1: "sequence_len"},
                # The logits time axis (frames) is shorter than the input axis
                # (samples), so it gets its own symbolic name.
                "logits": {0: "batch_size", 1: "num_frames"},
            },
        )
    print("ONNX model saved to {}".format(path))


torch_model = Wav2Vec2ForCTC.from_pretrained(Path(MODEL_DIR) / 'pytorch')
onnx_model_path = Path(MODEL_DIR) / "wav2vec2_base.onnx"
if not onnx_model_path.exists():
    export_model_to_onnx(torch_model, onnx_model_path)
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at model/pytorch and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
ONNX model saved to model/wav2vec2_base.onnx
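The export uses a dummy waveform of MAX_SEQ_LENGTH = 30480 samples, about 1.9 s of 16 kHz audio. The logits have far fewer time steps than the input has samples because the convolutional feature encoder of Wav2Vec2 downsamples the waveform. The sketch below reproduces that arithmetic, assuming the default Hugging Face Wav2Vec2Config values conv_kernel=(10, 3, 3, 3, 3, 2, 2) and conv_stride=(5, 2, 2, 2, 2, 2, 2) with unpadded convolutions; these values are an assumption not stated in this tutorial, although the result matches the [1,95,32] logits shape that benchmark_app reports for this input length.

```python
# Assumed feature-encoder layout (Hugging Face Wav2Vec2Config defaults).
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)
SAMPLE_RATE = 16000


def conv_output_length(length, kernel, stride):
    """Output length of an unpadded 1D convolution."""
    return (length - kernel) // stride + 1


def num_logit_frames(num_samples):
    """Number of CTC frames produced for a waveform with num_samples samples."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        length = conv_output_length(length, kernel, stride)
    return length


max_seq_length = 30480
print(f"{max_seq_length / SAMPLE_RATE:.3f} s of audio")  # 1.905 s of audio
print(num_logit_frames(max_seq_length), "frames")  # 95 frames
```

The total stride is 5 × 2^6 = 320 samples, so under these assumptions the model emits roughly one frame every 20 ms of 16 kHz audio.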
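Convert the ONNX Model to OpenVINO IR
Run Model Optimizer on the ONNX model. The --input_shape "[1,-1]" option fixes the batch size to 1 and leaves the sequence length dynamic, and --compress_to_fp16 stores the weights in FP16, so this IR is the FP16 model referred to below.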
ir_model_xml = onnx_model_path.with_suffix(".xml")
ir_model_bin = onnx_model_path.with_suffix(".bin")
if not ir_model_xml.exists():
    !mo --input_model $onnx_model_path --output_dir $MODEL_DIR --input_shape "[1,-1]" --compress_to_fp16
[ INFO ] The model was converted to IR v11, the latest model format that corresponds to the source DL framework input/output format. While IR v11 is backwards compatible with OpenVINO Inference Engine API v1.0, please use API v2.0 (as of 2022.1) to take advantage of the latest improvements in IR v11.
Find more information about API v2.0 and IR v11 at https://docs.openvino.ai/latest/openvino_2_0_transition_guide.html
[ SUCCESS ] Generated IR version 11 model.
[ SUCCESS ] XML file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-408/.workspace/scm/ov-notebook/notebooks/107-speech-recognition-quantization/model/wav2vec2_base.xml
[ SUCCESS ] BIN file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-408/.workspace/scm/ov-notebook/notebooks/107-speech-recognition-quantization/model/wav2vec2_base.bin
Prepare LibriSpeech Dataset
Use the code below to download the ‘test-clean’ subset of the LibriSpeech dataset and unpack the archive.
download_file("http://openslr.elda.org/resources/12/test-clean.tar.gz", directory=DATA_DIR, show_progress=True)
if not os.path.exists(f'{DATA_DIR}/LibriSpeech'):
    with tarfile.open(f"{DATA_DIR}/test-clean.tar.gz") as tar:
        tar.extractall(path=DATA_DIR)
Define DataLoader for POT
Define a DataLoader based on the POT API. It is used to collect statistics for quantization and to run model evaluation. The Wav2Vec2 model accepts a raw waveform of the speech signal as input and produces per-frame scores (logits) over its vocabulary as output. Since the dataset contains audio files in FLAC format, use the ‘soundfile’ package to read them as waveforms.
NOTE: Consider increasing samples_limit to get more precise results. A value of 300 or more is suggested, although processing will take longer. With the default samples_limit = 4, both statistics collection and accuracy evaluation use only 4 utterances, even though stat_subset_size is set to 300 in the quantization configuration.
class LibriSpeechDataLoader(DataLoader):
    # Maximum number of utterances loaded from the dataset.
    samples_limit = 4

    @staticmethod
    def read_flac(file_name):
        speech, samplerate = sf.read(file_name)
        assert samplerate == 16000, "read_flac: only 16kHz supported!"
        return speech

    # Required methods
    def __init__(self, config):
        """Constructor
        :param config: data loader specific config
        """
        super().__init__(config)
        self._data_dir = config["data_source"]
        self._ds = []
        self._prepare_dataset()

    def __len__(self):
        """Returns size of the dataset"""
        return len(self._ds)

    def __getitem__(self, index):
        """
        Returns annotation, data and metadata at the specified index.
        Possible formats:
        (index, annotation), data
        (index, annotation), data, metadata
        """
        label = self._ds[index][0]
        # Add a batch dimension: [num_samples] -> [1, num_samples].
        inputs = {'inputs': np.expand_dims(self._ds[index][1], axis=0)}
        return label, inputs

    # Methods specific to the current implementation
    def _prepare_dataset(self):
        # Each transcript line has the form "<utterance-id> <TRANSCRIPT>".
        pattern = re.compile(r'([0-9\-]+)\s+(.+)')
        data_folder = Path(self._data_dir)
        txts = list(data_folder.glob('**/*.txt'))
        counter = 0
        for txt in txts:
            content = txt.read_text().splitlines()
            for line in content:
                res = pattern.search(line)
                if not res:
                    continue
                name = res.group(1)
                transcript = res.group(2)
                fname = txt.parent / name
                fname = fname.with_suffix('.flac')
                identifier = str(fname.relative_to(data_folder))
                self._ds.append(((counter, transcript.upper()), LibriSpeechDataLoader.read_flac(os.path.join(self._data_dir, identifier))))
                counter += 1
                if counter >= self.samples_limit:
                    # Limit exceeded
                    return
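_prepare_dataset relies on the LibriSpeech layout, where each chapter folder contains a transcript file with one "<utterance-id> <TRANSCRIPT>" line per FLAC file. The stdlib-only sketch below isolates that mapping so it can be tried without the dataset; parse_transcript_line is a hypothetical helper, and the example line and path are made up in the LibriSpeech format.

```python
import re
from pathlib import Path

# Same pattern as in LibriSpeechDataLoader._prepare_dataset.
TRANSCRIPT_LINE = re.compile(r'([0-9\-]+)\s+(.+)')


def parse_transcript_line(txt_path, line):
    """Return (flac_path, TRANSCRIPT) for one transcript line, or None if it does not match.

    Hypothetical helper that mirrors the loop body of _prepare_dataset.
    """
    match = TRANSCRIPT_LINE.search(line)
    if not match:
        return None
    utterance_id, transcript = match.group(1), match.group(2)
    # The audio file sits next to the transcript and is named after the utterance ID.
    flac_path = (Path(txt_path).parent / utterance_id).with_suffix(".flac")
    return flac_path, transcript.upper()


# Made-up example line in the LibriSpeech transcript format.
example = parse_transcript_line(
    "LibriSpeech/test-clean/121/127105/121-127105.trans.txt",
    "121-127105-0000 SOME EXAMPLE TRANSCRIPT\n",
)
print(example)
```

Lines that do not start with an utterance ID, such as empty lines, are skipped, just as the continue branch in the data loader skips them.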
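Define WER Metric Calculation
In this step, the Metric interface is implemented for the WER metric, which is used to validate the accuracy of the model. WER stands for Word Error Rate: the minimum number of word substitutions, deletions, and insertions needed to turn the reference transcript into the recognized text, divided by the number of words in the reference. For more details, refer to the Wiki page.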
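As a worked illustration of that definition, the sketch below computes WER from a word-level Levenshtein distance in plain Python. word_edit_distance and word_error_rate are hypothetical helpers written for this example; they use the same dynamic-programming recurrence as MetricWER._editdistance_eval, but keep a single rolling row instead of a full matrix.

```python
def word_edit_distance(reference, hypothesis):
    """Minimum number of word substitutions, deletions and insertions."""
    ref, hyp = reference.split(), hypothesis.split()
    # prev[j]: distance between the reference words processed so far
    # and the first j hypothesis words.
    prev = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        cur = [i] + [0] * len(hyp)
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            cur[j] = min(prev[j] + 1,         # deletion
                         cur[j - 1] + 1,      # insertion
                         prev[j - 1] + cost)  # substitution or match
        prev = cur
    return prev[-1]


def word_error_rate(reference, hypothesis):
    """WER = edit distance / number of reference words (0 for an empty reference)."""
    num_ref_words = len(reference.split())
    if num_ref_words == 0:
        return 0.0
    return word_edit_distance(reference, hypothesis) / num_ref_words


print(word_error_rate("the cat sat down", "the bat sat"))  # 0.5
```

Here "cat" is substituted by "bat" and "down" is deleted, giving 2 errors over 4 reference words.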
class MetricWER(Metric):
    # Vocabulary of the CTC head; the list index is the class ID.
    alphabet = [
        "<pad>", "<s>", "</s>", "<unk>", "|",
        "e", "t", "a", "o", "n", "i", "h", "s", "r", "d", "l", "u",
        "m", "w", "c", "f", "g", "y", "p", "b", "v", "k", "'", "x", "j", "q", "z"]
    words_delimiter = '|'
    pad_token = '<pad>'

    @staticmethod
    def decode_logits(logits):
        """Greedy CTC decoding: argmax per frame, collapse repeats, drop padding."""
        decoding_vocab = dict(enumerate(MetricWER.alphabet))
        token_ids = np.squeeze(np.argmax(logits, -1))
        tokens = [decoding_vocab[idx] for idx in token_ids]
        tokens = [token_group[0] for token_group in groupby(tokens)]
        tokens = [t for t in tokens if t != MetricWER.pad_token]
        res_string = ''.join([t if t != MetricWER.words_delimiter else ' ' for t in tokens]).strip()
        # Collapse runs of spaces left by consecutive word delimiters.
        res_string = ' '.join(res_string.split())
        res_string = res_string.lower()
        return res_string

    # Required methods
    def __init__(self):
        super().__init__()
        self._name = "WER"
        self._sum_score = 0
        self._sum_words = 0
        self._cur_score = 0
        self._decoding_vocab = dict(enumerate(self.alphabet))

    @property
    def value(self):
        """Returns accuracy metric value for the last model output."""
        return {self._name: self._cur_score}

    @property
    def avg_value(self):
        """Returns accuracy metric value for all model outputs."""
        return {self._name: self._sum_score / self._sum_words if self._sum_words != 0 else 0}

    def update(self, output, target):
        """
        Updates prediction matches.
        :param output: model output
        :param target: annotations
        """
        assert len(output) == len(target), "sizes of output and target mismatch!"
        decoded = [self.decode_logits(i) for i in output]
        target = [i.lower() for i in target]
        for i in range(len(output)):
            # Reference transcript first, recognized text second.
            self._get_metric_per_sample(target[i], decoded[i])

    def reset(self):
        """
        Resets collected matches
        """
        self._sum_score = 0
        self._sum_words = 0
        self._cur_score = 0

    def get_attributes(self):
        """
        Returns a dictionary of metric attributes {metric_name: {attribute_name: value}}.
        Required attributes: 'direction': 'higher-better' or 'higher-worse'
                             'type': metric type
        """
        return {self._name: {"direction": "higher-worse", "type": "WER"}}

    # Methods specific to the current implementation
    def _get_metric_per_sample(self, annotation, prediction):
        cur_score = self._editdistance_eval(annotation.split(), prediction.split())
        # The WER denominator is the number of words in the reference (annotation).
        cur_words = len(annotation.split())
        self._sum_score += cur_score
        self._sum_words += cur_words
        self._cur_score = cur_score / cur_words if cur_words != 0 else 0
        return self._cur_score

    def _editdistance_eval(self, source, target):
        # Word-level Levenshtein distance computed with dynamic programming.
        n, m = len(source), len(target)
        distance = np.zeros((n + 1, m + 1), dtype=int)
        distance[:, 0] = np.arange(0, n + 1)
        distance[0, :] = np.arange(0, m + 1)
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                cost = 0 if source[i - 1] == target[j - 1] else 1
                distance[i][j] = min(distance[i - 1][j] + 1,
                                     distance[i][j - 1] + 1,
                                     distance[i - 1][j - 1] + cost)
        return distance[n][m]
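Note that avg_value pools the counts over the whole dataset: it divides the accumulated edit distance (_sum_score) by the accumulated word count (_sum_words) instead of averaging the per-sample WER values. The two differ when utterances have different lengths, as this small worked example with made-up counts shows.

```python
# Made-up (word errors, reference words) pairs for three utterances.
samples = [(1, 2), (0, 10), (1, 20)]

# Pooled, as in MetricWER.avg_value: total errors / total words.
pooled_wer = sum(errors for errors, _ in samples) / sum(words for _, words in samples)

# Unweighted mean of per-utterance WER values, shown for comparison.
mean_wer = sum(errors / words for errors, words in samples) / len(samples)

print(f"pooled WER: {pooled_wer:.4f}")  # pooled WER: 0.0625
print(f"mean per-sample WER: {mean_wer:.4f}")  # mean per-sample WER: 0.1833
```

The short, badly recognized first utterance dominates the unweighted mean but contributes only 2 of the 32 words to the pooled value.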
Run Quantization Pipeline
Use the code below to define a configuration for the quantization pipeline and run it. Note that the built-in IEEngine implementation of the Engine interface from the POT API is used here for model inference.
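The activations.range_estimator block in the configuration below controls how activation ranges are derived from the collected statistics. Assuming the usual POT semantics, where type is the statistic computed on each sample and aggregator combines those values across samples, the lower bound is the minimum of per-sample minimums and the upper bound is the mean of per-sample (1 - outlier_prob) quantiles, which makes it less sensitive to rare outliers. The plain-Python sketch below illustrates that reading on toy numbers; it is not POT's implementation, and the linear-interpolation quantile rule is an assumption.

```python
def quantile(values, q):
    """Quantile with linear interpolation between the closest ranks."""
    xs = sorted(values)
    pos = q * (len(xs) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def estimate_activation_range(samples, outlier_prob=0.0001):
    """Sketch of the configured range estimator for one activation tensor.

    min: statistic "min" per sample, aggregated with "min" across samples.
    max: statistic "quantile" (1 - outlier_prob) per sample, aggregated with "mean".
    """
    per_sample_min = [min(s) for s in samples]
    per_sample_max = [quantile(s, 1.0 - outlier_prob) for s in samples]
    return min(per_sample_min), sum(per_sample_max) / len(per_sample_max)


# Toy activations for two samples; the second contains one large outlier.
toy_samples = [[0.0, 1.0, 2.0, 3.0, 4.0], [-1.0, 0.0, 1.0, 2.0, 100.0]]
print(estimate_activation_range(toy_samples, outlier_prob=0.25))  # (-1.0, 2.5)
```

The toy call uses a much larger outlier_prob than the configured 0.0001 only so that the clipping effect is visible on five values; with real activation tensors, the 0.9999 quantile discards a correspondingly tiny tail.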
model_config = {"model_name": "wav2vec2_base", "model": ir_model_xml, "weights": ir_model_bin}
engine_config = {"device": "CPU"}
dataset_config = {"data_source": os.path.join(DATA_DIR, "LibriSpeech/test-clean")}
algorithms = [
    {
        "name": "DefaultQuantization",
        "params": {
            "target_device": "ANY",
            "model_type": "transformer",
            "preset": "performance",
            "stat_subset_size": 300,
            "activations": {
                "range_estimator": {
                    "min": {
                        "aggregator": "min",
                        "type": "min"
                    },
                    "max": {
                        "aggregator": "mean",
                        "type": "quantile",
                        "outlier_prob": 0.0001
                    }
                }
            },
            "ignored": {
                "scope": ["214"]
            }
        }
    }
]
# Step 1: Load the model.
model = load_model(model_config=model_config)
# Step 2: Initialize the data loader.
data_loader = LibriSpeechDataLoader(config=dataset_config)
# Step 3 (Optional. Required for AccuracyAwareQuantization): Initialize the metric.
metric = MetricWER()
# Step 4: Initialize the engine for metric calculation and statistics collection.
engine = IEEngine(config=engine_config, data_loader=data_loader, metric=metric)
# Step 5: Create a pipeline of compression algorithms.
pipeline = create_pipeline(algo_config=algorithms, engine=engine)
# Step 6 (Optional): Evaluate the original model. Print the results.
start_time = time.perf_counter()
fp_results = pipeline.evaluate(model=model)
end_time = time.perf_counter()
print(f"Evaluation finished in {end_time - start_time:.2f} seconds")
if fp_results:
    print("FP16 model results:")
    for name, value in fp_results.items():
        print(f"{name}: {value:.5f}")
Evaluation finished in 8.06 seconds
FP16 model results:
WER: 0.03704
# Step 7: Execute the pipeline.
print(f"Quantizing model with {algorithms[0]['params']['preset']} preset and {algorithms[0]['name']}")
start_time = time.perf_counter()
compressed_model = pipeline.run(model=model)
end_time = time.perf_counter()
print(f"Quantization finished in {end_time - start_time:.2f} seconds")
# Step 8 (Optional): Compress model weights to quantized precision
# in order to reduce the size of the final .bin file.
compress_model_weights(model=compressed_model)
# Step 9: Save the compressed model to the desired path.
compressed_model_paths = save_model(model=compressed_model, save_path=MODEL_DIR, model_name="quantized_wav2vec2_base")
compressed_model_xml = compressed_model_paths[0]["model"]
Quantizing model with performance preset and DefaultQuantization
Quantization finished in 68.10 seconds
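Step 8 stores the quantized weights as 8-bit integers, which is what shrinks the .bin file. A rough estimate follows from the approximately 360 MB FP32 PyTorch checkpoint downloaded earlier, that is, roughly 90 million 4-byte values. The sketch below assumes every weight ends up either INT8 or FP16; the 90% quantized fraction is a made-up illustration, since layers in the ignored scope, biases, and normalization parameters are not quantized and the real split is not reported here.

```python
BYTES_PER_VALUE = {"fp32": 4, "fp16": 2, "int8": 1}


def estimated_weights_mb(num_values, quantized_fraction=1.0):
    """Rough .bin size in MB when a fraction of the values is INT8 and the rest FP16."""
    int8_bytes = num_values * quantized_fraction * BYTES_PER_VALUE["int8"]
    fp16_bytes = num_values * (1.0 - quantized_fraction) * BYTES_PER_VALUE["fp16"]
    return (int8_bytes + fp16_bytes) / 1e6


num_values = 360e6 / BYTES_PER_VALUE["fp32"]  # ~90 million values in the FP32 checkpoint
print(f"FP16 IR weights: ~{estimated_weights_mb(num_values, 0.0):.0f} MB")
print(f"All weights INT8: ~{estimated_weights_mb(num_values):.0f} MB")
print(f"90% INT8 (made-up split): ~{estimated_weights_mb(num_values, 0.9):.0f} MB")
```

Comparing these figures with the actual size of the saved quantized_wav2vec2_base.bin gives a quick sanity check that weight compression was applied.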
# Step 10 (Optional): Evaluate the compressed model and print the results.
int_results = pipeline.evaluate(model=compressed_model)
if int_results:
    print("INT8 model results:")
    for name, value in int_results.items():
        print(f"{name}: {value:.5f}")
INT8 model results:
WER: 0.05556
Model Usage Example with Inference Pipeline
The original (FP16) and quantized (INT8) models are used in exactly the same way.
Start by taking one example from the dataset to demonstrate the inference steps.
audio = LibriSpeechDataLoader.read_flac(f'{DATA_DIR}/LibriSpeech/test-clean/121/127105/121-127105-0017.flac')
ipd.Audio(audio, rate=16000)
Next, load the quantized model and compile it for inference on CPU.
ie = Core()
model = ie.read_model(compressed_model_xml)
compiled_model = ie.compile_model(model=model, device_name='CPU')
input_data = np.expand_dims(audio, axis=0)
output_layer = compiled_model.outputs[0]
Next, make a prediction.
predictions = compiled_model([input_data])[output_layer]
Now decode the predicted logits into text with the decode_logits method of the MetricWER class. Alternatively, use the built-in Wav2Vec2Processor tokenizer from the transformers package.
predicted_text = MetricWER.decode_logits(predictions)
predicted_text
'it was almost the tone of hope everybody will stay'
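decode_logits performs greedy CTC decoding. The stdlib-only sketch below spells out the same steps on per-frame class IDs instead of a NumPy logits array; argmax_per_frame and ctc_greedy_decode are hypothetical helpers, and the frame IDs in the example are made up. A repeated letter survives only when a <pad> frame separates the two runs, which is why the double "l" in "hello" needs a pad between them.

```python
from itertools import groupby

# Same vocabulary as MetricWER.alphabet: the list index is the class ID.
ALPHABET = ["<pad>", "<s>", "</s>", "<unk>", "|",
            "e", "t", "a", "o", "n", "i", "h", "s", "r", "d", "l", "u",
            "m", "w", "c", "f", "g", "y", "p", "b", "v", "k", "'", "x", "j", "q", "z"]
PAD_TOKEN = "<pad>"
WORDS_DELIMITER = "|"


def argmax_per_frame(logits):
    """Index of the highest score in each frame (logits: list of per-frame score lists)."""
    return [max(range(len(frame)), key=frame.__getitem__) for frame in logits]


def ctc_greedy_decode(frame_ids):
    tokens = [ALPHABET[i] for i in frame_ids]
    # 1) Merge runs of identical consecutive tokens.
    tokens = [token for token, _ in groupby(tokens)]
    # 2) Drop the CTC blank, which is the <pad> token here.
    tokens = [token for token in tokens if token != PAD_TOKEN]
    # 3) Turn word delimiters into spaces and normalize whitespace.
    text = "".join(" " if token == WORDS_DELIMITER else token for token in tokens)
    return " ".join(text.split())


# Made-up frame IDs: h h e <pad> l l <pad> l o | |
print(ctc_greedy_decode([11, 11, 5, 0, 15, 15, 0, 15, 8, 4, 4]))  # hello
```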
Compare Performance of the Original and Quantized Models
Finally, use Benchmark Tool to measure the inference performance of the FP16 and INT8 models.
NOTE: For more accurate performance measurements, run benchmark_app in a terminal/command prompt after closing other applications. Run benchmark_app -m model.xml -d CPU to benchmark async inference on CPU for one minute. Change CPU to GPU to benchmark on GPU. Run benchmark_app --help to see an overview of all command-line options.
# Inference FP16 model (OpenVINO IR)
! benchmark_app -m $ir_model_xml -shape "[1,30480]" -d CPU -api async
[Step 1/11] Parsing and validating input arguments
[ INFO ] Parsing input parameters
[Step 2/11] Loading OpenVINO Runtime
[ INFO ] OpenVINO:
[ INFO ] Build ................................. 2022.3.0-9052-9752fafe8eb-releases/2022/3
[ INFO ]
[ INFO ] Device info:
[ INFO ] CPU
[ INFO ] Build ................................. 2022.3.0-9052-9752fafe8eb-releases/2022/3
[ INFO ]
[ INFO ]
[Step 3/11] Setting device configuration
[ WARNING ] Performance hint was not explicitly specified in command line. Device(CPU) performance hint will be set to THROUGHPUT.
[Step 4/11] Reading model files
[ INFO ] Loading model files
[ INFO ] Read model took 152.67 ms
[ INFO ] Original model I/O parameters:
[ INFO ] Model inputs:
[ INFO ] inputs (node: inputs) : f32 / [...] / [1,?]
[ INFO ] Model outputs:
[ INFO ] logits (node: logits) : f32 / [...] / [1,?,32]
[Step 5/11] Resizing model to match image sizes and given batch
[ INFO ] Model batch size: 1
[ INFO ] Reshaping model: 'inputs': [1,30480]
[ INFO ] Reshape model took 20.98 ms
[Step 6/11] Configuring input of the model
[ INFO ] Model inputs:
[ INFO ] inputs (node: inputs) : f32 / [...] / [1,30480]
[ INFO ] Model outputs:
[ INFO ] logits (node: logits) : f32 / [...] / [1,95,32]
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 1045.18 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] Model:
[ INFO ] NETWORK_NAME: torch_jit
[ INFO ] OPTIMAL_NUMBER_OF_INFER_REQUESTS: 6
[ INFO ] NUM_STREAMS: 6
[ INFO ] AFFINITY: Affinity.CORE
[ INFO ] INFERENCE_NUM_THREADS: 24
[ INFO ] PERF_COUNT: False
[ INFO ] INFERENCE_PRECISION_HINT: <Type: 'float32'>
[ INFO ] PERFORMANCE_HINT: PerformanceMode.THROUGHPUT
[ INFO ] PERFORMANCE_HINT_NUM_REQUESTS: 0
[Step 9/11] Creating infer requests and preparing input tensors
[ WARNING ] No input files were given for input 'inputs'!. This input will be filled with random values!
[ INFO ] Fill input 'inputs' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 6 inference requests, limits: 60000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 70.85 ms
[Step 11/11] Dumping statistics report
[ INFO ] Count: 2754 iterations
[ INFO ] Duration: 60231.38 ms
[ INFO ] Latency:
[ INFO ] Median: 130.85 ms
[ INFO ] Average: 131.00 ms
[ INFO ] Min: 110.84 ms
[ INFO ] Max: 149.16 ms
[ INFO ] Throughput: 45.72 FPS
# Inference INT8 model (OpenVINO IR)
! benchmark_app -m $compressed_model_xml -shape "[1,30480]" -d CPU -api async
[Step 1/11] Parsing and validating input arguments
[ INFO ] Parsing input parameters
[Step 2/11] Loading OpenVINO Runtime
[ INFO ] OpenVINO:
[ INFO ] Build ................................. 2022.3.0-9052-9752fafe8eb-releases/2022/3
[ INFO ]
[ INFO ] Device info:
[ INFO ] CPU
[ INFO ] Build ................................. 2022.3.0-9052-9752fafe8eb-releases/2022/3
[ INFO ]
[ INFO ]
[Step 3/11] Setting device configuration
[ WARNING ] Performance hint was not explicitly specified in command line. Device(CPU) performance hint will be set to THROUGHPUT.
[Step 4/11] Reading model files
[ INFO ] Loading model files
[ INFO ] Read model took 147.54 ms
[ INFO ] Original model I/O parameters:
[ INFO ] Model inputs:
[ INFO ] inputs (node: inputs) : f32 / [...] / [1,?]
[ INFO ] Model outputs:
[ INFO ] logits (node: logits) : f32 / [...] / [1,?,32]
[Step 5/11] Resizing model to match image sizes and given batch
[ INFO ] Model batch size: 1
[ INFO ] Reshaping model: 'inputs': [1,30480]
[ INFO ] Reshape model took 24.48 ms
[Step 6/11] Configuring input of the model
[ INFO ] Model inputs:
[ INFO ] inputs (node: inputs) : f32 / [...] / [1,30480]
[ INFO ] Model outputs:
[ INFO ] logits (node: logits) : f32 / [...] / [1,95,32]
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 938.11 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] Model:
[ INFO ] NETWORK_NAME: torch_jit
[ INFO ] OPTIMAL_NUMBER_OF_INFER_REQUESTS: 6
[ INFO ] NUM_STREAMS: 6
[ INFO ] AFFINITY: Affinity.CORE
[ INFO ] INFERENCE_NUM_THREADS: 24
[ INFO ] PERF_COUNT: False
[ INFO ] INFERENCE_PRECISION_HINT: <Type: 'float32'>
[ INFO ] PERFORMANCE_HINT: PerformanceMode.THROUGHPUT
[ INFO ] PERFORMANCE_HINT_NUM_REQUESTS: 0
[Step 9/11] Creating infer requests and preparing input tensors
[ WARNING ] No input files were given for input 'inputs'!. This input will be filled with random values!
[ INFO ] Fill input 'inputs' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 6 inference requests, limits: 60000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 34.13 ms
[Step 11/11] Dumping statistics report
[ INFO ] Count: 7674 iterations
[ INFO ] Duration: 60046.26 ms
[ INFO ] Latency:
[ INFO ] Median: 46.52 ms
[ INFO ] Average: 46.82 ms
[ INFO ] Min: 31.95 ms
[ INFO ] Max: 61.98 ms
[ INFO ] Throughput: 127.80 FPS
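The reported numbers can be cross-checked with simple arithmetic. Throughput is the iteration count divided by the duration, and with a fixed number of requests in flight it should roughly equal the number of requests divided by the average latency (Little's law). The sketch below plugs in the values printed above (6 inference requests, counts, durations, and average latencies); the resulting speedup is specific to this machine and run.

```python
# Values copied from the two benchmark_app reports above.
runs = {
    "FP16": {"count": 2754, "duration_ms": 60231.38, "avg_latency_ms": 131.00},
    "INT8": {"count": 7674, "duration_ms": 60046.26, "avg_latency_ms": 46.82},
}
NUM_REQUESTS = 6  # "6 inference requests" in Step 10/11


def throughput_fps(run):
    """Completed iterations per second."""
    return run["count"] / (run["duration_ms"] / 1000)


def littles_law_fps(run, num_requests=NUM_REQUESTS):
    # With num_requests always in flight: throughput ~= requests / latency.
    return num_requests / (run["avg_latency_ms"] / 1000)


for name, run in runs.items():
    print(f"{name}: {throughput_fps(run):.2f} FPS measured, "
          f"~{littles_law_fps(run):.1f} FPS from latency")

speedup = throughput_fps(runs["INT8"]) / throughput_fps(runs["FP16"])
print(f"INT8 throughput speedup: {speedup:.2f}x")
```

Both estimates agree to within about 1% for each model, which suggests the six requests kept the device busy for the whole measurement.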