Convert a JAX Model to OpenVINO™ IR#
This Jupyter notebook can be launched after a local installation only.
JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. JAX provides a familiar NumPy-style API for ease of adoption by researchers and engineers.
In this tutorial we will show how to convert JAX ViT and Mixer models to OpenVINO format.
More detailed information about the models:
Vision Transformer#
Overview of the model: the authors split an image into fixed-size patches, linearly embed each of them, add position embeddings, and feed the resulting sequence of vectors to a standard Transformer encoder. To perform classification, they use the standard approach of adding an extra learnable “classification token” to the sequence.
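To make this description concrete, below is a minimal Flax sketch of the patch-embedding step with a classification token and position embeddings. Names such as PatchEmbedding, patch_size, and hidden_dim are illustrative assumptions and do not come from the vit_jax package; the reference implementation lives in vit_jax/models_vit.py.

import jax.numpy as jnp
import flax.linen as nn


class PatchEmbedding(nn.Module):
    # Illustrative hyperparameters, not the vit_jax defaults.
    patch_size: int = 32
    hidden_dim: int = 768

    @nn.compact
    def __call__(self, images):  # images: (batch, height, width, 3)
        # A convolution whose stride equals its kernel size is equivalent to
        # splitting the image into non-overlapping patches and applying a
        # shared linear embedding to each patch.
        x = nn.Conv(
            features=self.hidden_dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
        )(images)
        batch, h, w, c = x.shape
        x = x.reshape(batch, h * w, c)  # sequence of patch embeddings

        # Prepend a learnable "classification token" and add position embeddings.
        cls = self.param("cls", nn.initializers.zeros, (1, 1, c))
        x = jnp.concatenate([jnp.tile(cls, (batch, 1, 1)), x], axis=1)
        pos = self.param("pos_embedding", nn.initializers.normal(stddev=0.02), (1, x.shape[1], c))
        return x + pos  # fed to a standard Transformer encoder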
MLP-Mixer#
MLP-Mixer (Mixer for short) consists of per-patch linear embeddings, Mixer layers, and a classifier head. Mixer layers contain one token-mixing MLP and one channel-mixing MLP, each consisting of two fully-connected layers and a GELU nonlinearity. Other components include skip-connections, dropout, and a linear classifier head.
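As an illustration of one Mixer layer, the sketch below shows the token-mixing and channel-mixing MLPs with their skip-connections. Class and field names (MlpBlock, MixerBlock, tokens_mlp_dim, channels_mlp_dim) are assumptions for illustration only; the reference implementation is vit_jax/models_mixer.py.

import flax.linen as nn


class MlpBlock(nn.Module):
    # Two fully-connected layers with a GELU nonlinearity in between.
    mlp_dim: int

    @nn.compact
    def __call__(self, x):
        y = nn.Dense(self.mlp_dim)(x)
        y = nn.gelu(y)
        return nn.Dense(x.shape[-1])(y)


class MixerBlock(nn.Module):
    # One Mixer layer: a token-mixing MLP followed by a channel-mixing MLP.
    tokens_mlp_dim: int
    channels_mlp_dim: int

    @nn.compact
    def __call__(self, x):  # x: (batch, num_patches, channels)
        # Token mixing: transpose so the MLP mixes information across patches.
        y = nn.LayerNorm()(x)
        y = y.swapaxes(1, 2)
        y = MlpBlock(self.tokens_mlp_dim)(y)
        y = y.swapaxes(1, 2)
        x = x + y  # skip-connection
        # Channel mixing: the MLP acts on each patch independently.
        y = nn.LayerNorm()(x)
        y = MlpBlock(self.channels_mlp_dim)(y)
        return x + y  # skip-connection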
Installation Instructions#
This is a self-contained example that relies solely on its own code.
We recommend running the notebook in a virtual environment. You only need a Jupyter server to start. For details, please refer to Installation Guide.
Prerequisites#
import requests

# Fetch helper scripts used across the OpenVINO notebooks.
r = requests.get(
    url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py",
)
open("notebook_utils.py", "w").write(r.text)

r = requests.get(
    url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/cmd_helper.py",
)
open("cmd_helper.py", "w").write(r.text)
from cmd_helper import clone_repo

# Clone the reference implementation, which provides the vit_jax package used below.
clone_repo("https://github.com/google-research/vision_transformer.git")
%pip install -q "openvino>=2024.5.0"
%pip install -q Pillow "jax>=0.4.2" "absl-py>=0.12.0" "flax>=0.6.4" "pandas>=1.1.0" "tensorflow-cpu>=2.4.0" tf_keras tqdm "einops>=0.3.0" "ml-collections>=0.1.0"
import PIL
import jax
import numpy as np
from vit_jax import checkpoint
from vit_jax import models_vit
from vit_jax import models_mixer
from vit_jax.configs import models as models_config
import openvino as ov
import ipywidgets as widgets
available_models = ["ViT-B_32", "Mixer-B_16"]

model_to_use = widgets.Select(
    options=available_models,
    value=available_models[0],
    description="Select model:",
    disabled=False,
)

model_to_use
Select(description='Select model:', options=('ViT-B_32', 'Mixer-B_16'), value='ViT-B_32')
Load and run the original model and a sample#
Download a pre-trained model.
from notebook_utils import download_file
model_name = model_to_use.value
model_config = models_config.MODEL_CONFIGS[model_name]
if model_name.startswith("Mixer"):
    # Download model trained on imagenet2012.
    model_name_path = download_file(f"https://storage.googleapis.com/mixer_models/imagenet1k/{model_name}.npz", filename=f"{model_name}_imagenet2012.npz")
    model = models_mixer.MlpMixer(num_classes=1000, **model_config)
else:
    # Download model pre-trained on imagenet21k and fine-tuned on imagenet2012.
    model_name_path = download_file(
        f"https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/{model_name}.npz", filename=f"{model_name}_imagenet2012.npz"
    )
    model = models_vit.VisionTransformer(num_classes=1000, **model_config)
ViT-B_32_imagenet2012.npz: 0%| | 0.00/337M [00:00<?, ?B/s]
Load and convert pretrained checkpoint.
params = checkpoint.load(f"{model_name}_imagenet2012.npz")
params["pre_logits"] = {} # Need to restore empty leaf for Flax.
Get imagenet labels.
from notebook_utils import download_file
imagenet_labels_path = download_file("https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt")
imagenet_labels = dict(enumerate(open(imagenet_labels_path)))
ilsvrc2012_wordnet_lemmas.txt: 0%| | 0.00/21.2k [00:00<?, ?B/s]
Get a random picture with the correct dimensions.
resolution = 224 if model_name.startswith("Mixer") else 384
image_path = download_file(f"https://picsum.photos/{resolution}", filename="picsum.jpg")
img = PIL.Image.open(image_path)
picsum.jpg: 0%| | 0.00/30.5k [00:00<?, ?B/s]
img
Run the original model inference#
# Predict on a batch containing a single image.
data = (np.array(img) / 128 - 1)[None, ...]  # normalize pixel values to [-1, 1] and add a batch dimension

(logits,) = model.apply(dict(params=params), data, train=False)

preds = np.array(jax.nn.softmax(logits))
for idx in preds.argsort()[:-11:-1]:
    print(f"{preds[idx]:.5f} : {imagenet_labels[idx]}", end="")
0.95251 : alp
0.03884 : valley, vale
0.00192 : cliff, drop, drop-off
0.00173 : ski
0.00059 : lakeside, lakeshore
0.00049 : promontory, headland, head, foreland
0.00036 : volcano
0.00021 : snowmobile
0.00017 : mountain_bike, all-terrain_bike, off-roader
0.00017 : mountain_tent
Convert the model to OpenVINO IR#
OpenVINO supports JAX models via conversion to OpenVINO Intermediate Representation (IR). The OpenVINO model conversion API should be used for this purpose. The ov.convert_model function accepts the original JAX model instance and example input for tracing, and returns an ov.Model object representing this model in the OpenVINO framework. The converted model can be saved to disk using the ov.save_model function or loaded directly onto a device using core.compile_model.
Before conversion we need to create a Jaxpr (JAX’s internal intermediate representation (IR) of programs) object by tracing a Python function using the jax.make_jaxpr function. jax.make_jaxpr takes as its argument a function that performs the forward pass; in our case that is a call to the model.apply method. However, model.apply requires not only the input data but also params and, in our case, the keyword argument train=False. To handle this, we create a wrapper function model_apply that calls model.apply with the stored params, the input x, and train=False.
from pathlib import Path

model_path = Path(f"models/{model_name}.xml")


def model_apply(x):
    return model.apply(dict(params=params), x, train=False)


jaxpr = jax.make_jaxpr(model_apply)((np.array(img) / 128 - 1)[None, ...])

converted_model = ov.convert_model(jaxpr)
ov.save_model(converted_model, model_path)
Compiling the model#
Select a device from the dropdown list for running inference using OpenVINO.
from notebook_utils import device_widget
core = ov.Core()
device = device_widget()
device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
compiled_model = core.compile_model(model_path, device.value)
Run OpenVINO model inference#
(logits_ov,) = list(compiled_model(data).values())[0]

preds = np.array(jax.nn.softmax(logits_ov))
for idx in preds.argsort()[:-11:-1]:
    print(f"{preds[idx]:.5f} : {imagenet_labels[idx]}", end="")
0.95255 : alp
0.03881 : valley, vale
0.00192 : cliff, drop, drop-off
0.00173 : ski
0.00059 : lakeside, lakeshore
0.00049 : promontory, headland, head, foreland
0.00036 : volcano
0.00021 : snowmobile
0.00017 : mountain_bike, all-terrain_bike, off-roader
0.00017 : mountain_tent
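As a final sanity check, a minimal snippet like the one below (assuming logits and logits_ov from the previous cells are still in scope) can confirm that the converted model matches the original JAX model up to small numerical differences.

# Compare the original JAX logits with the OpenVINO logits.
# Small numerical differences are expected, hence the loose tolerances.
np.testing.assert_allclose(np.array(logits), np.array(logits_ov), rtol=1e-2, atol=1e-2)
print("Top-1 prediction matches:", np.array(logits).argmax() == np.array(logits_ov).argmax())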