Post-Training Quantization of PyTorch models with NNCF¶
This tutorial is also available as a Jupyter notebook that can be cloned directly from GitHub. See the installation guide for instructions to run this tutorial locally on Windows, Linux or macOS.
The goal of this tutorial is to demonstrate how to use NNCF (Neural Network Compression Framework) 8-bit quantization in post-training mode (that is, without a fine-tuning pipeline) to optimize a PyTorch model for high-speed inference via the OpenVINO™ Toolkit. The optimization process contains the following steps:
Evaluate the original model.
Transform the original model to a quantized one.
Export optimized and original models to ONNX and then to OpenVINO IR.
Compare performance of the obtained FP32 and INT8 models.
This tutorial uses a ResNet-50 model, pre-trained on Tiny ImageNet, which contains 100,000 images of 200 classes (500 per class) downsized to 64×64 color images. The tutorial demonstrates that only a small part of the dataset is needed for post-training quantization, and that no fine-tuning of the model is required.
NOTE: This notebook requires that a C++ compiler is accessible on the default binary search path of the OS you are running the notebook on.
Preparations¶
# On Windows, this script adds the directory that contains cl.exe to the PATH to enable PyTorch to find the
# required C++ tools. This code assumes that Visual Studio 2019 is installed in the default
# directory. If you have a different C++ compiler, add the correct path to os.environ["PATH"]
# directly.
# Adding the path to os.environ["LIB"] is not always required - it depends on the system configuration.
import sys

if sys.platform == "win32":
    import distutils.command.build_ext
    import distutils.core
    import os
    from pathlib import Path

    VS_INSTALL_DIR = r"C:/Program Files (x86)/Microsoft Visual Studio"
    cl_paths = sorted(list(Path(VS_INSTALL_DIR).glob("**/Hostx86/x64/cl.exe")))
    if len(cl_paths) == 0:
        raise ValueError(
            "Cannot find Visual Studio. This notebook requires C++. If you installed "
            "a C++ compiler, please add the directory that contains cl.exe to "
            "`os.environ['PATH']`."
        )
    else:
        # If multiple versions of MSVC are installed, get the most recent one.
        cl_path = cl_paths[-1]
        vs_dir = str(cl_path.parent)
        os.environ["PATH"] += f"{os.pathsep}{vs_dir}"
        # The code for finding the library dirs is from
        # https://stackoverflow.com/questions/47423246/get-pythons-lib-path
        d = distutils.core.Distribution()
        b = distutils.command.build_ext.build_ext(d)
        b.finalize_options()
        os.environ["LIB"] = os.pathsep.join(b.library_dirs)
        print(f"Added {vs_dir} to PATH")
Preparing model files¶
Note: All NNCF logging messages below ERROR level (INFO and WARNING) are disabled to simplify the tutorial. For production use, it is recommended to enable logging by removing set_log_level(logging.ERROR).
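For example, when adapting this notebook for your own pipeline, the default verbosity can be restored by setting the log level back to INFO. A minimal sketch, using the same set_log_level helper that is imported below:
import logging

from nncf.common.logging.logger import set_log_level

# Re-enable NNCF INFO and WARNING messages (the tutorial keeps
# logging at ERROR level only for readability).
set_log_level(logging.INFO)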
Imports¶
import logging
import os
import sys
import time
import warnings
import zipfile
from pathlib import Path
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from nncf import NNCFConfig # Important - should be imported directly after torch
from nncf.common.logging.logger import set_log_level
set_log_level(logging.ERROR) # Disables all NNCF info and warning messages
from nncf.torch import create_compressed_model, register_default_init_args
from openvino.runtime import Core
from torch.jit import TracerWarning
sys.path.append("../utils")
from notebook_utils import download_file
2023-03-23 22:54:32.964647: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable TF_ENABLE_ONEDNN_OPTS=0.
/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-369/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/openvino/offline_transformations/__init__.py:10: FutureWarning: The module is private and following namespace offline_transformations will be removed in the future, use openvino.runtime.passes instead!
  warnings.warn(
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, tensorflow, onnx, openvino
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
Settings¶
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")
MODEL_DIR = Path("model")
OUTPUT_DIR = Path("output")
BASE_MODEL_NAME = "resnet50"
IMAGE_SIZE = [64, 64]
OUTPUT_DIR.mkdir(exist_ok=True)
MODEL_DIR.mkdir(exist_ok=True)
# Paths where PyTorch, ONNX and OpenVINO IR models will be stored.
fp32_checkpoint_filename = Path(BASE_MODEL_NAME + "_fp32").with_suffix(".pth")
fp32_onnx_path = Path(OUTPUT_DIR / (BASE_MODEL_NAME + "_fp32")).with_suffix(".onnx")
fp32_ir_path = fp32_onnx_path.with_suffix(".xml")
int8_onnx_path = Path(OUTPUT_DIR / (BASE_MODEL_NAME + "_int8")).with_suffix(".onnx")
int8_ir_path = int8_onnx_path.with_suffix(".xml")
fp32_pth_url = "https://storage.openvinotoolkit.org/repositories/nncf/openvino_notebook_ckpts/304_resnet50_fp32.pth"
download_file(fp32_pth_url, directory=MODEL_DIR, filename=fp32_checkpoint_filename)
Using cpu device
model/resnet50_fp32.pth: 0%| | 0.00/91.5M [00:00<?, ?B/s]
PosixPath('/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-369/.workspace/scm/ov-notebook/notebooks/112-pytorch-post-training-quantization-nncf/model/resnet50_fp32.pth')
Download and Prepare Tiny ImageNet dataset¶
100k images of shape 3x64x64,
200 different classes: snake, spider, cat, truck, grasshopper, gull, etc.
def download_tiny_imagenet_200(
    output_dir: Path,
    url: str = "http://cs231n.stanford.edu/tiny-imagenet-200.zip",
    tarname: str = "tiny-imagenet-200.zip",
):
    archive_path = output_dir / tarname
    download_file(url, directory=output_dir, filename=tarname)
    zip_ref = zipfile.ZipFile(archive_path, "r")
    zip_ref.extractall(path=output_dir)
    zip_ref.close()
    print(f"Successfully downloaded and extracted dataset to: {output_dir}")


def create_validation_dir(dataset_dir: Path):
    VALID_DIR = dataset_dir / "val"
    val_img_dir = VALID_DIR / "images"

    # Map each validation image to its class, using the annotations file.
    fp = open(VALID_DIR / "val_annotations.txt", "r")
    data = fp.readlines()
    val_img_dict = {}
    for line in data:
        words = line.split("\t")
        val_img_dict[words[0]] = words[1]
    fp.close()

    # Move the images into per-class subfolders so that ImageFolder can read them.
    for img, folder in val_img_dict.items():
        newpath = val_img_dir / folder
        if not newpath.exists():
            os.makedirs(newpath)
        if (val_img_dir / img).exists():
            os.rename(val_img_dir / img, newpath / img)


DATASET_DIR = OUTPUT_DIR / "tiny-imagenet-200"
if not DATASET_DIR.exists():
    download_tiny_imagenet_200(OUTPUT_DIR)
    create_validation_dir(DATASET_DIR)
output/tiny-imagenet-200.zip: 0%| | 0.00/237M [00:00<?, ?B/s]
Successfully downloaded and extracted dataset to: output
Helper classes and functions¶
The code below helps to compute accuracy and visualize the validation progress.
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name: str, fmt: str = ":f"):
        self.name = name
        self.fmt = fmt
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val: float, n: int = 1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    """Displays the progress of validation process"""

    def __init__(self, num_batches: int, meters: List[AverageMeter], prefix: str = ""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch: int):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print("\t".join(entries))

    def _get_batch_fmtstr(self, num_batches: int):
        num_digits = len(str(num_batches))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


def accuracy(output: torch.Tensor, target: torch.Tensor, topk: Tuple[int] = (1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Take the indices of the maxk largest logits per sample.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
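As a quick sanity check of the accuracy helper, consider a hypothetical batch of two samples whose largest logits match the targets; the top-1 accuracy should then be 100%. This example is illustrative and not part of the original notebook:
# Hypothetical sanity check for the accuracy() helper.
logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # 2 samples, 2 classes
targets = torch.tensor([1, 0])                   # true classes match the argmax of each row
top1 = accuracy(logits, targets, topk=(1,))[0]
print(top1.item())  # 100.0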
Validation function¶
def validate(val_loader: torch.utils.data.DataLoader, model: torch.nn.Module):
    """Compute the metrics using data from val_loader for the model"""
    batch_time = AverageMeter("Time", ":3.3f")
    top1 = AverageMeter("Acc@1", ":2.2f")
    top5 = AverageMeter("Acc@5", ":2.2f")
    progress = ProgressMeter(len(val_loader), [batch_time, top1, top5], prefix="Test: ")

    # Switch to evaluate mode.
    model.eval()
    model.to(device)

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # Compute the output.
            output = model(images)

            # Measure accuracy and record loss.
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # Measure elapsed time.
            batch_time.update(time.time() - end)
            end = time.time()

            print_frequency = 10
            if i % print_frequency == 0:
                progress.display(i)

        print(" * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}".format(top1=top1, top5=top5))
    return top1.avg
Create and load original uncompressed model¶
ResNet-50 from the torchvision repository is pre-trained on ImageNet, which has more prediction classes than Tiny ImageNet, so the model is adjusted by swapping the last FC layer for one with fewer output values.
def create_model(model_path: Path):
    """Creates the ResNet-50 model and loads the pretrained weights"""
    model = models.resnet50()
    # Update the last FC layer for the Tiny ImageNet number of classes.
    NUM_CLASSES = 200
    model.fc = nn.Linear(in_features=2048, out_features=NUM_CLASSES, bias=True)
    model.to(device)
    if model_path.exists():
        checkpoint = torch.load(str(model_path), map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"], strict=True)
    else:
        raise RuntimeError("There is no checkpoint to load")
    return model


model = create_model(MODEL_DIR / fp32_checkpoint_filename)
Create train and validation dataloaders¶
def create_dataloaders(batch_size: int = 128):
    """Creates train dataloader that is used for quantization initialization and validation dataloader for computing the model accuracy"""
    train_dir = DATASET_DIR / "train"
    val_dir = DATASET_DIR / "val" / "images"
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_dataset = datasets.ImageFolder(
        train_dir,
        transforms.Compose(
            [
                transforms.Resize(IMAGE_SIZE),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )
    val_dataset = datasets.ImageFolder(
        val_dir,
        transforms.Compose(
            [transforms.Resize(IMAGE_SIZE), transforms.ToTensor(), normalize]
        ),
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
        sampler=None,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )
    return train_loader, val_loader


train_loader, val_loader = create_dataloaders()
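As an optional check, not part of the original notebook, one batch from the validation loader confirms the expected tensor shapes:
# Optional sanity check: inspect one batch of images and labels.
images, labels = next(iter(val_loader))
print(images.shape)  # torch.Size([128, 3, 64, 64])
print(labels.shape)  # torch.Size([128])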
Now that the training and validation pipelines and the model files have been prepared, it is time to perform the actual post-training quantization with NNCF.
I. Evaluate the loaded model¶
acc1 = validate(val_loader, model)
print(f"Test accuracy of FP32 model: {acc1:.3f}")
Test: [ 0/79] Time 0.275 (0.275) Acc@1 81.25 (81.25) Acc@5 92.19 (92.19)
Test: [10/79] Time 0.239 (0.244) Acc@1 56.25 (66.97) Acc@5 86.72 (87.50)
Test: [20/79] Time 0.240 (0.243) Acc@1 67.97 (64.29) Acc@5 85.16 (87.35)
Test: [30/79] Time 0.240 (0.242) Acc@1 53.12 (62.37) Acc@5 77.34 (85.33)
Test: [40/79] Time 0.241 (0.242) Acc@1 67.19 (60.86) Acc@5 90.62 (84.51)
Test: [50/79] Time 0.238 (0.242) Acc@1 60.16 (60.80) Acc@5 88.28 (84.42)
Test: [60/79] Time 0.264 (0.242) Acc@1 66.41 (60.46) Acc@5 86.72 (83.79)
Test: [70/79] Time 0.240 (0.241) Acc@1 52.34 (60.21) Acc@5 80.47 (83.33)
* Acc@1 60.740 Acc@5 83.960
Test accuracy of FP32 model: 60.740
Export the FP32 model to ONNX, which is supported by the OpenVINO Toolkit, to benchmark it in comparison with the INT8 model.
dummy_input = torch.randn(1, 3, *IMAGE_SIZE).to(device)
torch.onnx.export(model, dummy_input, fp32_onnx_path, opset_version=10)
print(f"FP32 ONNX model was exported to {fp32_onnx_path}.")
FP32 ONNX model was exported to output/resnet50_fp32.onnx.
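Optionally, the exported file can be validated with the onnx package's model checker. This assumes the onnx package is installed in the environment; it is not required by the rest of the tutorial:
# Optional: structurally validate the exported ONNX model.
# Assumes the `onnx` package is installed.
import onnx

onnx_model = onnx.load(str(fp32_onnx_path))
onnx.checker.check_model(onnx_model)
print("The FP32 ONNX model is well formed.")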
II. Create and initialize quantization¶
NNCF enables post-training quantization by adding quantization layers into the model graph and then using a subset of the training dataset to initialize the parameters of these additional quantization layers. The framework is designed so that modifications to your original training code are minor. Quantization is the simplest scenario and requires only three modifications.
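At a glance, the three modifications are: create an NNCFConfig that describes the model input and the compression algorithm, register a data loader for initialization, and wrap the model. A minimal sketch; the same calls are executed step by step in the cells below:
# The three modifications required for NNCF post-training quantization,
# collected in one place. Each call is executed separately below.
nncf_config = NNCFConfig.from_dict(nncf_config_dict)                    # 1. describe input shape and compression
nncf_config = register_default_init_args(nncf_config, train_loader)    # 2. provide data for initialization
compression_ctrl, model = create_compressed_model(model, nncf_config)  # 3. insert quantizers into the model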
Configure the NNCF parameters to specify compression:
nncf_config_dict = {
    "input_info": {"sample_size": [1, 3, *IMAGE_SIZE]},
    "log_dir": str(OUTPUT_DIR),
    "compression": {
        "algorithm": "quantization",
        "initializer": {
            "range": {"num_init_samples": 15000},
            "batchnorm_adaptation": {"num_bn_adaptation_samples": 4000},
        },
    },
}
nncf_config = NNCFConfig.from_dict(nncf_config_dict)
Provide a data loader to initialize the values of quantization ranges and determine which activation should be signed or unsigned from the collected statistics, using a given number of samples.
nncf_config = register_default_init_args(nncf_config, train_loader)
Create a quantized model from a pre-trained FP32 model and a configuration object.
compression_ctrl, model = create_compressed_model(model, nncf_config)
Evaluate the new model on the validation set after initialization of quantization. The accuracy should be close to the accuracy of the floating-point FP32 model for a simple case like the one demonstrated here.
acc1 = validate(val_loader, model)
print(f"Accuracy of initialized INT8 model: {acc1:.3f}")
Test: [ 0/79] Time 0.381 (0.381) Acc@1 82.03 (82.03) Acc@5 91.41 (91.41)
Test: [10/79] Time 0.375 (0.375) Acc@1 57.81 (68.11) Acc@5 86.72 (87.22)
Test: [20/79] Time 0.377 (0.375) Acc@1 67.97 (64.32) Acc@5 86.72 (87.05)
Test: [30/79] Time 0.375 (0.375) Acc@1 53.12 (62.10) Acc@5 75.00 (85.16)
Test: [40/79] Time 0.374 (0.375) Acc@1 67.19 (60.88) Acc@5 88.28 (84.28)
Test: [50/79] Time 0.375 (0.375) Acc@1 58.59 (60.71) Acc@5 85.16 (84.10)
Test: [60/79] Time 0.374 (0.375) Acc@1 66.41 (60.39) Acc@5 83.59 (83.40)
Test: [70/79] Time 0.375 (0.375) Acc@1 54.69 (60.26) Acc@5 78.91 (83.08)
* Acc@1 60.770 Acc@5 83.620
Accuracy of initialized INT8 model: 60.770
Export the INT8 model to ONNX.
warnings.filterwarnings("ignore", category=TracerWarning) # Ignore export warnings
warnings.filterwarnings("ignore", category=UserWarning)
compression_ctrl.export_model(int8_onnx_path)
print(f"INT8 ONNX model exported to {int8_onnx_path}.")
INT8 ONNX model exported to output/resnet50_int8.onnx.
III. Convert ONNX models to OpenVINO Intermediate Representation (OpenVINO IR)¶
Use Model Optimizer to convert the ONNX model to OpenVINO IR with FP16 precision. The models are saved to the current directory. Then, embed the input normalization into the model with the --mean_values and --scale_values arguments: the mean values are subtracted from the input and the result is divided by the standard deviation values. With these options, it is not necessary to normalize input data before propagating it through the network.
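The values passed to --mean_values and --scale_values are the transforms.Normalize statistics used above, rescaled from the [0, 1] range to [0, 255], since OpenVINO receives 8-bit pixel values. A small illustrative check:
# The Model Optimizer arguments are the torchvision normalization
# statistics multiplied by 255.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
print([round(m * 255, 3) for m in mean])  # [123.675, 116.28, 103.53]
print([round(s * 255, 3) for s in std])   # [58.395, 57.12, 57.375]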
For more information about Model Optimizer, refer to the Model Optimizer Developer Guide.
Executing the following command may take a while. There may be some errors or warnings in the output. When Model Optimizer converts the model to OpenVINO IR successfully, the last lines of the output will include: [ SUCCESS ] Generated IR version 11 model.
input_shape = [1, 3, *IMAGE_SIZE]
if not fp32_ir_path.exists():
!mo --input_model "$fp32_onnx_path" --input_shape "$input_shape" --mean_values "[123.675, 116.28 , 103.53]" --scale_values "[58.395, 57.12 , 57.375]" --compress_to_fp16 --output_dir "$OUTPUT_DIR"
assert fp32_ir_path.exists(), "The OpenVINO IR of FP32 model wasn't created"
[ INFO ] The model was converted to IR v11, the latest model format that corresponds to the source DL framework input/output format. While IR v11 is backwards compatible with OpenVINO Inference Engine API v1.0, please use API v2.0 (as of 2022.1) to take advantage of the latest improvements in IR v11.
Find more information about API v2.0 and IR v11 at https://docs.openvino.ai/latest/openvino_2_0_transition_guide.html
[ SUCCESS ] Generated IR version 11 model.
[ SUCCESS ] XML file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-369/.workspace/scm/ov-notebook/notebooks/112-pytorch-post-training-quantization-nncf/output/resnet50_fp32.xml
[ SUCCESS ] BIN file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-369/.workspace/scm/ov-notebook/notebooks/112-pytorch-post-training-quantization-nncf/output/resnet50_fp32.bin
if not int8_ir_path.exists():
!mo --input_model "$int8_onnx_path" --input_shape "$input_shape" --mean_values "[123.675, 116.28 , 103.53]" --scale_values "[58.395, 57.12 , 57.375]" --compress_to_fp16 --output_dir "$OUTPUT_DIR"
assert int8_ir_path.exists(), "The OpenVINO IR of INT8 model wasn't created"
[ INFO ] The model was converted to IR v11, the latest model format that corresponds to the source DL framework input/output format. While IR v11 is backwards compatible with OpenVINO Inference Engine API v1.0, please use API v2.0 (as of 2022.1) to take advantage of the latest improvements in IR v11.
Find more information about API v2.0 and IR v11 at https://docs.openvino.ai/latest/openvino_2_0_transition_guide.html
[ SUCCESS ] Generated IR version 11 model.
[ SUCCESS ] XML file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-369/.workspace/scm/ov-notebook/notebooks/112-pytorch-post-training-quantization-nncf/output/resnet50_int8.xml
[ SUCCESS ] BIN file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-369/.workspace/scm/ov-notebook/notebooks/112-pytorch-post-training-quantization-nncf/output/resnet50_int8.bin
IV. Compare performance of the INT8 model and FP32 model in OpenVINO¶
Finally, measure the inference performance of the FP32 and INT8 models, using Benchmark Tool, an inference performance measurement tool in OpenVINO. By default, Benchmark Tool runs inference for 60 seconds in asynchronous mode on CPU. It returns inference speed as latency (milliseconds per image) and throughput (frames per second) values.
Note: This notebook runs benchmark_app for 15 seconds to give a quick indication of performance. For more accurate performance, it is recommended to run benchmark_app in a terminal/command prompt after closing other applications. Run benchmark_app -m model.xml -d CPU to benchmark async inference on CPU for one minute. Change CPU to GPU to benchmark on GPU. Run benchmark_app --help to see an overview of all command-line options.
def parse_benchmark_output(benchmark_output: List[str]):
    """Prints the output from benchmark_app in human-readable format"""
    parsed_output = [line for line in benchmark_output if 'FPS' in line]
    print(*parsed_output, sep='\n')


print('Benchmark FP32 model (OpenVINO IR)')
benchmark_output = ! benchmark_app -m "$fp32_ir_path" -d CPU -api async -t 15
parse_benchmark_output(benchmark_output)

print('Benchmark INT8 model (OpenVINO IR)')
benchmark_output = ! benchmark_app -m "$int8_ir_path" -d CPU -api async -t 15
parse_benchmark_output(benchmark_output)
Benchmark FP32 model (OpenVINO IR)
[ INFO ] Throughput: 1255.77 FPS
Benchmark INT8 model (OpenVINO IR)
[ INFO ] Throughput: 1195.87 FPS
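To quantify the difference, the throughput values can be extracted from the raw benchmark_app output with a small helper. This is a hypothetical extension of parse_benchmark_output above, and the variable names in the commented usage are likewise hypothetical:
import re
from typing import List

def extract_fps(benchmark_output: List[str]) -> float:
    """Hypothetical helper: return the throughput (FPS) reported by benchmark_app."""
    for line in benchmark_output:
        match = re.search(r"Throughput:\s+([\d.]+)\s+FPS", line)
        if match:
            return float(match.group(1))
    raise ValueError("No throughput value found in benchmark output")

# fp32_fps = extract_fps(fp32_benchmark_output)  # output of the FP32 run
# int8_fps = extract_fps(int8_benchmark_output)  # output of the INT8 run
# print(f"INT8 throughput is {int8_fps / fp32_fps:.2f}x the FP32 throughput")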
Show CPU Information for reference:
ie = Core()
ie.get_property("CPU", "FULL_DEVICE_NAME")
'Intel(R) Core(TM) i9-10920X CPU @ 3.50GHz'