Quantization-aware Training (QAT)


Quantization-aware Training (QAT) is a popular method that quantizes a model and then fine-tunes it to recover the accuracy lost to quantization. Because the model is trained with quantization in place, it is generally the most accurate quantization method. This document describes how to apply QAT from the Neural Network Compression Framework (NNCF) to get 8-bit quantized models. It assumes that you are proficient in Python programming and familiar with the training code for the model in the source DL framework.
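To make the source of the accuracy degradation concrete, the sketch below simulates 8-bit quantization of a few floating-point values in plain Python: each value is snapped to one of 256 evenly spaced levels and mapped back to a float. The function name fake_quantize and the range [-1, 1] are assumptions chosen for illustration; this is not NNCF's implementation.

```python
def fake_quantize(x, low, high, levels=256):
    """Snap x to one of `levels` evenly spaced values in [low, high] and return it as a float."""
    scale = (high - low) / (levels - 1)       # distance between neighbouring levels
    clamped = min(max(x, low), high)          # values outside the range saturate
    index = round((clamped - low) / scale)    # integer level in [0, levels - 1]
    return low + index * scale


# Hypothetical weight values, chosen only for illustration.
values = [-1.0, -0.3337, 0.0, 0.4242, 1.0]
quantized = [fake_quantize(v, -1.0, 1.0) for v in values]

# The rounding error per value is at most half of one quantization step.
max_error = max(abs(q - v) for q, v in zip(quantized, values))
print(quantized)
print(max_error)
```

Fine-tuning with such rounding in the forward pass lets the weights adapt to it, which is the idea behind recovering accuracy in QAT.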


The steps below integrate QAT from NNCF into a training script written with PyTorch or TensorFlow 2.


Note: Currently, NNCF for TensorFlow 2 supports optimization of models created using the Keras Sequential API or Functional API.

1. Import NNCF API

In this step, you add NNCF-related imports at the beginning of the training script:

PyTorch:

import torch
import nncf  # Important - should be imported right after torch
from nncf import NNCFConfig
from nncf.torch import create_compressed_model, register_default_init_args

TensorFlow 2:

import tensorflow as tf

from nncf import NNCFConfig
from nncf.tensorflow import create_compressed_model, register_default_init_args

2. Create NNCF configuration

In this step, you define the NNCF configuration, which consists of model-related parameters (the "input_info" section) and parameters of the optimization methods (the "compression" section). For faster convergence, it is also recommended to register a dataset object specific to the DL framework. It is used at the model creation step to initialize quantization parameters.

PyTorch:

nncf_config_dict = {
    "input_info": {"sample_size": [1, 3, 224, 224]},  # input shape required for model tracing
    "compression": {
        "algorithm": "quantization",  # 8-bit quantization with default settings
    },
}
nncf_config = NNCFConfig.from_dict(nncf_config_dict)
nncf_config = register_default_init_args(nncf_config, train_loader)  # train_loader is an instance of torch.utils.data.DataLoader

TensorFlow 2:

nncf_config_dict = {
    "input_info": {"sample_size": [1, 3, 224, 224]},  # input shape required for model tracing
    "compression": {
        "algorithm": "quantization",  # 8-bit quantization with default settings
    },
}
nncf_config = NNCFConfig.from_dict(nncf_config_dict)
nncf_config = register_default_init_args(nncf_config, train_dataset, batch_size=1)  # train_dataset is an instance of tf.data.Dataset
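As a rough illustration of why a registered dataset helps, the sketch below derives a quantization range, scale, and zero point from a few sample batches using simple min/max statistics. The helper names and the sample data are hypothetical, and NNCF's own initializers may use different statistics; the point is only that observed data gives the quantizer a sensible starting range.

```python
def init_quantization_range(batches):
    """Return (low, high) covering every value seen in the calibration batches."""
    low, high = float("inf"), float("-inf")
    for batch in batches:
        low = min(low, min(batch))
        high = max(high, max(batch))
    return low, high


def scale_and_zero_point(low, high, levels=256):
    """Map the range [low, high] onto integer levels 0..levels-1."""
    # Extend the range to include zero so that 0.0 is exactly representable.
    low, high = min(low, 0.0), max(high, 0.0)
    if high == low:
        raise ValueError("calibration data contains only zeros")
    scale = (high - low) / (levels - 1)
    zero_point = round(-low / scale)  # integer level that represents 0.0
    return scale, zero_point


# Hypothetical activation values from three calibration batches.
calibration_batches = [[0.1, -0.5, 0.9], [1.5, -0.2, 0.0], [0.7, -1.0, 0.3]]
low, high = init_quantization_range(calibration_batches)
scale, zero_point = scale_and_zero_point(low, high)
print(low, high, scale, zero_point)
```

Starting fine-tuning from ranges that already match the data is what makes convergence faster than starting from arbitrary defaults.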

3. Apply optimization methods

In the next step, you wrap the original model object with the create_compressed_model() API using the configuration defined in the previous step. This method returns a so-called compression controller and a wrapped model that can be used the same way as the original model. Note that optimization methods are applied at this step, so the model undergoes a set of corresponding transformations and can contain additional operations required for the optimization. In the case of QAT, the compression controller object is used for model export and, optionally, in distributed training, as shown below.

PyTorch:

model = TorchModel()  # instance of torch.nn.Module
compression_ctrl, model = create_compressed_model(model, nncf_config)

TensorFlow 2:

model = KerasModel()  # instance of tensorflow.keras.Model
compression_ctrl, model = create_compressed_model(model, nncf_config)

4. Fine-tune the model

This step assumes that you fine-tune the quantized model the same way as the baseline model. In the case of QAT, it is required to train the model for a few epochs with a small learning rate, for example, 10e-5. In principle, you can skip this step, which means that post-training optimization is applied to the model.

PyTorch:

...  # fine-tuning preparations, e.g. dataset, loss, optimizer setup, etc.

# tune quantized model for 5 epochs the same way as the baseline
for epoch in range(0, 5):
    compression_ctrl.scheduler.epoch_step()  # Epoch control API

    for i, data in enumerate(train_loader):
        compression_ctrl.scheduler.step()  # Training iteration control API
        ...  # training loop body

TensorFlow 2:

from nncf.tensorflow.helpers.callback_creation import create_compression_callbacks

...  # fine-tuning preparations, e.g. dataset, loss, optimizer setup, etc.

# create compression callbacks to control optimization parameters and dump compression statistics
compression_callbacks = create_compression_callbacks(compression_ctrl, log_dir="./compression_log")
# tune quantized model for 5 epochs the same way as the baseline
model.fit(train_dataset, epochs=5, callbacks=compression_callbacks)
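The order of the two scheduler calls in the PyTorch loop can be checked with a self-contained sketch. RecordingScheduler below is a hypothetical stand-in that only records which method was called; it is not an NNCF class, and the two-element train_loader is placeholder data.

```python
class RecordingScheduler:
    """Hypothetical stand-in for compression_ctrl.scheduler that only records calls."""

    def __init__(self):
        self.calls = []

    def epoch_step(self):
        self.calls.append("epoch_step")

    def step(self):
        self.calls.append("step")


scheduler = RecordingScheduler()
train_loader = ["batch0", "batch1"]  # placeholder for a real data loader

for epoch in range(2):
    scheduler.epoch_step()      # once at the start of every epoch
    for batch in train_loader:
        scheduler.step()        # once per training iteration, before the loop body
        # forward pass, loss, backward pass, and optimizer step would go here

print(scheduler.calls)
```

Each epoch therefore produces one epoch_step call followed by one step call per batch.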

5. Multi-GPU distributed training

In the case of distributed multi-GPU training (not DataParallel), call compression_ctrl.distributed() before fine-tuning. This informs the optimization methods that they need to make some adjustments to function in distributed mode.

PyTorch:

compression_ctrl.distributed()  # call it before the training loop

TensorFlow 2:

compression_ctrl.distributed()  # call it before the training

6. Export quantized model

When fine-tuning finishes, the quantized model can be exported to the corresponding format for further inference: ONNX in the case of PyTorch and frozen graph in the case of TensorFlow 2.

PyTorch:

compression_ctrl.export_model("compressed_model.onnx")  # export to ONNX

TensorFlow 2:

compression_ctrl.export_model("compressed_model.pb")  # export to Frozen Graph


Note: The weights become INT8 only after the model is converted to OpenVINO Intermediate Representation. You can expect a reduction in model footprint only for that format.
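As a back-of-the-envelope estimate of the footprint reduction, the sketch below compares the storage needed for the same number of weights at 32 and 8 bits. The weight count is hypothetical, and the estimate ignores quantization parameters and any layers that stay in floating point, so the real ratio for a converted model is usually somewhat lower.

```python
def weights_size_bytes(num_weights, bits_per_weight):
    """Storage needed for the weights alone, in bytes."""
    return num_weights * bits_per_weight // 8


num_weights = 10_000_000  # hypothetical model size, for illustration only
fp32_bytes = weights_size_bytes(num_weights, 32)
int8_bytes = weights_size_bytes(num_weights, 8)

print(fp32_bytes, int8_bytes, fp32_bytes / int8_bytes)
```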

These are the basic steps for applying the QAT method from NNCF. However, in some cases you also need to save and load model checkpoints during training. Since NNCF wraps the original model with its own object, it provides an API for these needs.

7. (Optional) Save checkpoint

To save a model checkpoint, use the following API:

PyTorch:

checkpoint = {
    'state_dict': model.state_dict(),
    'compression_state': compression_ctrl.get_compression_state(),
    ...  # the rest of the user-defined objects to save
}
torch.save(checkpoint, path_to_checkpoint)

TensorFlow 2:

from nncf.tensorflow.utils.state import TFCompressionState
from nncf.tensorflow.callbacks.checkpoint_callback import CheckpointManagerCallback

checkpoint = tf.train.Checkpoint(model=model,
                                 compression_state=TFCompressionState(compression_ctrl),
                                 ...  # the rest of the user-defined objects to save
                                 )
callbacks = []
callbacks.append(CheckpointManagerCallback(checkpoint, path_to_checkpoint))
model.fit(..., callbacks=callbacks)

8. (Optional) Restore from checkpoint

To restore the model from a checkpoint, use the following API:

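PyTorch:

resuming_checkpoint = torch.load(path_to_checkpoint)
compression_state = resuming_checkpoint['compression_state']
compression_ctrl, model = create_compressed_model(model, nncf_config, compression_state=compression_state)
state_dict = resuming_checkpoint['state_dict']
model.load_state_dict(state_dict)  # load the saved weights into the wrapped model

TensorFlow 2:

from nncf.tensorflow.utils.state import TFCompressionState
from nncf.tensorflow.utils.state import TFCompressionStateLoader

# first restore only the compression state, which is needed to rebuild the wrapped model
checkpoint = tf.train.Checkpoint(compression_state=TFCompressionStateLoader())
checkpoint.restore(path_to_checkpoint)
compression_state = checkpoint.compression_state.state

compression_ctrl, model = create_compressed_model(model, nncf_config, compression_state)
# then restore the weights and the rest of the saved objects into the wrapped model
checkpoint = tf.train.Checkpoint(model=model,
                                 compression_state=TFCompressionState(compression_ctrl),
                                 ...  # the rest of the user-defined objects to load
                                 )
checkpoint.restore(path_to_checkpoint)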

For more details on saving and loading checkpoints in NNCF, see the NNCF documentation.

Deploying quantized model

The quantized model can be deployed with OpenVINO in the same way as the baseline model. No extra steps or options are required in this case. For more details, see the corresponding documentation.