Quantization-aware Training (QAT)#
Introduction#
Quantization-aware Training (QAT) is a popular method that quantizes a model and then fine-tunes it to recover the accuracy lost to quantization; it is generally the most accurate quantization method. This document describes how to apply QAT from the Neural Network Compression Framework (NNCF) to obtain 8-bit quantized models. It assumes that you are proficient in Python and familiar with the training code for the model in the source DL framework.
Steps required to apply QAT to the model:
Note
Currently, NNCF for TensorFlow supports the optimization of models created using the Keras Sequential API or Functional API.
1. Apply Post Training Quantization to the Model#
Quantize the model using the Post-Training Quantization method.
# PyTorch
import nncf
model = TorchModel()  # instance of torch.nn.Module
quantized_model = nncf.quantize(model, ...)

# TensorFlow 2
import nncf
model = KerasModel()  # instance of tensorflow.keras.Model
quantized_model = nncf.quantize(model, ...)
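The elided arguments of nncf.quantize() include, most importantly, a calibration dataset. Below is a minimal PyTorch sketch, assuming an existing DataLoader named calibration_loader that yields (images, labels) batches; the loader and transform function names are illustrative, not part of the NNCF API:

import nncf

# the transform function extracts the model input from a single batch of the data source
def transform_fn(data_item):
    images, _ = data_item
    return images

# wrap the existing data source so NNCF can iterate over it during calibration
calibration_dataset = nncf.Dataset(calibration_loader, transform_fn)
# insert fake-quantization operations whose parameters are initialized from calibration statistics
quantized_model = nncf.quantize(model, calibration_dataset)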
2. Fine-tune the Model#
This step fine-tunes the quantized model in the same way as the baseline model. For QAT, the model should be trained for a few epochs with a small learning rate, for example, 1e-5. During fine-tuning, the quantized model performs all computations in floating-point precision while modeling quantization errors in both the forward and backward passes.
# PyTorch
...  # fine-tuning preparations, e.g. dataset, loss, optimization setup, etc.

# tune quantized model for 5 epochs the same way as the baseline
for epoch in range(0, 5):
    for i, data in enumerate(train_loader):
        ...  # training loop body
# TensorFlow 2
...  # fine-tuning preparations, e.g. dataset, loss, optimization setup, etc.

# tune quantized model for 5 epochs the same way as the baseline
quantized_model.fit(train_dataset, epochs=5)
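To make the "small learning rate" recommendation concrete, here is a hedged PyTorch sketch of a possible loop body, assuming a classification task with a cross-entropy loss and batches of (images, labels); the optimizer choice and variable names are illustrative:

import torch

optimizer = torch.optim.Adam(quantized_model.parameters(), lr=1e-5)  # small learning rate for QAT
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(0, 5):
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = quantized_model(images)  # forward pass models quantization error
        loss = criterion(outputs, labels)
        loss.backward()                    # gradients flow through the fake-quantization operations
        optimizer.step()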
Note
The model weights switch to INT8 precision only after the model is converted to OpenVINO Intermediate Representation, so you can expect a reduction in the model footprint only for that format.
These steps cover the basics of applying QAT with NNCF. However, in some cases it is necessary to save and load model checkpoints during training. Because NNCF wraps the original model in its own object, it provides an API for this purpose.
3. (Optional) Save Checkpoint#
To save a model checkpoint, use the following API:
# PyTorch
checkpoint = {
    'state_dict': quantized_model.state_dict(),
    'nncf_config': quantized_model.nncf.get_config(),
    ...  # the rest of the user-defined objects to save
}
torch.save(checkpoint, path_to_checkpoint)
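When training can be interrupted, it is common to keep the optimizer state and the last completed epoch in the same dictionary so that fine-tuning can be resumed exactly where it stopped. This is plain PyTorch practice rather than an NNCF requirement, and the extra keys are illustrative:

checkpoint = {
    'state_dict': quantized_model.state_dict(),
    'nncf_config': quantized_model.nncf.get_config(),
    'optimizer': optimizer.state_dict(),  # illustrative: optimizer state for exact resumption
    'epoch': epoch,                       # illustrative: last completed epoch
}
torch.save(checkpoint, path_to_checkpoint)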
# TensorFlow 2
from nncf.tensorflow import ConfigState
from nncf.tensorflow import get_config
from nncf.tensorflow.callbacks.checkpoint_callback import CheckpointManagerCallback

nncf_config = get_config(quantized_model)
checkpoint = tf.train.Checkpoint(model=quantized_model,
                                 nncf_config_state=ConfigState(nncf_config),
                                 ...  # the rest of the user-defined objects to save
                                 )
callbacks = []
callbacks.append(CheckpointManagerCallback(checkpoint, path_to_checkpoint))
...
quantized_model.fit(..., callbacks=callbacks)
4. (Optional) Restore from Checkpoint#
To restore the model from a checkpoint, use the following API:
# PyTorch
resuming_checkpoint = torch.load(path_to_checkpoint)
nncf_config = resuming_checkpoint['nncf_config']
quantized_model = nncf.torch.load_from_config(model, nncf_config, example_input)
state_dict = resuming_checkpoint['state_dict']
quantized_model.load_state_dict(state_dict)
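Continuing the illustrative checkpoint layout from step 3, resuming fine-tuning could look like the sketch below; the 'optimizer' and 'epoch' keys are assumptions from that example, not part of the NNCF API:

optimizer = torch.optim.Adam(quantized_model.parameters(), lr=1e-5)
optimizer.load_state_dict(resuming_checkpoint['optimizer'])  # assumes the optimizer state was saved in step 3
start_epoch = resuming_checkpoint['epoch'] + 1

for epoch in range(start_epoch, 5):
    for i, data in enumerate(train_loader):
        ...  # training loop body, as in step 2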
# TensorFlow 2
from nncf.tensorflow import ConfigState
from nncf.tensorflow import load_from_config

checkpoint = tf.train.Checkpoint(nncf_config_state=ConfigState())
checkpoint.restore(path_to_checkpoint)
quantized_model = load_from_config(model, checkpoint.nncf_config_state.config)

checkpoint = tf.train.Checkpoint(model=quantized_model,
                                 ...  # the rest of the user-defined objects to load
                                 )
checkpoint.restore(path_to_checkpoint)
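Before calling fit() on the restored model, it has to be compiled, as for any Keras model. A minimal sketch, assuming an Adam optimizer with the small QAT learning rate, a standard classification loss, and the callbacks set up in step 3 (all of these are illustrative choices):

quantized_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
                        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                        metrics=['accuracy'])
quantized_model.fit(train_dataset, epochs=5, callbacks=callbacks)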
Deploying the Quantized Model#
After fine-tuning, you can convert the model to OpenVINO Intermediate Representation (IR) if needed, compile it, and run inference with OpenVINO without any additional steps.
# PyTorch
import openvino as ov
input_fp32 = ... # FP32 model input
# convert PyTorch model to OpenVINO model
ov_quantized_model = ov.convert_model(quantized_model, example_input=input_fp32)
# compile the model to transform quantized operations to int8
model_int8 = ov.compile_model(ov_quantized_model)
res = model_int8(input_fp32)
# save the model
ov.save_model(ov_quantized_model, "quantized_model.xml")
# TensorFlow 2
import openvino as ov
# convert TensorFlow model to OpenVINO model
ov_quantized_model = ov.convert_model(quantized_model)
# compile the model to transform quantized operations to int8
model_int8 = ov.compile_model(ov_quantized_model)
input_fp32 = ... # FP32 model input
res = model_int8(input_fp32)
# save the model
ov.save_model(ov_quantized_model, "quantized_model.xml")
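The saved IR can later be loaded and compiled for a specific device without the original training framework. A short sketch using the standard OpenVINO runtime API; the device name and file path are examples:

import openvino as ov

core = ov.Core()
ir_model = core.read_model("quantized_model.xml")  # weights are read from the accompanying .bin file
model_int8 = core.compile_model(ir_model, "CPU")   # quantized operations run as int8 on supported devices
res = model_int8(input_fp32)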
For more details, see the corresponding documentation.