Hugging Face Model Hub with OpenVINO™¶
The Hugging Face (HF) Model Hub is a central repository for pre-trained deep learning models. It allows exploration and provides access to thousands of models for a wide range of tasks, including text classification, question answering, and image classification. Hugging Face provides Python packages, namely transformers and diffusers, that serve as APIs and tools to easily download and fine-tune state-of-the-art pretrained models.
Throughout this notebook we will learn:
1. How to load a HF pipeline using the transformers package and then convert it to OpenVINO.
2. How to load the same pipeline using the Optimum Intel package.
Converting a Model from the HF Transformers Package¶
The Hugging Face transformers package provides an API for initializing a model and loading a set of pre-trained weights using the model's text handle (its model ID on the Hub). Finding the right model name is straightforward with the Models page of the HF website: you can choose a model that solves a particular machine learning problem and even sort the models by popularity and novelty.
Installing Requirements¶
%pip install -q --extra-index-url https://download.pytorch.org/whl/cpu "transformers[torch]>=4.33.0"
%pip install -q ipywidgets
%pip install -q "openvino>=2023.1.0"
Note: you may need to restart the kernel to use updated packages.
Imports¶
from pathlib import Path
import numpy as np
import torch
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
Initializing a Model Using the HF Transformers Package¶
In our example we will use a RoBERTa text sentiment classification model. It is a transformer-based encoder model (RoBERTa-base trained on tweets and fine-tuned for sentiment analysis); please refer to the model card to learn more.
Following the instructions on the model page, we use AutoModelForSequenceClassification to initialize the model and perform inference with it. For more information on HF pipelines and model initialization, please refer to the HF tutorials.
MODEL = "cardiffnlp/twitter-roberta-base-sentiment-latest"
# return_dict is a model option, not a tokenizer one, so it is not passed here
tokenizer = AutoTokenizer.from_pretrained(MODEL)
# The torchscript=True flag is used to ensure the model outputs are tuples
# instead of ModelOutput (which causes JIT errors).
model = AutoModelForSequenceClassification.from_pretrained(MODEL, torchscript=True)
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
return self.fget.__get__(instance, owner)()
Some weights of the model checkpoint at cardiffnlp/twitter-roberta-base-sentiment-latest were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Original Model Inference¶
Let's classify a simple prompt.
text = "HF models run perfectly with OpenVINO!"
encoded_input = tokenizer(text, return_tensors='pt')
output = model(**encoded_input)
scores = output[0][0]
scores = torch.softmax(scores, dim=0).numpy(force=True)
def print_prediction(scores):
    # Walk the class indices from the highest to the lowest probability
    for i, descending_index in enumerate(scores.argsort()[::-1]):
        label = model.config.id2label[descending_index]
        score = np.round(float(scores[descending_index]), 4)
        print(f"{i+1}) {label} {score}")
print_prediction(scores)
1) positive 0.9485
2) neutral 0.0484
3) negative 0.0031
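To make the post-processing above concrete without torch or NumPy, here is a minimal pure-Python sketch of the softmax and ranking steps. The logits are made-up values, and the label order (0: negative, 1: neutral, 2: positive) is an assumption; the notebook reads the real mapping from model.config.id2label.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtracting the max logit avoids overflow in exp()."""
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def rank_predictions(scores, id2label):
    """Return (rank, label, score rounded to 4 places), highest score first."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(rank + 1, id2label[i], round(scores[i], 4)) for rank, i in enumerate(order)]


# Hypothetical logits; this label order is an assumption, the notebook
# reads the real mapping from model.config.id2label.
id2label = {0: "negative", 1: "neutral", 2: "positive"}
probs = softmax([-2.0, 0.7, 3.7])
for rank, label, score in rank_predictions(probs, id2label):
    print(f"{rank}) {label} {score}")
```

Subtracting the largest logit does not change the result, because softmax is invariant to adding a constant to every logit, but it keeps exp() from overflowing on large inputs.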
Converting the Model to OpenVINO IR format¶
We use the OpenVINO model conversion API to convert the model (implemented in PyTorch) to the OpenVINO Intermediate Representation (IR).
Note how we reuse our real encoded_input, passing it to the ov.convert_model function as example_input. It is used to trace the model.
import openvino as ov
save_model_path = Path('./models/model.xml')
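if not save_model_path.exists():
    # Trace the PyTorch model with a real tokenized example and convert it to an ov.Model
    ov_model = ov.convert_model(model, example_input=dict(encoded_input))
    # Writes model.xml (topology) and model.bin (weights) side by side
    ov.save_model(ov_model, save_model_path)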
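The exists() check above only looks at model.xml. An OpenVINO IR is a pair of files sharing a stem, the .xml topology and the .bin weights, so a slightly stricter cache check can look for both. This is a hedged sketch; ir_files and ir_is_complete are hypothetical helper names, not part of the OpenVINO API.

```python
from pathlib import Path


def ir_files(xml_path):
    """Return the (topology, weights) paths of an OpenVINO IR.

    The .xml file describes the graph; ov.save_model writes the weights to a
    .bin file with the same stem in the same directory.
    """
    xml_path = Path(xml_path)
    return xml_path, xml_path.with_suffix(".bin")


def ir_is_complete(xml_path):
    """True only if both IR files exist, so a half-written export is redone."""
    return all(path.exists() for path in ir_files(xml_path))
```

With it, the guard in the cell above would read: if not ir_is_complete(save_model_path):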
Converted Model Inference¶
First, we pick a device for model inference:
import ipywidgets as widgets
core = ov.Core()
device = widgets.Dropdown(
    options=core.available_devices + ["AUTO"],
    value='AUTO',
    description='Device:',
    disabled=False,
)
device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
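Outside Jupyter there is no dropdown widget. The hypothetical pick_device helper sketched below uses a preferred device if the runtime reports it and otherwise falls back to AUTO, which lets OpenVINO choose a device itself. Its first argument is meant to be core.available_devices; the preferred device name is an illustrative choice.

```python
def pick_device(available, preferred=None):
    """Choose a device name for core.compile_model without a widget.

    available: device names reported by the runtime (e.g. core.available_devices).
    preferred: a device to use if it is present, e.g. "GPU" (illustrative).
    Falls back to "AUTO" otherwise.
    """
    if preferred is not None and preferred in available:
        return preferred
    return "AUTO"


print(pick_device(["CPU"], preferred="GPU"))  # prints AUTO
```

In a script, the result would replace device.value, for example core.compile_model(save_model_path, pick_device(core.available_devices, "GPU")).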
The OpenVINO IR must be compiled for a specific device before inference.
compiled_model = core.compile_model(save_model_path, device.value)
# Compiled model call is performed using the same parameters as for the original model
scores_ov = compiled_model(encoded_input.data)[0]
scores_ov = torch.softmax(torch.tensor(scores_ov[0]), dim=0).detach().numpy()
print_prediction(scores_ov)
1) positive 0.9483
2) neutral 0.0485
3) negative 0.0031
Note that the predictions of the converted model match those of the original model: the ranking is identical and the probabilities agree to within 0.0002.
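Because PyTorch and OpenVINO compute slightly differently, the printed probabilities differ in the fourth decimal place, so in practice "matching" means agreeing within a tolerance. A minimal pure-Python check could look like this; the helper names and the 1e-3 tolerance are illustrative assumptions, and the score lists are the rounded values printed above, arranged in one fixed label order (negative, neutral, positive).

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two score vectors."""
    if len(a) != len(b):
        raise ValueError("score vectors must have the same length")
    return max(abs(x - y) for x, y in zip(a, b))


def predictions_agree(a, b, atol=1e-3):
    """Same top-1 class and every probability within atol of its counterpart."""
    top_a = max(range(len(a)), key=a.__getitem__)
    top_b = max(range(len(b)), key=b.__getitem__)
    return top_a == top_b and max_abs_diff(a, b) <= atol


# Rounded probabilities printed above (negative, neutral, positive).
torch_scores = [0.0031, 0.0484, 0.9485]
ov_scores = [0.0031, 0.0485, 0.9483]
print(max_abs_diff(torch_scores, ov_scores))
print(predictions_agree(torch_scores, ov_scores))
```

With the NumPy arrays from the notebook, the same tolerance test can be written as np.allclose(scores, scores_ov, atol=1e-3), together with a comparison of the argmax indices.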
This is a rather simple example, as the pipeline includes just one encoder model. Contemporary state-of-the-art pipelines often consist of several models; feel free to explore other OpenVINO tutorials:
1. Stable Diffusion v2
2. Zero-shot Image Classification with OpenAI CLIP
3. Controllable Music Generation with MusicGen
The workflow for the diffusers package is exactly the same. The first example in the list above relies on diffusers.
Converting a Model Using the Optimum Intel Package¶
Optimum Intel is the interface between the Transformers and Diffusers libraries and the different tools and libraries provided by Intel to accelerate end-to-end pipelines on Intel architectures.
Among other use cases, Optimum Intel provides a simple interface to optimize your Transformers and Diffusers models, convert them to the OpenVINO Intermediate Representation (IR) format and run inference using OpenVINO Runtime.
Install Requirements for Optimum¶
%pip install -q "git+https://github.com/huggingface/optimum-intel.git" onnx
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks... To disable this warning, you can either: - Avoid using tokenizers before the fork if possible - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Note: you may need to restart the kernel to use updated packages.
Import Optimum¶
Documentation for Optimum Intel states:
> You can now easily perform inference with OpenVINO Runtime on a variety of Intel processors (see the full list of supported devices). For that, just replace the AutoModelForXxx class with the corresponding OVModelForXxx class.
You can find more information in the Optimum Intel documentation.
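The naming convention in the quote is mechanical, as this tiny sketch shows. ov_class_name is a hypothetical helper: it only rewrites the name and does not check whether Optimum Intel actually ships the resulting class, which you should confirm in its documentation.

```python
def ov_class_name(auto_class_name):
    """Turn a transformers "AutoModelForXxx" name into the "OVModelForXxx" name."""
    prefix = "AutoModelFor"
    if not auto_class_name.startswith(prefix):
        raise ValueError(f"expected a class name starting with {prefix!r}")
    return "OVModelFor" + auto_class_name[len(prefix):]


print(ov_class_name("AutoModelForSequenceClassification"))
# OVModelForSequenceClassification
```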
from optimum.intel.openvino import OVModelForSequenceClassification
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, tensorflow, onnx, openvino
2024-02-09 23:10:50.826096: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable TF_ENABLE_ONEDNN_OPTS=0. 2024-02-09 23:10:50.861099: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-02-09 23:10:51.428729: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Initialize and Convert the Model Automatically Using the OVModel Class¶
To load a Transformers model and convert it to the OpenVINO format on the fly, you can set export=True when loading your model. The model can be saved in OpenVINO format using the save_pretrained method, passing the directory for storing the model as an argument. Next time, you can skip the conversion step and load the previously saved model from disk using the from_pretrained method without the export argument. We also specify the device parameter to compile the model on a specific device; if it is not provided, the default device is used. The device can be changed later at runtime using model.to(device); note that compiling the model for a newly selected device may take some time.
In some cases, it can be useful to separate model initialization and compilation. For example, if you want to reshape the model using the reshape method, you can postpone compilation by passing compile=False to the from_pretrained method. Compilation can then be performed manually using the compile method, or it will be performed automatically during the first inference run.
model = OVModelForSequenceClassification.from_pretrained(MODEL, export=True, device=device.value)
# The save_pretrained() method saves the model weights to avoid conversion on the next load.
model.save_pretrained('./models/optimum_model')
Framework not specified. Using pt to export to ONNX.
Using the export variant default. Available variants are:
- default: The default ONNX variant.
Using framework PyTorch: 2.1.0+cpu
Overriding 1 configuration item(s)
- use_cache -> False
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.base has been moved to tensorflow.python.trackable.base. The old module will be deleted in version 2.11.
Compiling the model to AUTO ...
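The export-once, load-later workflow described above can be captured in a small decision function. This is a hedged sketch: the helper name is hypothetical, and it assumes that save_pretrained writes the IR as openvino_model.xml inside the target directory, which you should verify against the Optimum Intel version you have installed.

```python
from pathlib import Path


def from_pretrained_source(model_id, save_dir, xml_name="openvino_model.xml"):
    """Decide what to pass to OVModelForXxx.from_pretrained.

    Returns (source, export):
    - (save_dir, False) if a previously saved IR is found, so no conversion runs;
    - (model_id, True) otherwise, so Optimum converts the Hub model on the fly.
    xml_name is an assumption about how Optimum Intel names the saved IR.
    """
    save_dir = Path(save_dir)
    if (save_dir / xml_name).exists():
        return str(save_dir), False
    return model_id, True
```

In the notebook this would be used as source, export = from_pretrained_source(MODEL, "./models/optimum_model"), followed by OVModelForSequenceClassification.from_pretrained(source, export=export, device=device.value), calling model.save_pretrained("./models/optimum_model") only when export is True.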
Convert a Model Using the Optimum CLI¶
Alternatively, you can convert models with the Optimum CLI (supported starting from optimum-intel version 1.12). The general command format is:
optimum-cli export openvino --model <model_id_or_path> --task <task> <output_dir>
where task is the task to export the model for. If it is not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among: [‘default’, ‘fill-mask’, ‘text-generation’, ‘text2text-generation’, ‘text-classification’, ‘token-classification’, ‘multiple-choice’, ‘object-detection’, ‘question-answering’, ‘image-classification’, ‘image-segmentation’, ‘masked-im’, ‘semantic-segmentation’, ‘automatic-speech-recognition’, ‘audio-classification’, ‘audio-frame-classification’, ‘audio-xvector’, ‘image-to-text’, ‘stable-diffusion’, ‘zero-shot-object-detection’]. For decoder models, use xxx-with-past to export the model using past key values in the decoder.
You can find a mapping between tasks and model classes in the Optimum TaskManager documentation.
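When the export is scripted rather than typed, it can help to assemble the command as an argument list. The sketch below is a hypothetical helper; it uses only flags that appear in the command's --help output (--model, --task, --weight-format) and the -with-past task suffix described above. The resulting list can be passed to subprocess.run(cmd, check=True).

```python
import shlex


def optimum_export_command(model, output_dir, task=None, with_past=False, weight_format=None):
    """Assemble an `optimum-cli export openvino` call as an argument list.

    Only flags listed in the command's --help output are used.
    """
    cmd = ["optimum-cli", "export", "openvino", "--model", model]
    if task is not None:
        # Decoder models: the "-with-past" suffix exports past key values too.
        cmd += ["--task", f"{task}-with-past" if with_past else task]
    elif with_past:
        raise ValueError("with_past needs an explicit task")
    if weight_format is not None:
        cmd += ["--weight-format", weight_format]  # e.g. "fp16" or "int8"
    cmd.append(output_dir)  # positional output directory, placed last
    return cmd


cmd = optimum_export_command(
    "cardiffnlp/twitter-roberta-base-sentiment-latest",
    "models/optimum_model/fp16",
    task="text-classification",
    weight_format="fp16",
)
print(shlex.join(cmd))
```

Building a list instead of a single string avoids shell-quoting problems when model IDs or paths contain unusual characters; shlex.join is used only to display the command.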
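Additionally, you can enable weight compression: --fp16 compresses the model weights to FP16 and --int8 compresses them to INT8. Please note that INT8 compression requires nncf to be installed. In the optimum-intel version used here, --fp16 is deprecated in favor of --weight-format fp16, as the export log below shows.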
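A back-of-the-envelope estimate shows why weight compression matters: the weights take roughly the parameter count times the bytes per weight. The sketch assumes about 125 million parameters (the commonly quoted size of RoBERTa-base) and ignores metadata and any layers that compression leaves in higher precision; model.num_parameters() gives the exact count for the loaded model.

```python
# Bytes needed to store one weight in each format.
BYTES_PER_WEIGHT = {"fp32": 4, "fp16": 2, "int8": 1}


def weights_size_mb(num_params, weight_format):
    """Approximate weight storage in megabytes (1 MB = 10**6 bytes).

    Ignores file metadata and any layers a compression pass keeps in
    higher precision, so real files will differ somewhat.
    """
    return num_params * BYTES_PER_WEIGHT[weight_format] / 1e6


# Assumption: ~125 million parameters, the commonly quoted RoBERTa-base size;
# model.num_parameters() reports the exact count of the loaded model.
num_params = 125_000_000
for fmt in ("fp32", "fp16", "int8"):
    print(fmt, weights_size_mb(num_params, fmt), "MB")
```

Under that assumption the estimates come out at 500 MB for FP32, 250 MB for FP16 and 125 MB for INT8.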
The full list of supported arguments is available via --help:
!optimum-cli export openvino --help
usage: optimum-cli export openvino [-h] -m MODEL [--task TASK] [--cache_dir CACHE_DIR]
                                   [--framework {pt,tf}] [--trust-remote-code]
                                   [--pad-token-id PAD_TOKEN_ID] [--fp16] [--int8]
                                   [--weight-format {fp32,fp16,int8,int4_sym_g128,int4_asym_g128,int4_sym_g64,int4_asym_g64}]
                                   [--ratio RATIO] [--disable-stateful] [--convert-tokenizer]
                                   output

optional arguments:
  -h, --help            show this help message and exit

Required arguments:
  -m MODEL, --model MODEL
                        Model ID on huggingface.co or path on disk to load model from.
  output                Path indicating the directory where to store the generated OV model.

Optional arguments:
  --task TASK           The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among: ['sentence-similarity', 'object-detection', 'question-answering', 'text-to-audio', 'audio-xvector', 'stable-diffusion-xl', 'feature-extraction', 'image-to-image', 'text-generation', 'mask-generation', 'text-classification', 'image-segmentation', 'automatic-speech-recognition', 'text2text-generation', 'stable-diffusion', 'audio-classification', 'semantic-segmentation', 'fill-mask', 'depth-estimation', 'zero-shot-image-classification', 'image-to-text', 'zero-shot-object-detection', 'multiple-choice', 'conversational', 'image-classification', 'masked-im', 'audio-frame-classification', 'token-classification']. For decoder models, use xxx-with-past to export the model using past key values in the decoder.
  --cache_dir CACHE_DIR
                        Path indicating where to store cache.
  --framework {pt,tf}   The framework to use for the export. If not provided, will attempt to use the local checkpoint's original framework or what is available in the environment.
  --trust-remote-code   Allows to use custom code for the modeling hosted in the model repository. This option should only be set for repositories you trust and in which you have read the code, as it will execute on your local machine arbitrary code present in the model repository.
  --pad-token-id PAD_TOKEN_ID
                        This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it.
  --fp16                Compress weights to fp16
  --int8                Compress weights to int8
  --weight-format {fp32,fp16,int8,int4_sym_g128,int4_asym_g128,int4_sym_g64,int4_asym_g64}
                        The weight format of the exporting model, e.g. f32 stands for float32 weights, f16 - for float16 weights, i8 - INT8 weights, int4_* - for INT4 compressed weights.
  --ratio RATIO         Compression ratio between primary and backup precision. In the case of INT4, NNCF evaluates layer sensitivity and keeps the most impactful layers in INT8 precision (by default 20% in INT8). This helps to achieve better accuracy after weight compression.
  --disable-stateful    Disable stateful converted models, stateless models will be generated instead. Stateful models are produced by default when this key is not used. In stateful models all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs. If --disable-stateful option is used, it may result in sub-optimal inference performance. Use it when you intentionally want to use a stateless model, for example, to be compatible with existing OpenVINO native inference code that expects kv-cache inputs and outputs in the model.
  --convert-tokenizer   Add converted tokenizer and detokenizer with OpenVINO Tokenizers
Command-line export of the model from the example above, with FP16 weight compression:
!optimum-cli export openvino --model $MODEL --task text-classification --fp16 models/optimum_model/fp16
--fp16 option is deprecated and will be removed in a future version. Use --weight-format instead.
Framework not specified. Using pt to export to ONNX.
Using the export variant default. Available variants are:
- default: The default ONNX variant.
Using framework PyTorch: 2.1.0+cpu
Overriding 1 configuration item(s)
- use_cache -> False
After export, the model is available in the specified directory and can be loaded using the same OVModelForXxx class.
model = OVModelForSequenceClassification.from_pretrained("models/optimum_model/fp16", device=device.value)
Compiling the model to AUTO ...
Some models in the Hugging Face Model Hub are already converted and ready to run! You can find them by filtering the Models page by library name: just type OpenVINO.
The Optimum Model Inference¶
Model inference is exactly the same as for the original model!
output = model(**encoded_input)
scores = output[0][0]
scores = torch.softmax(scores, dim=0).numpy(force=True)
print_prediction(scores)
1) positive 0.9483
2) neutral 0.0485
3) negative 0.0031
You can find more examples of using Optimum Intel here:
1. Accelerate Inference of Sparse Transformer Models
2. Grammatical Error Correction with OpenVINO
3. Stable Diffusion v2.1 using Optimum-Intel OpenVINO
4. Image generation with Stable Diffusion XL
5. Instruction following using Databricks Dolly 2.0
6. Create LLM-powered Chatbot using OpenVINO
7. Document Visual Question Answering Using Pix2Struct and OpenVINO
8. Automatic speech recognition using Distil-Whisper and OpenVINO