GPT-2 Text Prediction with OpenVINO

This tutorial is also available as a Jupyter notebook that can be cloned directly from GitHub. See the installation guide for instructions to run this tutorial locally on Windows, Linux or macOS.


This notebook shows text prediction with OpenVINO, using the GPT-2 model, which is part of the Generative Pre-trained Transformer (GPT) family. GPT-2 is pre-trained on a large corpus of English text using unsupervised training. The model is available from Open Model Zoo, whose tools we use to download the model and convert it to OpenVINO IR.

Imports

import json
import sys
from pathlib import Path

import numpy as np
from IPython.display import Markdown, display
from openvino.runtime import Core
from transformers import GPT2Tokenizer

sys.path.append("../utils")

The model

# directory where the model will be downloaded
base_model_dir = "model"

# name of the model
model_name = "gpt-2"

# desired precision
precision = "FP16"

model_path = f"{base_model_dir}/public/{model_name}/{precision}/{model_name}.xml"
model_weights_path = f"{base_model_dir}/public/{model_name}/{precision}/{model_name}.bin"

Download GPT-2 from Open Model Zoo

We use omz_downloader, a command-line tool from the openvino-dev package, which automatically creates a directory structure and downloads the selected model. Skip this step if the model is already downloaded. For this demo, we download the gpt-2 model.

download_command = f"omz_downloader " \
                   f"--name {model_name} " \
                   f"--output_dir {base_model_dir} " \
                   f"--cache_dir {base_model_dir}"

display(Markdown(f"Download command: `{download_command}`"))
display(Markdown(f"Downloading {model_name}... (This may take a few minutes depending on your connection.)"))

! $download_command

Download command: omz_downloader --name gpt-2 --output_dir model --cache_dir model

Downloading gpt-2… (This may take a few minutes depending on your connection.)

################|| Downloading gpt-2 ||################

========== Downloading model/public/gpt-2/transformers-4.9.1-py3-none-any.whl


========== Downloading model/public/gpt-2/gpt2/pytorch_model.bin


========== Downloading model/public/gpt-2/gpt2/config.json


========== Downloading model/public/gpt-2/gpt2/vocab.json


========== Downloading model/public/gpt-2/gpt2/merges.txt


========== Downloading model/public/gpt-2/packaging-21.0-py3-none-any.whl


========== Unpacking model/public/gpt-2/transformers-4.9.1-py3-none-any.whl
========== Unpacking model/public/gpt-2/packaging-21.0-py3-none-any.whl
========== Replacing text in model/public/gpt-2/transformers/__init__.py
========== Replacing text in model/public/gpt-2/transformers/file_utils.py
========== Replacing text in model/public/gpt-2/transformers/file_utils.py
========== Replacing text in model/public/gpt-2/transformers/data/datasets/glue.py
========== Replacing text in model/public/gpt-2/transformers/data/datasets/squad.py
========== Replacing text in model/public/gpt-2/transformers/data/datasets/language_modeling.py
========== Replacing text in model/public/gpt-2/transformers/file_utils.py
========== Replacing text in model/public/gpt-2/transformers/file_utils.py
========== Replacing text in model/public/gpt-2/transformers/modelcard.py
========== Replacing text in model/public/gpt-2/transformers/deepspeed.py
========== Replacing text in model/public/gpt-2/transformers/trainer.py

Convert GPT-2 to OpenVINO IR

Since the downloaded GPT-2 model is not yet in OpenVINO IR format, we need to perform an additional step to convert it. Use the following command:

if not Path(model_path).exists():
    convert_command = (
        f"omz_converter --name {model_name} --precisions {precision}"
        f" --download_dir {base_model_dir} --output_dir {base_model_dir}"
    )
    display(Markdown(f"Convert command: `{convert_command}`"))
    display(Markdown(f"Converting {model_name}"))

    ! $convert_command

Convert command: omz_converter --name gpt-2 --precisions FP16 --download_dir model --output_dir model

Converting gpt-2

========== Converting gpt-2 to ONNX
Conversion to ONNX command: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-275/.workspace/scm/ov-notebook/.venv/bin/python -- /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-275/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/openvino/model_zoo/internal_scripts/pytorch_to_onnx.py --model-path=model/public/gpt-2 --model-path=/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-275/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/openvino/model_zoo/models/public/gpt-2 --model-name=create_model --import-module=model '--model-param=model_dir=r"model/public/gpt-2/gpt2"' --input-names=input --output-names=output '--input-shapes=[1,1024]' --output-file=model/public/gpt-2/gpt-2.onnx --inputs-dtype=long '--conversion-param=dynamic_axes={"input": {0: "batch_size", 1: "sequence_len"}, "output": {0: "batch_size", 1: "sequence_len"}}'

/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-275/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py:185: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  attn_weights = attn_weights / torch.tensor(
/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-275/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py:185: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  attn_weights = attn_weights / torch.tensor(
/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-275/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py:200: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
ONNX check passed successfully.

========== Converting gpt-2 to IR (FP16)
Conversion command: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-275/.workspace/scm/ov-notebook/.venv/bin/python -- /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-275/.workspace/scm/ov-notebook/.venv/bin/mo --framework=onnx --data_type=FP16 --output_dir=model/public/gpt-2/FP16 --model_name=gpt-2 --input=input --input_model=model/public/gpt-2/gpt-2.onnx --output=output '--layout=input(NS)'

Model Optimizer arguments:
Common parameters:
    - Path to the Input Model:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-275/.workspace/scm/ov-notebook/notebooks/223-gpt2-text-prediction/model/public/gpt-2/gpt-2.onnx
    - Path for generated IR:    /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-275/.workspace/scm/ov-notebook/notebooks/223-gpt2-text-prediction/model/public/gpt-2/FP16
    - IR output name:   gpt-2
    - Log level:    ERROR
    - Batch:    Not specified, inherited from the model
    - Input layers:     input
    - Output layers:    output
    - Input shapes:     Not specified, inherited from the model
    - Source layout:    Not specified
    - Target layout:    Not specified
    - Layout:   input(NS)
    - Mean values:  Not specified
    - Scale values:     Not specified
    - Scale factor:     Not specified
    - Precision of IR:  FP16
    - Enable fusing:    True
    - User transformations:     Not specified
    - Reverse input channels:   False
    - Enable IR generation for fixed input shape:   False
    - Use the transformations config file:  None
Advanced parameters:
    - Force the usage of legacy Frontend of Model Optimizer for model conversion into IR:   False
    - Force the usage of new Frontend of Model Optimizer for model conversion into IR:  False
OpenVINO runtime found in:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-275/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/openvino
OpenVINO runtime version:   2022.2.0-7713-af16ea1d79a-releases/2022/2
Model Optimizer version:    2022.2.0-7713-af16ea1d79a-releases/2022/2
Warning: One or more of the values of the Constant can't fit in the float16 data type. Those values were casted to the nearest limit value, the model can produce incorrect results.
Warning: One or more of the values of the Constant can't fit in the float16 data type. Those values were casted to the nearest limit value, the model can produce incorrect results.
Warning: One or more of the values of the Constant can't fit in the float16 data type. Those values were casted to the nearest limit value, the model can produce incorrect results.
Warning: One or more of the values of the Constant can't fit in the float16 data type. Those values were casted to the nearest limit value, the model can produce incorrect results.
Warning: One or more of the values of the Constant can't fit in the float16 data type. Those values were casted to the nearest limit value, the model can produce incorrect results.
Warning: One or more of the values of the Constant can't fit in the float16 data type. Those values were casted to the nearest limit value, the model can produce incorrect results.
Warning: One or more of the values of the Constant can't fit in the float16 data type. Those values were casted to the nearest limit value, the model can produce incorrect results.
Warning: One or more of the values of the Constant can't fit in the float16 data type. Those values were casted to the nearest limit value, the model can produce incorrect results.
Warning: One or more of the values of the Constant can't fit in the float16 data type. Those values were casted to the nearest limit value, the model can produce incorrect results.
Warning: One or more of the values of the Constant can't fit in the float16 data type. Those values were casted to the nearest limit value, the model can produce incorrect results.
Warning: One or more of the values of the Constant can't fit in the float16 data type. Those values were casted to the nearest limit value, the model can produce incorrect results.
Warning: One or more of the values of the Constant can't fit in the float16 data type. Those values were casted to the nearest limit value, the model can produce incorrect results.
[ SUCCESS ] Generated IR version 11 model.
[ SUCCESS ] XML file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-275/.workspace/scm/ov-notebook/notebooks/223-gpt2-text-prediction/model/public/gpt-2/FP16/gpt-2.xml
[ SUCCESS ] BIN file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-275/.workspace/scm/ov-notebook/notebooks/223-gpt2-text-prediction/model/public/gpt-2/FP16/gpt-2.bin
[ SUCCESS ] Total execution time: 4.36 seconds.
[ SUCCESS ] Memory consumed: 1348 MB.
[ INFO ] The model was converted to IR v11, the latest model format that corresponds to the source DL framework input/output format. While IR v11 is backwards compatible with OpenVINO Inference Engine API v1.0, please use API v2.0 (as of 2022.1) to take advantage of the latest improvements in IR v11.
Find more information about API v2.0 and IR v11 at https://docs.openvino.ai

Load the model

Converted models are located in a fixed directory structure, which indicates the source, model name and precision. We start by creating an OpenVINO Runtime Core object. Then we read the network architecture and model weights from the .xml and .bin files, respectively. Finally, we compile the model for the desired device. Because we use the dynamic shapes feature, which is currently only available on CPU, we must use CPU as the device. Dynamic shapes support on GPU is coming soon.

Since this model has a dynamic input shape, you cannot directly switch the device to GPU for inference on integrated or discrete Intel GPUs. To run inference on an iGPU or dGPU with this model, you would need to reshape the inputs to a fixed size and then try running inference on the GPU device, as sketched below.
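The following is a minimal, illustrative sketch of that fixed-shape approach, not part of the notebook's main flow. It assumes an Intel GPU with the OpenVINO GPU plugin is available and picks an arbitrary static shape of [1, 1024]; the rest of this notebook keeps the dynamic-shape CPU path.

from openvino.runtime import Core, PartialShape

core = Core()
static_model = core.read_model(model=model_path, weights=model_weights_path)

# fix batch size and sequence length to static values (illustrative choice)
static_model.reshape({static_model.input(0): PartialShape([1, 1024])})

# compile for GPU; this fails if no Intel GPU device is present
compiled_on_gpu = core.compile_model(model=static_model, device_name="GPU")

Note that with a static shape, every input would also have to be padded to exactly 1024 tokens before inference.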

# initialize OpenVINO Runtime
ie_core = Core()

# read the model and corresponding weights from file
model = ie_core.read_model(model=model_path, weights=model_weights_path)

# assign dynamic shapes to every input layer
for input_layer in model.inputs:
    input_shape = input_layer.partial_shape
    input_shape[0] = -1
    input_shape[1] = -1
    model.reshape({input_layer: input_shape})

# compile the model for CPU devices
compiled_model = ie_core.compile_model(model=model, device_name="CPU")

# get the input and output nodes
input_keys = next(iter(compiled_model.inputs))
output_keys = next(iter(compiled_model.outputs))

input_keys refers to the input node of the network and output_keys to the output node. In the case of GPT-2, the input has dimensions [batch size, sequence length] and the output has dimensions [batch size, sequence length, vocab size].
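As a quick sanity check (illustrative only, assuming the model was read and compiled as above), you can print the partial shapes of those nodes; both batch size and sequence length show up as dynamic dimensions.

# both shapes are dynamic in the batch and sequence dimensions
print("Input shape: ", input_keys.partial_shape)
print("Output shape:", output_keys.partial_shape)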

Pre-Processing

NLP models often take a list of tokens as a standard input. A token is a word or subword mapped to an integer. To provide the proper input, we use a vocabulary file to handle the mapping, so first let's load the vocabulary file.

def load_vocab_file(vocab_file_path):
    with open(vocab_file_path, "r", encoding="utf-8") as content:
        return json.load(content)


vocab_file_path = f"model/public/{model_name}/gpt2/vocab.json"
vocab = load_vocab_file(vocab_file_path)
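As a quick check (illustrative only), the loaded vocabulary maps token strings to integer IDs; for GPT-2 it is expected to contain 50257 entries, with <|endoftext|> as the last one.

print("Vocabulary size:", len(vocab))                      # expected: 50257
print("ID of <|endoftext|>:", vocab.get("<|endoftext|>"))  # expected: 50256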

Define tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# this function converts text to tokens
def tokenize(text):
    input_ids = tokenizer(text)['input_ids']
    input_ids = np.array(input_ids).reshape(1, -1)
    return input_ids
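For example (a small illustrative call), tokenizing a short phrase returns a 2-D array of token IDs with shape (1, number of tokens):

sample_ids = tokenize("OpenVINO makes inference fast")
print(sample_ids.shape)  # (1, N), where N is the number of BPE tokens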

The last token in the vocabulary list is the <|endoftext|> token. We store the index of this token in order to use it for padding at a later stage.

eos_token_id = len(vocab) - 1
tokenizer._convert_id_to_token(len(vocab) - 1)
'<|endoftext|>'

Define Softmax layer

A softmax function is used to convert top-k logits into a probability distribution.

def softmax(x):
    e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    summation = e_x.sum(axis=-1, keepdims=True)
    return e_x / summation
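A quick numeric check (illustrative only): each row of the result is a valid probability distribution, so it is non-negative and sums to 1.

print(softmax(np.array([[1.0, 2.0, 3.0]])))               # approx. [[0.09 0.245 0.665]]
print(softmax(np.array([[1.0, 2.0, 3.0]])).sum(axis=-1))  # [1.]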

Set the minimum sequence length

If the minimum sequence length has not been reached yet, the following code sets the score of the eos token to negative infinity so that it cannot be selected, which keeps the model generating further tokens.

def process_logits(input_ids, scores, eos_token_id, min_length=0):
    cur_length = input_ids.shape[-1]
    if cur_length < min_length:
        scores[:, eos_token_id] = -float("inf")
    return scores

Top-K sampling

In Top-K sampling, we filter the K most likely next words and redistribute the probability mass among only those K next words.

def get_top_k_logits(scores, top_k):
    filter_value = -float("inf")
    top_k = min(max(top_k, 1), scores.shape[-1])
    top_k_scores = -np.sort(-scores)[:, :top_k]
    indices_to_remove = scores < np.min(top_k_scores)
    filtered_scores = np.ma.array(scores, mask=indices_to_remove,
                                  fill_value=filter_value).filled()
    return filtered_scores
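A small illustration of the behaviour: with top_k=2, every logit except the two highest is masked to negative infinity, so it receives zero probability after the softmax.

demo_scores = np.array([[1.0, 3.0, 0.5, 2.0]])
print(get_top_k_logits(demo_scores, top_k=2))  # [[-inf   3. -inf   2.]]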

Main Processing Function

Generating the predicted sequence.

def generate_sequence(input_ids, max_sequence_length=128,
                      eos_token_id=eos_token_id):
    while True:
        cur_input_len = len(input_ids[0])
        pad_len = max_sequence_length - cur_input_len
        model_input = np.concatenate((input_ids,
                                      [[eos_token_id] * pad_len]), axis=-1)
        # pass the padded sequence into the model
        outputs = compiled_model(inputs=[model_input])[output_keys]
        next_token_logits = outputs[:, cur_input_len - 1, :]
        # pre-process distribution
        next_token_scores = process_logits(input_ids,
                                           next_token_logits, eos_token_id)
        top_k = 20
        next_token_scores = get_top_k_logits(next_token_scores, top_k)
        # get next token id
        probs = softmax(next_token_scores)
        next_tokens = np.random.choice(probs.shape[-1], 1,
                                       p=probs[0], replace=True)
        # break the loop if max length or end of text token is reached
        if cur_input_len == max_sequence_length or next_tokens == eos_token_id:
            break
        else:
            input_ids = np.concatenate((input_ids, [next_tokens]), axis=-1)
    return input_ids

Run

The text variable below is the input used to generate a predicted sequence.

text = "Deep learning is a type of machine learning that uses neural networks"
input_ids = tokenize(text)
output_ids = generate_sequence(input_ids)
S = " "
# Convert IDs to words and make the sentence from it
for i in output_ids[0]:
    S += tokenizer.convert_tokens_to_string(tokenizer._convert_id_to_token(i))
print("Input Text: ", text)
print()
print(f"Predicted Sequence:{S}")
Input Text:  Deep learning is a type of machine learning that uses neural networks

Predicted Sequence: Deep learning is a type of machine learning that uses neural networks to learn a large set of facts about a situation and then compares and contrasts that information with what's in the background.

The team of researchers from the University of Washington in Seattle, in collaboration with the University of Michigan and the University of Pennsylvania, analyzed the data on a large, well-known social network (Facebook Twitter) called Twitter Learning.

The researchers found that Twitter Learning was more efficient than the previous network of Facebook Twitter Learning in predicting outcomes on a scale of 1 to 10 as compared to Facebook Facebook learning. This was not surprising given that the social networks