Convert TensorFlow BERT Model

Pre-trained models for BERT (Bidirectional Encoder Representations from Transformers) are publicly available.

Supported Models

Currently, the following models from the pre-trained BERT model list are supported:

  • BERT-Base, Cased

  • BERT-Base, Uncased

  • BERT-Base, Multilingual Cased

  • BERT-Base, Multilingual Uncased

  • BERT-Base, Chinese

  • BERT-Large, Cased

  • BERT-Large, Uncased

Download the Pre-Trained BERT Model

Download and unzip an archive with the BERT-Base, Uncased model.

After the archive is unzipped, the directory uncased_L-12_H-768_A-12 is created and contains the following files:

  • bert_config.json

  • bert_model.ckpt.data-00000-of-00001

  • bert_model.ckpt.index

  • bert_model.ckpt.meta

  • vocab.txt

The bert_model.ckpt.* files hold the pre-trained model checkpoint, including its meta-graph (bert_model.ckpt.meta).
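
Before converting, it can help to confirm that the archive unpacked completely. The sketch below is a minimal Python 3 illustration; `missing_bert_files` is a hypothetical helper, not part of BERT or the Model Optimizer, and it checks only the file names listed above.

```python
from pathlib import Path

# File names listed above for the unzipped uncased_L-12_H-768_A-12 directory.
EXPECTED_FILES = (
    "bert_config.json",
    "bert_model.ckpt.index",
    "bert_model.ckpt.meta",
    "vocab.txt",
)


def missing_bert_files(model_dir):
    """Return the expected BERT files that are absent from model_dir (hypothetical helper)."""
    root = Path(model_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]
```

For example, `missing_bert_files("uncased_L-12_H-768_A-12")` returns an empty list when every listed file is present.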

Convert TensorFlow BERT Model to IR

To generate the BERT Intermediate Representation (IR) of the model, run the Model Optimizer with the following parameters:

    mo \
        --input_meta_graph uncased_L-12_H-768_A-12/bert_model.ckpt.meta \
        --output bert/pooler/dense/Tanh \
        --input Placeholder{i32},Placeholder_1{i32},Placeholder_2{i32}
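
When the conversion is scripted rather than typed into a shell, the same command can be assembled as an argument list. This is a sketch under stated assumptions: `meta_graph_mo_args` is a hypothetical helper, and actually running the list assumes the `mo` executable is on `PATH`. Passing a list to `subprocess.run` also avoids any shell interpretation of the `{i32}` braces.

```python
import shlex


def meta_graph_mo_args(model_dir="uncased_L-12_H-768_A-12"):
    """Return the Model Optimizer command above as an argument list (hypothetical helper)."""
    placeholders = ("Placeholder", "Placeholder_1", "Placeholder_2")
    return [
        "mo",
        "--input_meta_graph", f"{model_dir}/bert_model.ckpt.meta",
        "--output", "bert/pooler/dense/Tanh",
        # Each input is declared as name{type}; all three inputs are int32 (i32).
        "--input", ",".join(f"{p}{{i32}}" for p in placeholders),
    ]


args = meta_graph_mo_args()
# Shell-quoted form of the same command, for inspection only.
print(shlex.join(args))
```

To run the conversion, the list can be passed to `subprocess.run(args, check=True)`.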

Pre-trained models are not suitable for batch reshaping out of the box, because they contain multiple hard-coded shapes.

Convert Reshape-able TensorFlow BERT Model to the Intermediate Representation

Follow these steps to make a pre-trained TensorFlow BERT model reshape-able over the batch dimension:

  1. Download the pre-trained BERT model you would like to use from the Supported Models list.

  2. Clone the google-research/bert Git repository:

    git clone https://github.com/google-research/bert.git
  3. Go to the root directory of the cloned repository:

    cd bert
  4. (Optional) Check out the commit that the conversion was tested on:

    git checkout eedf5716c
  5. Download the script that loads the GLUE data:

    • For UNIX-like systems, run the following command:

    • For Windows systems:

      Download the Python script to the current working directory.

  6. Download the GLUE data by running:

    python3 --tasks MRPC
  7. Open the file in a text editor and delete lines 923-924, which should look like this:

    if not non_static_indexes:
        return shape
  8. Open the file and insert the following code after line 645:

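    import os, sys
    import tensorflow as tf
    from tensorflow.python.framework import graph_io
    with tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph()) as sess:
        # Map the model's trainable variables to the tensors stored in the checkpoint.
        (assignment_map, initialized_variable_names) = \
            modeling.get_assignment_map_from_checkpoint(tf.compat.v1.trainable_variables(), init_checkpoint)
        # Make the variable initializers load their values from the checkpoint...
        tf.compat.v1.train.init_from_checkpoint(init_checkpoint, assignment_map)
        # ...and run them, so the variables hold the pre-trained weights before freezing.
        sess.run(tf.compat.v1.global_variables_initializer())
        # Replace variables with constants, keeping the subgraph that feeds the pooler output.
        frozen = tf.compat.v1.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["bert/pooler/dense/Tanh"])
        # Save the frozen graph as a binary protobuf in the current working directory.
        graph_io.write_graph(frozen, './', 'inference_graph.pb', as_text=False)
    print('BERT frozen model path {}'.format(os.path.join(os.path.dirname(__file__), 'inference_graph.pb')))
    # Stop here: only the frozen graph is needed, not the rest of the run.
    sys.exit(0)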

    Lines before the inserted code should look like this:

    (total_loss, per_example_loss, logits, probabilities) = create_model(
        bert_config, is_training, input_ids, input_mask, segment_ids, label_ids,
        num_labels, use_one_hot_embeddings)
  9. Set the BERT_BASE_DIR and BERT_REPO_DIR environment variables, then run the script to create the inference_graph.pb file in the root of the cloned BERT repository.

    export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
    export BERT_REPO_DIR=/current/working/directory
    python3 \
        --task_name=MRPC \
        --do_eval=true \
        --data_dir=$BERT_REPO_DIR/glue_data/MRPC \
        --vocab_file=$BERT_BASE_DIR/vocab.txt \
        --bert_config_file=$BERT_BASE_DIR/bert_config.json \
        --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
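
Line numbers such as 923 and 645 only hold for the commit the conversion was tested on, so editing by line number is fragile. The following sketch applies steps 7 and 8 only after checking that the surrounding text matches what those steps show. `delete_lines` and `insert_after` are hypothetical helpers, not part of the BERT repository, and the file paths are left to the caller.

```python
from pathlib import Path


def _matches(lines, first, expected):
    """True if the lines starting at 1-based `first` equal `expected`, ignoring indentation."""
    if first < 1:
        return False
    found = [line.strip() for line in lines[first - 1:first - 1 + len(expected)]]
    return found == [e.strip() for e in expected]


def delete_lines(path, first, expected):
    """Delete len(expected) lines starting at 1-based `first`, but only if they match."""
    lines = Path(path).read_text().splitlines(keepends=True)
    if not _matches(lines, first, expected):
        raise ValueError(f"{path}: text at line {first} does not match; check the commit")
    del lines[first - 1:first - 1 + len(expected)]
    Path(path).write_text("".join(lines))


def insert_after(path, line_no, expected_before, snippet, indent=""):
    """Insert `snippet` after 1-based `line_no` if the lines ending there match."""
    lines = Path(path).read_text().splitlines(keepends=True)
    if not _matches(lines, line_no - len(expected_before) + 1, expected_before):
        raise ValueError(f"{path}: text before line {line_no + 1} does not match; check the commit")
    if not lines[line_no - 1].endswith("\n"):
        lines[line_no - 1] += "\n"  # keep the inserted code on its own lines
    # Prefix every non-blank snippet line with `indent` so it nests inside the target function.
    block = "".join(indent + s + "\n" if s.strip() else "\n" for s in snippet.splitlines())
    lines.insert(line_no, block)
    Path(path).write_text("".join(lines))
```

For step 7, a call could look like `delete_lines(path, 923, ["if not non_static_indexes:", "return shape"])`. For step 8, `insert_after(path, 645, expected, snippet, indent=...)` takes the three lines shown before the inserted code as `expected`, the inserted code as `snippet`, and the indentation of the `(total_loss, ...` line as `indent`.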

Run the Model Optimizer with the following command-line parameters to generate a reshape-able BERT Intermediate Representation (IR):

    mo \
        --input_model inference_graph.pb \
        --input "IteratorGetNext:0{i32}[1 128],IteratorGetNext:1{i32}[1 128],IteratorGetNext:4{i32}[1 128]"
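
The `--input` value above follows the pattern `name{type}[dimensions]`, one entry per input, separated by commas. A small hypothetical helper, `mo_input_spec`, can build it, which makes it easy to try another batch size. Whether a given shape converts and runs correctly is an assumption to verify; this sketch only formats the string.

```python
def mo_input_spec(names, dtype="i32", shape=(1, 128)):
    """Join Model Optimizer --input entries of the form name{type}[d0 d1 ...]."""
    dims = " ".join(str(d) for d in shape)
    return ",".join(f"{name}{{{dtype}}}[{dims}]" for name in names)


inputs = ["IteratorGetNext:0", "IteratorGetNext:1", "IteratorGetNext:4"]
print(mo_input_spec(inputs))
# IteratorGetNext:0{i32}[1 128],IteratorGetNext:1{i32}[1 128],IteratorGetNext:4{i32}[1 128]
```

Calling `mo_input_spec(inputs, shape=(8, 128))` produces the same entries with `[8 128]` as the shape.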

For other applicable parameters, refer to Convert Model from TensorFlow.

For more information about reshape abilities, refer to Using Shape Inference.