Quantize NLP models with OpenVINO Post-Training Optimization Tool

This tutorial is also available as a Jupyter notebook that can be cloned directly from GitHub. See the installation guide for instructions to run this tutorial locally on Windows, Linux or macOS.

This tutorial demonstrates how to apply INT8 quantization to the Natural Language Processing model known as BERT, using the Post-Training Optimization Tool API (part of the OpenVINO Toolkit). We will use a fine-tuned HuggingFace BERT PyTorch model trained on the Microsoft Research Paraphrase Corpus (MRPC). The tutorial is designed to be extendable to custom models and datasets. It consists of the following steps:

  • Download and prepare the BERT model and MRPC dataset

  • Define data loading and accuracy validation functionality

  • Prepare the model for quantization

  • Run the optimization pipeline

  • Compare performance of the original and quantized models
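INT8 quantization stores values as 8-bit integers plus a floating-point scale. To illustrate only the arithmetic, the sketch below applies symmetric per-tensor quantization to a small NumPy array. It is not the exact scheme POT uses (DefaultQuantization collects statistics on a data subset and inserts FakeQuantize operations into the IR), and the helper names are hypothetical.

```python
import numpy as np


def quantize_symmetric_int8(x: np.ndarray):
    """Quantize a float array to int8 using one symmetric scale for the whole tensor."""
    max_abs = float(np.max(np.abs(x)))
    # Map [-max_abs, max_abs] onto [-127, 127]; guard against an all-zero tensor.
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    """Map int8 codes back to float32 approximations of the original values."""
    return q.astype(np.float32) * scale


weights = np.array([-0.9, -0.1, 0.0, 0.25, 0.9], dtype=np.float32)
q, scale = quantize_symmetric_int8(weights)
restored = dequantize(q, scale)
print(q.tolist())
# Per-element rounding error is at most scale / 2 for values inside the range.
print(float(np.max(np.abs(weights - restored))))
```

The error bound of half the scale is why a wide value range (a large max_abs) costs precision for small values.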

Imports

import os
import sys
import time
import warnings
from pathlib import Path
from zipfile import ZipFile

import numpy as np
import torch
from addict import Dict
from compression.api import DataLoader as POTDataLoader
from compression.api import Metric
from compression.engines.ie_engine import IEEngine
from compression.graph import load_model, save_model
from compression.graph.model_utils import compress_model_weights
from compression.pipeline.initializer import create_pipeline
from torch.utils.data import TensorDataset
from transformers import BertForSequenceClassification, BertTokenizer
from transformers import (
    glue_convert_examples_to_features as convert_examples_to_features,
)
from transformers import glue_output_modes as output_modes
from transformers import glue_processors as processors

sys.path.append("../utils")
from notebook_utils import download_file

Settings

# Set the data and model directories, model source URL and model filename
DATA_DIR = "data"
MODEL_DIR = "model"
MODEL_LINK = "https://download.pytorch.org/tutorial/MRPC.zip"
FILE_NAME = MODEL_LINK.split("/")[-1]

os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

Prepare the Model

The next steps are:

  • Download and unpack the pre-trained BERT model for MRPC provided by PyTorch

  • Convert the model to ONNX

  • Run the OpenVINO Model Optimizer tool to convert the model from the ONNX representation to the OpenVINO Intermediate Representation (IR)

download_file(MODEL_LINK, directory=MODEL_DIR, show_progress=True)
with ZipFile(f"{MODEL_DIR}/{FILE_NAME}", "r") as zip_ref:
    zip_ref.extractall(MODEL_DIR)

Load the original PyTorch model and convert it to the ONNX representation. The export traces the model with a dummy input of shape [1, 128] and marks the batch and sequence dimensions as dynamic.

BATCH_SIZE = 1
MAX_SEQ_LENGTH = 128


def export_model_to_onnx(model, path):
    with torch.no_grad():
        default_input = torch.ones(1, MAX_SEQ_LENGTH, dtype=torch.int64)
        inputs = {
            "input_ids": default_input,
            "attention_mask": default_input,
            "token_type_ids": default_input,
        }
        symbolic_names = {0: "batch_size", 1: "max_seq_len"}
        torch.onnx.export(
            model,
            (inputs["input_ids"], inputs["attention_mask"], inputs["token_type_ids"]),
            path,
            opset_version=11,
            do_constant_folding=True,
            input_names=["input_ids", "input_mask", "segment_ids"],
            output_names=["output"],
            dynamic_axes={
                "input_ids": symbolic_names,
                "input_mask": symbolic_names,
                "segment_ids": symbolic_names,
            },
        )
        print("ONNX model saved to {}".format(path))


torch_model = BertForSequenceClassification.from_pretrained(os.path.join(MODEL_DIR, "MRPC"))
onnx_model_path = Path(MODEL_DIR) / "bert_mrpc.onnx"
if not onnx_model_path.exists():
    export_model_to_onnx(torch_model, onnx_model_path)
ONNX model saved to model/bert_mrpc.onnx

Convert the ONNX Model to OpenVINO IR

ir_model_xml = onnx_model_path.with_suffix(".xml")
ir_model_bin = onnx_model_path.with_suffix(".bin")

if not ir_model_xml.exists():
    !mo --input_model $onnx_model_path --output_dir $MODEL_DIR --model_name bert_mrpc --input input_ids,input_mask,segment_ids --input_shape [1,128],[1,128],[1,128] --output output --data_type FP16
Model Optimizer arguments:
Common parameters:
    - Path to the Input Model:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/notebooks/105-language-quantize-bert/model/bert_mrpc.onnx
    - Path for generated IR:    /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/notebooks/105-language-quantize-bert/model
    - IR output name:   bert_mrpc
    - Log level:    ERROR
    - Batch:    Not specified, inherited from the model
    - Input layers:     input_ids,input_mask,segment_ids
    - Output layers:    output
    - Input shapes:     [1,128],[1,128],[1,128]
    - Mean values:  Not specified
    - Scale values:     Not specified
    - Scale factor:     Not specified
    - Precision of IR:  FP16
    - Enable fusing:    True
    - Enable grouped convolutions fusing:   True
    - Move mean values to preprocess section:   None
    - Reverse input channels:   False
ONNX specific parameters:
    - Inference Engine found in:    /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/openvino
Inference Engine version:   2021.4.2-3976-0943ed67223-refs/pull/539/head
Model Optimizer version:    2021.4.2-3976-0943ed67223-refs/pull/539/head
[ WARNING ]  Convert data type of Parameter "input_ids" to int32
[ WARNING ]  Convert data type of Parameter "input_mask" to int32
[ WARNING ]  Convert data type of Parameter "segment_ids" to int32
[ SUCCESS ] Generated IR version 10 model.
[ SUCCESS ] XML file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/notebooks/105-language-quantize-bert/model/bert_mrpc.xml
[ SUCCESS ] BIN file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/notebooks/105-language-quantize-bert/model/bert_mrpc.bin
[ SUCCESS ] Total execution time: 34.91 seconds.
[ SUCCESS ] Memory consumed: 1349 MB.
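The `!mo` line relies on IPython's shell escape, so it only works inside Jupyter. To run the same conversion from a plain Python script, one option is to build the command as an argument list and hand it to `subprocess.run`. The sketch below reuses the flags from the cell above; `build_mo_command` is a hypothetical helper, and the actual call is left commented out because it needs the `mo` executable on PATH.

```python
import shlex
from pathlib import Path


def build_mo_command(onnx_model_path, output_dir, model_name="bert_mrpc", seq_len=128):
    """Return the Model Optimizer command from the cell above as an argument list.

    The input names must match the names given to torch.onnx.export.
    """
    input_names = ["input_ids", "input_mask", "segment_ids"]
    # One static [1, seq_len] shape per input, comma-separated as in the cell above
    shapes = ",".join(f"[1,{seq_len}]" for _ in input_names)
    return [
        "mo",
        "--input_model", str(onnx_model_path),
        "--output_dir", str(output_dir),
        "--model_name", model_name,
        "--input", ",".join(input_names),
        "--input_shape", shapes,
        "--output", "output",
        "--data_type", "FP16",
    ]


cmd = build_mo_command(Path("model") / "bert_mrpc.onnx", "model")
print(shlex.join(cmd))  # shell-quoted form, handy for copying into a terminal
# To run the conversion (requires `mo` on PATH):
# import subprocess; subprocess.run(cmd, check=True)
```

Passing a list instead of a single string avoids shell quoting issues with the bracketed shapes.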

Prepare MRPC Task Dataset

To run this tutorial, you need the General Language Understanding Evaluation (GLUE) data for the MRPC task. The code below downloads a helper script from the HuggingFace Transformers repository that fetches the MRPC data and formats it.

download_file(
    "https://raw.githubusercontent.com/huggingface/transformers/f98ef14d161d7bcdc9808b5ec399981481411cc1/utils/download_glue_data.py",
    show_progress=False,
)
PosixPath('/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/notebooks/105-language-quantize-bert/download_glue_data.py')
from download_glue_data import format_mrpc

format_mrpc(DATA_DIR, "")
Processing MRPC...
Local MRPC data not specified, downloading data from https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt
    Completed!
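`format_mrpc` writes the MRPC splits as tab-separated files under data/MRPC. A minimal sketch of reading such a file with the standard library, assuming the five-column layout the Transformers MRPC processor expects (a header row, then label, two sentence IDs and two sentences). `read_mrpc_rows` and the two sample rows are made up for illustration.

```python
import csv
import io


def read_mrpc_rows(tsv_text):
    """Parse MRPC-style TSV text into (label, sentence1, sentence2) tuples.

    Assumes a header row followed by rows of:
    label, #1 ID, #2 ID, #1 String, #2 String
    """
    # QUOTE_NONE keeps quotation marks inside sentences as plain text
    reader = csv.reader(io.StringIO(tsv_text), delimiter="\t", quoting=csv.QUOTE_NONE)
    next(reader)  # skip the header row
    return [(row[0], row[3], row[4]) for row in reader if row]


# Hypothetical sample in the same layout as the generated files
sample = (
    "Quality\t#1 ID\t#2 ID\t#1 String\t#2 String\n"
    "1\t101\t102\tThe cat sat on the mat.\tA cat was sitting on the mat.\n"
    "0\t103\t104\tShares rose 5 percent.\tThe company hired a new CEO.\n"
)
rows = read_mrpc_rows(sample)
print(rows[0])
```

To inspect the real evaluation split, you could pass it the text of data/MRPC/dev.tsv (assumed file name for the split that `get_dev_examples` reads).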

Define DataLoader for POT

In this step, we define a DataLoader based on the POT API. It is used to collect statistics for quantization and to run model evaluation. We use helper functions from the HuggingFace Transformers library for data preprocessing: they take raw text data and encode the sentences and words, producing three model inputs (input_ids, input_mask and segment_ids). For more details about data preprocessing and tokenization, please refer to the HuggingFace Transformers documentation.

class MRPCDataLoader(POTDataLoader):
    # Required methods
    def __init__(self, config):
        """Constructor
        :param config: data loader specific config
        """
        if not isinstance(config, Dict):
            config = Dict(config)
        super().__init__(config)
        self._task = config["task"].lower()
        self._model_dir = config["model_dir"]
        self._data_dir = config["data_source"]
        self._batch_size = config["batch_size"]
        self._max_length = config["max_length"]
        self._prepare_dataset()

    def __len__(self):
        """Returns size of the dataset"""
        return len(self.dataset)

    def __getitem__(self, index):
        """
        Returns annotation, data and metadata at the specified index.
        Possible formats:
        (index, annotation), data
        (index, annotation), data, metadata
        """
        if index >= len(self):
            raise IndexError

        batch = self.dataset[index]
        batch = tuple(t.detach().cpu().numpy() for t in batch)
        inputs = {"input_ids": batch[0], "input_mask": batch[1], "segment_ids": batch[2]}
        labels = batch[3]
        return (index, labels), inputs

    # Methods specific to the current implementation
    def _prepare_dataset(self):
        """Prepare dataset"""
        tokenizer = BertTokenizer.from_pretrained(self._model_dir, do_lower_case=True)
        processor = processors[self._task]()
        output_mode = output_modes[self._task]
        label_list = processor.get_labels()
        examples = processor.get_dev_examples(self._data_dir)
        features = convert_examples_to_features(
            examples,
            tokenizer,
            label_list=label_list,
            max_length=self._max_length,
            output_mode=output_mode,
        )
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
        all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
        self.dataset = TensorDataset(
            all_input_ids, all_attention_mask, all_token_type_ids, all_labels
        )
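Each item in `self.dataset` holds the three encoded inputs and the label. To make the encoding concrete, here is a simplified sketch for a sentence pair that has already been split into token IDs. The special-token IDs 101 ([CLS]), 102 ([SEP]) and 0 ([PAD]) are assumed from the uncased BERT vocabulary (check vocab.txt in model/MRPC), the word IDs are made up, and unlike the real tokenizer the sketch does not truncate long pairs.

```python
import numpy as np


def encode_pair(ids_a, ids_b, max_length=128, cls_id=101, sep_id=102, pad_id=0):
    """Build BERT-style inputs for a pre-tokenized sentence pair (hypothetical helper).

    Layout: [CLS] A [SEP] B [SEP], then right padding up to max_length.
    Returns (input_ids, input_mask, segment_ids) as int64 arrays.
    """
    input_ids = [cls_id] + list(ids_a) + [sep_id] + list(ids_b) + [sep_id]
    # Segment 0 covers [CLS], sentence A and the first [SEP]; segment 1 covers B and the last [SEP]
    segment_ids = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)
    if len(input_ids) > max_length:
        raise ValueError("pair too long; truncation is not handled in this sketch")
    input_mask = [1] * len(input_ids)  # 1 marks real tokens
    padding = max_length - len(input_ids)
    input_ids += [pad_id] * padding
    input_mask += [0] * padding  # 0 marks padding the model should ignore
    segment_ids += [0] * padding
    return tuple(np.array(v, dtype=np.int64) for v in (input_ids, input_mask, segment_ids))


input_ids, input_mask, segment_ids = encode_pair([2023, 2003], [2009], max_length=8)
print(input_ids.tolist())  # [101, 2023, 2003, 102, 2009, 102, 0, 0]
# Add a leading batch axis; the IR above expects [1, 128] with the default max_length.
inputs = {
    "input_ids": input_ids[None, :],
    "input_mask": input_mask[None, :],
    "segment_ids": segment_ids[None, :],
}
```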

Define Accuracy Metric Calculation

In this step, we implement the POT Metric interface for the MRPC task. It is used to validate the accuracy of the models.

class Accuracy(Metric):

    # Required methods
    def __init__(self):
        super().__init__()
        self._name = "Accuracy"
        self._matches = []

    @property
    def value(self):
        """Returns accuracy metric value for the last model output."""
        return {self._name: self._matches[-1]}

    @property
    def avg_value(self):
        """Returns accuracy metric value for all model outputs."""
        return {self._name: np.ravel(self._matches).mean()}

    def update(self, output, target):
        """
        Updates prediction matches.

        :param output: model output
        :param target: annotations
        """
        if len(output) > 1:
            raise Exception(
                "The accuracy metric cannot be calculated for a model with multiple outputs"
            )
        output = np.argmax(output)
        match = output == target[0]
        self._matches.append(match)

    def reset(self):
        """
        Resets collected matches
        """
        self._matches = []

    def get_attributes(self):
        """
        Returns a dictionary of metric attributes {metric_name: {attribute_name: value}}.
        Required attributes: 'direction': 'higher-better' or 'higher-worse'
                             'type': metric type
        """
        return {self._name: {"direction": "higher-better", "type": "accuracy"}}
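The `Accuracy` class records one boolean per sample in `update()` and averages them in `avg_value`. For a quick cross-check outside POT, the same computation can be vectorized over an array of logits. `batch_accuracy` is a hypothetical helper, and the two-column logits mirror the model's [1, 2] output.

```python
import numpy as np


def batch_accuracy(logits, labels):
    """Fraction of rows whose argmax class matches the label.

    logits: array of shape (num_samples, num_classes), e.g. (N, 2) for MRPC
    labels: array of shape (num_samples,)
    """
    predictions = np.argmax(np.asarray(logits), axis=-1)
    return float(np.mean(predictions == np.asarray(labels)))


logits = np.array([[0.2, 1.3], [2.0, -1.0], [0.1, 0.4], [1.5, 0.5]])
labels = np.array([1, 0, 0, 0])
print(batch_accuracy(logits, labels))  # 0.75
```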

Run Quantization Pipeline

Here we define a configuration for the quantization pipeline and run it. Note that we use the built-in IEEngine implementation of the Engine interface from the POT API for model inference.

warnings.filterwarnings("ignore")  # Suppress accuracychecker warnings

model_config = Dict({"model_name": "bert_mrpc", "model": ir_model_xml, "weights": ir_model_bin})
engine_config = Dict({"device": "CPU"})
dataset_config = {
    "task": "mrpc",
    "data_source": os.path.join(DATA_DIR, "MRPC"),
    "model_dir": os.path.join(MODEL_DIR, "MRPC"),
    "batch_size": BATCH_SIZE,
    "max_length": MAX_SEQ_LENGTH,
}
algorithms = [
    {
        "name": "DefaultQuantization",
        "params": {
            "target_device": "ANY",
            "model_type": "transformer",
            "preset": "performance",
            "stat_subset_size": 250,
        },
    }
]


# Step 1: Load the model
model = load_model(model_config=model_config)

# Step 2: Initialize the data loader
data_loader = MRPCDataLoader(config=dataset_config)

# Step 3 (Optional. Required for AccuracyAwareQuantization): Initialize the metric
metric = Accuracy()

# Step 4: Initialize the engine for metric calculation and statistics collection
engine = IEEngine(config=engine_config, data_loader=data_loader, metric=metric)

# Step 5: Create a pipeline of compression algorithms
pipeline = create_pipeline(algo_config=algorithms, engine=engine)

# Step 6 (Optional): Evaluate the original model. Print the results
fp_results = pipeline.evaluate(model=model)
if fp_results:
    print("FP16 model results:")
    for name, value in fp_results.items():
        print(f"{name}: {value:.5f}")
FP16 model results:
Accuracy: 0.86029
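Step 3 notes that the metric is required for AccuracyAwareQuantization. As an untested sketch, switching to that algorithm would mean replacing the `algorithms` list and passing it to `create_pipeline` as before. The `maximal_drop` parameter name and its meaning follow our reading of the POT documentation and should be verified against your installed version.

```python
# Hypothetical alternative to the `algorithms` list above. It keeps the
# DefaultQuantization parameters and adds an accuracy-drop budget.
aa_algorithms = [
    {
        "name": "AccuracyAwareQuantization",
        "params": {
            "target_device": "ANY",
            "model_type": "transformer",
            "preset": "performance",
            "stat_subset_size": 250,
            # Maximum tolerated accuracy drop after quantization (0.01 = 1%)
            "maximal_drop": 0.01,
        },
    }
]
# Usage with the objects defined above (not executed here):
# pipeline = create_pipeline(algo_config=aa_algorithms, engine=engine)
```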
# Step 7: Execute the pipeline
warnings.filterwarnings("ignore")  # Suppress accuracychecker warnings
print(
    f"Quantizing model with {algorithms[0]['params']['preset']} preset and {algorithms[0]['name']}"
)
start_time = time.perf_counter()
compressed_model = pipeline.run(model=model)
end_time = time.perf_counter()
print(f"Quantization finished in {end_time - start_time:.2f} seconds")

# Step 8 (Optional): Compress model weights to quantized precision
#                    in order to reduce the size of final .bin file
compress_model_weights(model=compressed_model)

# Step 9: Save the compressed model to the desired path
compressed_model_paths = save_model(
    model=compressed_model, save_path=MODEL_DIR, model_name="quantized_bert_mrpc"
)
compressed_model_xml = compressed_model_paths[0]["model"]
Quantizing model with performance preset and DefaultQuantization
Quantization finished in 103.78 seconds
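Step 8 compresses the weights to shrink the final .bin file. A small sketch for checking the effect; the helper names are hypothetical, and the file paths in the commented usage are assumptions based on the Model Optimizer output and the `model_name` passed to `save_model`.

```python
import os


def summarize_sizes(original_bytes, quantized_bytes):
    """Return sizes in MiB and how many times smaller the quantized file is."""
    mib = 2**20
    return {
        "original_mib": original_bytes / mib,
        "quantized_mib": quantized_bytes / mib,
        "ratio": original_bytes / quantized_bytes,
    }


def compare_weight_files(original_bin, quantized_bin):
    """Compare the on-disk sizes of two IR weight (.bin) files."""
    return summarize_sizes(os.path.getsize(original_bin), os.path.getsize(quantized_bin))


# Assumed paths, not verified here:
# report = compare_weight_files("model/bert_mrpc.bin", "model/quantized_bert_mrpc.bin")
# print(f"{report['original_mib']:.1f} MiB -> {report['quantized_mib']:.1f} MiB "
#       f"({report['ratio']:.2f}x smaller)")
```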
# Step 10 (Optional): Evaluate the compressed model and print the results
int_results = pipeline.evaluate(model=compressed_model)

if int_results:
    print("INT8 model results:")
    for name, value in int_results.items():
        print(f"{name}: {value:.5f}")
INT8 model results:
Accuracy: 0.85294
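The accuracy drop can be expressed in absolute and relative terms. The sketch below assumes the evaluation ran on the 408 sentence pairs of the GLUE MRPC dev split read by `get_dev_examples`; under that assumption, the two reported accuracies correspond to 351/408 and 348/408 correct predictions, so the INT8 model gets three more pairs wrong.

```python
fp16_accuracy = 0.86029  # "FP16 model results" printed above
int8_accuracy = 0.85294  # "INT8 model results" printed above
num_dev_samples = 408    # assumed size of the MRPC dev split

absolute_drop = fp16_accuracy - int8_accuracy
relative_drop = absolute_drop / fp16_accuracy
print(f"absolute drop: {absolute_drop:.5f}")  # 0.00735
print(f"relative drop: {relative_drop:.2%}")  # 0.85%
# Number of correctly classified pairs implied by each accuracy
print(round(fp16_accuracy * num_dev_samples), round(int8_accuracy * num_dev_samples))  # 351 348
```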

Compare Performance of the Original and Quantized Models

Finally, we measure the inference performance of the original FP16 model and the quantized INT8 model. To do this, we use Benchmark Tool, OpenVINO’s inference performance measurement tool.

NOTE: For more accurate performance measurements, we recommend running benchmark_app in a terminal/command prompt after closing other applications. Run benchmark_app -m model.xml -d CPU to benchmark async inference on CPU for one minute. Change CPU to GPU to benchmark on GPU. Run benchmark_app --help to see an overview of all command line options.

## compressed_model_xml is defined after quantizing the model.
## Uncomment the lines below to set default values for the model file locations.
# ir_model_xml = "model/bert_mrpc.xml"
# compressed_model_xml = "model/quantized_bert_mrpc.xml"
# Benchmark the FP16 model (IR)
! benchmark_app -m $ir_model_xml -d CPU -api async
[Step 1/11] Parsing and validating input arguments
[ WARNING ]  -nstreams default value is determined automatically for a device. Although the automatic selection usually provides a reasonable performance, but it still may be non-optimal for some cases, for more information look at README.
[Step 2/11] Loading Inference Engine
[ INFO ] InferenceEngine:
         API version............. 2021.4.2-3976-0943ed67223-refs/pull/539/head
[ INFO ] Device info
         CPU
         MKLDNNPlugin............ version 2.1
         Build................... 2021.4.2-3976-0943ed67223-refs/pull/539/head

[Step 3/11] Setting device configuration
[ WARNING ] -nstreams default value is determined automatically for CPU device. Although the automatic selection usually provides a reasonable performance,but it still may be non-optimal for some cases, for more information look at README.
[Step 4/11] Reading network files
[ INFO ] Read network took 120.48 ms
[Step 5/11] Resizing network to match image sizes and given batch
[ INFO ] Network batch size: 1
[Step 6/11] Configuring input of the model
[ INFO ] Network input 'input_ids' precision I32, dimensions (NC): 1 128
[ INFO ] Network input 'input_mask' precision I32, dimensions (NC): 1 128
[ INFO ] Network input 'segment_ids' precision I32, dimensions (NC): 1 128
[ INFO ] Network output 'output' precision FP32, dimensions (NC): 1 2
[Step 7/11] Loading the model to the device
[ INFO ] Load network took 788.31 ms
[Step 8/11] Setting optimal runtime parameters
[Step 9/11] Creating infer requests and filling input blobs with images
[ WARNING ] No input files were given: all inputs will be filled with random values!
[ INFO ] Infer Request 0 filling
[ INFO ] Fill input 'input_ids' with random values (some binary data is expected)
[ INFO ] Fill input 'input_mask' with random values (some binary data is expected)
[ INFO ] Fill input 'segment_ids' with random values (some binary data is expected)
[ INFO ] Infer Request 1 filling
[ INFO ] Fill input 'input_ids' with random values (some binary data is expected)
[ INFO ] Fill input 'input_mask' with random values (some binary data is expected)
[ INFO ] Fill input 'segment_ids' with random values (some binary data is expected)
[ INFO ] Infer Request 2 filling
[ INFO ] Fill input 'input_ids' with random values (some binary data is expected)
[ INFO ] Fill input 'input_mask' with random values (some binary data is expected)
[ INFO ] Fill input 'segment_ids' with random values (some binary data is expected)
[ INFO ] Infer Request 3 filling
[ INFO ] Fill input 'input_ids' with random values (some binary data is expected)
[ INFO ] Fill input 'input_mask' with random values (some binary data is expected)
[ INFO ] Fill input 'segment_ids' with random values (some binary data is expected)
[ INFO ] Infer Request 4 filling
[ INFO ] Fill input 'input_ids' with random values (some binary data is expected)
[ INFO ] Fill input 'input_mask' with random values (some binary data is expected)
[ INFO ] Fill input 'segment_ids' with random values (some binary data is expected)
[ INFO ] Infer Request 5 filling
[ INFO ] Fill input 'input_ids' with random values (some binary data is expected)
[ INFO ] Fill input 'input_mask' with random values (some binary data is expected)
[ INFO ] Fill input 'segment_ids' with random values (some binary data is expected)
[Step 10/11] Measuring performance (Start inference asynchronously, 6 inference requests using 6 streams for CPU, limits: 60000 ms duration)
[ INFO ] First inference took 122.41 ms
[Step 11/11] Dumping statistics report
Count:      3084 iterations
Duration:   60144.82 ms
Latency:    116.60 ms
Throughput: 51.28 FPS
# Benchmark the INT8 model (IR)
! benchmark_app -m $compressed_model_xml -d CPU -api async
[Step 1/11] Parsing and validating input arguments
[ WARNING ]  -nstreams default value is determined automatically for a device. Although the automatic selection usually provides a reasonable performance, but it still may be non-optimal for some cases, for more information look at README.
[Step 2/11] Loading Inference Engine
[ INFO ] InferenceEngine:
         API version............. 2021.4.2-3976-0943ed67223-refs/pull/539/head
[ INFO ] Device info
         CPU
         MKLDNNPlugin............ version 2.1
         Build................... 2021.4.2-3976-0943ed67223-refs/pull/539/head

[Step 3/11] Setting device configuration
[ WARNING ] -nstreams default value is determined automatically for CPU device. Although the automatic selection usually provides a reasonable performance,but it still may be non-optimal for some cases, for more information look at README.
[Step 4/11] Reading network files
[ INFO ] Read network took 90.80 ms
[Step 5/11] Resizing network to match image sizes and given batch
[ INFO ] Network batch size: 1
[Step 6/11] Configuring input of the model
[ INFO ] Network input 'input_ids' precision I32, dimensions (NC): 1 128
[ INFO ] Network input 'input_mask' precision I32, dimensions (NC): 1 128
[ INFO ] Network input 'segment_ids' precision I32, dimensions (NC): 1 128
[ INFO ] Network output 'output' precision FP32, dimensions (NC): 1 2
[Step 7/11] Loading the model to the device
[ INFO ] Load network took 586.32 ms
[Step 8/11] Setting optimal runtime parameters
[Step 9/11] Creating infer requests and filling input blobs with images
[ WARNING ] No input files were given: all inputs will be filled with random values!
[ INFO ] Infer Request 0 filling
[ INFO ] Fill input 'input_ids' with random values (some binary data is expected)
[ INFO ] Fill input 'input_mask' with random values (some binary data is expected)
[ INFO ] Fill input 'segment_ids' with random values (some binary data is expected)
[ INFO ] Infer Request 1 filling
[ INFO ] Fill input 'input_ids' with random values (some binary data is expected)
[ INFO ] Fill input 'input_mask' with random values (some binary data is expected)
[ INFO ] Fill input 'segment_ids' with random values (some binary data is expected)
[ INFO ] Infer Request 2 filling
[ INFO ] Fill input 'input_ids' with random values (some binary data is expected)
[ INFO ] Fill input 'input_mask' with random values (some binary data is expected)
[ INFO ] Fill input 'segment_ids' with random values (some binary data is expected)
[ INFO ] Infer Request 3 filling
[ INFO ] Fill input 'input_ids' with random values (some binary data is expected)
[ INFO ] Fill input 'input_mask' with random values (some binary data is expected)
[ INFO ] Fill input 'segment_ids' with random values (some binary data is expected)
[ INFO ] Infer Request 4 filling
[ INFO ] Fill input 'input_ids' with random values (some binary data is expected)
[ INFO ] Fill input 'input_mask' with random values (some binary data is expected)
[ INFO ] Fill input 'segment_ids' with random values (some binary data is expected)
[ INFO ] Infer Request 5 filling
[ INFO ] Fill input 'input_ids' with random values (some binary data is expected)
[ INFO ] Fill input 'input_mask' with random values (some binary data is expected)
[ INFO ] Fill input 'segment_ids' with random values (some binary data is expected)
[Step 10/11] Measuring performance (Start inference asynchronously, 6 inference requests using 6 streams for CPU, limits: 60000 ms duration)
[ INFO ] First inference took 83.35 ms
[Step 11/11] Dumping statistics report
Count:      8700 iterations
Duration:   60044.16 ms
Latency:    40.48 ms
Throughput: 144.89 FPS
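The two reports can be compared numerically. The sketch below parses the summary lines of the benchmark_app output shown above; the regular expressions assume this 2021.4 report format (other releases print the summary differently), and `parse_benchmark_summary` is a hypothetical helper.

```python
import re


def parse_benchmark_summary(report):
    """Extract Count, Duration, Latency and Throughput from a benchmark_app report."""
    patterns = {
        "count": r"Count:\s+(\d+) iterations",
        "duration_ms": r"Duration:\s+([\d.]+) ms",
        "latency_ms": r"Latency:\s+([\d.]+) ms",
        "throughput_fps": r"Throughput:\s+([\d.]+) FPS",
    }
    return {key: float(re.search(p, report).group(1)) for key, p in patterns.items()}


# Summary lines copied from the two runs above
fp16_report = """Count:      3084 iterations
Duration:   60144.82 ms
Latency:    116.60 ms
Throughput: 51.28 FPS"""
int8_report = """Count:      8700 iterations
Duration:   60044.16 ms
Latency:    40.48 ms
Throughput: 144.89 FPS"""

fp16 = parse_benchmark_summary(fp16_report)
int8 = parse_benchmark_summary(int8_report)
print(f"throughput speedup: {int8['throughput_fps'] / fp16['throughput_fps']:.2f}x")  # 2.83x
print(f"latency reduction:  {fp16['latency_ms'] / int8['latency_ms']:.2f}x")  # 2.88x
```

Throughput here is roughly Count divided by Duration. Because both runs use random inputs, 6 asynchronous requests and one minute per model, the ratios describe this machine and configuration only.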