Convert a JAX Model to OpenVINO™ IR
This Jupyter notebook can be launched after a local installation only.
JAX is a Python library for
accelerator-oriented array computation and program transformation,
designed for high-performance numerical computing and large-scale
machine learning. JAX provides a familiar NumPy-style API for ease of
adoption by researchers and engineers.
In this tutorial we will show how to convert JAX ViT and Mixer models
to OpenVINO format.
More detailed information about the models
MLP-Mixer
MLP-Mixer (Mixer for short) consists of per-patch linear embeddings,
Mixer layers, and a classifier head. Mixer layers contain one
token-mixing MLP and one channel-mixing MLP, each consisting of two
fully-connected layers and a GELU nonlinearity. Other components
include: skip-connections, dropout, and linear classifier head.
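To make the layer structure concrete, here is a minimal NumPy sketch of a single Mixer layer as described above. It is an illustration only, not the vit_jax implementation: the function names (gelu, mlp, mixer_layer, init_mlp), the toy sizes, and the random initialization are assumptions made for this example, and layer normalization and dropout are left out for brevity.

```python
import numpy as np


def gelu(x):
    # Tanh approximation of the GELU nonlinearity.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def mlp(x, w1, b1, w2, b2):
    # Two fully-connected layers with a GELU in between.
    return gelu(x @ w1 + b1) @ w2 + b2


def mixer_layer(x, token_params, channel_params):
    # x has shape (num_patches, num_channels).
    # Token mixing: transpose so the MLP acts across patches, then add the skip connection.
    x = x + mlp(x.T, *token_params).T
    # Channel mixing: the MLP acts across channels, again with a skip connection.
    return x + mlp(x, *channel_params)


def init_mlp(rng, dim, hidden):
    # Hypothetical small random initialization for a dim -> hidden -> dim MLP.
    return (
        rng.normal(scale=0.1, size=(dim, hidden)),
        np.zeros(hidden),
        rng.normal(scale=0.1, size=(hidden, dim)),
        np.zeros(dim),
    )


rng = np.random.default_rng(0)
num_patches, num_channels, hidden = 4, 8, 16
x = rng.normal(size=(num_patches, num_channels))

y = mixer_layer(x, init_mlp(rng, num_patches, hidden), init_mlp(rng, num_channels, hidden))
print(y.shape)  # (4, 8)
```

Because both MLPs sit on a skip connection, the layer preserves the input shape, which is what allows Mixer layers to be stacked.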
Installation Instructions
This is a self-contained example that relies solely on its own code.
We recommend running the notebook in a virtual environment. You only
need a Jupyter server to start. For details, please refer to the
Installation Guide.
Prerequisites
import requests

# Fetch the helper modules used later in this notebook.
r = requests.get(
    url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py",
)
open("notebook_utils.py", "w").write(r.text)

r = requests.get(
    url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/cmd_helper.py",
)
open("cmd_helper.py", "w").write(r.text)
Convert a JAX Model to OpenVINO™ IR — OpenVINO™ documentationCopy to clipboardCopy to clipboardCopy to clipboardCopy to clipboardCopy to clipboardCopy to clipboardCopy to clipboardCopy to clipboardCopy to clipboardCopy to clipboardCopy to clipboardCopy to clipboardCopy to clipboardCopy to clipboardCopy to clipboardCopy to clipboardCopy to clipboardCopy to clipboardCopy to clipboardCopy to clipboardCopy to clipboardCopy to clipboard — Version(2024)
from cmd_helper import clone_repo

clone_repo("https://github.com/google-research/vision_transformer.git")
%pip install -q "openvino>=2024.5.0"
%pip install -q Pillow "jax>=0.4.2" "absl-py>=0.12.0" "flax>=0.6.4" "pandas>=1.1.0" "tensorflow-cpu>=2.4.0" tf_keras tqdm "einops>=0.3.0" "ml-collections>=0.1.0"
import PIL.Image
import jax
import numpy as np
from vit_jax import checkpoint
from vit_jax import models_vit
from vit_jax import models_mixer
from vit_jax.configs import models as models_config
import openvino as ov
import ipywidgets as widgets

available_models = ["ViT-B_32", "Mixer-B_16"]

model_to_use = widgets.Select(
    options=available_models,
    value=available_models[0],
    description="Select model:",
    disabled=False,
)

model_to_use
Select(description='Select model:', options=('ViT-B_32', 'Mixer-B_16'), value='ViT-B_32')
Load and run the original model and a sample
Download a pre-trained model.
from notebook_utils import download_file

model_name = model_to_use.value
model_config = models_config.MODEL_CONFIGS[model_name]

if model_name.startswith("Mixer"):
    # Download model trained on imagenet2012
    model_name_path = download_file(f"https://storage.googleapis.com/mixer_models/imagenet1k/{model_name}.npz", filename=f"{model_name}_imagenet2012.npz")
    model = models_mixer.MlpMixer(num_classes=1000, **model_config)
else:
    # Download model pre-trained on imagenet21k and fine-tuned on imagenet2012.
    model_name_path = download_file(
        f"https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/{model_name}.npz", filename=f"{model_name}_imagenet2012.npz"
    )
    model = models_vit.VisionTransformer(num_classes=1000, **model_config)
Load and convert pretrained checkpoint.
params = checkpoint.load(f"{model_name}_imagenet2012.npz")
params["pre_logits"] = {}  # Need to restore empty leaf for Flax.
Get imagenet labels.
from notebook_utils import download_file

imagenet_labels_path = download_file("https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt")
# Map each class index to its label line (one label per line in the file).
imagenet_labels = dict(enumerate(open(imagenet_labels_path)))
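As a small illustration of what dict(enumerate(open(...))) produces, the sketch below applies the same pattern to an in-memory stand-in for the labels file. The three sample lines are taken from the labels printed later in this notebook and are not in the file's real order; the name sample_file is an assumption for this example.

```python
import io

# Hypothetical stand-in for the downloaded labels file: one label per line.
sample_file = io.StringIO("alp\nvalley, vale\nski\n")

# Iterating over a text file yields its lines, so enumerate() pairs each
# line with its zero-based index, which serves as the class id.
sample_labels = dict(enumerate(sample_file))

print(repr(sample_labels[1]))  # 'valley, vale\n'
```

Each value keeps its trailing newline, which is presumably why the prediction loops in this notebook call print with end="".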
Get a random picture with the correct dimensions.
resolution = 224 if model_name.startswith("Mixer") else 384
image_path = download_file(f"https://picsum.photos/{resolution}", filename="picsum.jpg")
img = PIL.Image.open(image_path)
img
Run the original model inference
# Predict on a batch with a single item (note very efficient TPU usage...)
data = (np.array(img) / 128 - 1)[None, ...]
(logits,) = model.apply(dict(params=params), data, train=False)

# Print the ten most probable classes, highest first.
preds = np.array(jax.nn.softmax(logits))
for idx in preds.argsort()[:-11:-1]:
    print(f"{preds[idx]:.5f} : {imagenet_labels[idx]}", end="")
0.95251 : alp
0.03884 : valley, vale
0.00192 : cliff, drop, drop-off
0.00173 : ski
0.00059 : lakeside, lakeshore
0.00049 : promontory, headland, head, foreland
0.00036 : volcano
0.00021 : snowmobile
0.00017 : mountain_bike, all-terrain_bike, off-roader
0.00017 : mountain_tent
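The preprocessing and the top-10 selection in the cell above are compact, so here is a small NumPy sketch of the same two idioms on toy data. The arrays pixels and logits and the softmax helper are made up for this example; the notebook itself uses jax.nn.softmax and slices with [:-11:-1].

```python
import numpy as np

# 8-bit pixel values are mapped from [0, 255] to roughly [-1, 1).
pixels = np.array([0, 128, 255], dtype=np.uint8)
scaled = pixels / 128 - 1
print(scaled)

# A leading [None, ...] adds the batch axis expected by the model.
batch = scaled[None, ...]
print(batch.shape)  # (1, 3)


def softmax(z):
    # Numerically stable softmax over the last axis.
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


logits = np.array([1.0, 3.0, 2.0, 0.5])
probs = softmax(logits)

# argsort() is ascending, so [:-3:-1] walks back from the end and yields the
# indices of the two largest values, largest first.
top2 = probs.argsort()[:-3:-1]
print(top2)  # [1 2]
```

The same slice with -11 in place of -3 gives the ten highest-scoring class indices in descending order of probability.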
Convert the model to OpenVINO IR
OpenVINO supports JAX models via conversion to OpenVINO Intermediate
Representation (IR). The OpenVINO model conversion API should be used
for this purpose. The ov.convert_model function accepts the original
JAX model instance and an example input for tracing, and returns an
ov.Model representing this model in the OpenVINO framework. The
converted model can be saved to disk with the ov.save_model function
or loaded directly on a device with core.compile_model.
Before conversion, we need to create a Jaxpr (JAX's internal
intermediate representation (IR) of programs) object by tracing a
Python function with the jax.make_jaxpr function. jax.make_jaxpr takes
as its argument a function that performs the forward pass; in our case,
that is a call to the model.apply method. However, model.apply requires
not only the input data but also params and, in our case, the keyword
argument train=False. To handle this, create a wrapper function
model_apply that calls model.apply(dict(params=params), x, train=False).
from pathlib import Path

model_path = Path(f"models/{model_name}.xml")


def model_apply(x):
    # Close over params and fix train=False so that only x is traced.
    return model.apply(dict(params=params), x, train=False)


# Trace the forward pass on an example input to obtain the Jaxpr.
jaxpr = jax.make_jaxpr(model_apply)((np.array(img) / 128 - 1)[None, ...])

converted_model = ov.convert_model(jaxpr)
ov.save_model(converted_model, model_path)
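To see on a small scale what tracing through a wrapper produces, the sketch below applies the same pattern to a toy single-layer function instead of the real model. Everything prefixed with toy_ is hypothetical and exists only for this illustration; it assumes JAX is installed as in the prerequisites.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters for a single dense layer.
toy_params = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}


def toy_apply(variables, x, train=False):
    # Stand-in for model.apply; in a real model `train` would toggle dropout.
    return x @ variables["w"] + variables["b"]


def toy_model_apply(x):
    # Closes over toy_params and fixes train=False, so only x is traced.
    return toy_apply(toy_params, x, train=False)


toy_jaxpr = jax.make_jaxpr(toy_model_apply)(jnp.ones((1, 3)))
print(toy_jaxpr)

# The traced program has a single input variable: the example array x.
print(len(toy_jaxpr.jaxpr.invars))  # 1
```

Because the parameters are captured by the closure rather than passed as arguments, they are recorded as constants of the traced program, and the input array is its only input.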
Compiling the model
Select device from dropdown list for running inference using OpenVINO.
from notebook_utils import device_widget

core = ov.Core()

device = device_widget()

device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
compiled_model = core.compile_model(model_path, device.value)
Run OpenVINO model inference
# Take the first (and only) output and unpack the single batch item.
(logits_ov,) = list(compiled_model(data).values())[0]

preds = np.array(jax.nn.softmax(logits_ov))
for idx in preds.argsort()[:-11:-1]:
    print(f"{preds[idx]:.5f} : {imagenet_labels[idx]}", end="")
0.95255 : alp
0.03881 : valley, vale
0.00192 : cliff, drop, drop-off
0.00173 : ski
0.00059 : lakeside, lakeshore
0.00049 : promontory, headland, head, foreland
0.00036 : volcano
0.00021 : snowmobile
0.00017 : mountain_bike, all-terrain_bike, off-roader
0.00017 : mountain_tent
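The two printed rankings list the same labels in the same order, with only small differences in the probabilities. As a rough check, the sketch below compares the ten values copied from the outputs above; the variable names are assumptions for this example. With the full arrays in memory, a call such as np.allclose(logits, logits_ov, atol=...) with a tolerance of your choosing would serve the same purpose.

```python
# Top-10 probabilities copied from the two printed outputs above.
jax_probs = [0.95251, 0.03884, 0.00192, 0.00173, 0.00059, 0.00049, 0.00036, 0.00021, 0.00017, 0.00017]
ov_probs = [0.95255, 0.03881, 0.00192, 0.00173, 0.00059, 0.00049, 0.00036, 0.00021, 0.00017, 0.00017]

# Largest absolute difference between the original and the converted model.
max_diff = max(abs(a - b) for a, b in zip(jax_probs, ov_probs))
print(f"max abs difference: {max_diff:.5f}")  # max abs difference: 0.00004
```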