This tutorial explains how to convert the Google* Neural Machine Translation (GNMT) model to the Intermediate Representation (IR).

On GitHub*, you can find several public implementations of the TensorFlow* GNMT model. This tutorial explains how to convert the GNMT model from the TensorFlow* Neural Machine Translation (NMT) repository to the IR.
Before converting the model, you need to create a `GNMT_inference.patch` file for the repository. The patch modifies the framework code by adding a special command-line argument that enables inference graph dumping.

> **NOTE**: Use TensorFlow* version 1.13 or lower.
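If you produce the patch yourself after editing the framework code, a generic way to capture the changes as a patch file is with `git diff` (a minimal sketch; the patch contents themselves are specific to the NMT repository and are not reproduced here):

```sh
# Inside your modified clone of the NMT repository:
# capture all local changes into a patch file
git diff > GNMT_inference.patch
```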
Step 1. Clone the GitHub repository and check out the commit:
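A sketch of this step, assuming the TensorFlow* NMT repository referenced above; substitute the specific commit hash used by this tutorial, which is not reproduced here:

```sh
git clone https://github.com/tensorflow/nmt.git
cd nmt
git checkout <commit_hash>
```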
Step 2. Get a trained model. You have two options:

- Download a pre-trained checkpoint, or
- Train the model with the `wmt16_gnmt_4_layer.json` or `wmt16_gnmt_8_layer.json` configuration file using the NMT framework.

This tutorial assumes the use of the trained GNMT model from the `wmt16_gnmt_4_layer.json` configuration, German-to-English translation.
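If you choose to train the model yourself, the NMT framework is driven from the command line. The sketch below follows the training invocation documented in the NMT repository README for the WMT16 German-English task; all data paths are placeholders:

```sh
python -m nmt.nmt \
    --src=de --tgt=en \
    --hparams_path=nmt/standard_hparams/wmt16_gnmt_4_layer.json \
    --out_dir=/path/to/trained/model \
    --vocab_prefix=/path/to/wmt16/vocab.bpe.32000 \
    --train_prefix=/path/to/wmt16/train.tok.clean.bpe.32000 \
    --dev_prefix=/path/to/wmt16/newstest2013.tok.bpe.32000 \
    --test_prefix=/path/to/wmt16/newstest2015.tok.bpe.32000
```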
Step 3. Create an inference graph:
OpenVINO™ assumes that a model is used for inference only. Hence, before converting the model to the IR, you need to transform the training graph into the inference graph. For the GNMT model, the training graph and the inference graph have different decoders: the training graph uses a greedy search decoding algorithm, while the inference graph uses a beam search decoding algorithm.
Apply the `GNMT_inference.patch` patch to the repository (refer to the patch file instructions above if you do not have it), then dump the inference graph as shown in the sketch after this paragraph. If you use different checkpoints, use the corresponding values for the `src`, `tgt`, `ckpt`, `hparams_path`, and `vocab_prefix` parameters. The inference checkpoint `inference_GNMT_graph` and the frozen inference graph `frozen_GNMT_inference_graph.pb` will appear in the `/path/to/dump/model/` folder.
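A sketch of this step, assuming the patch adds `--dump_inference_graph` and `--result_file` options to the NMT framework (the exact option names are defined by the patch file and are assumptions here); all paths are placeholders:

```sh
# Apply the patch inside the NMT repository clone
git apply /path/to/GNMT_inference.patch

# Dump the inference graph for the German-to-English 4-layer configuration
python -m nmt.nmt \
    --src=de --tgt=en \
    --ckpt=/path/to/ckpt/translate.ckpt \
    --hparams_path=nmt/standard_hparams/wmt16_gnmt_4_layer.json \
    --vocab_prefix=/path/to/vocab/vocab.bpe.32000 \
    --out_dir="" \
    --dump_inference_graph=True \
    --result_file=/path/to/dump/model/
```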
Step 4. Convert the model to the IR:
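A sketch of the Model Optimizer invocation, assuming `batch_size = 1` and `max_sequence_length = 50`; the constant values frozen into the `LookupTableFindV2` nodes (`[2]` and `[1]`) are vocabulary-index assumptions for this setup, and depending on your OpenVINO™ version the entry point may be `mo` or `mo.py`:

```sh
mo \
    --input_model /path/to/dump/model/frozen_GNMT_inference_graph.pb \
    --input "IteratorGetNext:1{i32}[1],IteratorGetNext:0{i32}[1,50],dynamic_seq2seq/hash_table_Lookup_1:0[1]->[2],dynamic_seq2seq/hash_table_Lookup:0[1]->[1]" \
    --output dynamic_seq2seq/decoder/decoder/GatherTree \
    --output_dir /path/to/output/IR
```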
Input and output cutting with the `--input` and `--output` options is required since OpenVINO™ does not support the `IteratorGetNext` and `LookupTableFindV2` operations.
Input cutting:

- The `IteratorGetNext` operation iterates over a dataset. It is cut by output ports: port 0 contains the data tensor with shape `[batch_size, max_sequence_length]`, and port 1 contains the `sequence_length` for every batch, with shape `[batch_size]`.
- The `LookupTableFindV2` operations (the `dynamic_seq2seq/hash_table_Lookup_1` and `dynamic_seq2seq/hash_table_Lookup` nodes in the graph) are cut and replaced with constant values.

Output cutting:

- The `LookupTableFindV2` operation is cut from the output, and the `dynamic_seq2seq/decoder/decoder/GatherTree` node is treated as a new exit point.

For more information about model cutting, refer to Cutting Off Parts of a Model.
> **NOTE**: This step assumes you have converted the model to the Intermediate Representation.
Inputs of the model:

- The `IteratorGetNext/placeholder_out_port_0` input, with shape `[batch_size, max_sequence_length]`, contains `batch_size` encoded input sentences. Every sentence is encoded as the indices of its elements in the vocabulary and padded with the index of the `eos` (end-of-sentence) token: if the length of a sentence is less than `max_sequence_length`, the remaining elements are filled with the index of the `eos` token.
- The `IteratorGetNext/placeholder_out_port_1` input, with shape `[batch_size]`, contains the sequence length for every sentence from the first input. For example, if `max_sequence_length = 50`, `batch_size = 1`, and the sentence has only 30 elements, the input tensor for `IteratorGetNext/placeholder_out_port_1` should be `[30]`.

Outputs of the model:
- The `dynamic_seq2seq/decoder/decoder/GatherTree` tensor, with shape `[max_sequence_length * 2, batch, beam_size]`, contains the `beam_size` best translations for every sentence from the input (also encoded as indices of words in the vocabulary).

> **NOTE**: The shape of this tensor in TensorFlow* can be different: instead of `max_sequence_length * 2`, it can be any value less than that, because OpenVINO™ does not support dynamic output shapes, while TensorFlow* can stop decoding iterations when the `eos` symbol is generated.
> **NOTE**: Before running the example, insert a path to your GNMT `.xml` and `.bin` files into `MODEL_PATH` and `WEIGHTS_PATH`, and fill the `input_data_tensor` and `seq_lengths` tensors according to your input data.
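The original example script is not reproduced here; the following is a minimal sketch of such a script using the Inference Engine Python API (`IECore`). The IR file names and the dummy input values are assumptions:

```python
import numpy as np
from openvino.inference_engine import IECore

# Paths to the converted IR files (assumed names; adjust to your location)
MODEL_PATH = '/path/to/output/IR/frozen_GNMT_inference_graph.xml'
WEIGHTS_PATH = '/path/to/output/IR/frozen_GNMT_inference_graph.bin'

batch_size, max_sequence_length = 1, 50

# Fill with your data: vocabulary indices padded with the eos index,
# plus the real (unpadded) length of every sentence
input_data_tensor = np.zeros((batch_size, max_sequence_length), dtype=np.int32)
seq_lengths = np.array([30], dtype=np.int32)

# Read the IR and load the network on a device
ie = IECore()
network = ie.read_network(model=MODEL_PATH, weights=WEIGHTS_PATH)
exec_net = ie.load_network(network=network, device_name='CPU')

# Map the data to the model inputs described above
input_data = {
    'IteratorGetNext/placeholder_out_port_0': input_data_tensor,
    'IteratorGetNext/placeholder_out_port_1': seq_lengths,
}

# Run inference; the output holds beam_size best translations per sentence
result = exec_net.infer(input_data)
translations = result['dynamic_seq2seq/decoder/decoder/GatherTree']
```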
For more information about Python API, refer to Inference Engine Python API Overview.