Automatic speech recognition using Distil-Whisper and OpenVINO

This Jupyter notebook can be launched after a local installation only.

Distil-Whisper is a distilled variant of the Whisper model by OpenAI, proposed in the paper Robust Knowledge Distillation via Large-Scale Pseudo Labelling. According to the authors, compared to Whisper, Distil-Whisper runs several times faster with 50% fewer parameters, while performing within 1% word error rate (WER) on out-of-distribution evaluation data.

Whisper is a Transformer-based encoder-decoder model, also referred to as a sequence-to-sequence model. It maps a sequence of audio spectrogram features to a sequence of text tokens. First, the raw audio inputs are converted to a log-Mel spectrogram by the feature extractor. Then, the Transformer encoder encodes the spectrogram to form a sequence of encoder hidden states. Finally, the decoder autoregressively predicts text tokens, conditioned on both the previous tokens and the encoder hidden states.

You can see the model architecture in the diagram below:

whisper_architecture.svg

In this tutorial, we consider how to run Distil-Whisper using OpenVINO. We will use the pre-trained model from the Hugging Face Transformers library. To simplify the user experience, the Hugging Face Optimum library is used to convert the model to OpenVINO™ IR format. To further improve the performance of the OpenVINO Distil-Whisper model, INT8 post-training quantization from NNCF is applied.

Table of contents:

- Prerequisites
- Load PyTorch model
- Prepare input sample
- Run model inference
- Load OpenVINO model using Optimum library
- Select Inference device
- Compile OpenVINO model
- Run OpenVINO model inference
- Compare performance PyTorch vs OpenVINO
- Compare with OpenAI Whisper
- Usage OpenVINO model with HuggingFace pipelines
- Quantization
- Prepare calibration datasets
- Quantize Distil-Whisper encoder and decoder models
- Run quantized model inference
- Compare performance and accuracy of the original and quantized models
- Interactive demo

Prerequisites

%pip install -q "transformers>=4.35" onnx "git+https://github.com/huggingface/optimum-intel.git" --extra-index-url https://download.pytorch.org/whl/cpu
%pip install -q "openvino>=2023.2.0" datasets  "gradio>=4.0" "librosa" "soundfile"
%pip install -q "nncf>=2.6.0" "jiwer"

Load PyTorch model

The AutoModelForSpeechSeq2Seq.from_pretrained method is used to initialize the PyTorch Whisper model using the transformers library. By default, we will use the distil-whisper/distil-large-v2 model as an example in this tutorial. The model is downloaded once during the first run, and this process may take some time.

You may also choose other models from the Distil-Whisper Hugging Face collection, such as distil-whisper/distil-medium.en or distil-whisper/distil-small.en. Models of the original Whisper architecture are also available; more on them here.

Preprocessing and post-processing are important steps when using this model. The AutoProcessor class, used to initialize the WhisperProcessor, is responsible for preparing audio input data for the model (converting it to a Mel spectrogram) and for decoding the predicted output token_ids into a string using the tokenizer.

import ipywidgets as widgets

model_ids = {
    "Distil-Whisper": [
        "distil-whisper/distil-large-v2",
        "distil-whisper/distil-medium.en",
        "distil-whisper/distil-small.en"
    ],
    "Whisper": [
        "openai/whisper-large-v3",
        "openai/whisper-large-v2",
        "openai/whisper-large",
        "openai/whisper-medium",
        "openai/whisper-small",
        "openai/whisper-base",
        "openai/whisper-tiny",
        "openai/whisper-medium.en",
        "openai/whisper-small.en",
        "openai/whisper-base.en",
        "openai/whisper-tiny.en",
    ]
}

model_type = widgets.Dropdown(
    options=model_ids.keys(),
    value="Distil-Whisper",
    description="Model type:",
    disabled=False,
)

model_type
model_id = widgets.Dropdown(
    options=model_ids[model_type.value],
    value=model_ids[model_type.value][0],
    description="Model:",
    disabled=False,
)

model_id
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

processor = AutoProcessor.from_pretrained(model_id.value)

pt_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id.value)
pt_model.eval();
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

Prepare input sample

The processor expects audio data as a numpy array together with its sampling rate, and returns the input_features tensor for making predictions. Conversion of audio to numpy format is handled by the Hugging Face datasets implementation.

from datasets import load_dataset

def extract_input_features(sample):
    input_features = processor(
        sample["audio"]["array"],
        sampling_rate=sample["audio"]["sampling_rate"],
        return_tensors="pt",
    ).input_features
    return input_features

dataset = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
sample = dataset[0]
input_features = extract_input_features(sample)
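
The dataset sample above is already a numpy array at the expected sampling rate. If you would like to transcribe your own recording instead, here is a minimal sketch: load the audio with librosa, resampled to the rate the processor expects, and wrap it in the same structure. The file name my_audio.wav is a placeholder, not a file shipped with this notebook.

import librosa

# Load a local file as a mono float numpy array, resampled to the processor's rate (16 kHz for Whisper).
audio_array, sampling_rate = librosa.load("my_audio.wav", sr=processor.feature_extractor.sampling_rate)
my_sample = {"audio": {"array": audio_array, "sampling_rate": sampling_rate}}
my_input_features = extract_input_features(my_sample)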

Run model inference

To perform speech recognition, one can use the generate interface of the model. After generation is finished, processor.batch_decode can be used to decode the predicted token_ids into a text transcription.

import IPython.display as ipd

predicted_ids = pt_model.generate(input_features)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

display(ipd.Audio(sample["audio"]["array"], rate=sample["audio"]["sampling_rate"]))
print(f"Reference: {sample['text']}")
print(f"Result: {transcription[0]}")
Reference: MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL
Result:  Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.
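
The high-level generate call hides the encode-once, decode-step-by-step flow described in the architecture overview above. Below is a minimal sketch of that flow, reusing pt_model, processor, and input_features from the previous cells and assuming the standard Hugging Face WhisperForConditionalGeneration API. It uses plain greedy decoding; the real generate call additionally inserts task and language prompt tokens and uses key/value caching, so its output may differ slightly.

import torch

with torch.no_grad():
    # 1. The encoder runs once over the whole log-Mel spectrogram.
    encoder_outputs = pt_model.get_encoder()(input_features)

    # 2. The decoder predicts tokens one by one, conditioned on the tokens generated
    #    so far and on the encoder hidden states (greedy decoding for simplicity).
    decoder_input_ids = torch.tensor([[pt_model.config.decoder_start_token_id]])
    for _ in range(128):
        logits = pt_model(
            decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs
        ).logits
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)
        if next_token.item() == pt_model.config.eos_token_id:
            break

print(processor.batch_decode(decoder_input_ids, skip_special_tokens=True)[0])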

Load OpenVINO model using Optimum library

The Hugging Face Optimum API is a high-level API that enables us to convert and quantize models from the Hugging Face Transformers library to the OpenVINO™ IR format. For more details, refer to the Hugging Face Optimum documentation.

Optimum Intel can be used to load optimized models from the Hugging Face Hub and create pipelines to run an inference with OpenVINO Runtime using Hugging Face APIs. The Optimum Inference models are API compatible with Hugging Face Transformers models. This means we just need to replace the AutoModelForXxx class with the corresponding OVModelForXxx class.

Below is an example for the distil-whisper model:

-from transformers import AutoModelForSpeechSeq2Seq
+from optimum.intel.openvino import OVModelForSpeechSeq2Seq
from transformers import AutoTokenizer, pipeline

model_id = "distil-whisper/distil-large-v2"
-model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
+model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True)

Model class initialization starts with calling the from_pretrained method. When downloading and converting the Transformers model, the parameter export=True should be added. We can save the converted model for later use with the save_pretrained method. Tokenizers and processors are distributed together with the model and are also compatible with the OpenVINO model, which means that we can reuse the processor initialized earlier.

from pathlib import Path
from optimum.intel.openvino import OVModelForSpeechSeq2Seq

model_path = Path(model_id.value.replace('/', '_'))
ov_config = {"CACHE_DIR": ""}

if not model_path.exists():
    ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
        model_id.value, ov_config=ov_config, export=True, compile=False, load_in_8bit=False
    )
    ov_model.half()
    ov_model.save_pretrained(model_path)
else:
    ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
        model_path, ov_config=ov_config, compile=False
    )
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, onnx, openvino
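
After save_pretrained, the model directory contains OpenVINO IR files for the encoder, decoder, and decoder-with-past (the same file names are reused in the quantization section below), along with the configuration files. A quick sanity check of the exported files:

# List the exported OpenVINO IR files (each .xml is paired with a .bin weights file).
for xml_file in sorted(model_path.glob("*.xml")):
    print(xml_file.name)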

Select Inference device

import openvino as ov
import ipywidgets as widgets

core = ov.Core()

device = widgets.Dropdown(
    options=core.available_devices + ["AUTO"],
    value="AUTO",
    description="Device:",
    disabled=False,
)

device
Dropdown(description='Device:', index=4, options=('CPU', 'GPU.0', 'GPU.1', 'GPU.2', 'AUTO'), value='AUTO')

Compile OpenVINO model

ov_model.to(device.value)
ov_model.compile()
Compiling the encoder to AUTO ...
Compiling the decoder to AUTO ...
Compiling the decoder to AUTO ...

Run OpenVINO model inference

predicted_ids = ov_model.generate(input_features)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

display(ipd.Audio(sample["audio"]["array"], rate=sample["audio"]["sampling_rate"]))
print(f"Reference: {sample['text']}")
print(f"Result: {transcription[0]}")
/home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/optimum/intel/openvino/modeling_seq2seq.py:457: FutureWarning: shared_memory is deprecated and will be removed in 2024.0. Value of shared_memory is going to override share_inputs value. Please use only share_inputs explicitly.
  last_hidden_state = torch.from_numpy(self.request(inputs, shared_memory=True)["last_hidden_state"]).to(
/home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/optimum/intel/openvino/modeling_seq2seq.py:538: FutureWarning: shared_memory is deprecated and will be removed in 2024.0. Value of shared_memory is going to override share_inputs value. Please use only share_inputs explicitly.
  self.request.start_async(inputs, shared_memory=True)
Reference: MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL
Result:  Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.

Compare performance PyTorch vs OpenVINO

import time
import numpy as np
from tqdm.notebook import tqdm


def measure_perf(model, sample, n=10):
    """
    Measure the median generation time over n runs for the given model and audio sample.
    """
    timers = []
    input_features = extract_input_features(sample)
    for _ in tqdm(range(n), desc="Measuring performance"):
        start = time.perf_counter()
        model.generate(input_features)
        end = time.perf_counter()
        timers.append(end - start)
    return np.median(timers)
perf_torch = measure_perf(pt_model, sample)
perf_ov = measure_perf(ov_model, sample)
Measuring performance:   0%|          | 0/10 [00:00<?, ?it/s]
Measuring performance:   0%|          | 0/10 [00:00<?, ?it/s]
print(f"Mean torch {model_id.value} generation time: {perf_torch:.3f}s")
print(f"Mean openvino {model_id.value} generation time: {perf_ov:.3f}s")
print(f"Performance {model_id.value} openvino speedup: {perf_torch / perf_ov:.3f}")
Median torch distil-large-v2 generation time: 3.064s
Median openvino distil-large-v2 generation time: 1.819s
Performance distil-large-v2 openvino speedup: 1.684

Compare with OpenAI Whisper

You can also compare Distil-Whisper with the original OpenAI Whisper interactively in the distil-whisper/whisper-vs-distil-whisper Hugging Face Space.

Usage OpenVINO model with HuggingFace pipelines

Like the original PyTorch model, the OpenVINO model is also compatible with the HuggingFace pipeline interface for automatic-speech-recognition. The pipeline can be used for long-form audio transcription: Distil-Whisper uses a chunked algorithm to transcribe long-form audio files. In practice, this chunked long-form algorithm is 9x faster than the sequential algorithm proposed by OpenAI in the Whisper paper. To enable chunking, pass the chunk_length_s parameter to the pipeline. For Distil-Whisper, a chunk length of 15 seconds is optimal. To activate batching, pass the argument batch_size.

from transformers import pipeline

ov_model.generation_config = pt_model.generation_config

pipe = pipeline(
    "automatic-speech-recognition",
    model=ov_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    batch_size=16,
)
The model 'OVModelForWhisper' is not supported for automatic-speech-recognition. Supported models are ['Pop2PianoForConditionalGeneration', 'SeamlessM4TForSpeechToText', 'SpeechEncoderDecoderModel', 'Speech2TextForConditionalGeneration', 'SpeechT5ForSpeechToText', 'WhisperForConditionalGeneration', 'Data2VecAudioForCTC', 'HubertForCTC', 'MCTCTForCTC', 'SEWForCTC', 'SEWDForCTC', 'UniSpeechForCTC', 'UniSpeechSatForCTC', 'Wav2Vec2ForCTC', 'Wav2Vec2ConformerForCTC', 'WavLMForCTC'].
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample_long = dataset[0]


def format_timestamp(seconds: float):
    """
    format time in srt-file expected format
    """
    assert seconds >= 0, "non-negative timestamp expected"
    milliseconds = round(seconds * 1000.0)

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    return (
        f"{hours}:" if hours > 0 else "00:"
    ) + f"{minutes:02d}:{seconds:02d},{milliseconds:03d}"


def prepare_srt(transcription):
    """
    Format transcription into srt file format
    """
    segment_lines = []
    for idx, segment in enumerate(transcription["chunks"]):
        segment_lines.append(str(idx + 1) + "\n")
        timestamps = segment["timestamp"]
        time_start = format_timestamp(timestamps[0])
        time_end = format_timestamp(timestamps[1])
        time_str = f"{time_start} --> {time_end}\n"
        segment_lines.append(time_str)
        segment_lines.append(segment["text"] + "\n\n")
    return segment_lines

The return_timestamps argument allows getting the start and end timestamps of the speech associated with each processed chunk. This can be useful for tasks like speech separation or generation of video subtitles. In this example, we format the output as SRT, one of the popular subtitle formats.

result = pipe(sample_long["audio"].copy(), return_timestamps=True)
srt_lines = prepare_srt(result)

display(
    ipd.Audio(sample_long["audio"]["array"], rate=sample_long["audio"]["sampling_rate"])
)
print("".join(srt_lines))
1
00:00:00,000 --> 00:00:06,560
 Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.

2
00:00:06,560 --> 00:00:11,280
 Nor is Mr. Quilter's manner less interesting than his matter.

3
00:00:11,280 --> 00:00:16,840
 He tells us that at this festive season of the year, with Christmas and roast beef looming

4
00:00:16,840 --> 00:00:23,760
 before us, similes drawn from eating and its results occur most readily to the mind.

5
00:00:23,760 --> 00:00:29,360
 He has grave doubts whether Sir Frederick Leighton's work is really Greek after all, and

6
00:00:29,360 --> 00:00:33,640
 can discover in it but little of Rocky Ithaca.

7
00:00:33,640 --> 00:00:39,760
 Lennel's pictures are a sort of upgards and Adam paintings, and Mason's exquisite

8
00:00:39,760 --> 00:00:44,720
 idles are as national as a jingo poem.

9
00:00:44,720 --> 00:00:50,320
 Mr. Burkett Foster's landscapes smile at one much in the same way that Mr. Carker used

10
00:00:50,320 --> 00:00:52,920
 to flash his teeth.

11
00:00:52,920 --> 00:00:58,680
 And Mr. John Collier gives his sitter a cheerful slap on the back, before he says, like

12
00:00:58,680 --> 00:01:01,120
 a shampooer and a Turkish bath,

13
00:01:01,120 --> 00:01:02,000
 Next man!

Quantization

NNCF enables post-training quantization by adding the quantization layers into the model graph and then using a subset of the training dataset to initialize the parameters of these additional quantization layers. The framework is designed so that modifications to your original training code are minor.

The optimization process contains the following steps:

  1. Create a calibration dataset for quantization.

  2. Run nncf.quantize to obtain quantized encoder and decoder models.

  3. Serialize the INT8 model using openvino.save_model function.

NOTE: Quantization is a time- and memory-consuming operation. Running the quantization code below may take some time.

Please select below whether you would like to run Distil-Whisper quantization.

to_quantize = widgets.Checkbox(
    value=True,
    description='Quantization',
    disabled=False,
)

to_quantize
Checkbox(value=True, description='Quantization')
# Fetch skip_kernel_extension module
import urllib.request

urllib.request.urlretrieve(
    url='https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/main/notebooks/utils/skip_kernel_extension.py',
    filename='skip_kernel_extension.py'
)

%load_ext skip_kernel_extension

Prepare calibration datasets

The first step is to prepare calibration datasets for quantization. Since we quantize the Whisper encoder and decoder separately, we need to prepare a calibration dataset for each of the models. We import an InferRequestWrapper class that intercepts model inputs and collects them in a list. Then we run model inference on a small number of audio samples. Generally, increasing the calibration dataset size improves quantization quality.

%%skip not $to_quantize.value

from itertools import islice
from optimum.intel.openvino.quantization import InferRequestWrapper


def collect_calibration_dataset(ov_model: OVModelForSpeechSeq2Seq, calibration_dataset_size: int):
    # Overwrite model request properties, saving the original ones for restoring later
    original_encoder_request = ov_model.encoder.request
    original_decoder_with_past_request = ov_model.decoder_with_past.request
    encoder_calibration_data = []
    decoder_calibration_data = []
    ov_model.encoder.request = InferRequestWrapper(original_encoder_request, encoder_calibration_data)
    ov_model.decoder_with_past.request = InferRequestWrapper(original_decoder_with_past_request,
                                                             decoder_calibration_data)

    calibration_dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
    for sample in tqdm(islice(calibration_dataset, calibration_dataset_size), desc="Collecting calibration data",
                       total=calibration_dataset_size):
        input_features = extract_input_features(sample)
        ov_model.generate(input_features)

    ov_model.encoder.request = original_encoder_request
    ov_model.decoder_with_past.request = original_decoder_with_past_request

    return encoder_calibration_data, decoder_calibration_data

Quantize Distil-Whisper encoder and decoder models

Below we run the quantize function, which calls nncf.quantize on the Distil-Whisper encoder and decoder-with-past models. We don't quantize the first-step decoder because its share of the overall inference time is negligible.

%%skip not $to_quantize.value

import gc
import shutil
import nncf

CALIBRATION_DATASET_SIZE = 50
quantized_model_path = Path(f"{model_path}_quantized")


def quantize(ov_model: OVModelForSpeechSeq2Seq, calibration_dataset_size: int):
    if not quantized_model_path.exists():
        encoder_calibration_data, decoder_calibration_data = collect_calibration_dataset(
            ov_model, calibration_dataset_size
        )
        print("Quantizing encoder")
        quantized_encoder = nncf.quantize(
            ov_model.encoder.model,
            nncf.Dataset(encoder_calibration_data),
            subset_size=len(encoder_calibration_data),
            model_type=nncf.ModelType.TRANSFORMER,
            # Smooth Quant algorithm reduces activation quantization error; optimal alpha value was obtained through grid search
            advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=0.50)
        )
        ov.save_model(quantized_encoder, quantized_model_path / "openvino_encoder_model.xml")
        del quantized_encoder
        del encoder_calibration_data
        gc.collect()

        print("Quantizing decoder with past")
        quantized_decoder_with_past = nncf.quantize(
            ov_model.decoder_with_past.model,
            nncf.Dataset(decoder_calibration_data),
            subset_size=len(decoder_calibration_data),
            model_type=nncf.ModelType.TRANSFORMER,
            # Smooth Quant algorithm reduces activation quantization error; optimal alpha value was obtained through grid search
            advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=0.95)
        )
        ov.save_model(quantized_decoder_with_past, quantized_model_path / "openvino_decoder_with_past_model.xml")
        del quantized_decoder_with_past
        del decoder_calibration_data
        gc.collect()

        # Copy the config file and the first-step-decoder manually
        shutil.copy(model_path / "config.json", quantized_model_path / "config.json")
        shutil.copy(model_path / "openvino_decoder_model.xml", quantized_model_path / "openvino_decoder_model.xml")
        shutil.copy(model_path / "openvino_decoder_model.bin", quantized_model_path / "openvino_decoder_model.bin")

    quantized_ov_model = OVModelForSpeechSeq2Seq.from_pretrained(quantized_model_path, ov_config=ov_config, compile=False)
    quantized_ov_model.to(device.value)
    quantized_ov_model.compile()
    return quantized_ov_model


ov_quantized_model = quantize(ov_model, CALIBRATION_DATASET_SIZE)
Collecting calibration data:   0%|          | 0/10 [00:00<?, ?it/s]
Quantizing encoder
Statistics collection: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:15<00:00,  1.55s/it]
Applying Smooth Quant: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [00:10<00:00, 12.24it/s]
INFO:nncf:96 ignored nodes was found by name in the NNCFGraph
Statistics collection: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:29<00:00,  2.99s/it]
Applying Fast Bias correction: 100%|██████████████████████████████████████████████████████████████████████████████████████| 162/162 [00:21<00:00,  7.60it/s]
Quantizing decoder with past
Statistics collection: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:04<00:00, 85.63it/s]
Applying Smooth Quant: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 16.09it/s]
INFO:nncf:12 ignored nodes was found by name in the NNCFGraph
Statistics collection: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:07<00:00, 52.93it/s]
Applying Fast Bias correction: 100%|████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 18.50it/s]
Compiling the encoder to AUTO ...
Compiling the decoder to AUTO ...
Compiling the decoder to AUTO ...

Run quantized model inference

Let’s compare the transcription results for the original and quantized models.

%%skip not $to_quantize.value

dataset = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
sample = dataset[0]
input_features = extract_input_features(sample)

predicted_ids = ov_model.generate(input_features)
transcription_original = processor.batch_decode(predicted_ids, skip_special_tokens=True)

predicted_ids = ov_quantized_model.generate(input_features)
transcription_quantized = processor.batch_decode(predicted_ids, skip_special_tokens=True)

display(ipd.Audio(sample["audio"]["array"], rate=sample["audio"]["sampling_rate"]))
print(f"Original : {transcription_original[0]}")
print(f"Quantized: {transcription_quantized[0]}")
Original :  Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.
Quantized:  Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.

Results are the same!

Compare performance and accuracy of the original and quantized models

Finally, we compare the original and quantized Distil-Whisper models from the accuracy and performance standpoints.

To measure accuracy, we use 1 - WER as a metric, where WER stands for Word Error Rate.
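
For intuition, here is a minimal, self-contained sketch of the WER computation on a single reference/hypothesis pair using jiwer (the strings below are made-up examples, independent of the measurement code that follows):

from jiwer import wer

reference = "mister quilter is the apostle of the middle classes"
hypothesis = "mister quilter is the apostle of middle classes"

# One deleted word out of nine reference words -> WER = 1/9 ≈ 0.111
error_rate = wer(reference, hypothesis)
print(f"WER: {error_rate:.3f}, word accuracy (1 - WER): {1 - error_rate:.3f}")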

When measuring inference time, we do it separately for the encoder and decoder-with-past forward passes, as well as for the whole model inference.

%%skip not $to_quantize.value

import time
from contextlib import contextmanager
from jiwer import wer, wer_standardize


TEST_DATASET_SIZE = 50
MEASURE_TIME = False

@contextmanager
def time_measurement():
    global MEASURE_TIME
    try:
        MEASURE_TIME = True
        yield
    finally:
        MEASURE_TIME = False

def time_fn(obj, fn_name, time_list):
    original_fn = getattr(obj, fn_name)

    def wrapper(*args, **kwargs):
        if not MEASURE_TIME:
            return original_fn(*args, **kwargs)
        start_time = time.perf_counter()
        result = original_fn(*args, **kwargs)
        end_time = time.perf_counter()
        time_list.append(end_time - start_time)
        return result

    setattr(obj, fn_name, wrapper)

def calculate_transcription_time_and_accuracy(ov_model, test_samples):
    encoder_infer_times = []
    decoder_with_past_infer_times = []
    whole_infer_times = []
    time_fn(ov_model, "generate", whole_infer_times)
    time_fn(ov_model.encoder, "forward", encoder_infer_times)
    time_fn(ov_model.decoder_with_past, "forward", decoder_with_past_infer_times)

    ground_truths = []
    predictions = []
    for data_item in tqdm(test_samples, desc="Measuring performance and accuracy"):
        input_features = extract_input_features(data_item)

        with time_measurement():
            predicted_ids = ov_model.generate(input_features)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

        ground_truths.append(data_item["text"])
        predictions.append(transcription[0])

    word_accuracy = (1 - wer(ground_truths, predictions, reference_transform=wer_standardize,
                             hypothesis_transform=wer_standardize)) * 100
    # Total inference times; only speedup ratios are reported below, so sum vs. mean makes no difference
    whole_infer_time = sum(whole_infer_times)
    encoder_infer_time = sum(encoder_infer_times)
    decoder_with_past_infer_time = sum(decoder_with_past_infer_times)
    return word_accuracy, (whole_infer_time, encoder_infer_time, decoder_with_past_infer_time)

test_dataset = load_dataset("librispeech_asr", "clean", split="test", streaming=True)
test_dataset = test_dataset.shuffle(seed=42).take(TEST_DATASET_SIZE)
test_samples = [sample for sample in test_dataset]

accuracy_original, times_original = calculate_transcription_time_and_accuracy(ov_model, test_samples)
accuracy_quantized, times_quantized = calculate_transcription_time_and_accuracy(ov_quantized_model, test_samples)
print(f"Encoder performance speedup: {times_original[1] / times_quantized[1]:.3f}")
print(f"Decoder with past performance speedup: {times_original[2] / times_quantized[2]:.3f}")
print(f"Whole pipeline performance speedup: {times_original[0] / times_quantized[0]:.3f}")
print(f"Whisper transcription word accuracy. Original model: {accuracy_original:.2f}%. Quantized model: {accuracy_quantized:.2f}%.")
print(f"Accuracy drop: {accuracy_original - accuracy_quantized:.2f}%.")
Got disconnected from remote data host. Retrying in 5sec [1/20]
Got disconnected from remote data host. Retrying in 5sec [2/20]
Measuring performance and accuracy:   0%|          | 0/50 [00:00<?, ?it/s]
Measuring performance and accuracy:   0%|          | 0/50 [00:00<?, ?it/s]
Encoder performance speedup: 1.751
Decoder with past performance speedup: 1.777
Whole pipeline performance speedup: 1.711
Whisper transcription word accuracy. Original model: 85.29%. Quantized model: 85.29%.
Accuracy drop: 0.00%.

As we can see, quantization significantly improves the model inference time without a major accuracy drop!

Interactive demo

We also provide an interactive demo using the Gradio interface, where you can test the model capabilities on your own audio data (using the upload button) or record with your microphone. Please note that Distil-Whisper is currently available only for English speech recognition. Multilingual support will be provided later.

from transformers.pipelines.audio_utils import ffmpeg_read
import gradio as gr
import urllib.request

urllib.request.urlretrieve(
    url="https://huggingface.co/spaces/distil-whisper/whisper-vs-distil-whisper/resolve/main/assets/example_1.wav",
    filename="example_1.wav",
)

BATCH_SIZE = 16
MAX_AUDIO_MINS = 30  # maximum audio input in minutes


generate_kwargs = {"language": "en", "task": "transcribe"} if not model_id.value.endswith(".en") else {}
ov_pipe = pipeline(
    "automatic-speech-recognition",
    model=ov_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    generate_kwargs=generate_kwargs,
)
ov_pipe_forward = ov_pipe._forward

if to_quantize.value:
    ov_quantized_model.generation_config = ov_model.generation_config
    ov_quantized_pipe = pipeline(
        "automatic-speech-recognition",
        model=ov_quantized_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=15,
        generate_kwargs=generate_kwargs,
    )
    ov_quantized_pipe_forward = ov_quantized_pipe._forward


def transcribe(inputs, quantized=False):
    pipe = ov_quantized_pipe if quantized else ov_pipe
    pipe_forward = ov_quantized_pipe_forward if quantized else ov_pipe_forward

    if inputs is None:
        raise gr.Error(
            "No audio file submitted! Please record or upload an audio file before submitting your request."
        )

    with open(inputs, "rb") as f:
        inputs = f.read()

    inputs = ffmpeg_read(inputs, pipe.feature_extractor.sampling_rate)
    audio_length_mins = len(inputs) / pipe.feature_extractor.sampling_rate / 60

    if audio_length_mins > MAX_AUDIO_MINS:
        raise gr.Error(
            f"To ensure fair usage of the Space, the maximum audio length permitted is {MAX_AUDIO_MINS} minutes."
            f"Got an audio of length {round(audio_length_mins, 3)} minutes."
        )

    inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}

    def _forward_ov_time(*args, **kwargs):
        global ov_time
        start_time = time.time()
        result = pipe_forward(*args, **kwargs)
        ov_time = time.time() - start_time
        ov_time = round(ov_time, 2)
        return result

    pipe._forward = _forward_ov_time
    ov_text = pipe(inputs.copy(), batch_size=BATCH_SIZE)["text"]
    return ov_text, ov_time


with gr.Blocks() as demo:
    gr.HTML(
        """
                <div style="text-align: center; max-width: 700px; margin: 0 auto;">
                  <div
                    style="
                      display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
                    "
                  >
                    <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
                      OpenVINO Distil-Whisper demo
                    </h1>
                  </div>
                </div>
            """
    )
    audio = gr.components.Audio(type="filepath", label="Audio input")
    with gr.Row():
        button = gr.Button("Transcribe")
        if to_quantize.value:
            button_q = gr.Button("Transcribe quantized")
    with gr.Row():
        infer_time = gr.components.Textbox(
            label="OpenVINO Distil-Whisper Transcription Time (s)"
        )
        if to_quantize.value:
            infer_time_q = gr.components.Textbox(
                label="OpenVINO Quantized Distil-Whisper Transcription Time (s)"
            )
    with gr.Row():
        transcription = gr.components.Textbox(
            label="OpenVINO Distil-Whisper Transcription", show_copy_button=True
        )
        if to_quantize.value:
            transcription_q = gr.components.Textbox(
                label="OpenVINO Quantized Distil-Whisper Transcription", show_copy_button=True
            )
    button.click(
        fn=transcribe,
        inputs=audio,
        outputs=[transcription, infer_time],
    )
    if to_quantize.value:
        button_q.click(
            fn=transcribe,
            inputs=[audio, gr.Number(value=1, visible=False)],
            outputs=[transcription_q, infer_time_q],
        )
    gr.Markdown("## Examples")
    gr.Examples(
        [["./example_1.wav"]],
        audio,
        outputs=[transcription, infer_time],
        fn=transcribe,
        cache_examples=False,
    )
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/
try:
    demo.launch(debug=False)
except Exception:
    demo.launch(share=True, debug=False)