Quantization-aware Training (QAT) with PyTorch

Below are the steps required to integrate QAT from NNCF into a training script written with PyTorch:

1. Apply Post-Training Quantization to the Model

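Quantize the model using the Post-Training Quantization (PTQ) method. The resulting quantized model, with its quantizer parameters initialized on the calibration data, is the starting point for fine-tuning.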

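import nncf

model = TorchModel()  # instance of torch.nn.Module
model = nncf.quantize(model, ...)  # returns the quantized model; "..." stands for the calibration dataset and other arguments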
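PTQ relies on a small calibration dataset (in NNCF it is passed to nncf.quantize() wrapped in an nncf.Dataset): the observed ranges of values are used to choose each quantizer's parameters. As a framework-free illustration of that idea only, the sketch below derives an 8-bit asymmetric scale and zero point from the min/max of some sample values. The min/max rule, the function names, and the sample data are assumptions made for this example; NNCF's own range estimation is not reproduced here.

```python
def minmax_qparams(values, num_bits=8):
    """Derive (scale, zero_point) for asymmetric unsigned quantization
    from the observed range of `values` (simplified min/max calibration)."""
    qmin, qmax = 0, 2 ** num_bits - 1
    # The representable range must include 0.0 so that zero maps exactly.
    lo = min(min(values), 0.0)
    hi = max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin) if hi > lo else 1.0
    zero_point = max(qmin, min(qmax, round(qmin - lo / scale)))
    return scale, zero_point


def quantize(x, scale, zero_point, num_bits=8):
    """Map a float to its integer code, clamped to the valid range."""
    q = round(x / scale) + zero_point
    return max(0, min(2 ** num_bits - 1, q))


def dequantize(q, scale, zero_point):
    """Map an integer code back to an approximate float."""
    return (q - zero_point) * scale


# Hypothetical calibration values (stand-ins for real model activations).
calibration = [-1.0, -0.25, 0.0, 0.5, 1.55]
scale, zp = minmax_qparams(calibration)

x = 0.5
print(scale, zp, dequantize(quantize(x, scale, zp), scale, zp))
```

Values outside the calibrated range are clamped to the end codes, which is one reason the calibration data should resemble the inputs the model will actually see.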

2. Fine-tune the Model

Fine-tune the quantized model the same way as the baseline model. For QAT, train it for a few epochs with a small learning rate, for example, 1e-5. During fine-tuning, the quantized model still performs all computations in floating-point precision: quantization errors are modeled in both the forward and backward passes.

...  # fine-tuning preparations, e.g. dataset, loss, optimizer setup, etc.

# fine-tune the quantized model for 5 epochs, as for the baseline
for epoch in range(5):
    for i, data in enumerate(train_loader):
        ...  # training loop body
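To make "modeling quantization errors in the forward and backward passes" concrete, the following framework-free sketch trains the single weight w of a toy model y = w * x. The forward pass uses a fake-quantized (quantize, then dequantize) copy of w, so the loss sees the rounding error. The backward pass uses a straight-through estimator: rounding is treated as identity for the gradient, and the update is applied to the full-precision weight. The symmetric 8-bit scheme, the fixed scale, the learning rate, and the data are illustrative assumptions; NNCF's actual quantizers are more elaborate.

```python
SCALE = 0.05            # fixed quantization step (assumed, not calibrated)
QMIN, QMAX = -127, 127  # symmetric signed 8-bit range


def fake_quantize(w):
    """Quantize-dequantize: the result is still a float, but it lies on the
    8-bit grid, so the forward pass sees the rounding error."""
    q = max(QMIN, min(QMAX, round(w / SCALE)))
    return q * SCALE


def ste_grad(w, upstream):
    """Straight-through estimator: pass the gradient through unchanged inside
    the representable range and block it where the value would be clamped."""
    inside = QMIN * SCALE <= w <= QMAX * SCALE
    return upstream if inside else 0.0


# Toy data generated by y = 1.23 * x.
data = [(x, 1.23 * x) for x in (-2.0, -1.0, 0.5, 1.0, 2.0)]

w = 0.0    # full-precision "shadow" weight updated by the optimizer
lr = 0.05  # toy value so this example converges in 5 epochs; real QAT uses a much smaller rate
for epoch in range(5):
    for x, y in data:
        w_q = fake_quantize(w)            # forward pass with simulated quantization
        pred = w_q * x
        grad_wq = 2.0 * (pred - y) * x    # d(squared error) / d(w_q)
        w -= lr * ste_grad(w, grad_wq)    # gradient flows through the STE to the float weight

print(f"shadow weight: {w:.4f}, quantized weight: {fake_quantize(w):.2f}")
```

The quantized weight can only settle on a grid point near the target (1.23 is not representable with this step), which is the kind of error QAT teaches the rest of a real network to compensate for.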

Note

Weights are stored in INT8 precision only after the model is converted to OpenVINO Intermediate Representation (IR), so a reduction in model footprint can be expected only for that format.
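The footprint difference comes mainly from storage width: an FP32 weight takes 4 bytes and an INT8 weight takes 1 byte. The sketch below uses only Python's standard array module to store the same hypothetical weights both ways, with a symmetric per-tensor scale (an assumption for illustration), and compares their sizes. It ignores the overhead of storing quantization parameters and any layers that stay in floating point.

```python
from array import array

# Hypothetical FP32 weights of one layer.
weights = [0.12, -0.5, 0.33, 0.9, -0.07, 0.0, -0.81, 0.44]

# Symmetric per-tensor scale mapping the largest magnitude to 127 (assumption).
scale = max(abs(w) for w in weights) / 127

fp32 = array("f", weights)                               # 4 bytes per element
int8 = array("b", [round(w / scale) for w in weights])   # 1 byte per element

fp32_bytes = fp32.itemsize * len(fp32)
int8_bytes = int8.itemsize * len(int8)
print(fp32_bytes, int8_bytes)  # 32 8

# Dequantized values stay within half a quantization step of the originals.
max_err = max(abs(q * scale - w) for q, w in zip(int8, weights))
```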

These steps cover the basics of applying QAT with NNCF. In some cases, however, you also need to save and load model checkpoints during training. Since NNCF wraps the original model with its own object, it provides a dedicated API for this.

3. (Optional) Save Checkpoint

To save a model checkpoint, use the following API:

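import torch

checkpoint = {
    'state_dict': model.state_dict(),  # model weights, including quantizer parameters
    'nncf_config': model.nncf.get_config(),  # NNCF configuration needed to rebuild the quantized model
    # ... the rest of the user-defined objects to save
}
torch.save(checkpoint, path_to_checkpoint)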

4. (Optional) Restore from Checkpoint

To restore the model from a checkpoint, use the following API:

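import nncf.torch
import torch

resuming_checkpoint = torch.load(path_to_checkpoint)
nncf_config = resuming_checkpoint['nncf_config']
state_dict = resuming_checkpoint['state_dict']

# model: a fresh instance of the original (non-quantized) torch.nn.Module
# example_input: a sample input used to trace the model
quantized_model = nncf.torch.load_from_config(model, nncf_config, example_input)
# load the saved weights into the restored quantized model, not the original one
quantized_model.load_state_dict(state_dict)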

Deploying the Quantized Model

The model can be converted into the OpenVINO Intermediate Representation (IR) if needed, compiled, and run with OpenVINO without any additional steps.

import openvino as ov

input_fp32 = ...  # FP32 model input

# convert the quantized PyTorch model (fine-tuned, or restored from a checkpoint) to an OpenVINO model
ov_quantized_model = ov.convert_model(quantized_model, example_input=input_fp32)

# compile the model to transform quantized operations to INT8
model_int8 = ov.compile_model(ov_quantized_model)

# run inference
res = model_int8(input_fp32)

# save the model as OpenVINO IR (.xml topology and .bin weights)
ov.save_model(ov_quantized_model, "quantized_model.xml")
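Conceptually, running quantized operations in INT8 means that the inner loops work on small integers and the floating-point scales are applied once at the end. The sketch below shows this for a single dot product under symmetric per-tensor quantization: 8-bit codes are multiplied and accumulated in a wide integer (typically int32 in real kernels; plain Python ints here), then rescaled by the product of the two scales. The quantization scheme, function names, and values are illustrative assumptions, not a description of OpenVINO's actual kernels.

```python
def symmetric_quantize(values, num_bits=8):
    """Return integer codes and the per-tensor scale (symmetric, no zero point)."""
    qmax = 2 ** (num_bits - 1) - 1                          # 127 for 8 bits
    scale = max(abs(v) for v in values) / qmax or 1.0       # avoid a zero scale
    return [round(v / scale) for v in values], scale


def int8_dot(a, b):
    """Dot product computed on integer codes, rescaled to float once at the end."""
    qa, sa = symmetric_quantize(a)
    qb, sb = symmetric_quantize(b)
    acc = sum(x * y for x, y in zip(qa, qb))  # integer multiply-accumulate
    return acc * sa * sb


activations = [0.5, -1.2, 3.0, 0.75]  # hypothetical inputs
weights = [0.2, 0.45, -0.1, 0.8]      # hypothetical weights

exact = sum(x * y for x, y in zip(activations, weights))
approx = int8_dot(activations, weights)
print(f"fp32: {exact:.4f}  int8: {approx:.4f}")
```

The small gap between the two results is the quantization error that QAT fine-tuning helps the model tolerate.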

For more details, see the corresponding documentation.

Examples