MMS: Scaling Speech Technology to 1000+ languages with OpenVINO™

This Jupyter notebook can be launched after a local installation only.

The Massively Multilingual Speech (MMS) project expands speech technology from about 100 languages to over 1,000 by building a single multilingual speech recognition model supporting over 1,100 languages (more than 10 times as many as before), language identification models able to identify over 4,000 languages (40 times more than before), pretrained models supporting over 1,400 languages, and text-to-speech models for over 1,100 languages.

The MMS model was proposed in the paper Scaling Speech Technology to 1,000+ Languages. The models and code were originally released here.

The MMS project provides several open-source models: Automatic Speech Recognition (ASR), Language Identification (LID), and Speech Synthesis (TTS). A simple diagram of the LID and ASR flow is shown below.

LID and ASR flow

In this notebook, we consider ASR and LID. We use the LID model to identify the spoken language, and then a language-specific ASR model to transcribe the speech. An additional quantization step is applied to speed up model inference. At the end of the notebook, there is a Gradio-based interactive demo.

Installation Instructions

This is a self-contained example that relies solely on its own code.

We recommend running the notebook in a virtual environment. You only need a Jupyter server to start. For details, please refer to the Installation Guide.

Prerequisites

%pip install -q --upgrade pip
%pip install -q "transformers>=4.33.1" "torch>=2.1" "openvino>=2023.1.0" "numpy>=1.21.0" "nncf>=2.9.0"
%pip install -q --extra-index-url https://download.pytorch.org/whl/cpu torch "datasets>=2.14.6" accelerate soundfile librosa "gradio>=4.19" jiwer
from pathlib import Path

import torch

import openvino as ov

Prepare example audio

Read an audio file and process the audio data. Make sure that the audio data is sampled at 16 kHz (16000 Hz). For this example, we use a streamable version of the Multilingual LibriSpeech (MLS) dataset. It contains examples in 7 languages: 'german', 'dutch', 'french', 'spanish', 'italian', 'portuguese', 'polish'. Choose one of them.
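The MLS samples used here are already at 16 kHz. Audio recorded at any other rate must be resampled before it is passed to the models. The following minimal sketch shows what resampling does. It uses a hypothetical helper, resample_linear, that relies on plain linear interpolation via np.interp. It applies no anti-aliasing filter, so for real recordings a proper resampler such as librosa.resample is the better choice.

```python
import numpy as np

TARGET_SR = 16_000  # 16 kHz = 16000 samples per second, the rate the MMS models expect


def resample_linear(audio: np.ndarray, orig_sr: int, target_sr: int = TARGET_SR) -> np.ndarray:
    """Hypothetical helper: resample a mono signal by linear interpolation.

    No anti-aliasing filter is applied, so this is only an illustration.
    """
    if orig_sr == target_sr:
        return audio.astype(np.float32)
    # Number of output samples covering the same duration in seconds
    n_out = int(round(len(audio) * target_sr / orig_sr))
    # Timestamps (in seconds) of the input and output samples
    t_in = np.arange(len(audio)) / orig_sr
    t_out = np.arange(n_out) / target_sr
    return np.interp(t_out, t_in, audio).astype(np.float32)


# One second of a 440 Hz tone recorded at 44.1 kHz
sr_in = 44_100
tone = np.sin(2 * np.pi * 440 * np.arange(sr_in) / sr_in)
resampled = resample_linear(tone, sr_in)
print(len(resampled))  # 16000 samples = 1 second at 16 kHz
```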

import ipywidgets as widgets


SAMPLE_LANG = widgets.Dropdown(
    options=["german", "dutch", "french", "spanish", "italian", "portuguese", "polish"],
    value="german",
    description="Dataset language:",
    disabled=False,
)

SAMPLE_LANG
Dropdown(description='Dataset language:', options=('german', 'dutch', 'french', 'spanish', 'italian', 'portugu…

Pass streaming=True to avoid downloading the entire dataset.

from datasets import load_dataset


mls_dataset = load_dataset("facebook/multilingual_librispeech", SAMPLE_LANG.value, split="test", streaming=True, trust_remote_code=True)
mls_dataset = iter(mls_dataset)  # get an iterator over the streamed examples

example = next(mls_dataset)  # get one example
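With streaming=True, load_dataset returns an iterable that fetches examples lazily, which is why the code above calls iter() and next(). The sketch below imitates this behaviour with a hypothetical generator, fake_stream, that stands in for the streamed dataset. It shows that itertools.islice takes the first few examples without producing the rest.

```python
from itertools import islice

fetched = []  # records which examples were actually produced


def fake_stream(n_items):
    """Hypothetical stand-in for a streamed dataset: yields examples lazily."""
    for i in range(n_items):
        fetched.append(i)
        yield {"id": i, "transcript": f"utterance {i}"}


stream = iter(fake_stream(1000))
first_example = next(stream)           # only the first example is produced
first_three = list(islice(stream, 3))  # the next three examples

print(first_example["id"], [item["id"] for item in first_three])  # 0 [1, 2, 3]
print(len(fetched))  # 4: the remaining 996 examples were never generated
```

The notebook later uses the streamed dataset's take() method for the same purpose when building the test subset.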

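Each example is a dictionary that contains the audio data and its text transcription. Depending on the dataset version, the transcription is stored under the "text" key (as in the printed structure below) or under the "transcript" key (as the code in this notebook assumes).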

print(example)  # look at structure
{'file': None, 'audio': {'path': '1054_1599_000000.flac', 'array': array([-0.00131226, -0.00152588, -0.00134277, ...,  0.00411987,
        0.00308228, -0.00015259]), 'sampling_rate': 16000}, 'text': 'mein sechster sohn scheint wenigstens auf den ersten blick der tiefsinnigste von allen ein kopfhänger und doch ein schwätzer deshalb kommt man ihm nicht leicht bei ist er am unterliegen so verfällt er in unbesiegbare traurigkeit', 'speaker_id': 1054, 'chapter_id': 1599, 'id': '1054_1599_000000'}
import IPython.display as ipd

print(example["transcript"])
ipd.Audio(example["audio"]["array"], rate=16_000)
mein sechster sohn scheint wenigstens auf den ersten blick der tiefsinnigste von allen ein kopfhänger und doch ein schwätzer deshalb kommt man ihm nicht leicht bei ist er am unterliegen so verfällt er in unbesiegbare traurigkeit

Language Identification (LID)

Download pretrained model and processor

Several LID models are available, differing in the number of languages they can recognize: 126, 256, 512, 1024, 2048, and 4017. We use the 126-language model.

from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor

model_id = "facebook/mms-lid-126"

lid_processor = AutoFeatureExtractor.from_pretrained(model_id)
lid_model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)

Use the original model to run inference

inputs = lid_processor(example["audio"]["array"], sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    outputs = lid_model(**inputs).logits

lang_id = torch.argmax(outputs, dim=-1)[0].item()
detected_lang = lid_model.config.id2label[lang_id]
print(detected_lang)
deu
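The LID model outputs one logit per supported language. The predicted language is the index with the highest logit, mapped to a language code through config.id2label. The following minimal NumPy sketch shows this step using made-up logits. It replaces the real 126-entry label map with a hypothetical three-entry one. The softmax is added only to show a confidence score and is not part of the notebook code.

```python
import numpy as np

# Hypothetical 3-language label map; the real model config has 126 entries
toy_id2label = {0: "eng", 1: "deu", 2: "fra"}

# Made-up logits with shape (batch=1, num_languages=3)
toy_logits = np.array([[0.3, 4.1, 1.2]])

predicted_id = int(np.argmax(toy_logits, axis=-1)[0])
predicted_lang = toy_id2label[predicted_id]

# Softmax turns logits into probabilities (used here only as a confidence estimate)
probs = np.exp(toy_logits - toy_logits.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)

print(predicted_lang)  # deu
print(float(probs[0, predicted_id]))
```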

Convert to OpenVINO IR model and run inference

Select the device for running inference with OpenVINO from the dropdown list.

core = ov.Core()

import requests

r = requests.get(
    url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py",
)
open("notebook_utils.py", "w").write(r.text)

from notebook_utils import device_widget

device = device_widget("CPU", exclude=["NPU"])

device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')

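Convert the model to OpenVINO format and compile it. MAX_SEQ_LENGTH only sets the length of the dummy input used to trace the model during conversion. The converted model still accepts audio of other lengths, as the later examples with much longer recordings show.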

MAX_SEQ_LENGTH = 30480

lid_model_xml_path = Path("models/ov_lid_model.xml")


def get_lid_model(model_path, compiled=True):
    input_values = torch.zeros([1, MAX_SEQ_LENGTH], dtype=torch.float)

    if not model_path.exists() and model_path == lid_model_xml_path:
        lid_model_xml_path.parent.mkdir(parents=True, exist_ok=True)
        converted_model = ov.convert_model(lid_model, example_input={"input_values": input_values})
        ov.save_model(converted_model, lid_model_xml_path)
        if not compiled:
            return converted_model
    if compiled:
        return core.compile_model(model_path, device_name=device.value)
    return core.read_model(model_path)


compiled_lid_model = get_lid_model(lid_model_xml_path)
/home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:595: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
/home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:634: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):

Now we can run inference.

def detect_language(compiled_model, audio_data):
    inputs = lid_processor(audio_data, sampling_rate=16_000, return_tensors="pt")

    outputs = compiled_model(inputs["input_values"])[0]

    lang_id = torch.argmax(torch.from_numpy(outputs), dim=-1)[0].item()
    detected_lang = lid_model.config.id2label[lang_id]

    return detected_lang
detect_language(compiled_lid_model, example["audio"]["array"])
'deu'

Let’s check another language.

SAMPLE_LANG = widgets.Dropdown(
    options=["german", "dutch", "french", "spanish", "italian", "portuguese", "polish"],
    value="french",
    description="Dataset language:",
    disabled=False,
)

SAMPLE_LANG
Dropdown(description='Dataset language:', index=2, options=('german', 'dutch', 'french', 'spanish', 'italian',…
mls_dataset = load_dataset("facebook/multilingual_librispeech", SAMPLE_LANG.value, split="test", streaming=True, trust_remote_code=True)
mls_dataset = iter(mls_dataset)

example = next(mls_dataset)
print(example["transcript"])
ipd.Audio(example["audio"]["array"], rate=16_000)
grisé par ce parfum il fit des vers en l'honneur de l'humble fleur des bois et il les récita tout haut à ses pieds une violette l'entendit elle crut qu'il ne parlait que pour elle
language_id = detect_language(compiled_lid_model, example["audio"]["array"])
print(language_id)
fra

Automatic Speech Recognition (ASR)

Download pretrained model and processor

Download the pretrained model and processor. By default, MMS loads the adapter weights for English. To load the adapter weights for another language, specify target_lang=<your-chosen-target-lang> together with ignore_mismatched_sizes=True. The ignore_mismatched_sizes=True keyword is required so that the language model head can be resized to match the vocabulary of the specified language. Likewise, load the processor with the same target language. The target language can also be changed after loading.

from transformers import Wav2Vec2ForCTC, AutoProcessor

model_id = "facebook/mms-1b-all"

asr_processor = AutoProcessor.from_pretrained(model_id)
asr_model = Wav2Vec2ForCTC.from_pretrained(model_id)

You can list all supported languages:

asr_processor.tokenizer.vocab.keys()
dict_keys(['abi', 'abk', 'abp', 'aca', 'acd', 'ace', 'acf', 'ach', 'acn', 'acr', 'acu', 'ade', 'adh', 'adj', 'adx', 'aeu', 'afr', 'agd', 'agg', 'agn', 'agr', 'agu', 'agx', 'aha', 'ahk', 'aia', 'aka', 'akb', 'ake', 'akp', 'alj', 'alp', 'alt', 'alz', 'ame', 'amf', 'amh', 'ami', 'amk', 'ann', 'any', 'aoz', 'apb', 'apr', 'ara', 'arl', 'asa', 'asg', 'asm', 'ast', 'ata', 'atb', 'atg', 'ati', 'atq', 'ava', 'avn', 'avu', 'awa', 'awb', 'ayo', 'ayr', 'ayz', 'azb', 'azg', 'azj-script_cyrillic', 'azj-script_latin', 'azz', 'bak', 'bam', 'ban', 'bao', 'bas', 'bav', 'bba', 'bbb', 'bbc', 'bbo', 'bcc-script_arabic', 'bcc-script_latin', 'bcl', 'bcw', 'bdg', 'bdh', 'bdq', 'bdu', 'bdv', 'beh', 'bel', 'bem', 'ben', 'bep', 'bex', 'bfa', 'bfo', 'bfy', 'bfz', 'bgc', 'bgq', 'bgr', 'bgt', 'bgw', 'bha', 'bht', 'bhz', 'bib', 'bim', 'bis', 'biv', 'bjr', 'bjv', 'bjw', 'bjz', 'bkd', 'bkv', 'blh', 'blt', 'blx', 'blz', 'bmq', 'bmr', 'bmu', 'bmv', 'bng', 'bno', 'bnp', 'boa', 'bod', 'boj', 'bom', 'bor', 'bos', 'bov', 'box', 'bpr', 'bps', 'bqc', 'bqi', 'bqj', 'bqp', 'bre', 'bru', 'bsc', 'bsq', 'bss', 'btd', 'bts', 'btt', 'btx', 'bud', 'bul', 'bus', 'bvc', 'bvz', 'bwq', 'bwu', 'byr', 'bzh', 'bzi', 'bzj', 'caa', 'cab', 'cac-dialect_sanmateoixtatan', 'cac-dialect_sansebastiancoatan', 'cak-dialect_central', 'cak-dialect_santamariadejesus', 'cak-dialect_santodomingoxenacoj', 'cak-dialect_southcentral', 'cak-dialect_western', 'cak-dialect_yepocapa', 'cap', 'car', 'cas', 'cat', 'cax', 'cbc', 'cbi', 'cbr', 'cbs', 'cbt', 'cbu', 'cbv', 'cce', 'cco', 'cdj', 'ceb', 'ceg', 'cek', 'ces', 'cfm', 'cgc', 'che', 'chf', 'chv', 'chz', 'cjo', 'cjp', 'cjs', 'ckb', 'cko', 'ckt', 'cla', 'cle', 'cly', 'cme', 'cmn-script_simplified', 'cmo-script_khmer', 'cmo-script_latin', 'cmr', 'cnh', 'cni', 'cnl', 'cnt', 'coe', 'cof', 'cok', 'con', 'cot', 'cou', 'cpa', 'cpb', 'cpu', 'crh', 'crk-script_latin', 'crk-script_syllabics', 'crn', 'crq', 'crs', 'crt', 'csk', 'cso', 'ctd', 'ctg', 'cto', 'ctu', 'cuc', 'cui', 'cuk', 'cul', 'cwa', 
'cwe', 'cwt', 'cya', 'cym', 'daa', 'dah', 'dan', 'dar', 'dbj', 'dbq', 'ddn', 'ded', 'des', 'deu', 'dga', 'dgi', 'dgk', 'dgo', 'dgr', 'dhi', 'did', 'dig', 'dik', 'dip', 'div', 'djk', 'dnj-dialect_blowowest', 'dnj-dialect_gweetaawueast', 'dnt', 'dnw', 'dop', 'dos', 'dsh', 'dso', 'dtp', 'dts', 'dug', 'dwr', 'dyi', 'dyo', 'dyu', 'dzo', 'eip', 'eka', 'ell', 'emp', 'enb', 'eng', 'enx', 'epo', 'ese', 'ess', 'est', 'eus', 'evn', 'ewe', 'eza', 'fal', 'fao', 'far', 'fas', 'fij', 'fin', 'flr', 'fmu', 'fon', 'fra', 'frd', 'fry', 'ful', 'gag-script_cyrillic', 'gag-script_latin', 'gai', 'gam', 'gau', 'gbi', 'gbk', 'gbm', 'gbo', 'gde', 'geb', 'gej', 'gil', 'gjn', 'gkn', 'gld', 'gle', 'glg', 'glk', 'gmv', 'gna', 'gnd', 'gng', 'gof-script_latin', 'gog', 'gor', 'gqr', 'grc', 'gri', 'grn', 'grt', 'gso', 'gub', 'guc', 'gud', 'guh', 'guj', 'guk', 'gum', 'guo', 'guq', 'guu', 'gux', 'gvc', 'gvl', 'gwi', 'gwr', 'gym', 'gyr', 'had', 'hag', 'hak', 'hap', 'hat', 'hau', 'hay', 'heb', 'heh', 'hif', 'hig', 'hil', 'hin', 'hlb', 'hlt', 'hne', 'hnn', 'hns', 'hoc', 'hoy', 'hrv', 'hsb', 'hto', 'hub', 'hui', 'hun', 'hus-dialect_centralveracruz', 'hus-dialect_westernpotosino', 'huu', 'huv', 'hvn', 'hwc', 'hye', 'hyw', 'iba', 'ibo', 'icr', 'idd', 'ifa', 'ifb', 'ife', 'ifk', 'ifu', 'ify', 'ign', 'ikk', 'ilb', 'ilo', 'imo', 'ina', 'inb', 'ind', 'iou', 'ipi', 'iqw', 'iri', 'irk', 'isl', 'ita', 'itl', 'itv', 'ixl-dialect_sangasparchajul', 'ixl-dialect_sanjuancotzal', 'ixl-dialect_santamarianebaj', 'izr', 'izz', 'jac', 'jam', 'jav', 'jbu', 'jen', 'jic', 'jiv', 'jmc', 'jmd', 'jpn', 'jun', 'juy', 'jvn', 'kaa', 'kab', 'kac', 'kak', 'kam', 'kan', 'kao', 'kaq', 'kat', 'kay', 'kaz', 'kbo', 'kbp', 'kbq', 'kbr', 'kby', 'kca', 'kcg', 'kdc', 'kde', 'kdh', 'kdi', 'kdj', 'kdl', 'kdn', 'kdt', 'kea', 'kek', 'ken', 'keo', 'ker', 'key', 'kez', 'kfb', 'kff-script_telugu', 'kfw', 'kfx', 'khg', 'khm', 'khq', 'kia', 'kij', 'kik', 'kin', 'kir', 'kjb', 'kje', 'kjg', 'kjh', 'kki', 'kkj', 'kle', 'klu', 'klv', 'klw', 'kma', 'kmd', 
'kml', 'kmr-script_arabic', 'kmr-script_cyrillic', 'kmr-script_latin', 'kmu', 'knb', 'kne', 'knf', 'knj', 'knk', 'kno', 'kog', 'kor', 'kpq', 'kps', 'kpv', 'kpy', 'kpz', 'kqe', 'kqp', 'kqr', 'kqy', 'krc', 'kri', 'krj', 'krl', 'krr', 'krs', 'kru', 'ksb', 'ksr', 'kss', 'ktb', 'ktj', 'kub', 'kue', 'kum', 'kus', 'kvn', 'kvw', 'kwd', 'kwf', 'kwi', 'kxc', 'kxf', 'kxm', 'kxv', 'kyb', 'kyc', 'kyf', 'kyg', 'kyo', 'kyq', 'kyu', 'kyz', 'kzf', 'lac', 'laj', 'lam', 'lao', 'las', 'lat', 'lav', 'law', 'lbj', 'lbw', 'lcp', 'lee', 'lef', 'lem', 'lew', 'lex', 'lgg', 'lgl', 'lhu', 'lia', 'lid', 'lif', 'lin', 'lip', 'lis', 'lit', 'lje', 'ljp', 'llg', 'lln', 'lme', 'lnd', 'lns', 'lob', 'lok', 'lom', 'lon', 'loq', 'lsi', 'lsm', 'ltz', 'luc', 'lug', 'luo', 'lwo', 'lww', 'lzz', 'maa-dialect_sanantonio', 'maa-dialect_sanjeronimo', 'mad', 'mag', 'mah', 'mai', 'maj', 'mak', 'mal', 'mam-dialect_central', 'mam-dialect_northern', 'mam-dialect_southern', 'mam-dialect_western', 'maq', 'mar', 'maw', 'maz', 'mbb', 'mbc', 'mbh', 'mbj', 'mbt', 'mbu', 'mbz', 'mca', 'mcb', 'mcd', 'mco', 'mcp', 'mcq', 'mcu', 'mda', 'mdf', 'mdv', 'mdy', 'med', 'mee', 'mej', 'men', 'meq', 'met', 'mev', 'mfe', 'mfh', 'mfi', 'mfk', 'mfq', 'mfy', 'mfz', 'mgd', 'mge', 'mgh', 'mgo', 'mhi', 'mhr', 'mhu', 'mhx', 'mhy', 'mib', 'mie', 'mif', 'mih', 'mil', 'mim', 'min', 'mio', 'mip', 'miq', 'mit', 'miy', 'miz', 'mjl', 'mjv', 'mkd', 'mkl', 'mkn', 'mlg', 'mlt', 'mmg', 'mnb', 'mnf', 'mnk', 'mnw', 'mnx', 'moa', 'mog', 'mon', 'mop', 'mor', 'mos', 'mox', 'moz', 'mpg', 'mpm', 'mpp', 'mpx', 'mqb', 'mqf', 'mqj', 'mqn', 'mri', 'mrw', 'msy', 'mtd', 'mtj', 'mto', 'muh', 'mup', 'mur', 'muv', 'muy', 'mvp', 'mwq', 'mwv', 'mxb', 'mxq', 'mxt', 'mxv', 'mya', 'myb', 'myk', 'myl', 'myv', 'myx', 'myy', 'mza', 'mzi', 'mzj', 'mzk', 'mzm', 'mzw', 'nab', 'nag', 'nan', 'nas', 'naw', 'nca', 'nch', 'ncj', 'ncl', 'ncu', 'ndj', 'ndp', 'ndv', 'ndy', 'ndz', 'neb', 'new', 'nfa', 'nfr', 'nga', 'ngl', 'ngp', 'ngu', 'nhe', 'nhi', 'nhu', 'nhw', 'nhx', 'nhy', 'nia', 
'nij', 'nim', 'nin', 'nko', 'nlc', 'nld', 'nlg', 'nlk', 'nmz', 'nnb', 'nno', 'nnq', 'nnw', 'noa', 'nob', 'nod', 'nog', 'not', 'npi', 'npl', 'npy', 'nso', 'nst', 'nsu', 'ntm', 'ntr', 'nuj', 'nus', 'nuz', 'nwb', 'nxq', 'nya', 'nyf', 'nyn', 'nyo', 'nyy', 'nzi', 'obo', 'oci', 'ojb-script_latin', 'ojb-script_syllabics', 'oku', 'old', 'omw', 'onb', 'ood', 'orm', 'ory', 'oss', 'ote', 'otq', 'ozm', 'pab', 'pad', 'pag', 'pam', 'pan', 'pao', 'pap', 'pau', 'pbb', 'pbc', 'pbi', 'pce', 'pcm', 'peg', 'pez', 'pib', 'pil', 'pir', 'pis', 'pjt', 'pkb', 'pls', 'plw', 'pmf', 'pny', 'poh-dialect_eastern', 'poh-dialect_western', 'poi', 'pol', 'por', 'poy', 'ppk', 'pps', 'prf', 'prk', 'prt', 'pse', 'pss', 'ptu', 'pui', 'pus', 'pwg', 'pww', 'pxm', 'qub', 'quc-dialect_central', 'quc-dialect_east', 'quc-dialect_north', 'quf', 'quh', 'qul', 'quw', 'quy', 'quz', 'qvc', 'qve', 'qvh', 'qvm', 'qvn', 'qvo', 'qvs', 'qvw', 'qvz', 'qwh', 'qxh', 'qxl', 'qxn', 'qxo', 'qxr', 'rah', 'rai', 'rap', 'rav', 'raw', 'rej', 'rel', 'rgu', 'rhg', 'rif-script_arabic', 'rif-script_latin', 'ril', 'rim', 'rjs', 'rkt', 'rmc-script_cyrillic', 'rmc-script_latin', 'rmo', 'rmy-script_cyrillic', 'rmy-script_latin', 'rng', 'rnl', 'roh-dialect_sursilv', 'roh-dialect_vallader', 'rol', 'ron', 'rop', 'rro', 'rub', 'ruf', 'rug', 'run', 'rus', 'sab', 'sag', 'sah', 'saj', 'saq', 'sas', 'sat', 'sba', 'sbd', 'sbl', 'sbp', 'sch', 'sck', 'sda', 'sea', 'seh', 'ses', 'sey', 'sgb', 'sgj', 'sgw', 'shi', 'shk', 'shn', 'sho', 'shp', 'sid', 'sig', 'sil', 'sja', 'sjm', 'sld', 'slk', 'slu', 'slv', 'sml', 'smo', 'sna', 'snd', 'sne', 'snn', 'snp', 'snw', 'som', 'soy', 'spa', 'spp', 'spy', 'sqi', 'sri', 'srm', 'srn', 'srp-script_cyrillic', 'srp-script_latin', 'srx', 'stn', 'stp', 'suc', 'suk', 'sun', 'sur', 'sus', 'suv', 'suz', 'swe', 'swh', 'sxb', 'sxn', 'sya', 'syl', 'sza', 'tac', 'taj', 'tam', 'tao', 'tap', 'taq', 'tat', 'tav', 'tbc', 'tbg', 'tbk', 'tbl', 'tby', 'tbz', 'tca', 'tcc', 'tcs', 'tcz', 'tdj', 'ted', 'tee', 'tel', 'tem', 'teo', 
'ter', 'tes', 'tew', 'tex', 'tfr', 'tgj', 'tgk', 'tgl', 'tgo', 'tgp', 'tha', 'thk', 'thl', 'tih', 'tik', 'tir', 'tkr', 'tlb', 'tlj', 'tly', 'tmc', 'tmf', 'tna', 'tng', 'tnk', 'tnn', 'tnp', 'tnr', 'tnt', 'tob', 'toc', 'toh', 'tom', 'tos', 'tpi', 'tpm', 'tpp', 'tpt', 'trc', 'tri', 'trn', 'trs', 'tso', 'tsz', 'ttc', 'tte', 'ttq-script_tifinagh', 'tue', 'tuf', 'tuk-script_arabic', 'tuk-script_latin', 'tuo', 'tur', 'tvw', 'twb', 'twe', 'twu', 'txa', 'txq', 'txu', 'tye', 'tzh-dialect_bachajon', 'tzh-dialect_tenejapa', 'tzj-dialect_eastern', 'tzj-dialect_western', 'tzo-dialect_chamula', 'tzo-dialect_chenalho', 'ubl', 'ubu', 'udm', 'udu', 'uig-script_arabic', 'uig-script_cyrillic', 'ukr', 'umb', 'unr', 'upv', 'ura', 'urb', 'urd-script_arabic', 'urd-script_devanagari', 'urd-script_latin', 'urk', 'urt', 'ury', 'usp', 'uzb-script_cyrillic', 'uzb-script_latin', 'vag', 'vid', 'vie', 'vif', 'vmw', 'vmy', 'vot', 'vun', 'vut', 'wal-script_ethiopic', 'wal-script_latin', 'wap', 'war', 'waw', 'way', 'wba', 'wlo', 'wlx', 'wmw', 'wob', 'wol', 'wsg', 'wwa', 'xal', 'xdy', 'xed', 'xer', 'xho', 'xmm', 'xnj', 'xnr', 'xog', 'xon', 'xrb', 'xsb', 'xsm', 'xsr', 'xsu', 'xta', 'xtd', 'xte', 'xtm', 'xtn', 'xua', 'xuo', 'yaa', 'yad', 'yal', 'yam', 'yao', 'yas', 'yat', 'yaz', 'yba', 'ybb', 'ycl', 'ycn', 'yea', 'yka', 'yli', 'yor', 'yre', 'yua', 'yue-script_traditional', 'yuz', 'yva', 'zaa', 'zab', 'zac', 'zad', 'zae', 'zai', 'zam', 'zao', 'zaq', 'zar', 'zas', 'zav', 'zaw', 'zca', 'zga', 'zim', 'ziw', 'zlm', 'zmz', 'zne', 'zos', 'zpc', 'zpg', 'zpi', 'zpl', 'zpm', 'zpo', 'zpt', 'zpu', 'zpz', 'ztq', 'zty', 'zul', 'zyb', 'zyp', 'zza'])
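Most keys are plain ISO 639-3 codes. Some languages, however, have several adapters distinguished by a "-script_<name>" or "-dialect_<name>" suffix, for example srp-script_cyrillic and srp-script_latin. The following small sketch finds every adapter for a given base code. It uses a hypothetical helper, adapters_for, applied here to a few keys copied from the list above. In the notebook, you would pass asr_processor.tokenizer.vocab.keys() instead.

```python
def adapters_for(base_code, vocab_keys):
    """Hypothetical helper: return all adapter keys whose base ISO 639-3 code matches."""
    return sorted(key for key in vocab_keys if key.split("-", 1)[0] == base_code)


# A few keys copied from the list above
sample_keys = ["deu", "fra", "srp-script_cyrillic", "srp-script_latin",
               "mam-dialect_central", "mam-dialect_northern", "spa"]

print(adapters_for("srp", sample_keys))  # ['srp-script_cyrillic', 'srp-script_latin']
print(adapters_for("fra", sample_keys))  # ['fra']
print(adapters_for("eng", sample_keys))  # [] (not in this short list)
```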

Switch the language adapter by calling load_adapter() on the model and set_target_lang() on the tokenizer. Pass the language_id detected in the previous step as the target language.

asr_processor.tokenizer.set_target_lang(language_id)
asr_model.load_adapter(language_id)
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize

Use the original model for inference

inputs = asr_processor(example["audio"]["array"], sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    outputs = asr_model(**inputs).logits

ids = torch.argmax(outputs, dim=-1)[0]
transcription = asr_processor.decode(ids)
print(transcription)
grisé par ce parfum il fit des vers en l'honneur de l'humble fleur des bois et il les récita tout haut à ses pieds une violette l'entendit elle crut qu'il ne parlait que pour elle
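Wav2Vec2ForCTC produces one vector of logits per audio frame. Taking the argmax per frame gives a sequence of token IDs, which asr_processor.decode() turns into text. For CTC, this conceptually means merging consecutive repeated IDs, dropping the blank (padding) token, and mapping the word-delimiter token to spaces. The following minimal NumPy sketch shows this greedy CTC decoding. It assumes a tiny hypothetical vocabulary in which ID 0 is the blank and "|" is the word delimiter. It illustrates the idea and is not the tokenizer's actual implementation.

```python
import numpy as np

# Hypothetical toy vocabulary: ID 0 is the CTC blank, "|" separates words
vocab = ["<pad>", "|", "a", "b", "c"]
BLANK_ID = 0
WORD_DELIMITER = "|"


def greedy_ctc_decode(logits: np.ndarray) -> str:
    """Greedy CTC decoding of a (num_frames, vocab_size) logits array."""
    ids = logits.argmax(axis=-1)
    tokens = []
    prev = None
    for token_id in ids:
        # Merge consecutive repeats and skip blanks
        if token_id != prev and token_id != BLANK_ID:
            tokens.append(vocab[token_id])
        prev = token_id
    text = "".join(tokens).replace(WORD_DELIMITER, " ")
    return " ".join(text.split())  # collapse extra spaces


# Per-frame best IDs: a a <blank> a b | | c
frame_ids = [2, 2, 0, 2, 3, 1, 1, 4]
toy_logits = np.eye(len(vocab))[frame_ids]  # one-hot rows, so argmax == frame_ids
print(greedy_ctc_decode(toy_logits))  # aab c
```

The blank between the two "a" frames is what allows CTC to emit a genuinely doubled letter. Without it, the repeats would be merged into one.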

Convert to OpenVINO IR model and run inference

Convert the model to OpenVINO IR format directly with the ov.convert_model function, and serialize the result with the ov.save_model function. For convenience, we wrap these steps in a function. It also loads the adapter for the requested language and reuses a previously saved IR if one already exists.

asr_model_xml_path_template = "models/ov_asr_{}_model.xml"


def get_asr_model(model_path_template, language_id, compiled=True):
    input_values = torch.zeros([1, MAX_SEQ_LENGTH], dtype=torch.float)
    model_path = Path(model_path_template.format(language_id))

    asr_processor.tokenizer.set_target_lang(language_id)
    if not model_path.exists() and model_path_template == asr_model_xml_path_template:
        asr_model.load_adapter(language_id)

        model_path.parent.mkdir(parents=True, exist_ok=True)
        converted_model = ov.convert_model(asr_model, example_input={"input_values": input_values})
        ov.save_model(converted_model, model_path)
        if not compiled:
            return converted_model

    if compiled:
        return core.compile_model(model_path, device_name=device.value)
    return core.read_model(model_path)


compiled_asr_model = get_asr_model(asr_model_xml_path_template, language_id)
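get_lid_model() and get_asr_model() follow a convert-once pattern. The expensive conversion runs only if the IR file is missing, and later calls simply read the saved file. The following generic sketch of that pattern runs without OpenVINO. It uses hypothetical convert_fn and load_fn callables in place of ov.convert_model/ov.save_model and core.read_model/core.compile_model.

```python
import tempfile
from pathlib import Path


def get_or_convert(path: Path, convert_fn, load_fn):
    """Hypothetical helper: convert and save only if `path` does not exist yet."""
    if not path.exists():
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(convert_fn())  # stands in for convert_model + save_model
    return load_fn(path)  # stands in for read_model / compile_model


calls = {"convert": 0}


def fake_convert():
    calls["convert"] += 1
    return "<ir model>"


with tempfile.TemporaryDirectory() as tmp:
    ir_path = Path(tmp) / "models" / "ov_asr_fra_model.xml"
    first = get_or_convert(ir_path, fake_convert, Path.read_text)
    second = get_or_convert(ir_path, fake_convert, Path.read_text)

print(first == second, calls["convert"])  # True 1
```

The notebook functions share one caveat with this pattern: an existing file is reused as-is. To force reconversion, delete the saved IR files.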

Run inference.

def recognize_audio(compiled_model, src_audio):
    inputs = asr_processor(src_audio, sampling_rate=16_000, return_tensors="pt")
    outputs = compiled_model(inputs["input_values"])[0]

    ids = torch.argmax(torch.from_numpy(outputs), dim=-1)[0]
    transcription = asr_processor.decode(ids)

    return transcription


transcription = recognize_audio(compiled_asr_model, example["audio"]["array"])
print("Original text:", example["transcript"])
print("Transcription:", transcription)
Original text: grisé par ce parfum il fit des vers en l'honneur de l'humble fleur des bois et il les récita tout haut à ses pieds une violette l'entendit elle crut qu'il ne parlait que pour elle
Transcription: grisé par ce parfum il fit des vers en l'honneur de l'humble fleur des bois et il les récita tout haut à ses pieds une violette l'entendit elle crut qu'il ne parlait que pour elle

Quantization

NNCF enables post-training quantization by adding quantization layers into the model graph. It then uses a small calibration subset of data (here, samples from the dev split of MLS) to initialize the parameters of these additional layers. Quantized operations are executed in INT8 instead of FP32/FP16, which makes model inference faster.
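Conceptually, a quantized tensor is mapped to 8-bit integers with a scale (plus a zero point in asymmetric schemes) and recovered approximately by multiplying back. The following minimal NumPy sketch shows symmetric per-tensor INT8 quantization. NNCF's actual scheme is more elaborate, with per-channel scales, ranges estimated from calibration statistics, SmoothQuant, and bias correction, so this only illustrates the arithmetic.

```python
import numpy as np


def quantize_symmetric_int8(x: np.ndarray):
    """Symmetric per-tensor INT8 quantization: x is approximated by q * scale."""
    scale = np.abs(x).max() / 127.0  # map the largest magnitude to 127
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    return q.astype(np.float32) * scale


rng = np.random.default_rng(0)
weights = rng.normal(size=(4, 8)).astype(np.float32)

q, scale = quantize_symmetric_int8(weights)
restored = dequantize(q, scale)

max_error = np.abs(weights - restored).max()
print(q.dtype, q.nbytes, weights.nbytes)  # int8 32 128: a quarter of the FP32 bytes
print(max_error <= scale / 2 + 1e-6)      # True: rounding error is at most half a step
```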

The optimization process contains the following steps:

  1. Create a calibration dataset for quantization.

  2. Run nncf.quantize() to obtain quantized models.

  3. Serialize the quantized INT8 model using the openvino.save_model() function.

Note: Quantization is a time- and memory-consuming operation. Running the quantization code below may take some time.

from notebook_utils import quantization_widget

compiled_quantized_lid_model = None
quantized_asr_model_xml_path_template = None

to_quantize = quantization_widget()

to_quantize
Checkbox(value=True, description='Quantization')

Load the skip magic extension so that the quantization cells are skipped if to_quantize is not selected.

# Fetch `skip_kernel_extension` module
import requests

r = requests.get(
    url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/skip_kernel_extension.py",
)
open("skip_kernel_extension.py", "w").write(r.text)

%load_ext skip_kernel_extension

Preparing calibration dataset

Select the language to quantize the model for:

%%skip not $to_quantize.value

from IPython.display import display

display(SAMPLE_LANG)

Load the dev split of the same MLS dataset for the selected language.

%%skip not $to_quantize.value

mls_dataset = iter(load_dataset("facebook/multilingual_librispeech", SAMPLE_LANG.value, split="dev", streaming=True, trust_remote_code=True))
example = next(mls_dataset)

Create a calibration dataset for quantization.

%%skip not $to_quantize.value

CALIBRATION_DATASET_SIZE = 5

calibration_data = []
for i in range(CALIBRATION_DATASET_SIZE):
    data = asr_processor(next(mls_dataset)['audio']['array'], sampling_rate=16_000, return_tensors="np")
    calibration_data.append(data["input_values"])

Language identification model quantization

Run LID model quantization.

%%skip not $to_quantize.value

import nncf

quantized_lid_model_xml_path = Path(str(lid_model_xml_path).replace(".xml", "_quantized.xml"))

if not quantized_lid_model_xml_path.exists():
    quantized_lid_model = nncf.quantize(
        get_lid_model(lid_model_xml_path, compiled=False),
        calibration_dataset=nncf.Dataset(calibration_data),
        subset_size=len(calibration_data),
        model_type=nncf.ModelType.TRANSFORMER
    )
    ov.save_model(quantized_lid_model, quantized_lid_model_xml_path)
    compiled_quantized_lid_model = core.compile_model(quantized_lid_model, device_name=device.value)
else:
    compiled_quantized_lid_model = get_lid_model(quantized_lid_model_xml_path)
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, openvino
Statistics collection: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:06<00:00,  1.24s/it]
Applying Smooth Quant: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 291/291 [00:18<00:00, 15.34it/s]
INFO:nncf:144 ignored nodes was found by name in the NNCFGraph
Statistics collection: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:18<00:00,  3.65s/it]
Applying Fast Bias correction: 100%|██████████████████████████████████████████████████████████████████████████████████████| 298/298 [05:09<00:00,  1.04s/it]

Detect language with the quantized model.

%%skip not $to_quantize.value

language_id = detect_language(compiled_quantized_lid_model, example['audio']['array'])
print("Detected language:", language_id)
Detected language: fra

Speech recognition model quantization

Run ASR model quantization.

%%skip not $to_quantize.value

quantized_asr_model_xml_path_template = asr_model_xml_path_template.replace(".xml", "_quantized.xml")
quantized_asr_model_xml_path = Path(quantized_asr_model_xml_path_template.format(language_id))

if not quantized_asr_model_xml_path.exists():
    quantized_asr_model = nncf.quantize(
        get_asr_model(asr_model_xml_path_template, language_id, compiled=False),
        calibration_dataset=nncf.Dataset(calibration_data),
        subset_size=len(calibration_data),
        model_type=nncf.ModelType.TRANSFORMER
    )
    ov.save_model(quantized_asr_model, quantized_asr_model_xml_path)
    compiled_quantized_asr_model = core.compile_model(quantized_asr_model, device_name=device.value)
else:
    compiled_quantized_asr_model = get_asr_model(quantized_asr_model_xml_path_template, language_id)
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Statistics collection: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:05<00:00,  1.17s/it]
Applying Smooth Quant: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 290/290 [00:17<00:00, 16.39it/s]
INFO:nncf:144 ignored nodes was found by name in the NNCFGraph
Statistics collection: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:19<00:00,  3.93s/it]
Applying Fast Bias correction: 100%|██████████████████████████████████████████████████████████████████████████████████████| 393/393 [05:22<00:00,  1.22it/s]

Run transcription with the quantized model and compare the result to the one produced by the original model.

%%skip not $to_quantize.value

compiled_asr_model = get_asr_model(asr_model_xml_path_template, language_id)
transcription_original = recognize_audio(compiled_asr_model, example['audio']['array'])
transcription_quantized = recognize_audio(compiled_quantized_asr_model, example['audio']['array'])
print("Transcription by original model: ", transcription_original)
print("Transcription by quantized model:", transcription_quantized)
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Transcription by original model:  le salon était de la plus haute magnificence dorée comme la galerie de diane aux tuileries avec des tableaux à l'huile au lombri il y avait des tâches claires dans ces tableaux julien apprit plus tard que les sujets avaient semblé peu décent à la maîtresse du logis qui avait fait corriger les tableaux
Transcription by quantized model: le salon était de la plus haute magnificence doré comme la galerie de diane aux tuileries avec des tableaux à l'huile au lombri il y avait des tâches claires dans ces tableaux julien apprit plus tard que les sujets avaient semblé peu decent à la maîtresse du logis qui avait fait corriger les tableaux

Compare model size, performance and accuracy

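First, we compare model sizes. Note that ov.save_model() compresses weights to FP16 by default. The baseline IR labeled "FP32" below therefore stores its weights in FP16, which is why INT8 quantization roughly halves the file size instead of shrinking it four times.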

%%skip not $to_quantize.value

def calculate_compression_rate(model_path_ov, model_path_ov_int8, model_type):
    model_size_fp32 = model_path_ov.with_suffix(".bin").stat().st_size / 10 ** 6
    model_size_int8 = model_path_ov_int8.with_suffix(".bin").stat().st_size / 10 ** 6
    print(f"{model_type} model footprint comparison:")
    print(f"    * FP32 IR model size: {model_size_fp32:.2f} MB")
    print(f"    * INT8 IR model size: {model_size_int8:.2f} MB")
    return model_size_fp32, model_size_int8

lid_model_size_fp32, lid_model_size_int8 = \
    calculate_compression_rate(lid_model_xml_path, quantized_lid_model_xml_path, 'LID')
asr_model_size_fp32, asr_model_size_int8 = \
    calculate_compression_rate(Path(asr_model_xml_path_template.format(language_id)), quantized_asr_model_xml_path, 'ASR')
LID model footprint comparison:
    * FP32 IR model size: 1931.81 MB
    * INT8 IR model size: 968.96 MB
ASR model footprint comparison:
    * FP32 IR model size: 1930.10 MB
    * INT8 IR model size: 968.29 MB

Second, we compare the accuracy of the original and quantized models on the first TEST_DATASET_SIZE (20) samples of the MLS test split. We use the Word Error Rate (WER) metric and compute accuracy as (1 - WER).

We also measure inference time for both language identification and speech recognition models.

%%skip not $to_quantize.value

import time
from tqdm.notebook import tqdm
import numpy as np
from jiwer import wer

TEST_DATASET_SIZE = 20
test_dataset = load_dataset("facebook/multilingual_librispeech", SAMPLE_LANG.value, split="test", streaming=True, trust_remote_code=True)
test_dataset = test_dataset.take(TEST_DATASET_SIZE)

def calculate_transcription_time_and_accuracy(lid_model, asr_model):
    ground_truths = []
    predictions = []
    identification_time = []
    transcription_time = []
    for data_item in tqdm(test_dataset, desc="Measuring performance and accuracy", total=TEST_DATASET_SIZE):
        audio = data_item["audio"]["array"]

        start_time = time.perf_counter()
        detect_language(lid_model, audio)
        end_time = time.perf_counter()
        identification_time.append(end_time - start_time)

        start_time = time.perf_counter()
        transcription = recognize_audio(asr_model, audio)
        end_time = time.perf_counter()
        transcription_time.append(end_time - start_time)

        ground_truths.append(data_item["transcript"])
        predictions.append(transcription)

    word_accuracy = (1 - wer(ground_truths, predictions)) * 100
    mean_identification_time = np.mean(identification_time)
    mean_transcription_time = np.mean(transcription_time)
    return mean_identification_time, mean_transcription_time, word_accuracy

identification_time_fp32, transcription_time_fp32, accuracy_fp32 = \
    calculate_transcription_time_and_accuracy(compiled_lid_model, compiled_asr_model)
identification_time_int8, transcription_time_int8, accuracy_int8 = \
    calculate_transcription_time_and_accuracy(compiled_quantized_lid_model, compiled_quantized_asr_model)
print(f"LID model footprint reduction: {lid_model_size_fp32 / lid_model_size_int8:.3f}")
print(f"ASR model footprint reduction: {asr_model_size_fp32 / asr_model_size_int8:.3f}")
print(f"Language identification performance speedup: {identification_time_fp32 / identification_time_int8:.3f}")
print(f"Language transcription performance speedup:  {transcription_time_fp32 / transcription_time_int8:.3f}")
print(f"Transcription word accuracy. FP32: {accuracy_fp32:.2f}%. INT8: {accuracy_int8:.2f}%. Accuracy drop :{accuracy_fp32 - accuracy_int8:.2f}%.")
LID model footprint reduction: 1.994
ASR model footprint reduction: 1.993
Language identification performance speedup: 1.425
Language transcription performance speedup:  1.489
Transcription word accuracy. FP32: 85.01%. INT8: 84.76%. Accuracy drop :0.25%.
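The accuracy figures above come from jiwer.wer. It computes the word error rate as (substitutions + deletions + insertions) / number of reference words, using a word-level edit distance. When given lists, jiwer sums errors and reference words over the whole set rather than averaging per-sentence WERs. The following minimal pure-Python sketch applies the definition to a single reference/hypothesis pair. The pair is taken from the start of the original and quantized transcriptions shown earlier.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """WER = (S + D + I) / N via word-level Levenshtein distance."""
    ref, hyp = reference.split(), hypothesis.split()
    # dist[i][j]: edits needed to turn the first i reference words into the first j hypothesis words
    dist = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dist[i][0] = i  # i deletions
    for j in range(len(hyp) + 1):
        dist[0][j] = j  # j insertions
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dist[i][j] = min(dist[i - 1][j] + 1,         # deletion
                             dist[i][j - 1] + 1,         # insertion
                             dist[i - 1][j - 1] + cost)  # substitution or match
    return dist[len(ref)][len(hyp)] / len(ref)


reference = "le salon était de la plus haute magnificence dorée"
hypothesis = "le salon était de la plus haute magnificence doré"
error = word_error_rate(reference, hypothesis)
print(f"WER: {error:.3f}, accuracy: {(1 - error) * 100:.1f}%")  # WER: 0.111, accuracy: 88.9%
```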

Interactive demo with Gradio

In this demo, you can try your own examples. Make sure that the audio data is sampled at 16 kHz (16000 Hz).

import gradio as gr
import librosa
import time


current_state = {
    "fp32": {"model": None, "language": None},
    "int8": {"model": None, "language": None},
}


def infer(src_audio_path, quantized):
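    # Resample to 16 kHz, the rate expected by the MMS models (librosa's default is 22050 Hz)
    src_audio, _ = librosa.load(src_audio_path, sr=16_000)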
    lid_model = compiled_quantized_lid_model if quantized else compiled_lid_model

    start_time = time.perf_counter()
    detected_language_id = detect_language(lid_model, src_audio)
    end_time = time.perf_counter()
    identification_delta_time = f"{end_time - start_time:.2f}"

    state = current_state["int8" if quantized else "fp32"]
    if detected_language_id != state["language"]:
        template_path = quantized_asr_model_xml_path_template if quantized else asr_model_xml_path_template
        try:
            gr.Info(f"Loading {'quantized ' if quantized else ''}ASR model for '{detected_language_id}' language. This will take some time.")
            state["model"] = get_asr_model(template_path, detected_language_id)
            state["language"] = detected_language_id
        except RuntimeError as e:
            if "Unable to read the model:" in str(e) and quantized:
                raise gr.Error(f"There is no quantized ASR model for '{detected_language_id}' language. Please run quantization for this language first.")
            raise  # re-raise any other error instead of silently reusing a model for another language

    start_time = time.perf_counter()
    transcription = recognize_audio(state["model"], src_audio)
    end_time = time.perf_counter()
    transcription_delta_time = f"{end_time - start_time:.2f}"

    return (
        detected_language_id,
        transcription,
        identification_delta_time,
        transcription_delta_time,
    )
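current_state keeps only the most recently used ASR model per precision, so alternating between two languages reloads a model on every switch. One possible alternative is a small LRU cache built on functools.lru_cache that keeps a few models in memory. The sketch below assumes a hypothetical load_asr_model(precision, language) loader standing in for get_asr_model(). The cache size is a trade-off against memory: each ASR IR in this notebook is roughly 1 to 2 GB on disk.

```python
from functools import lru_cache

load_log = []  # records which (precision, language) pairs were actually loaded


def load_asr_model(precision: str, language: str):
    """Hypothetical stand-in for get_asr_model(): loading is expensive."""
    load_log.append((precision, language))
    return f"{precision} ASR model for {language}"


@lru_cache(maxsize=2)  # keep at most two models in memory
def cached_asr_model(precision: str, language: str):
    return load_asr_model(precision, language)


for lang in ["fra", "deu", "fra", "deu", "spa", "fra"]:
    cached_asr_model("int8", lang)

print(load_log)  # 4 loads instead of the 6 that a single-slot cache would need
```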
if not Path("gradio_helper.py").exists():
    r = requests.get(
        url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/mms-massively-multilingual-speech/gradio_helper.py"
    )
    open("gradio_helper.py", "w").write(r.text)

from gradio_helper import make_demo

demo = make_demo(fn=infer, quantized=to_quantize.value)

try:
    demo.queue().launch(debug=False)
except Exception:
    demo.queue().launch(share=True, debug=False)
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/
# please uncomment and run this cell for stopping gradio interface
# demo.close()