Automatic speech recognition using Whisper and OpenVINO with Generate API#

This Jupyter notebook can be launched after a local installation only.


Whisper is an automatic speech recognition (ASR) system trained on 680,000 hours of multilingual and multitask supervised data collected from the web.

Whisper is a Transformer based encoder-decoder model, also referred to as a sequence-to-sequence model. It maps a sequence of audio spectrogram features to a sequence of text tokens. First, the raw audio inputs are converted to a log-Mel spectrogram by action of the feature extractor. Then, the Transformer encoder encodes the spectrogram to form a sequence of encoder hidden states. Finally, the decoder autoregressively predicts text tokens, conditional on both the previous tokens and the encoder hidden states.

You can see the model architecture in the diagram below:

whisper_architecture.svg

In this tutorial, we consider how to run Whisper using OpenVINO. We will use the pre-trained model from the Hugging Face Transformers library. The Hugging Face Optimum Intel library converts the models to OpenVINO™ IR format. To simplify the user experience, we will use OpenVINO Generate API for Whisper automatic speech recognition scenarios.

Installation Instructions#

This is a self-contained example that relies solely on its own code.

We recommend running the notebook in a virtual environment. You only need a Jupyter server to start. For details, please refer to Installation Guide.


Prerequisites#

%pip install -q "transformers>=4.35" "torch>=2.3" "torchvision>=0.18.1" "onnx>=1.16.1" --extra-index-url https://download.pytorch.org/whl/cpu
%pip install -q "git+https://github.com/huggingface/optimum-intel.git"
%pip install -q --pre -U "openvino" "openvino-tokenizers" "openvino-genai" --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
%pip install -q datasets  "gradio>=4.0" "soundfile>=0.12" "librosa" "python-ffmpeg<=1.0.16"
%pip install -q "nncf>=2.13.0" "jiwer"
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
aeiou 0.0.20 requires soundfile<=0.10.2, but you have soundfile 0.12.1 which is incompatible.
descript-audiotools 0.7.2 requires protobuf<3.20,>=3.9.2, but you have protobuf 3.20.3 which is incompatible.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
import requests
from pathlib import Path

if not Path("notebook_utils.py").exists():
    r = requests.get(
        url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py",
    )
    open("notebook_utils.py", "w").write(r.text)

Load PyTorch model#

The AutoModelForSpeechSeq2Seq.from_pretrained method is used to initialize the PyTorch Whisper model from the transformers library. The model is downloaded once during the first run, which may take some time.

You may also choose other models from the Whisper collection; more on them here.

Preprocessing and post-processing are important when using this model. The AutoProcessor class, used to initialize WhisperProcessor, is responsible for preparing the audio input data for the model (converting it to a Mel-spectrogram) and for decoding the predicted output token_ids back into a string using the tokenizer. We will use the pipeline method to transcribe audio of arbitrary length.
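
As a quick illustration of the preprocessing role described above, the sketch below converts one second of synthetic silence into a log-Mel spectrogram (the openai/whisper-tiny id and the silent input are placeholders for illustration only):

from transformers import AutoProcessor
import numpy as np

demo_processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
silence = np.zeros(16000, dtype=np.float32)  # 1 second of silence at Whisper's expected 16 kHz rate
mel = demo_processor(silence, sampling_rate=16000, return_tensors="np").input_features
print(mel.shape)  # (1, 80, 3000): 80 Mel bins over a 30-second padded window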

import ipywidgets as widgets

model_ids = {
    "Multilingual models": [
        "openai/whisper-large-v3-turbo",
        "openai/whisper-large-v3",
        "openai/whisper-large-v2",
        "openai/whisper-large",
        "openai/whisper-medium",
        "openai/whisper-small",
        "openai/whisper-base",
        "openai/whisper-tiny",
    ],
    "English-only models": [
        "openai/whisper-medium.en",
        "openai/whisper-small.en",
        "openai/whisper-base.en",
        "openai/whisper-tiny.en",
    ],
}

model_type = widgets.Dropdown(
    options=model_ids.keys(),
    value="Multilingual models",
    description="Model:",
    disabled=False,
)

model_type
Dropdown(description='Model:', options=('Multilingual models', 'English-only models'), value='Multilingual mod…
model_id = widgets.Dropdown(
    options=model_ids[model_type.value],
    value=model_ids[model_type.value][-1],
    description="Model:",
    disabled=False,
)

model_id
Dropdown(description='Model:', index=7, options=('openai/whisper-large-v3-turbo', 'openai/whisper-large-v3', '…
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, pipeline
from transformers.utils import logging

processor = AutoProcessor.from_pretrained(model_id.value)

pt_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id.value)

pipe_pt = pipeline(
    "automatic-speech-recognition",
    model=pt_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    device="cpu",
)
2024-10-08 06:42:59.802092: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable TF_ENABLE_ONEDNN_OPTS=0.
2024-10-08 06:42:59.836152: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-10-08 06:43:00.493598: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT

Run PyTorch model inference#

The pipeline expects audio data in numpy array format. We will use a .wav file and convert it to a numpy array for that purpose.

from notebook_utils import download_file

en_example_short = Path("data", "librispeech_asr_demo_validation_short.wav")

# a wav sample
download_file(
    "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/librispeech_asr_demo_validation_0.wav",
    en_example_short.name,
    directory=en_example_short.parent,
)
data/librispeech_asr_demo_validation_short.wav:   0%|          | 0.00/183k [00:00<?, ?B/s]
PosixPath('/opt/home/k8sworker/ci-ai/cibuilds/jobs/ov-notebook/jobs/OVNotebookOps/builds/790/archive/.workspace/scm/ov-notebook/notebooks/whisper-asr-genai/data/librispeech_asr_demo_validation_short.wav')
import librosa

en_raw_speech, samplerate = librosa.load(str(en_example_short), sr=16000)

Let’s check how the transcribe task works.

import copy
import IPython.display as ipd

logging.set_verbosity_error()

sample = copy.deepcopy(en_raw_speech)

display(ipd.Audio(sample, rate=samplerate))

pt_result = pipe_pt(sample)
print(f"Result: {pt_result['text']}")
/opt/home/k8sworker/ci-ai/cibuilds/jobs/ov-notebook/jobs/OVNotebookOps/builds/790/archive/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/models/whisper/generation_whisper.py:496: FutureWarning: The input name inputs is deprecated. Please make sure to use input_features instead.
  warnings.warn(
Result:  Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.

If a multilingual model was chosen, let’s see how the translate task works. We will use the google/fleurs multilingual dataset, so you can choose the language. The model will translate audio from the selected language into English. Conversion of the audio to numpy format is handled by the Hugging Face datasets implementation. A complete list of languages supported by the model can be found in the paper.

import ipywidgets as widgets

languages = {"japanese": "ja_jp", "dutch": "da_dk", "french": "fr_fr", "spanish": "ca_es", "italian": "it_it", "portuguese": "pt_br", "polish": "pl_pl"}

SAMPLE_LANG = None
if model_type.value == "Multilingual models":
    SAMPLE_LANG = widgets.Dropdown(
        options=languages.keys(),
        value="italian",
        description="Dataset language:",
        disabled=False,
    )

SAMPLE_LANG
Dropdown(description='Dataset language:', index=4, options=('japanese', 'dutch', 'french', 'spanish', 'italian…
from datasets import load_dataset

mls_dataset = None
if model_type.value == "Multilingual models":
    mls_dataset = load_dataset("google/fleurs", languages[SAMPLE_LANG.value], split="test", streaming=True, trust_remote_code=True)
    mls_dataset = iter(mls_dataset)  # make it iterable
    mls_example = next(mls_dataset)  # get one example
if model_type.value == "Multilingual models":
    sample = copy.deepcopy(mls_example["audio"])

    display(ipd.Audio(sample["array"], rate=sample["sampling_rate"]))
    print(f"Reference: {mls_example['raw_transcription']}")

    pt_result = pipe_pt(sample, generate_kwargs={"task": "translate"})
    print(f"\nResult: {pt_result['text']}")
Reference: Il blog è uno strumento che si prefigge di incoraggiare la collaborazione e sviluppare l'apprendimento degli studenti ben oltre la giornata scolastica normale.

Result:  The blog is our tool that is prefilled to encourage collaboration and develop the learning of the students and to attract a normal school class.

Download and convert model to OpenVINO IR via Optimum Intel CLI#

The Whisper models listed above are available for download via the Hugging Face Hub. We will use the optimum-cli interface to export the selected model into OpenVINO Intermediate Representation (IR) format.

The Optimum CLI interface for converting models supports export to OpenVINO (available since optimum-intel version 1.12). The general command format is:

optimum-cli export openvino --model <model_id_or_path> --task <task> <output_dir>

where the --model argument is a model id from the Hugging Face Hub or a local directory with the model (saved using the .save_pretrained method), and --task is one of the supported tasks that the exported model should solve. For Whisper models it is automatic-speech-recognition-with-past. If model initialization requires remote code, the --trust-remote-code flag should also be passed. The full list of supported arguments is available via --help. For more details and examples of usage, please check the optimum documentation.
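
The same export can also be done from Python through Optimum Intel (a minimal sketch; this notebook uses the CLI path below, and the output directory name here is hypothetical):

from optimum.intel.openvino import OVModelForSpeechSeq2Seq

# export=True converts the PyTorch checkpoint to OpenVINO IR on the fly
ov_exported = OVModelForSpeechSeq2Seq.from_pretrained(model_id.value, export=True)
ov_exported.save_pretrained("whisper-ov-python-export")  # hypothetical output folder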

import logging
import nncf
import os
from IPython.display import display, Markdown

nncf.set_log_level(logging.ERROR)

model_path = Path(model_id.value.split("/")[1])
export_command = f"optimum-cli export openvino --model {model_id.value} --library transformers --task automatic-speech-recognition-with-past --framework pt {str(model_path)}"

display(Markdown("**Export command:**"))
display(Markdown(f"`{export_command}`"))

exit_code = os.system(export_command)
if exit_code != 0:
    raise Exception("Failed to load and convert model!")
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, tensorflow, onnx, openvino

Export command:

optimum-cli export openvino --model openai/whisper-tiny --library transformers --task automatic-speech-recognition-with-past --framework pt whisper-tiny

2024-10-08 06:43:10.758697: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Moving the following attributes in the config to the generation config: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50358, 50359, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}. You are seeing this warning because you've set generation parameters in the model config, as opposed to in the generation config.
Using framework PyTorch: 2.3.1+cpu
Overriding 1 configuration item(s)
    - use_cache -> False
/opt/home/k8sworker/ci-ai/cibuilds/jobs/ov-notebook/jobs/OVNotebookOps/builds/790/archive/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/models/whisper/modeling_whisper.py:1071: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if input_features.shape[-1] != expected_seq_length:
/opt/home/k8sworker/ci-ai/cibuilds/jobs/ov-notebook/jobs/OVNotebookOps/builds/790/archive/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/models/whisper/modeling_whisper.py:388: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
Using framework PyTorch: 2.3.1+cpu
Overriding 1 configuration item(s)
    - use_cache -> True
Passing a tuple of past_key_values is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of EncoderDecoderCache instead, e.g. past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values).
/opt/home/k8sworker/ci-ai/cibuilds/jobs/ov-notebook/jobs/OVNotebookOps/builds/790/archive/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/models/whisper/modeling_whisper.py:101: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if sequence_length != 1:
Using framework PyTorch: 2.3.1+cpu
Overriding 1 configuration item(s)
    - use_cache -> True
/opt/home/k8sworker/ci-ai/cibuilds/jobs/ov-notebook/jobs/OVNotebookOps/builds/790/archive/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/cache_utils.py:447: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
/opt/home/k8sworker/ci-ai/cibuilds/jobs/ov-notebook/jobs/OVNotebookOps/builds/790/archive/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/cache_utils.py:432: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  elif len(self.key_cache[layer_idx]) == 0:  # fills previously skipped layers; checking for tensor causes errors

Run inference on the OpenVINO model with WhisperPipeline#

To simplify the user experience, we will use the OpenVINO Generate API. First, we create the pipeline with WhisperPipeline. You can construct it straight away from the folder with the converted model; it will automatically load the model, tokenizer, detokenizer and the default generation configuration.

from notebook_utils import device_widget

device = device_widget(default="CPU", exclude=["NPU"])

device
Dropdown(description='Device:', options=('CPU', 'AUTO'), value='CPU')
import openvino_genai

ov_pipe = openvino_genai.WhisperPipeline(str(model_path), device=device.value)

Let’s run the transcribe task. We just call generate and pass the audio array as input.

sample = copy.deepcopy(en_raw_speech)

genai_result = ov_pipe.generate(sample)

display(ipd.Audio(sample, rate=samplerate))
print(f"Result: {genai_result}")
Result:  Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.

Let’s see how the translate task works. It is supported for multilingual models only. For that case we need to specify the language and task options, which can be done in different ways: we can get the default config with get_generation_config(), set up its parameters and pass the config directly to generate(), or we can specify the needed options directly as arguments of the generate() method. We will use the latter approach: we simply run the generate method and get the output in text format.
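
For reference, here is a minimal sketch of the config-based alternative (it assumes the generation config exposes the task, language and max_new_tokens fields; the cell below uses the direct-arguments approach instead):

config = ov_pipe.get_generation_config()
config.task = "translate"
config.language = "<|it|>"  # illustrative language token
config.max_new_tokens = 100
# result = ov_pipe.generate(sample["array"], config)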

languages_genai = {
    "japanese": "<|ja|>",
    "dutch": "<|da|>",
    "french": "<|fr|>",
    "spanish": "<|es|>",
    "italian": "<|it|>",
    "portuguese": "<|pt|>",
    "polish": "<|pl|>",
}

if model_type.value == "Multilingual models":
    sample = copy.deepcopy(mls_example["audio"])

    genai_result_ml = ov_pipe.generate(sample["array"], max_new_tokens=100, task="translate", language=languages_genai[SAMPLE_LANG.value])

    display(ipd.Audio(sample["array"], rate=sample["sampling_rate"]))
    print(f"Reference: {mls_example['raw_transcription']}")
    print(f"\nResult: {genai_result_ml}")
Reference: Il blog è uno strumento che si prefigge di incoraggiare la collaborazione e sviluppare l'apprendimento degli studenti ben oltre la giornata scolastica normale.

Result:  The blog is our tool that is prefilled to encourage collaboration and develop the learning of the students and to attract a normal school class.

Compare performance: PyTorch vs OpenVINO#

import time
import numpy as np
from tqdm.notebook import tqdm


def measure_perf(pipe, n=10, model_type="ov"):
    timers = []
    for _ in tqdm(range(n), desc="Measuring performance"):
        sample = copy.deepcopy(en_raw_speech)
        start = time.perf_counter()
        if model_type == "pt":
            pipe(sample)
        elif model_type == "ov":
            pipe.generate(sample)
        end = time.perf_counter()
        timers.append(end - start)
    return np.median(timers)
perf_torch = measure_perf(pipe_pt, model_type="pt")
perf_ov = measure_perf(ov_pipe)
Measuring performance:   0%|          | 0/10 [00:00<?, ?it/s]
Measuring performance:   0%|          | 0/10 [00:00<?, ?it/s]
print(f"Mean torch {model_id.value} generation time: {perf_torch:.3f}s")
print(f"Mean openvino {model_id.value} generation time: {perf_ov:.3f}s")
print(f"Performance {model_id.value} openvino speedup: {perf_torch / perf_ov:.3f}")
Mean torch openai/whisper-tiny generation time: 0.273s
Mean openvino openai/whisper-tiny generation time: 0.166s
Performance openai/whisper-tiny openvino speedup: 1.650

Quantization#

NNCF enables post-training quantization by adding the quantization layers into the model graph and then using a subset of the training dataset to initialize the parameters of these additional quantization layers. The framework is designed so that modifications to your original training code are minor.

The optimization process contains the following steps:

  1. Create a calibration dataset for quantization.

  2. Run nncf.quantize to obtain quantized encoder and decoder models.

  3. Serialize the INT8 model using openvino.save_model function.
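
The three steps above in miniature, for a generic OpenVINO model (an illustrative sketch with hypothetical file paths; the Whisper-specific implementation follows later in this notebook):

import nncf
import openvino as ov

core = ov.Core()
fp_model = core.read_model("model.xml")  # hypothetical floating-point IR
calibration_samples = [...]  # assumed: model inputs collected beforehand
calibration_data = nncf.Dataset(calibration_samples)  # step 1: wrap the calibration samples
int8_model = nncf.quantize(fp_model, calibration_data, model_type=nncf.ModelType.TRANSFORMER)  # step 2
ov.save_model(int8_model, "model_int8.xml")  # step 3: serialize the INT8 model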

Note: Quantization is a time- and memory-consuming operation. Running the quantization code below may take some time.

Please select below whether you would like to run Whisper quantization.

from notebook_utils import quantization_widget

to_quantize = quantization_widget()

to_quantize
Checkbox(value=True, description='Quantization')
# Fetch `skip_kernel_extension` module
import requests

r = requests.get(
    url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/skip_kernel_extension.py",
)
open("skip_kernel_extension.py", "w").write(r.text)

%load_ext skip_kernel_extension

Let’s load the converted OpenVINO model using Optimum Intel so we can easily quantize it.

Optimum Intel can be used to load optimized models from the Hugging Face Hub or a local folder and to create pipelines for running inference with OpenVINO Runtime using Hugging Face APIs. Optimum Inference models are API compatible with Hugging Face Transformers models. This means we just need to replace the AutoModelForXxx class with the corresponding OVModelForXxx class.

Below is an example for the whisper-tiny model:

-from transformers import AutoModelForSpeechSeq2Seq
+from optimum.intel.openvino import OVModelForSpeechSeq2Seq
from transformers import AutoTokenizer, pipeline

model_id = "openai/whisper-tiny"
-model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
+model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True)

Like the original PyTorch model, the OpenVINO model is also compatible with the HuggingFace pipeline interface for automatic-speech-recognition. The pipeline can be used for long audio transcription. Distil-Whisper uses a chunked algorithm to transcribe long-form audio files. In practice, this chunked long-form algorithm is 9x faster than the sequential algorithm proposed by OpenAI in the Whisper paper. To enable chunking, pass the chunk_length_s parameter to the pipeline. For Distil-Whisper, a chunk length of 15 seconds is optimal. To activate batching, pass the argument batch_size.

from optimum.intel.openvino import OVModelForSpeechSeq2Seq

ov_model = OVModelForSpeechSeq2Seq.from_pretrained(str(model_path), device=device.value)
ov_processor = AutoProcessor.from_pretrained(str(model_path))
Compiling the encoder to CPU ...
Compiling the decoder to CPU ...
Compiling the decoder to CPU ...
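
With ov_model and ov_processor loaded, a chunked long-form pipeline can be set up as sketched below (the chunk_length_s and batch_size values and the audio path are illustrative):

long_form_pipe = pipeline(
    "automatic-speech-recognition",
    model=ov_model,
    tokenizer=ov_processor.tokenizer,
    feature_extractor=ov_processor.feature_extractor,
    chunk_length_s=15,
    batch_size=4,
)
# transcription = long_form_pipe("path/to/long_audio.wav")["text"]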

Prepare calibration datasets#

The first step is to prepare calibration datasets for quantization. Since we quantize the Whisper encoder and decoder separately, we need to prepare a calibration dataset for each of the models. We import an InferRequestWrapper class that intercepts model inputs and collects them into a list. Then we run model inference on a small number of audio samples. Generally, increasing the calibration dataset size improves quantization quality.

%%skip not $to_quantize.value

from itertools import islice
from optimum.intel.openvino.quantization import InferRequestWrapper


def collect_calibration_dataset(ov_model: OVModelForSpeechSeq2Seq, calibration_dataset_size: int):
    # Overwrite model request properties, saving the original ones for restoring later
    encoder_calibration_data = []
    decoder_calibration_data = []
    ov_model.encoder.request = InferRequestWrapper(ov_model.encoder.request, encoder_calibration_data, apply_caching=True)
    ov_model.decoder_with_past.request = InferRequestWrapper(ov_model.decoder_with_past.request,
                                                             decoder_calibration_data,
                                                             apply_caching=True)

    pipe = pipeline(
      "automatic-speech-recognition",
      model=ov_model,
      chunk_length_s=30,
      tokenizer=ov_processor.tokenizer,
      feature_extractor=ov_processor.feature_extractor)
    try:
        calibration_dataset = load_dataset("openslr/librispeech_asr", "clean", split="validation", streaming=True, trust_remote_code=True)
        for sample in tqdm(islice(calibration_dataset, calibration_dataset_size), desc="Collecting calibration data",
                           total=calibration_dataset_size):
            pipe(sample["audio"], return_timestamps=True)
    finally:
        ov_model.encoder.request = ov_model.encoder.request.request
        ov_model.decoder_with_past.request = ov_model.decoder_with_past.request.request

    return encoder_calibration_data, decoder_calibration_data

Quantize Whisper encoder and decoder models#

Below we run the quantize function, which calls nncf.quantize on the Whisper encoder and decoder-with-past models. We don’t quantize the first-step decoder because its share of the whole inference time is negligible.

%%skip not $to_quantize.value

import gc
import shutil
import nncf
import openvino as ov
from datasets import load_dataset
from tqdm.notebook import tqdm

def extract_input_features(sample):
    input_features = processor(
        sample["audio"]["array"],
        sampling_rate=sample["audio"]["sampling_rate"],
        return_tensors="pt",
    ).input_features
    return input_features



CALIBRATION_DATASET_SIZE = 30
quantized_model_path = Path(f"{model_path}-quantized")


def quantize(ov_model: OVModelForSpeechSeq2Seq, calibration_dataset_size: int):
    if not quantized_model_path.exists():
        encoder_calibration_data, decoder_calibration_data = collect_calibration_dataset(
            ov_model, calibration_dataset_size
        )
        print("Quantizing encoder")
        quantized_encoder = nncf.quantize(
            ov_model.encoder.model,
            nncf.Dataset(encoder_calibration_data),
            subset_size=len(encoder_calibration_data),
            model_type=nncf.ModelType.TRANSFORMER,
            # Smooth Quant algorithm reduces activation quantization error; optimal alpha value was obtained through grid search
            advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=0.80)
        )
        ov.save_model(quantized_encoder, quantized_model_path / "openvino_encoder_model.xml")
        del quantized_encoder
        del encoder_calibration_data
        gc.collect()

        print("Quantizing decoder with past")
        quantized_decoder_with_past = nncf.quantize(
            ov_model.decoder_with_past.model,
            nncf.Dataset(decoder_calibration_data),
            subset_size=len(decoder_calibration_data),
            model_type=nncf.ModelType.TRANSFORMER,
            # Smooth Quant algorithm reduces activation quantization error; optimal alpha value was obtained through grid search
            advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=0.96)
        )
        ov.save_model(quantized_decoder_with_past, quantized_model_path / "openvino_decoder_with_past_model.xml")
        del quantized_decoder_with_past
        del decoder_calibration_data
        gc.collect()

        # Copy the config file and the first-step-decoder manually
        shutil.copy(model_path / "config.json", quantized_model_path / "config.json")
        shutil.copy(model_path / "generation_config.json", quantized_model_path / "generation_config.json")
        shutil.copy(model_path / "openvino_decoder_model.xml", quantized_model_path / "openvino_decoder_model.xml")
        shutil.copy(model_path / "openvino_decoder_model.bin", quantized_model_path / "openvino_decoder_model.bin")
        shutil.copy(model_path / "openvino_tokenizer.xml", quantized_model_path / "openvino_tokenizer.xml")
        shutil.copy(model_path / "openvino_tokenizer.bin", quantized_model_path / "openvino_tokenizer.bin")
        shutil.copy(model_path / "openvino_detokenizer.xml", quantized_model_path / "openvino_detokenizer.xml")
        shutil.copy(model_path / "openvino_detokenizer.bin", quantized_model_path / "openvino_detokenizer.bin")
        shutil.copy(model_path / "tokenizer_config.json", quantized_model_path / "tokenizer_config.json")
        shutil.copy(model_path / "tokenizer.json", quantized_model_path / "tokenizer.json")
        shutil.copy(model_path / "vocab.json", quantized_model_path / "vocab.json")
        shutil.copy(model_path / "preprocessor_config.json", quantized_model_path / "preprocessor_config.json")
        shutil.copy(model_path / "special_tokens_map.json", quantized_model_path / "special_tokens_map.json")
        shutil.copy(model_path / "normalizer.json", quantized_model_path / "normalizer.json")
        shutil.copy(model_path / "merges.txt", quantized_model_path / "merges.txt")
        shutil.copy(model_path / "added_tokens.json", quantized_model_path / "added_tokens.json")

    quantized_ov_pipe = openvino_genai.WhisperPipeline(str(quantized_model_path), device=device.value)
    return quantized_ov_pipe


ov_quantized_pipe = quantize(ov_model, CALIBRATION_DATASET_SIZE)
Collecting calibration data:   0%|          | 0/30 [00:00<?, ?it/s]
Output()
Quantizing encoder
Output()
Output()
Output()
Quantizing decoder with past
Output()
Output()
Output()
Output()

Run quantized model inference#

Let’s compare the transcription results for original and quantized models.

%%skip not $to_quantize.value

sample = copy.deepcopy(en_raw_speech)

genai_result = ov_pipe.generate(sample)
quantized_genai_result = ov_quantized_pipe.generate(sample)

display(ipd.Audio(sample, rate=samplerate))

print(f"Original : {genai_result}")
print(f"Quantized: {quantized_genai_result}")
Original :  Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.
Quantized:  Mr Quilder is the apostle of the middle classes and we are glad to welcome his gospel.

Compare performance and accuracy of the original and quantized models#

Finally, we compare the original and quantized Whisper models from the accuracy and performance standpoints.

To measure accuracy, we use 1 - WER as a metric, where WER stands for Word Error Rate.
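
To make the metric concrete, here is a toy example with hypothetical reference and hypothesis strings (one substitution out of nine words gives a word accuracy of about 88.9%):

from jiwer import wer, wer_standardize

reference = ["mister quilter is the apostle of the middle classes"]
hypothesis = ["mister quilter is the apostle of the little classes"]
error_rate = wer(reference, hypothesis, reference_transform=wer_standardize, hypothesis_transform=wer_standardize)
print(f"Word accuracy: {(1 - error_rate) * 100:.1f}%")  # 1 - WER, as used below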

%%skip not $to_quantize.value

import time
from contextlib import contextmanager
from jiwer import wer, wer_standardize


TEST_DATASET_SIZE = 50

def calculate_transcription_time_and_accuracy(ov_model, test_samples):
    whole_infer_times = []

    ground_truths = []
    predictions = []
    for data_item in tqdm(test_samples, desc="Measuring performance and accuracy"):

        start_time = time.perf_counter()
        transcription = ov_model.generate(data_item["audio"]["array"])
        end_time = time.perf_counter()
        whole_infer_times.append(end_time - start_time)

        ground_truths.append(data_item["text"])
        predictions.append(transcription.texts[0])

    word_accuracy = (1 - wer(ground_truths, predictions, reference_transform=wer_standardize,
                             hypothesis_transform=wer_standardize)) * 100
    whole_infer_time = sum(whole_infer_times)
    return word_accuracy, whole_infer_time

test_dataset = load_dataset("openslr/librispeech_asr", "clean", split="test", streaming=True, trust_remote_code=True)
test_dataset = test_dataset.shuffle(seed=42).take(TEST_DATASET_SIZE)
test_samples = [sample for sample in test_dataset]

accuracy_original, times_original = calculate_transcription_time_and_accuracy(ov_pipe, test_samples)
accuracy_quantized, times_quantized = calculate_transcription_time_and_accuracy(ov_quantized_pipe, test_samples)
print(f"Whole pipeline performance speedup: {times_original / times_quantized:.3f}")
print(f"Whisper transcription word accuracy. Original model: {accuracy_original:.2f}%. Quantized model: {accuracy_quantized:.2f}%.")
print(f"Accuracy drop: {accuracy_original - accuracy_quantized:.2f}%.")
Measuring performance and accuracy:   0%|          | 0/50 [00:00<?, ?it/s]
Measuring performance and accuracy:   0%|          | 0/50 [00:00<?, ?it/s]
Whole pipeline performance speedup: 1.381
Whisper transcription word accuracy. Original model: 82.88%. Quantized model: 84.13%.
Accuracy drop: -1.25%.

Interactive demo#

We also provide an interactive demo using the Gradio interface, where you can test the model’s capabilities on your own audio data (using the upload button) or record with your microphone.

import requests

if not Path("gradio_helper.py").exists():
    r = requests.get(url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/distil-whisper-asr/gradio_helper.py")
    open("gradio_helper.py", "w").write(r.text)

from gradio_helper import make_demo, GradioPipeline

pipe = ov_quantized_pipe if to_quantize.value else ov_pipe

gr_pipeline = GradioPipeline(pipe, multilingual=(not model_id.value.endswith(".en")), quantized=to_quantize.value)

demo = make_demo(gr_pipeline)

try:
    demo.launch(debug=False)
except Exception:
    demo.launch(share=True, debug=False)
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/
Running on local URL:  http://127.0.0.1:7860

To create a public link, set share=True in launch().