Converting a JAX/Flax Model (Experimental)
The `openvino.convert_model` function supports the following JAX/Flax model object types:

- `jax._src.core.ClosedJaxpr`
- `flax.linen.Module`
A `jax._src.core.ClosedJaxpr` object is created by tracing a Python function with the `jax.make_jaxpr` function. Here is an example of creating a `jax._src.core.ClosedJaxpr` object and converting it to an OpenVINO model:
```python
import jax
import jax.numpy as jnp
import openvino as ov

# define a simple JAX function
def jax_func(x, y):
    return jax.lax.tanh(jax.lax.max(x, y))

# 1. Create ClosedJaxpr object
x = jnp.array([1.0, 2.0])
y = jnp.array([-1.0, 10.0])
jaxpr = jax.make_jaxpr(jax_func)(x, y)

# 2. Convert to OpenVINO
ov_model = ov.convert_model(jaxpr)
```
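A `ClosedJaxpr` records the primitive operations the traced function executed, together with the abstract shapes and dtypes of its inputs; this is the graph that `ov.convert_model` receives. The sketch below only inspects it on the JAX side and does not call OpenVINO. It reads `closed_jaxpr.jaxpr.eqns` and `closed_jaxpr.in_avals`, which are attributes of JAX's internal `ClosedJaxpr` class and may change between JAX versions.

```python
import jax
import jax.numpy as jnp

def jax_func(x, y):
    return jax.lax.tanh(jax.lax.max(x, y))

x = jnp.array([1.0, 2.0])
y = jnp.array([-1.0, 10.0])
closed_jaxpr = jax.make_jaxpr(jax_func)(x, y)

# Each equation is one traced primitive, listed in execution order.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
# Input "avals" keep only shape and dtype; the concrete values are not stored.
input_specs = [(aval.shape, str(aval.dtype)) for aval in closed_jaxpr.in_avals]

print(primitive_names)  # ['max', 'tanh']
print(input_specs)      # [((2,), 'float32'), ((2,), 'float32')]
```

Because only shapes and dtypes are recorded, the traced graph is specialized to the shapes of the example inputs; tracing with differently shaped inputs produces a different jaxpr.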
Here is an example of converting a minimal `flax.linen.Module` model:
```python
import flax.linen as nn
import jax
import jax.numpy as jnp
import openvino as ov

# define a simple Flax module
class SimpleModule(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=self.features)(x)

module = SimpleModule(features=4)

# create an example input used for initialization and tracing
example_input = jnp.ones((2, 3))

# prepare parameters to initialize the module
# they can also be loaded using pickle or flax.serialization
key = jax.random.PRNGKey(0)
params = module.init(key, example_input)
module = module.bind(params)

ov_model = ov.convert_model(module, example_input=example_input)
```
When a `flax.linen.Module` is used as the input model, `openvino.convert_model` requires the `example_input` parameter to be specified. Internally, it traces the model during conversion, using the `jax.make_jaxpr` function.
The `__call__` method of a `flax.linen.Module` object can also have extra custom flags, such as `training`, in its input signature. In this case, you need to create a helper function whose input signature contains only the input data, without any extra custom flags or parameters unrelated to the input data. Here is an example of handling such a case:
```python
import jax
import jax.numpy as jnp
import openvino as ov
from flax import linen as nn

class SimpleModuleWithExtraFlag(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x, training):
        x = nn.Dense(self.features)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        return x

# 1. Initialize the model
module = SimpleModuleWithExtraFlag(features=10)
key = jax.random.PRNGKey(0)
input_data = jnp.ones((4, 5))  # Batch of 4 samples, each with 5 features
# contains both the 'params' and 'batch_stats' collections
params = module.init(key, input_data, training=False)

# 2. Create helper function with only input data parameter
def helper_function(x):
    return module.apply(params, x, training=False)

# 3. Trace the helper function
jaxpr = jax.make_jaxpr(helper_function)(input_data)

# 4. Convert to OpenVINO
ov_model = ov.convert_model(jaxpr)
```
Note
The resulting OpenVINO IR model can be saved to disk with no additional JAX-specific steps. Use the standard `ov.save_model(ov_model, 'model.xml')` command.
Exporting a JAX/Flax Model to TensorFlow SavedModel Format
An alternative way to convert JAX/Flax models is to first export them to the TensorFlow SavedModel format with `jax.experimental.jax2tf.convert`, and then convert the resulting SavedModel directory to OpenVINO IR with `openvino.convert_model`. This approach can serve as a fallback if a model cannot be converted directly as described above.
Refer to the JAX and TensorFlow interoperation guide to learn how to export models from JAX/Flax to the TensorFlow SavedModel format. Then follow the Convert a TensorFlow Model chapter to produce an OpenVINO IR model. Here is an illustration of using these two steps together:
```python
import flax.linen as nn
import jax
import jax.experimental.jax2tf as jax2tf
import jax.numpy as jnp
import openvino as ov
import tensorflow as tf

# define a simple Flax module
class SimpleModule(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=self.features)(x)

flax_module = SimpleModule(features=4)

# prepare parameters to initialize the module
# they can also be loaded using pickle or flax.serialization
example_input = jnp.ones((2, 3))
key = jax.random.PRNGKey(0)
params = flax_module.init(key, example_input)

# 1. Export to SavedModel
# wrap the module and its parameters into a plain JAX function,
# convert it to a TF function, and attach it to a TF Module
def jax_fn(x):
    return flax_module.apply(params, x)

tf_function = tf.function(jax2tf.convert(jax_fn, native_serialization=False), autograph=False,
                          input_signature=[tf.TensorSpec(shape=[2, 3], dtype=tf.float32)])
tf_module = tf.Module()
tf_module.f = tf_function
tf.saved_model.save(tf_module, './saved_model')

# 2. Convert to OpenVINO
ov_model = ov.convert_model('./saved_model')
```
Note
As of JAX version 0.4.15, you must pass the `native_serialization=False` parameter to `jax2tf.convert` to use graph serialization mode. Without it, the resulting TensorFlow function contains embedded StableHLO modules, which the OpenVINO TensorFlow Frontend does not handle.