MMS: Scaling Speech Technology to 1000+ languages with OpenVINO™
This Jupyter notebook can be launched after a local installation only.
The Massively Multilingual Speech (MMS) project expands speech technology from about 100 languages to over 1,000 by building a single multilingual speech recognition model supporting over 1,100 languages (more than 10 times as many as before), language identification models able to identify over 4,000 languages (40 times more than before), pretrained models supporting over 1,400 languages, and text-to-speech models for over 1,100 languages.
The MMS model was proposed in Scaling Speech Technology to 1,000+ Languages. The models and code are originally released here.
There are different open sourced models in the MMS project: Automatic Speech Recognition (ASR), Language Identification (LID) and Speech Synthesis (TTS). A simple diagram of this is below.
In this notebook we consider ASR and LID. We use the LID model to identify the spoken language and then a language-specific ASR model to transcribe the speech. An additional quantization step is applied to speed up model inference. The notebook ends with a Gradio-based interactive demo.
Prerequisites
%pip install -q --upgrade pip
%pip install -q "transformers>=4.33.1" "openvino>=2023.1.0" "numpy>=1.21.0,<=1.24" "nncf>=2.6.0"
%pip install -q --extra-index-url https://download.pytorch.org/whl/cpu torch datasets accelerate soundfile librosa gradio jiwer
from pathlib import Path
import torch
import openvino as ov
Prepare an example audio
Read an audio file and process the audio data. Make sure that the audio
data is sampled at 16,000 Hz (16 kHz). For this example we will use a streamable
version of the Multilingual LibriSpeech (MLS)
dataset.
It contains examples in 7 languages:
'german', 'dutch', 'french', 'spanish', 'italian', 'portuguese', 'polish'.
Choose one of them.
import ipywidgets as widgets
SAMPLE_LANG = widgets.Dropdown(
    options=['german', 'dutch', 'french', 'spanish', 'italian', 'portuguese', 'polish'],
    value='german',
    description='Dataset language:',
    disabled=False,
)
SAMPLE_LANG
Dropdown(description='Dataset language:', options=('german', 'dutch', 'french', 'spanish', 'italian', 'portugu…
Specify streaming=True to avoid downloading the entire dataset.
from datasets import load_dataset
mls_dataset = load_dataset("facebook/multilingual_librispeech", SAMPLE_LANG.value, split="test", streaming=True)
mls_dataset = iter(mls_dataset)  # create an iterator over the streamed dataset
example = next(mls_dataset)  # get one example
Each example is a dictionary. It contains the audio data and a text transcription.
print(example) # look at structure
{'file': None, 'audio': {'path': '1054_1599_000000.flac', 'array': array([-0.00131226, -0.00152588, -0.00134277, ..., 0.00411987,
0.00308228, -0.00015259]), 'sampling_rate': 16000}, 'text': 'mein sechster sohn scheint wenigstens auf den ersten blick der tiefsinnigste von allen ein kopfhänger und doch ein schwätzer deshalb kommt man ihm nicht leicht bei ist er am unterliegen so verfällt er in unbesiegbare traurigkeit', 'speaker_id': 1054, 'chapter_id': 1599, 'id': '1054_1599_000000'}
import IPython.display as ipd
print(example['text'])
ipd.Audio(example['audio']['array'], rate=16_000)
mein sechster sohn scheint wenigstens auf den ersten blick der tiefsinnigste von allen ein kopfhänger und doch ein schwätzer deshalb kommt man ihm nicht leicht bei ist er am unterliegen so verfällt er in unbesiegbare traurigkeit
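The MLS samples above already report 'sampling_rate': 16000, but your own recordings may not. The following is a minimal, dependency-free sketch of bringing a mono signal to 16 kHz by linear interpolation. The function name resample_linear is hypothetical and the interpolation is deliberately simple; in practice an audio library with a proper resampler would be a better choice.

```python
TARGET_RATE = 16_000  # Hz, the rate passed to the processors in this notebook


def resample_linear(samples, src_rate, dst_rate=TARGET_RATE):
    """Resample a mono signal by linear interpolation (illustrative only)."""
    if src_rate == dst_rate or not samples:
        return list(samples)
    n_out = int(round(len(samples) * dst_rate / src_rate))
    out = []
    for i in range(n_out):
        pos = i * src_rate / dst_rate  # fractional index in the source signal
        left = min(int(pos), len(samples) - 1)
        right = min(left + 1, len(samples) - 1)
        frac = pos - left
        out.append(samples[left] * (1 - frac) + samples[right] * frac)
    return out


# One second of silence recorded at 8 kHz becomes 16,000 samples at 16 kHz.
one_second_at_8k = [0.0] * 8_000
resampled = resample_linear(one_second_at_8k, src_rate=8_000)
print(len(resampled))  # 16000
```

The duration is preserved: the number of samples scales with the ratio of the two rates.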
Language Identification (LID)
Download pretrained model and processor
Different LID models are available based on the number of languages they can recognize - 126, 256, 512, 1024, 2048, 4017. We will use 126.
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
model_id = "facebook/mms-lid-126"
lid_processor = AutoFeatureExtractor.from_pretrained(model_id)
lid_model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)
Use the original model to run an inference
inputs = lid_processor(example['audio']['array'], sampling_rate=16_000, return_tensors="pt")
with torch.no_grad():
    outputs = lid_model(**inputs).logits
lang_id = torch.argmax(outputs, dim=-1)[0].item()
detected_lang = lid_model.config.id2label[lang_id]
print(detected_lang)
deu
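The classification step above reduces to picking the index of the largest logit and looking it up in the label map. A toy sketch of that logic in plain Python follows; the three-class logits and the toy_id2label dictionary are made-up values for illustration, whereas the real model used here distinguishes 126 languages.

```python
def pick_label(logits, id2label):
    """Return the label whose logit is the largest (argmax over classes)."""
    best_id = max(range(len(logits)), key=lambda i: logits[i])
    return id2label[best_id]


# Hypothetical scores for a single utterance.
toy_logits = [0.3, 4.1, -1.2]
toy_id2label = {0: "eng", 1: "deu", 2: "fra"}

print(pick_label(toy_logits, toy_id2label))  # deu
```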
Convert to OpenVINO IR model and run an inference
Select device from dropdown list for running inference using OpenVINO
core = ov.Core()
device = widgets.Dropdown(
    options=core.available_devices + ["AUTO"],
    value='AUTO',
    description='Device:',
    disabled=False,
)
device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
Convert model to OpenVINO format and compile it
MAX_SEQ_LENGTH = 30480
lid_model_xml_path = Path('models/ov_lid_model.xml')
def get_lid_model(model_path, compiled=True):
    input_values = torch.zeros([1, MAX_SEQ_LENGTH], dtype=torch.float)
    # Convert and save the model only once; later calls reuse the saved IR file
    if not model_path.exists() and model_path == lid_model_xml_path:
        lid_model_xml_path.parent.mkdir(parents=True, exist_ok=True)
        converted_model = ov.convert_model(lid_model, example_input={'input_values': input_values})
        ov.save_model(converted_model, lid_model_xml_path)
        if not compiled:
            return converted_model
    if compiled:
        return core.compile_model(model_path, device_name=device.value)
    return core.read_model(model_path)
compiled_lid_model = get_lid_model(lid_model_xml_path)
/home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:595: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
/home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:634: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
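The helper above follows a convert-once, reuse-later pattern: the expensive conversion runs only when the IR file is missing. A small stand-alone sketch of that pattern is shown here, with a fake conversion function standing in for ov.convert_model and a text file standing in for the IR file. The names get_or_create and fake_convert are hypothetical.

```python
import tempfile
from pathlib import Path


def get_or_create(path, build):
    """Return the cached artifact at `path`, building and saving it on first use."""
    path = Path(path)
    if not path.exists():
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(build())
    return path.read_text()


calls = []


def fake_convert():
    """Stand-in for an expensive model conversion; records each invocation."""
    calls.append(1)
    return "converted-model"


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "models" / "ov_lid_model.xml"
    first = get_or_create(target, fake_convert)
    second = get_or_create(target, fake_convert)  # served from disk

print(first == second, len(calls))  # True 1
```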
Now it is possible to run an inference.
def detect_language(compiled_model, audio_data):
    inputs = lid_processor(audio_data, sampling_rate=16_000, return_tensors="pt")
    outputs = compiled_model(inputs['input_values'])[0]
    lang_id = torch.argmax(torch.from_numpy(outputs), dim=-1)[0].item()
    detected_lang = lid_model.config.id2label[lang_id]
    return detected_lang
detect_language(compiled_lid_model, example['audio']['array'])
'deu'
Let’s check another language.
SAMPLE_LANG = widgets.Dropdown(
    options=['german', 'dutch', 'french', 'spanish', 'italian', 'portuguese', 'polish'],
    value='french',
    description='Dataset language:',
    disabled=False,
)
SAMPLE_LANG
Dropdown(description='Dataset language:', index=2, options=('german', 'dutch', 'french', 'spanish', 'italian',…
mls_dataset = load_dataset("facebook/multilingual_librispeech", SAMPLE_LANG.value, split="test", streaming=True)
mls_dataset = iter(mls_dataset)
example = next(mls_dataset)
print(example['text'])
ipd.Audio(example['audio']['array'], rate=16_000)
grisé par ce parfum il fit des vers en l'honneur de l'humble fleur des bois et il les récita tout haut à ses pieds une violette l'entendit elle crut qu'il ne parlait que pour elle
language_id = detect_language(compiled_lid_model, example['audio']['array'])
print(language_id)
fra
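Note that the dataset configurations use English language names ('german', 'french'), while the LID model returns three-letter codes ('deu', 'fra'). A small lookup table makes it easy to check the two against each other. The table below is hand-written for illustration: the 'german' and 'french' entries match the outputs shown above, and the remaining entries are the corresponding ISO 639-3 codes, which also appear in the ASR language list. The helper name lid_matches_dataset is hypothetical.

```python
# English dataset configuration name -> ISO 639-3 code returned by the LID model.
MLS_TO_ISO639_3 = {
    "german": "deu",
    "dutch": "nld",
    "french": "fra",
    "spanish": "spa",
    "italian": "ita",
    "portuguese": "por",
    "polish": "pol",
}


def lid_matches_dataset(dataset_language, detected_code):
    """Check whether the detected code corresponds to the chosen dataset language."""
    return MLS_TO_ISO639_3.get(dataset_language) == detected_code


print(lid_matches_dataset("french", "fra"))  # True
print(lid_matches_dataset("french", "deu"))  # False
```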
Automatic Speech Recognition (ASR)
Download pretrained model and processor
Download the pretrained model and processor. By default, MMS loads adapter
weights for English. If you want to load adapter weights of another
language, make sure to specify target_lang=<your-chosen-target-lang>
as well as ignore_mismatched_sizes=True. The ignore_mismatched_sizes=True
keyword has to be passed to allow the language model head to be resized
according to the vocabulary of the specified language. Similarly, the
processor should be loaded with the same target language. It is also
possible to change the supported language later.
from transformers import Wav2Vec2ForCTC, AutoProcessor
model_id = "facebook/mms-1b-all"
asr_processor = AutoProcessor.from_pretrained(model_id)
asr_model = Wav2Vec2ForCTC.from_pretrained(model_id)
You can look at all supported languages:
asr_processor.tokenizer.vocab.keys()
dict_keys(['abi', 'abk', 'abp', 'aca', 'acd', 'ace', 'acf', 'ach', 'acn', 'acr', 'acu', 'ade', 'adh', 'adj', 'adx', 'aeu', 'afr', 'agd', 'agg', 'agn', 'agr', 'agu', 'agx', 'aha', 'ahk', 'aia', 'aka', 'akb', 'ake', 'akp', 'alj', 'alp', 'alt', 'alz', 'ame', 'amf', 'amh', 'ami', 'amk', 'ann', 'any', 'aoz', 'apb', 'apr', 'ara', 'arl', 'asa', 'asg', 'asm', 'ast', 'ata', 'atb', 'atg', 'ati', 'atq', 'ava', 'avn', 'avu', 'awa', 'awb', 'ayo', 'ayr', 'ayz', 'azb', 'azg', 'azj-script_cyrillic', 'azj-script_latin', 'azz', 'bak', 'bam', 'ban', 'bao', 'bas', 'bav', 'bba', 'bbb', 'bbc', 'bbo', 'bcc-script_arabic', 'bcc-script_latin', 'bcl', 'bcw', 'bdg', 'bdh', 'bdq', 'bdu', 'bdv', 'beh', 'bel', 'bem', 'ben', 'bep', 'bex', 'bfa', 'bfo', 'bfy', 'bfz', 'bgc', 'bgq', 'bgr', 'bgt', 'bgw', 'bha', 'bht', 'bhz', 'bib', 'bim', 'bis', 'biv', 'bjr', 'bjv', 'bjw', 'bjz', 'bkd', 'bkv', 'blh', 'blt', 'blx', 'blz', 'bmq', 'bmr', 'bmu', 'bmv', 'bng', 'bno', 'bnp', 'boa', 'bod', 'boj', 'bom', 'bor', 'bos', 'bov', 'box', 'bpr', 'bps', 'bqc', 'bqi', 'bqj', 'bqp', 'bre', 'bru', 'bsc', 'bsq', 'bss', 'btd', 'bts', 'btt', 'btx', 'bud', 'bul', 'bus', 'bvc', 'bvz', 'bwq', 'bwu', 'byr', 'bzh', 'bzi', 'bzj', 'caa', 'cab', 'cac-dialect_sanmateoixtatan', 'cac-dialect_sansebastiancoatan', 'cak-dialect_central', 'cak-dialect_santamariadejesus', 'cak-dialect_santodomingoxenacoj', 'cak-dialect_southcentral', 'cak-dialect_western', 'cak-dialect_yepocapa', 'cap', 'car', 'cas', 'cat', 'cax', 'cbc', 'cbi', 'cbr', 'cbs', 'cbt', 'cbu', 'cbv', 'cce', 'cco', 'cdj', 'ceb', 'ceg', 'cek', 'ces', 'cfm', 'cgc', 'che', 'chf', 'chv', 'chz', 'cjo', 'cjp', 'cjs', 'ckb', 'cko', 'ckt', 'cla', 'cle', 'cly', 'cme', 'cmn-script_simplified', 'cmo-script_khmer', 'cmo-script_latin', 'cmr', 'cnh', 'cni', 'cnl', 'cnt', 'coe', 'cof', 'cok', 'con', 'cot', 'cou', 'cpa', 'cpb', 'cpu', 'crh', 'crk-script_latin', 'crk-script_syllabics', 'crn', 'crq', 'crs', 'crt', 'csk', 'cso', 'ctd', 'ctg', 'cto', 'ctu', 'cuc', 'cui', 'cuk', 'cul', 'cwa', 
'cwe', 'cwt', 'cya', 'cym', 'daa', 'dah', 'dan', 'dar', 'dbj', 'dbq', 'ddn', 'ded', 'des', 'deu', 'dga', 'dgi', 'dgk', 'dgo', 'dgr', 'dhi', 'did', 'dig', 'dik', 'dip', 'div', 'djk', 'dnj-dialect_blowowest', 'dnj-dialect_gweetaawueast', 'dnt', 'dnw', 'dop', 'dos', 'dsh', 'dso', 'dtp', 'dts', 'dug', 'dwr', 'dyi', 'dyo', 'dyu', 'dzo', 'eip', 'eka', 'ell', 'emp', 'enb', 'eng', 'enx', 'epo', 'ese', 'ess', 'est', 'eus', 'evn', 'ewe', 'eza', 'fal', 'fao', 'far', 'fas', 'fij', 'fin', 'flr', 'fmu', 'fon', 'fra', 'frd', 'fry', 'ful', 'gag-script_cyrillic', 'gag-script_latin', 'gai', 'gam', 'gau', 'gbi', 'gbk', 'gbm', 'gbo', 'gde', 'geb', 'gej', 'gil', 'gjn', 'gkn', 'gld', 'gle', 'glg', 'glk', 'gmv', 'gna', 'gnd', 'gng', 'gof-script_latin', 'gog', 'gor', 'gqr', 'grc', 'gri', 'grn', 'grt', 'gso', 'gub', 'guc', 'gud', 'guh', 'guj', 'guk', 'gum', 'guo', 'guq', 'guu', 'gux', 'gvc', 'gvl', 'gwi', 'gwr', 'gym', 'gyr', 'had', 'hag', 'hak', 'hap', 'hat', 'hau', 'hay', 'heb', 'heh', 'hif', 'hig', 'hil', 'hin', 'hlb', 'hlt', 'hne', 'hnn', 'hns', 'hoc', 'hoy', 'hrv', 'hsb', 'hto', 'hub', 'hui', 'hun', 'hus-dialect_centralveracruz', 'hus-dialect_westernpotosino', 'huu', 'huv', 'hvn', 'hwc', 'hye', 'hyw', 'iba', 'ibo', 'icr', 'idd', 'ifa', 'ifb', 'ife', 'ifk', 'ifu', 'ify', 'ign', 'ikk', 'ilb', 'ilo', 'imo', 'ina', 'inb', 'ind', 'iou', 'ipi', 'iqw', 'iri', 'irk', 'isl', 'ita', 'itl', 'itv', 'ixl-dialect_sangasparchajul', 'ixl-dialect_sanjuancotzal', 'ixl-dialect_santamarianebaj', 'izr', 'izz', 'jac', 'jam', 'jav', 'jbu', 'jen', 'jic', 'jiv', 'jmc', 'jmd', 'jpn', 'jun', 'juy', 'jvn', 'kaa', 'kab', 'kac', 'kak', 'kam', 'kan', 'kao', 'kaq', 'kat', 'kay', 'kaz', 'kbo', 'kbp', 'kbq', 'kbr', 'kby', 'kca', 'kcg', 'kdc', 'kde', 'kdh', 'kdi', 'kdj', 'kdl', 'kdn', 'kdt', 'kea', 'kek', 'ken', 'keo', 'ker', 'key', 'kez', 'kfb', 'kff-script_telugu', 'kfw', 'kfx', 'khg', 'khm', 'khq', 'kia', 'kij', 'kik', 'kin', 'kir', 'kjb', 'kje', 'kjg', 'kjh', 'kki', 'kkj', 'kle', 'klu', 'klv', 'klw', 'kma', 'kmd', 
'kml', 'kmr-script_arabic', 'kmr-script_cyrillic', 'kmr-script_latin', 'kmu', 'knb', 'kne', 'knf', 'knj', 'knk', 'kno', 'kog', 'kor', 'kpq', 'kps', 'kpv', 'kpy', 'kpz', 'kqe', 'kqp', 'kqr', 'kqy', 'krc', 'kri', 'krj', 'krl', 'krr', 'krs', 'kru', 'ksb', 'ksr', 'kss', 'ktb', 'ktj', 'kub', 'kue', 'kum', 'kus', 'kvn', 'kvw', 'kwd', 'kwf', 'kwi', 'kxc', 'kxf', 'kxm', 'kxv', 'kyb', 'kyc', 'kyf', 'kyg', 'kyo', 'kyq', 'kyu', 'kyz', 'kzf', 'lac', 'laj', 'lam', 'lao', 'las', 'lat', 'lav', 'law', 'lbj', 'lbw', 'lcp', 'lee', 'lef', 'lem', 'lew', 'lex', 'lgg', 'lgl', 'lhu', 'lia', 'lid', 'lif', 'lin', 'lip', 'lis', 'lit', 'lje', 'ljp', 'llg', 'lln', 'lme', 'lnd', 'lns', 'lob', 'lok', 'lom', 'lon', 'loq', 'lsi', 'lsm', 'ltz', 'luc', 'lug', 'luo', 'lwo', 'lww', 'lzz', 'maa-dialect_sanantonio', 'maa-dialect_sanjeronimo', 'mad', 'mag', 'mah', 'mai', 'maj', 'mak', 'mal', 'mam-dialect_central', 'mam-dialect_northern', 'mam-dialect_southern', 'mam-dialect_western', 'maq', 'mar', 'maw', 'maz', 'mbb', 'mbc', 'mbh', 'mbj', 'mbt', 'mbu', 'mbz', 'mca', 'mcb', 'mcd', 'mco', 'mcp', 'mcq', 'mcu', 'mda', 'mdf', 'mdv', 'mdy', 'med', 'mee', 'mej', 'men', 'meq', 'met', 'mev', 'mfe', 'mfh', 'mfi', 'mfk', 'mfq', 'mfy', 'mfz', 'mgd', 'mge', 'mgh', 'mgo', 'mhi', 'mhr', 'mhu', 'mhx', 'mhy', 'mib', 'mie', 'mif', 'mih', 'mil', 'mim', 'min', 'mio', 'mip', 'miq', 'mit', 'miy', 'miz', 'mjl', 'mjv', 'mkd', 'mkl', 'mkn', 'mlg', 'mlt', 'mmg', 'mnb', 'mnf', 'mnk', 'mnw', 'mnx', 'moa', 'mog', 'mon', 'mop', 'mor', 'mos', 'mox', 'moz', 'mpg', 'mpm', 'mpp', 'mpx', 'mqb', 'mqf', 'mqj', 'mqn', 'mri', 'mrw', 'msy', 'mtd', 'mtj', 'mto', 'muh', 'mup', 'mur', 'muv', 'muy', 'mvp', 'mwq', 'mwv', 'mxb', 'mxq', 'mxt', 'mxv', 'mya', 'myb', 'myk', 'myl', 'myv', 'myx', 'myy', 'mza', 'mzi', 'mzj', 'mzk', 'mzm', 'mzw', 'nab', 'nag', 'nan', 'nas', 'naw', 'nca', 'nch', 'ncj', 'ncl', 'ncu', 'ndj', 'ndp', 'ndv', 'ndy', 'ndz', 'neb', 'new', 'nfa', 'nfr', 'nga', 'ngl', 'ngp', 'ngu', 'nhe', 'nhi', 'nhu', 'nhw', 'nhx', 'nhy', 'nia', 
'nij', 'nim', 'nin', 'nko', 'nlc', 'nld', 'nlg', 'nlk', 'nmz', 'nnb', 'nno', 'nnq', 'nnw', 'noa', 'nob', 'nod', 'nog', 'not', 'npi', 'npl', 'npy', 'nso', 'nst', 'nsu', 'ntm', 'ntr', 'nuj', 'nus', 'nuz', 'nwb', 'nxq', 'nya', 'nyf', 'nyn', 'nyo', 'nyy', 'nzi', 'obo', 'oci', 'ojb-script_latin', 'ojb-script_syllabics', 'oku', 'old', 'omw', 'onb', 'ood', 'orm', 'ory', 'oss', 'ote', 'otq', 'ozm', 'pab', 'pad', 'pag', 'pam', 'pan', 'pao', 'pap', 'pau', 'pbb', 'pbc', 'pbi', 'pce', 'pcm', 'peg', 'pez', 'pib', 'pil', 'pir', 'pis', 'pjt', 'pkb', 'pls', 'plw', 'pmf', 'pny', 'poh-dialect_eastern', 'poh-dialect_western', 'poi', 'pol', 'por', 'poy', 'ppk', 'pps', 'prf', 'prk', 'prt', 'pse', 'pss', 'ptu', 'pui', 'pus', 'pwg', 'pww', 'pxm', 'qub', 'quc-dialect_central', 'quc-dialect_east', 'quc-dialect_north', 'quf', 'quh', 'qul', 'quw', 'quy', 'quz', 'qvc', 'qve', 'qvh', 'qvm', 'qvn', 'qvo', 'qvs', 'qvw', 'qvz', 'qwh', 'qxh', 'qxl', 'qxn', 'qxo', 'qxr', 'rah', 'rai', 'rap', 'rav', 'raw', 'rej', 'rel', 'rgu', 'rhg', 'rif-script_arabic', 'rif-script_latin', 'ril', 'rim', 'rjs', 'rkt', 'rmc-script_cyrillic', 'rmc-script_latin', 'rmo', 'rmy-script_cyrillic', 'rmy-script_latin', 'rng', 'rnl', 'roh-dialect_sursilv', 'roh-dialect_vallader', 'rol', 'ron', 'rop', 'rro', 'rub', 'ruf', 'rug', 'run', 'rus', 'sab', 'sag', 'sah', 'saj', 'saq', 'sas', 'sat', 'sba', 'sbd', 'sbl', 'sbp', 'sch', 'sck', 'sda', 'sea', 'seh', 'ses', 'sey', 'sgb', 'sgj', 'sgw', 'shi', 'shk', 'shn', 'sho', 'shp', 'sid', 'sig', 'sil', 'sja', 'sjm', 'sld', 'slk', 'slu', 'slv', 'sml', 'smo', 'sna', 'snd', 'sne', 'snn', 'snp', 'snw', 'som', 'soy', 'spa', 'spp', 'spy', 'sqi', 'sri', 'srm', 'srn', 'srp-script_cyrillic', 'srp-script_latin', 'srx', 'stn', 'stp', 'suc', 'suk', 'sun', 'sur', 'sus', 'suv', 'suz', 'swe', 'swh', 'sxb', 'sxn', 'sya', 'syl', 'sza', 'tac', 'taj', 'tam', 'tao', 'tap', 'taq', 'tat', 'tav', 'tbc', 'tbg', 'tbk', 'tbl', 'tby', 'tbz', 'tca', 'tcc', 'tcs', 'tcz', 'tdj', 'ted', 'tee', 'tel', 'tem', 'teo', 
'ter', 'tes', 'tew', 'tex', 'tfr', 'tgj', 'tgk', 'tgl', 'tgo', 'tgp', 'tha', 'thk', 'thl', 'tih', 'tik', 'tir', 'tkr', 'tlb', 'tlj', 'tly', 'tmc', 'tmf', 'tna', 'tng', 'tnk', 'tnn', 'tnp', 'tnr', 'tnt', 'tob', 'toc', 'toh', 'tom', 'tos', 'tpi', 'tpm', 'tpp', 'tpt', 'trc', 'tri', 'trn', 'trs', 'tso', 'tsz', 'ttc', 'tte', 'ttq-script_tifinagh', 'tue', 'tuf', 'tuk-script_arabic', 'tuk-script_latin', 'tuo', 'tur', 'tvw', 'twb', 'twe', 'twu', 'txa', 'txq', 'txu', 'tye', 'tzh-dialect_bachajon', 'tzh-dialect_tenejapa', 'tzj-dialect_eastern', 'tzj-dialect_western', 'tzo-dialect_chamula', 'tzo-dialect_chenalho', 'ubl', 'ubu', 'udm', 'udu', 'uig-script_arabic', 'uig-script_cyrillic', 'ukr', 'umb', 'unr', 'upv', 'ura', 'urb', 'urd-script_arabic', 'urd-script_devanagari', 'urd-script_latin', 'urk', 'urt', 'ury', 'usp', 'uzb-script_cyrillic', 'uzb-script_latin', 'vag', 'vid', 'vie', 'vif', 'vmw', 'vmy', 'vot', 'vun', 'vut', 'wal-script_ethiopic', 'wal-script_latin', 'wap', 'war', 'waw', 'way', 'wba', 'wlo', 'wlx', 'wmw', 'wob', 'wol', 'wsg', 'wwa', 'xal', 'xdy', 'xed', 'xer', 'xho', 'xmm', 'xnj', 'xnr', 'xog', 'xon', 'xrb', 'xsb', 'xsm', 'xsr', 'xsu', 'xta', 'xtd', 'xte', 'xtm', 'xtn', 'xua', 'xuo', 'yaa', 'yad', 'yal', 'yam', 'yao', 'yas', 'yat', 'yaz', 'yba', 'ybb', 'ycl', 'ycn', 'yea', 'yka', 'yli', 'yor', 'yre', 'yua', 'yue-script_traditional', 'yuz', 'yva', 'zaa', 'zab', 'zac', 'zad', 'zae', 'zai', 'zam', 'zao', 'zaq', 'zar', 'zas', 'zav', 'zaw', 'zca', 'zga', 'zim', 'ziw', 'zlm', 'zmz', 'zne', 'zos', 'zpc', 'zpg', 'zpi', 'zpl', 'zpm', 'zpo', 'zpt', 'zpu', 'zpz', 'ztq', 'zty', 'zul', 'zyb', 'zyp', 'zza'])
Switch out the language adapters by calling the load_adapter()
function for the model and set_target_lang() for the tokenizer. Pass
the target language as an input: the language_id value that was
detected in the previous step.
asr_processor.tokenizer.set_target_lang(language_id)
asr_model.load_adapter(language_id)
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Use the original model for inference
inputs = asr_processor(example['audio']['array'], sampling_rate=16_000, return_tensors="pt")
with torch.no_grad():
    outputs = asr_model(**inputs).logits
ids = torch.argmax(outputs, dim=-1)[0]
transcription = asr_processor.decode(ids)
print(transcription)
grisé par ce parfum il fit des vers en l'honneur de l'humble fleur des bois et il les récita tout haut à ses pieds une violette l'entendit elle crut qu'il ne parlait que pour elle
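The argmax-then-decode step above is greedy CTC decoding: take the most likely token per audio frame, collapse consecutive repeats, and drop the blank token. The processor does this internally. The sketch below only illustrates the idea on a toy vocabulary, where the blank id 0 and the "|" word delimiter are assumptions made for the example rather than values read from the real tokenizer.

```python
def ctc_greedy_decode(ids, id2char, blank_id=0, word_delimiter="|"):
    """Collapse repeated ids, drop blanks, and map the rest to characters."""
    chars = []
    prev = None
    for i in ids:
        # Emit a token only when it differs from the previous frame and is not blank.
        if i != prev and i != blank_id:
            chars.append(id2char[i])
        prev = i
    return "".join(chars).replace(word_delimiter, " ").strip()


# Hypothetical vocabulary and per-frame argmax ids.
toy_vocab = {0: "<pad>", 1: "|", 2: "e", 3: "l", 4: "u", 5: "n"}
frame_ids = [2, 2, 0, 3, 3, 0, 3, 2, 1, 1, 4, 0, 5, 5, 2]

print(ctc_greedy_decode(frame_ids, toy_vocab))  # elle une
```

The blank between the two runs of id 3 is what allows the double "l" in "elle" to survive the repeat-collapsing step.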
Convert to OpenVINO IR model and run inference
Convert the model to OpenVINO IR format directly with the ov.convert_model
function. Use the ov.save_model function to serialize the result of the
conversion. For convenience of further use, we will create a function
for these purposes.
asr_model_xml_path_template = 'models/ov_asr_{}_model.xml'
def get_asr_model(model_path_template, language_id, compiled=True):
    input_values = torch.zeros([1, MAX_SEQ_LENGTH], dtype=torch.float)
    model_path = Path(model_path_template.format(language_id))
    asr_processor.tokenizer.set_target_lang(language_id)
    # Convert and save a language-specific model only if it has not been saved before
    if not model_path.exists() and model_path_template == asr_model_xml_path_template:
        asr_model.load_adapter(language_id)
        model_path.parent.mkdir(parents=True, exist_ok=True)
        converted_model = ov.convert_model(asr_model, example_input={'input_values': input_values})
        ov.save_model(converted_model, model_path)
        if not compiled:
            return converted_model
    if compiled:
        return core.compile_model(model_path, device_name=device.value)
    return core.read_model(model_path)
compiled_asr_model = get_asr_model(asr_model_xml_path_template, language_id)
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Run inference.
def recognize_audio(compiled_model, src_audio):
    inputs = asr_processor(src_audio, sampling_rate=16_000, return_tensors="pt")
    outputs = compiled_model(inputs['input_values'])[0]
    ids = torch.argmax(torch.from_numpy(outputs), dim=-1)[0]
    transcription = asr_processor.decode(ids)
    return transcription
transcription = recognize_audio(compiled_asr_model, example['audio']['array'])
print("Original text:", example['text'])
print("Transcription:", transcription)
Original text: grisé par ce parfum il fit des vers en l'honneur de l'humble fleur des bois et il les récita tout haut à ses pieds une violette l'entendit elle crut qu'il ne parlait que pour elle
Transcription: grisé par ce parfum il fit des vers en l'honneur de l'humble fleur des bois et il les récita tout haut à ses pieds une violette l'entendit elle crut qu'il ne parlait que pour elle
Quantization
NNCF enables post-training quantization by adding quantization layers into
the model graph and then using a subset of the training dataset to
initialize the parameters of these additional quantization layers.
Quantized operations are executed in INT8 instead of FP32/FP16, making
model inference faster.
The optimization process contains the following steps:
Create a calibration dataset for quantization.
Run nncf.quantize() to obtain quantized models.
Serialize the quantized INT8 model using the openvino.save_model() function.
Note: Quantization is a time- and memory-consuming operation. Running the quantization code below may take some time.
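To make the INT8 idea more concrete, here is a minimal sketch of symmetric per-tensor quantization: floats are mapped to integers in [-127, 127] with a single scale factor and mapped back with a small rounding error. This is only the basic arithmetic and not NNCF's actual algorithm, which also involves calibration statistics and further correction steps. The function names are hypothetical.

```python
def quantize_int8(values):
    """Symmetric per-tensor quantization of floats to integers in [-127, 127]."""
    max_abs = max(abs(v) for v in values) or 1.0  # avoid a zero scale for all-zero input
    scale = max_abs / 127
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return q, scale


def dequantize(q, scale):
    """Map quantized integers back to approximate float values."""
    return [x * scale for x in q]


weights = [0.5, -1.27, 0.0, 1.0]
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)

print(q)  # [50, -127, 0, 100]
print(max(abs(a - b) for a, b in zip(weights, restored)) <= scale / 2 + 1e-12)  # True
```

Each value is stored in one byte instead of four, at the cost of a reconstruction error of at most half a quantization step.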
compiled_quantized_lid_model = None
quantized_asr_model_xml_path_template = None
to_quantize = widgets.Checkbox(
    value=False,
    description='Quantization',
    disabled=False,
)
to_quantize
Checkbox(value=True, description='Quantization')
Let’s load the skip magic extension so that the quantization cells are skipped when to_quantize is not selected.
import sys
sys.path.append("../utils")
%load_ext skip_kernel_extension
Preparing calibration dataset
Select the language to quantize the model for:
%%skip not $to_quantize.value
from IPython.display import display
display(SAMPLE_LANG)
Load the validation split of the same MLS dataset for the selected language.
%%skip not $to_quantize.value
mls_dataset = iter(load_dataset("facebook/multilingual_librispeech", SAMPLE_LANG.value, split="validation", streaming=True))
example = next(mls_dataset)
Create calibration dataset for quantization.
%%skip not $to_quantize.value
CALIBRATION_DATASET_SIZE = 5
calibration_data = []
for i in range(CALIBRATION_DATASET_SIZE):
    data = asr_processor(next(mls_dataset)['audio']['array'], sampling_rate=16_000, return_tensors="np")
    calibration_data.append(data["input_values"])
Language identification model quantization
Run LID model quantization.
%%skip not $to_quantize.value
import nncf
quantized_lid_model_xml_path = Path(str(lid_model_xml_path).replace(".xml", "_quantized.xml"))
if not quantized_lid_model_xml_path.exists():
    quantized_lid_model = nncf.quantize(
        get_lid_model(lid_model_xml_path, compiled=False),
        calibration_dataset=nncf.Dataset(calibration_data),
        preset=nncf.QuantizationPreset.MIXED,
        subset_size=len(calibration_data),
        model_type=nncf.ModelType.TRANSFORMER
    )
    ov.save_model(quantized_lid_model, quantized_lid_model_xml_path)
    compiled_quantized_lid_model = core.compile_model(quantized_lid_model, device_name=device.value)
else:
    compiled_quantized_lid_model = get_lid_model(quantized_lid_model_xml_path)
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, openvino
Statistics collection: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:06<00:00, 1.24s/it]
Applying Smooth Quant: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 291/291 [00:18<00:00, 15.34it/s]
INFO:nncf:144 ignored nodes was found by name in the NNCFGraph
Statistics collection: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:18<00:00, 3.65s/it]
Applying Fast Bias correction: 100%|██████████████████████████████████████████████████████████████████████████████████████| 298/298 [05:09<00:00, 1.04s/it]
Detect language with the quantized model.
%%skip not $to_quantize.value
language_id = detect_language(compiled_quantized_lid_model, example['audio']['array'])
print("Detected language:", language_id)
Detected language: fra
Speech recognition model quantization
Run ASR model quantization.
%%skip not $to_quantize.value
quantized_asr_model_xml_path_template = asr_model_xml_path_template.replace(".xml", "_quantized.xml")
quantized_asr_model_xml_path = Path(quantized_asr_model_xml_path_template.format(language_id))
if not quantized_asr_model_xml_path.exists():
    quantized_asr_model = nncf.quantize(
        get_asr_model(asr_model_xml_path_template, language_id, compiled=False),
        calibration_dataset=nncf.Dataset(calibration_data),
        preset=nncf.QuantizationPreset.MIXED,
        subset_size=len(calibration_data),
        model_type=nncf.ModelType.TRANSFORMER
    )
    ov.save_model(quantized_asr_model, quantized_asr_model_xml_path)
    compiled_quantized_asr_model = core.compile_model(quantized_asr_model, device_name=device.value)
else:
    compiled_quantized_asr_model = get_asr_model(quantized_asr_model_xml_path_template, language_id)
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Statistics collection: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:05<00:00, 1.17s/it]
Applying Smooth Quant: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 290/290 [00:17<00:00, 16.39it/s]
INFO:nncf:144 ignored nodes was found by name in the NNCFGraph
Statistics collection: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:19<00:00, 3.93s/it]
Applying Fast Bias correction: 100%|██████████████████████████████████████████████████████████████████████████████████████| 393/393 [05:22<00:00, 1.22it/s]
Run transcription with the quantized model and compare the result to the one produced by the original model.
%%skip not $to_quantize.value
compiled_asr_model = get_asr_model(asr_model_xml_path_template, language_id)
transcription_original = recognize_audio(compiled_asr_model, example['audio']['array'])
transcription_quantized = recognize_audio(compiled_quantized_asr_model, example['audio']['array'])
print("Transcription by original model: ", transcription_original)
print("Transcription by quantized model:", transcription_quantized)
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Transcription by original model: le salon était de la plus haute magnificence dorée comme la galerie de diane aux tuileries avec des tableaux à l'huile au lombri il y avait des tâches claires dans ces tableaux julien apprit plus tard que les sujets avaient semblé peu décent à la maîtresse du logis qui avait fait corriger les tableaux
Transcription by quantized model: le salon était de la plus haute magnificence doré comme la galerie de diane aux tuileries avec des tableaux à l'huile au lombri il y avait des tâches claires dans ces tableaux julien apprit plus tard que les sujets avaient semblé peu decent à la maîtresse du logis qui avait fait corriger les tableaux
Compare model size, performance and accuracy
First we compare model size.
%%skip not $to_quantize.value
def calculate_compression_rate(model_path_ov, model_path_ov_int8, model_type):
    # Sizes of the IR weight files in megabytes
    model_size_fp32 = model_path_ov.with_suffix(".bin").stat().st_size / 10 ** 6
    model_size_int8 = model_path_ov_int8.with_suffix(".bin").stat().st_size / 10 ** 6
    print(f"{model_type} model footprint comparison:")
    print(f" * FP32 IR model size: {model_size_fp32:.2f} MB")
    print(f" * INT8 IR model size: {model_size_int8:.2f} MB")
    return model_size_fp32, model_size_int8
lid_model_size_fp32, lid_model_size_int8 = \
    calculate_compression_rate(lid_model_xml_path, quantized_lid_model_xml_path, 'LID')
asr_model_size_fp32, asr_model_size_int8 = \
    calculate_compression_rate(Path(asr_model_xml_path_template.format(language_id)), quantized_asr_model_xml_path, 'ASR')
LID model footprint comparison:
* FP32 IR model size: 1931.81 MB
* INT8 IR model size: 968.96 MB
ASR model footprint comparison:
* FP32 IR model size: 1930.10 MB
* INT8 IR model size: 968.29 MB
Secondly we compare accuracy values of the original and quantized models
on a test split of the MLS dataset. We rely on the Word Error Rate (WER)
metric and compute accuracy as (1 - WER).
We also measure inference time for both language identification and speech recognition models.
%%skip not $to_quantize.value
import time
from tqdm.notebook import tqdm
import numpy as np
from jiwer import wer
TEST_DATASET_SIZE = 20
test_dataset = load_dataset("facebook/multilingual_librispeech", SAMPLE_LANG.value, split="test", streaming=True)
test_dataset = test_dataset.take(TEST_DATASET_SIZE)
def calculate_transcription_time_and_accuracy(lid_model, asr_model):
    ground_truths = []
    predictions = []
    identification_time = []
    transcription_time = []
    for data_item in tqdm(test_dataset, desc="Measuring performance and accuracy", total=TEST_DATASET_SIZE):
        audio = data_item["audio"]["array"]
        # Time language identification
        start_time = time.perf_counter()
        detect_language(lid_model, audio)
        end_time = time.perf_counter()
        identification_time.append(end_time - start_time)
        # Time transcription
        start_time = time.perf_counter()
        transcription = recognize_audio(asr_model, audio)
        end_time = time.perf_counter()
        transcription_time.append(end_time - start_time)
        ground_truths.append(data_item["text"])
        predictions.append(transcription)
    word_accuracy = (1 - wer(ground_truths, predictions)) * 100
    mean_identification_time = np.mean(identification_time)
    mean_transcription_time = np.mean(transcription_time)
    return mean_identification_time, mean_transcription_time, word_accuracy
identification_time_fp32, transcription_time_fp32, accuracy_fp32 = \
    calculate_transcription_time_and_accuracy(compiled_lid_model, compiled_asr_model)
identification_time_int8, transcription_time_int8, accuracy_int8 = \
    calculate_transcription_time_and_accuracy(compiled_quantized_lid_model, compiled_quantized_asr_model)
print(f"LID model footprint reduction: {lid_model_size_fp32 / lid_model_size_int8:.3f}")
print(f"ASR model footprint reduction: {asr_model_size_fp32 / asr_model_size_int8:.3f}")
print(f"Language identification performance speedup: {identification_time_fp32 / identification_time_int8:.3f}")
print(f"Language transcription performance speedup: {transcription_time_fp32 / transcription_time_int8:.3f}")
print(f"Transcription word accuracy. FP32: {accuracy_fp32:.2f}%. INT8: {accuracy_int8:.2f}%. Accuracy drop :{accuracy_fp32 - accuracy_int8:.2f}%.")
Measuring performance and accuracy: 0%| | 0/20 [00:00<?, ?it/s]
Measuring performance and accuracy: 0%| | 0/20 [00:00<?, ?it/s]
LID model footprint reduction: 1.994
ASR model footprint reduction: 1.993
Language identification performance speedup: 1.425
Language transcription performance speedup: 1.489
Transcription word accuracy. FP32: 85.01%. INT8: 84.76%. Accuracy drop :0.25%.
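The word accuracy reported above is derived from WER, which counts word-level substitutions, deletions and insertions relative to the number of reference words. The following is a minimal sketch of that computation using a standard edit-distance recurrence. It is an illustration of the metric and not the jiwer implementation, which may differ in normalization details. The sample strings are short fragments adapted from the two transcriptions shown earlier.

```python
def edit_distance(ref_words, hyp_words):
    """Word-level Levenshtein distance computed row by row."""
    prev = list(range(len(hyp_words) + 1))
    for i, r in enumerate(ref_words, start=1):
        curr = [i]
        for j, h in enumerate(hyp_words, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # match or substitution
        prev = curr
    return prev[-1]


def word_error_rate(references, hypotheses):
    """Total word errors divided by total reference words over all pairs."""
    errors = 0
    total = 0
    for ref, hyp in zip(references, hypotheses):
        ref_words = ref.split()
        errors += edit_distance(ref_words, hyp.split())
        total += len(ref_words)
    return errors / total


refs = ["dorée comme la galerie", "peu décent à la maîtresse"]
hyps = ["doré comme la galerie", "peu decent à la maîtresse"]
wer_value = word_error_rate(refs, hyps)
print(f"WER: {wer_value:.3f}, word accuracy: {(1 - wer_value) * 100:.2f}%")
# WER: 0.222, word accuracy: 77.78%
```

Two substituted words out of nine reference words give a WER of 2/9, which shows how strongly a single accent difference affects the score on short texts.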
Interactive demo with Gradio
In this demo you can try your own examples. Make sure that the audio data is sampled at 16,000 Hz (16 kHz).
import gradio as gr
import librosa
import time
title = 'MMS with Gradio'
description = 'Gradio Demo for MMS and OpenVINO™. Upload a source audio, then click the "Submit" button to detect a language ID and a transcription. ' \
    'Make sure that the audio data is sampled at 16 kHz. If this language has not been used before, it may take some time to prepare the ASR model.' \
    '\n' \
    '> Note: In order to run quantized model to transcribe some language, first the quantized model for that specific language must be prepared.'
current_state = {
    "fp32": {"model": None, "language": None},
    "int8": {"model": None, "language": None}
}
def infer(src_audio_path, quantized):
    # Resample to the 16 kHz rate the MMS models expect (librosa.load defaults to 22050 Hz)
    src_audio, _ = librosa.load(src_audio_path, sr=16000)
    lid_model = compiled_quantized_lid_model if quantized else compiled_lid_model

    start_time = time.perf_counter()
    detected_language_id = detect_language(lid_model, src_audio)
    end_time = time.perf_counter()
    identification_delta_time = f"{end_time - start_time:.2f}"

    # Reload the ASR model only when the detected language changes
    state = current_state["int8" if quantized else "fp32"]
    if detected_language_id != state["language"]:
        template_path = quantized_asr_model_xml_path_template if quantized else asr_model_xml_path_template
        try:
            gr.Info(f"Loading {'quantized ' if quantized else ''}ASR model for '{detected_language_id}' language. "
                    "This will take some time.")
            state["model"] = get_asr_model(template_path, detected_language_id)
            state["language"] = detected_language_id
        except RuntimeError as e:
            if "Unable to read the model:" in str(e) and quantized:
                raise gr.Error(f"There is no quantized ASR model for '{detected_language_id}' language. "
                               "Please run quantization for this language first.")
            # Re-raise unrelated errors instead of silently continuing with a stale model
            raise

    start_time = time.perf_counter()
    transcription = recognize_audio(state["model"], src_audio)
    end_time = time.perf_counter()
    transcription_delta_time = f"{end_time - start_time:.2f}"

    return detected_language_id, transcription, identification_delta_time, transcription_delta_time
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(f"# {title}")
    with gr.Row():
        gr.Markdown(description)

    run_button = {True: None, False: None}
    detected_language = {True: None, False: None}
    transcription = {True: None, False: None}
    identification_time = {True: None, False: None}
    transcription_time = {True: None, False: None}
    for quantized in [False, True]:
        if quantized and not to_quantize.value:
            break
        with gr.Row():
            with gr.Column():
                if not quantized:
                    audio = gr.Audio(label="Source Audio", type='filepath')
                run_button_name = "Run INT8" if quantized else "Run FP32" if to_quantize.value else "Run"
                run_button[quantized] = gr.Button(value=run_button_name)
            with gr.Column():
                detected_language[quantized] = gr.Textbox(label=f"Detected language ID{' (Quantized)' if quantized else ''}")
                transcription[quantized] = gr.Textbox(label=f"Transcription{' (Quantized)' if quantized else ''}")
                identification_time[quantized] = gr.Textbox(label=f"Identification time{' (Quantized)' if quantized else ''}")
                transcription_time[quantized] = gr.Textbox(label=f"Transcription time{' (Quantized)' if quantized else ''}")

    run_button[False].click(infer,
                            inputs=[audio, gr.Number(0, visible=False)],
                            outputs=[detected_language[False], transcription[False], identification_time[False], transcription_time[False]])
    if to_quantize.value:
        run_button[True].click(infer,
                               inputs=[audio, gr.Number(1, visible=False)],
                               outputs=[detected_language[True], transcription[True], identification_time[True], transcription_time[True]])

try:
    demo.queue().launch(debug=False)
except Exception:
    demo.queue().launch(share=True, debug=False)
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/
Running on local URL: http://127.0.0.1:7860 To create a public link, set share=True in launch().
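The infer function above keeps one ASR model per precision in current_state and reloads it only when the detected language changes. A minimal sketch of that caching pattern follows. Here load_model is a hypothetical stand-in for the notebook's get_asr_model, get_cached_model is a name introduced only for this illustration, and the language codes are example values.

```python
load_calls = []


def load_model(precision, language):
    """Hypothetical stand-in for an expensive model load."""
    load_calls.append((precision, language))
    return f"{precision}-asr-{language}"


state = {
    "fp32": {"model": None, "language": None},
    "int8": {"model": None, "language": None},
}


def get_cached_model(precision, language):
    """Return the cached model, loading it only if the language changed."""
    slot = state[precision]
    if slot["language"] != language:
        # Load first, then record the language, so a failed load leaves the slot unchanged
        slot["model"] = load_model(precision, language)
        slot["language"] = language
    return slot["model"]


get_cached_model("fp32", "deu")  # first request: loads the model
get_cached_model("fp32", "deu")  # same language: served from the cache
get_cached_model("fp32", "fra")  # language changed: loads again
print(len(load_calls))  # prints 2
```

Keeping a single slot per precision bounds memory use to two ASR models at a time, at the cost of a reload whenever consecutive uploads are in different languages.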
Ignored unknown kwarg option normalize
WARNING:nncf:NNCF provides best results with torch==2.0.1, while current torch version is 1.13.1+cu117. If you encounter issues, consider switching to torch==2.0.1
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda-11.7'
/home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:595: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
/home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:634: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):