The goal of this tutorial is to demonstrate how to speed up the ImageBind model by
applying 8-bit post-training quantization from
NNCF (Neural Network
Compression Framework) and to run inference on the quantized model with the OpenVINO™ Toolkit.
The optimization process contains the following steps:

1. Quantize the converted OpenVINO model from the 239-image-bind-convert notebook with NNCF.
2. Compare probability matrices between the converted and quantized models on input data examples.
3. Compare the model size of the converted and quantized models.
4. Compare the performance of the converted and quantized models.
NOTE: you should run the
239-image-bind-convert notebook
first to generate the OpenVINO IR models that are used for quantization.
```python
from pathlib import Path

repo_dir = Path("ImageBind")

if not repo_dir.exists():
    raise RuntimeError('This notebook should be run after 239-image-bind-convert notebook')

%cd {repo_dir}
```
NNCF enables
post-training quantization by adding quantization layers into the
model graph and then using a subset of the training dataset to
initialize the parameters of these additional quantization layers. The
framework is designed so that modifications to your original training
code are minor. Quantization is the simplest scenario and requires only a few
modifications.
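To make the idea of a quantization layer concrete, the sketch below shows a textbook symmetric, per-tensor min-max scheme: the layer's scale is initialized from the largest absolute value observed in calibration data, and values are rounded to signed 8-bit integers and mapped back to floats. This is a simplified illustration under that assumption, not NNCF's actual algorithm, which is more elaborate; `init_scale` and `fake_quantize` are hypothetical helper names.

```python
def init_scale(calibration_values, num_bits=8):
    """Symmetric per-tensor scale from the absolute maximum seen in calibration data."""
    qmax = 2 ** (num_bits - 1) - 1  # 127 for 8 bits
    abs_max = max(abs(v) for v in calibration_values)
    return abs_max / qmax if abs_max > 0 else 1.0


def fake_quantize(values, scale, num_bits=8):
    """Round to signed integers, clamp to the representable range, then dequantize."""
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    restored = []
    for v in values:
        q = max(qmin, min(qmax, round(v / scale)))  # integer code stored by the quantized layer
        restored.append(q * scale)                  # value the next layer effectively sees
    return restored


# Hypothetical activations collected on a few calibration samples.
calibration = [0.1, -0.5, 0.9, -1.27, 0.3]
scale = init_scale(calibration)
restored = fake_quantize(calibration, scale)
```

Values inside the calibrated range are reproduced to within half a quantization step, while values outside it are clamped, which is why the calibration subset should be representative of real inputs.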
The quantization workflow consists of the following steps:

1. Create a Dataset for quantization.
2. Run nncf.quantize to obtain a quantized model.
3. Serialize the INT8 model using the openvino.save_model function.
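nncf.Dataset wraps any iterable of calibration samples, optionally together with a transform function that turns each raw sample into model input, and nncf.quantize consumes at most subset_size samples (300 by default). The pure-Python class below is a hypothetical stand-in, not NNCF code, that shows the same "iterable plus lazy transform plus limit" idea without any dependencies; `CalibrationDataset` and `iter_model_inputs` are illustrative names only.

```python
import itertools
from typing import Any, Callable, Iterable, Iterator, Optional


class CalibrationDataset:
    """Hypothetical stand-in for the idea behind nncf.Dataset:
    an iterable data source plus an optional per-sample transform."""

    def __init__(self, data_source: Iterable[Any],
                 transform_func: Optional[Callable[[Any], Any]] = None):
        self._data_source = data_source
        # Identity transform when none is given: samples are already model inputs.
        self._transform = transform_func if transform_func is not None else (lambda sample: sample)

    def iter_model_inputs(self, limit: Optional[int] = None) -> Iterator[Any]:
        """Lazily yield transformed samples, stopping after `limit` of them."""
        samples = iter(self._data_source)
        if limit is not None:
            samples = itertools.islice(samples, limit)
        for sample in samples:
            yield self._transform(sample)


# Raw samples shaped like the collected batches (values are placeholders).
raw_samples = [{"pixel_values": i, "input_ids": f"caption {i}"} for i in range(10)]
vision_only = CalibrationDataset(raw_samples, transform_func=lambda s: s["pixel_values"])
first_three = list(vision_only.iter_model_inputs(limit=3))
```

In this notebook no transform is needed, because the collected vision data already consists of ready-made input tensors, so the list is wrapped directly with nncf.Dataset.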
The Conceptual
Captions dataset,
consisting of ~3.3M images annotated with captions, is used to
quantize the image and text models.
```python
import os
import tempfile

import imagebind.data as data
import requests
from requests.packages.urllib3.exceptions import InsecureRequestWarning

# Images are downloaded with verify=False below, so silence the resulting warnings.
requests.packages.urllib3.disable_warnings(InsecureRequestWarning)


def check_text_data(value):
    """
    Check if the given data is text-based.
    """
    if isinstance(value, str):
        return True
    if isinstance(value, list):
        return all(isinstance(x, str) for x in value)
    return False


def collate_fn(examples, image_column="image_url", text_column="caption"):
    """
    Collates examples into a batch for processing.
    Preprocesses each example by loading and transforming image and text data.
    Checks if the text data in the example is valid by calling the `check_text_data` function.
    Downloads the image specified by the URL in the image_column of the example dictionary.
    Constructs and returns a dictionary representing the collated batch with the following keys:
       - "pixel_values": The pixel values of the preprocessed example.
       - "input_ids": The transformed text data of the preprocessed example.
    Returns None if the image cannot be downloaded or decoded.
    """
    assert len(examples) == 1
    example = examples[0]
    if not check_text_data(example[text_column]):
        raise ValueError("Text data is not valid")

    url = example[image_column]
    with tempfile.TemporaryDirectory() as tempdir:
        f_name = os.path.join(tempdir, 'image.jpg')
        try:
            response = requests.get(url, verify=False, timeout=20)
            with open(f_name, "wb") as file:
                file.write(response.content)
            pixel_values = data.load_and_transform_vision_data([f_name], "cpu")
        except Exception:
            print(f"Can't load image from url: {url}")
            return None

    text = data.load_and_transform_text([example[text_column]], "cpu")

    return {
        "pixel_values": pixel_values,
        "input_ids": text
    }
```
```python
import itertools

import torch
from datasets import load_dataset
from tqdm.notebook import tqdm


def collect_vision_text_data(dataloader, init_steps):
    """
    This function collects vision and text data from a dataloader for a specified number of initialization steps.
    It iterates over the dataloader, fetching batches and storing the relevant vision and text data.
    Returns a tuple containing the collected vision_data and text_data lists.
    """
    text_data = []
    vision_data = []
    print(f"Fetching {init_steps} samples for the initialization...")
    counter = 0
    for batch in tqdm(dataloader):
        if counter == init_steps:
            break
        with torch.no_grad():
            if batch:  # collate_fn returns None for images that could not be loaded
                counter += 1
                text_data.append(batch["input_ids"].to("cpu"))
                vision_data.append(batch["pixel_values"].to("cpu"))
    return vision_data, text_data


def prepare_vision_text_dataset(opt_init_steps=300):
    """
    Prepares a vision-text dataset for quantization by collecting vision and text data.
    """
    dataset = load_dataset("conceptual_captions", streaming=True)
    train_dataset = dataset["train"].shuffle(seed=0)
    dataloader = torch.utils.data.DataLoader(train_dataset, collate_fn=collate_fn, batch_size=1)
    vision_data, text_data = collect_vision_text_data(dataloader, opt_init_steps)
    return vision_data, text_data
```
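Because many Conceptual Captions URLs are no longer reachable, collate_fn returns None for those examples, and collect_vision_text_data only counts batches that loaded successfully. The torch-free sketch below mirrors that control flow on a hypothetical in-memory stream so it can run on its own; `collect_valid_samples` is an illustrative name, not part of the notebook.

```python
from typing import Any, Iterable, List, Optional, Tuple


def collect_valid_samples(batches: Iterable[Optional[dict]],
                          init_steps: int) -> Tuple[List[Any], List[Any]]:
    """Torch-free mirror of the collect_vision_text_data loop (hypothetical helper)."""
    vision_data, text_data = [], []
    for batch in batches:
        if len(vision_data) == init_steps:  # enough valid samples collected
            break
        if batch:  # None means the image could not be loaded: skip it, do not count it
            vision_data.append(batch["pixel_values"])
            text_data.append(batch["input_ids"])
    return vision_data, text_data


# Hypothetical stream: two of the five examples failed to download.
stream = [
    None,
    {"pixel_values": "img-1", "input_ids": "txt-1"},
    None,
    {"pixel_values": "img-2", "input_ids": "txt-2"},
    {"pixel_values": "img-3", "input_ids": "txt-3"},
]
vision, text = collect_valid_samples(iter(stream), init_steps=2)
```

Like the original loop, this checks the counter before using the next batch, so one extra batch (and therefore one extra image download) is fetched before the loop stops.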
The ESC-50 dataset is
used to quantize the audio modality of the ImageBind model. It is a
labeled collection of 2000 environmental audio recordings suitable for
benchmarking methods of environmental sound classification, consisting
of 5-second-long recordings organized into 50 semantic classes.
```python
import numpy as np
import torchaudio


def collect_audio_data(dataloader, init_steps=300):
    """
    This function collects audio data from a dataloader for a specified number of initialization steps.
    It iterates over the dataloader, fetching batches and storing them in a list.
    """
    audio_data = []
    for batch in tqdm(itertools.islice(dataloader, init_steps), total=init_steps):
        with torch.no_grad():
            audio_data.append(batch)
    return audio_data


def prepare_audio_dataset():
    """
    Prepares an "ashraq/esc50" audio dataset for quantization by collecting audio data.
    Collects audio data from the dataloader by calling the `collect_audio_data` function.
    Returns a list containing the collected calibration audio data batches.
    """
    audio_dataset = load_dataset("ashraq/esc50", streaming=True)
    train_dataset = audio_dataset["train"].shuffle(seed=42, buffer_size=1000)

    def collate_fn(examples):
        assert len(examples) == 1
        with tempfile.TemporaryDirectory() as tempdir:
            f_name = os.path.join(tempdir, 'audio.wav')
            audio_data = examples[0]['audio']['array']
            sample_rate = examples[0]['audio']["sampling_rate"]
            # torchaudio.save expects a float tensor shaped (channels, samples)
            audio_data = torch.from_numpy(audio_data).to(torch.float32).unsqueeze(0)
            torchaudio.save(f_name, audio_data, sample_rate)
            # The ImageBind loader reads audio from file paths, hence the temporary WAV file
            return data.load_and_transform_audio_data([f_name], "cpu")

    dataloader = torch.utils.data.DataLoader(train_dataset, collate_fn=collate_fn, batch_size=1)
    calibration_data = collect_audio_data(dataloader)
    return calibration_data
```
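The audio collate_fn writes each waveform to a temporary WAV file because data.load_and_transform_audio_data is called with a list of file paths rather than arrays. The standard-library sketch below reproduces that round trip with the wave module in place of torchaudio.save; unlike the torchaudio call, it stores 16-bit integer PCM, and the 16 kHz test tone is a hypothetical input. `write_float_wav` and `roundtrip_through_tempfile` are illustrative names.

```python
import math
import os
import struct
import tempfile
import wave


def write_float_wav(path, samples, sample_rate):
    """Write mono float samples in [-1, 1] as a 16-bit PCM WAV file (stdlib only)."""
    frames = bytearray()
    for s in samples:
        s = max(-1.0, min(1.0, s))                       # clip to the valid range
        frames += struct.pack("<h", int(round(s * 32767)))  # little-endian int16
    with wave.open(path, "wb") as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)  # 2 bytes per sample = 16-bit
        wav.setframerate(sample_rate)
        wav.writeframes(bytes(frames))


def roundtrip_through_tempfile(samples, sample_rate):
    """Write samples to a temporary WAV file and read its header back,
    the way the audio collate_fn hands a file path to a path-based loader."""
    with tempfile.TemporaryDirectory() as tempdir:
        f_name = os.path.join(tempdir, "audio.wav")
        write_float_wav(f_name, samples, sample_rate)
        with wave.open(f_name, "rb") as wav:
            return wav.getnframes(), wav.getframerate()


sr = 16000  # hypothetical sample rate for this demo
tone = [0.5 * math.sin(2 * math.pi * 440 * n / sr) for n in range(sr)]  # 1 s at 440 Hz
n_frames, rate = roundtrip_through_tempfile(tone, sr)
```

The temporary directory is removed when the `with` block exits, so the file only needs to live long enough for the loader to read it, which is why the notebook returns the loaded data from inside the block.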
Create a quantized model from the pre-trained FP16 model.
```python
import logging

import nncf
import openvino as ov

nncf.set_log_level(logging.ERROR)

core = ov.Core()


def quantize_openvino_model(modality, calibration_data):
    # fp_model_paths / int8_model_paths map each modality to its FP16 / INT8 IR path
    model_path = fp_model_paths[modality]
    if not os.path.exists(model_path):
        raise RuntimeError(f"Model: {model_path} not found. "
                           "First run 239-image-bind-convert notebook to convert model to OpenVINO IR.")

    model = core.read_model(model_path)
    quantized_model = nncf.quantize(
        model=model,
        calibration_dataset=calibration_data,
        model_type=nncf.ModelType.TRANSFORMER,
        # remove ignored_scope for nncf>=2.6.0 (PR with fix https://github.com/openvinotoolkit/nncf/pull/1953)
        ignored_scope=nncf.IgnoredScope(types=["ReduceL2"])
    )
    ov.save_model(quantized_model, int8_model_paths[modality])
    return quantized_model
```
NOTE: Quantization is a time- and memory-consuming operation.
Running the quantization code below may take a long time.
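```python
from imagebind.models.imagebind_model import ModalityType

# Collect image-caption pairs from Conceptual Captions; unreachable URLs are reported and skipped.
vision_data, text_data = prepare_vision_text_dataset()

if len(vision_data) == 0:
    raise RuntimeError('Calibration dataset is empty. Please check internet connection and try to download images manually from the URLs above.')

vision_dataset = nncf.Dataset(vision_data)
vision_quantized_model = quantize_openvino_model(modality=ModalityType.VISION, calibration_data=vision_dataset)
```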
Statistics collection: 100% | 300/300 [01:18<00:00, 3.81it/s]
Applying Smooth Quant: 100% | 129/129 [00:13<00:00, 9.69it/s]
Statistics collection: 100% | 300/300 [03:03<00:00, 1.64it/s]
Applying Fast Bias correction: 100% | 128/128 [00:23<00:00, 5.54it/s]
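The quantization log shows NNCF running a Smooth Quant pass before Fast Bias correction. SmoothQuant moves activation outliers into the weights of the following matrix multiplication: each input channel j gets a factor s_j = max|X_j|^alpha / max|W_j|^(1 - alpha), activations are divided by s_j and the matching weight rows are multiplied by it, so the product X·W is unchanged while the activations become easier to quantize. The pure-Python sketch below shows that identity with alpha = 0.5 on hypothetical 2×2 matrices; it is the textbook formulation, not NNCF's implementation, and which layers NNCF smooths and how it chooses alpha may differ.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product of a (n x k) and b (k x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def smooth(x: Matrix, w: Matrix, alpha: float = 0.5,
           eps: float = 1e-8) -> Tuple[Matrix, Matrix, List[float]]:
    """Textbook SmoothQuant rescaling for Y = X @ W.

    x: activations (tokens x in_channels); w: weights (in_channels x out_channels).
    Returns X / s per input channel, s * W per weight row, and the factors s.
    """
    n_in = len(w)
    s = []
    for j in range(n_in):
        act_max = max(max(abs(row[j]) for row in x), eps)
        w_max = max(max(abs(v) for v in w[j]), eps)
        s.append(act_max ** alpha / w_max ** (1 - alpha))
    x_smoothed = [[row[j] / s[j] for j in range(n_in)] for row in x]
    w_smoothed = [[s[j] * v for v in w[j]] for j in range(n_in)]
    return x_smoothed, w_smoothed, s


# Hypothetical activations: input channel 1 carries large outliers.
x = [[1.0, 40.0], [-2.0, 35.0]]
w = [[0.5, -0.3], [0.02, 0.04]]
x_s, w_s, factors = smooth(x, w)
y_ref = matmul(x, w)
y_smoothed = matmul(x_s, w_s)
```

After smoothing, both activation channels have a similar range, so a single per-tensor activation scale wastes fewer of the 256 INT8 levels on the outlier channel.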
NNCF also supports quantization-aware training and compression
algorithms other than quantization. See the NNCF
documentation
in the NNCF repository for more information.
Compare results for the OpenVINO FP16 model and the quantized model
Compare the probability matrices of the FP16 and INT8 models. More
details about the probability matrix can be found in the
239-image-bind-convert notebook.
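The sketch below assumes that a probability matrix is a row-wise softmax over the dot-product similarities between the embeddings of two modalities (for example, images and captions), and that the quantized model is acceptable when its matrix stays close to the FP16 one and picks the same best match per row. The tiny 2-D embeddings, including the "INT8" ones built by hand-perturbing the FP16 vectors, are hypothetical stand-ins for real model outputs.

```python
import math


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def probability_matrix(emb_a, emb_b):
    """Row i: softmax over the dot products of emb_a[i] with every row of emb_b."""
    return [softmax([sum(p * q for p, q in zip(a, b)) for b in emb_b]) for a in emb_a]


def max_abs_diff(p, q):
    return max(abs(x - y) for row_p, row_q in zip(p, q) for x, y in zip(row_p, row_q))


# Hypothetical embeddings for 2 images and 2 captions.
vision_fp16 = [[1.0, 0.0], [0.0, 1.0]]
vision_int8 = [[0.98, 0.01], [0.01, 1.02]]  # small perturbation standing in for quantization error
text_emb = [[0.9, 0.1], [0.2, 0.8]]

p_fp16 = probability_matrix(vision_fp16, text_emb)
p_int8 = probability_matrix(vision_int8, text_emb)
difference = max_abs_diff(p_fp16, p_int8)
```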
```python
def calculate_compression_rate(modality):
    fp16_ir_model_size = Path(fp_model_paths[modality]).with_suffix(".bin").stat().st_size / 1024
    quantized_model_size = Path(int8_model_paths[modality]).with_suffix(".bin").stat().st_size / 1024
    print(f'Modality: {modality}')
    print(f"    * FP16 IR model size: {fp16_ir_model_size:.2f} KB")
    print(f"    * INT8 model size: {quantized_model_size:.2f} KB")
    print(f"    * Model compression rate: {fp16_ir_model_size / quantized_model_size:.3f}")
```
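calculate_compression_rate divides the size of the FP16 weights file (.bin) by that of the INT8 one. As a rough expectation, assuming each quantized weight is stored in 1 byte instead of 2 and any weights left unquantized stay in FP16, the ratio approaches 2 but does not exceed it under these assumptions; quantization parameters stored alongside the weights lower it slightly further. The hypothetical helper below estimates that bound while ignoring those parameters; its name and the parameter counts are illustrative.

```python
FP16_BYTES = 2  # bytes per weight in the FP16 IR
INT8_BYTES = 1  # bytes per quantized weight (assumption: stored as 8-bit integers)


def estimated_compression_rate(quantized_params: int, fp16_params: int) -> float:
    """FP16 weight bytes / INT8-model weight bytes, ignoring quantization parameters.

    quantized_params: number of weights converted to INT8.
    fp16_params: number of weights left in FP16.
    """
    fp16_size = FP16_BYTES * (quantized_params + fp16_params)
    int8_size = INT8_BYTES * quantized_params + FP16_BYTES * fp16_params
    return fp16_size / int8_size


all_quantized = estimated_compression_rate(1_000_000, 0)           # every weight quantized
mostly_quantized = estimated_compression_rate(900_000, 100_000)    # 10% of weights kept in FP16
```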
Compare inference time of the FP16 IR and quantized models
To measure the inference performance of the FP16 and INT8
models, we use the median inference time on the calibration dataset. This gives
only an approximate estimate of the speed-up of the quantized models.
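The median is used rather than the mean because it is robust to outliers such as a slow first (warm-up) inference. The sketch below times any callable over a set of inputs with time.perf_counter and reports the median; in the notebook the callable would be a compiled OpenVINO model applied to one calibration sample, but here `fake_model` is a hypothetical stand-in, and `median_latency` and `speed_up` are illustrative names. The speed-up is the FP16 median divided by the INT8 median.

```python
import statistics
import time
from typing import Any, Callable, Iterable


def median_latency(infer: Callable[[Any], Any], inputs: Iterable[Any]) -> float:
    """Run `infer` once per input and return the median wall-clock time in seconds."""
    timings = []
    for sample in inputs:
        start = time.perf_counter()
        infer(sample)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


def speed_up(fp16_median: float, int8_median: float) -> float:
    """How many times faster the INT8 model is than the FP16 one."""
    return fp16_median / int8_median


def fake_model(_sample):
    """Hypothetical stand-in for a model call."""
    time.sleep(0.001)


latency = median_latency(fake_model, range(5))
```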
NOTE: For the most accurate performance estimation, it is
recommended to run benchmark_app with static shapes in a terminal/command prompt
after closing other applications.