Super Resolution with PaddleGAN and OpenVINO

This tutorial is also available as a Jupyter notebook that can be cloned directly from GitHub. See the installation guide for instructions to run this tutorial locally on Windows, Linux or macOS. To run it without installing anything, launch it on Binder.


This notebook demonstrates converting the RealSR (real-world super-resolution) model from PaddlePaddle/PaddleGAN to OpenVINO’s Intermediate Representation (IR) format, and shows inference results on both the PaddleGAN and IR models.

For more information about the various PaddleGAN superresolution models, see PaddleGAN’s documentation. For more information about RealSR, see the research paper from CVPR 2020.

This notebook works best with small images (up to 800x600).

Imports

import sys
import time
import warnings
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import paddle
from IPython.display import HTML, FileLink, ProgressBar, clear_output, display
from IPython.display import Image as DisplayImage
from PIL import Image
from openvino.runtime import Core, PartialShape
from paddle.static import InputSpec
from ppgan.apps import RealSRPredictor

sys.path.append("../utils")
from notebook_utils import NotebookAlert

Settings

# The filenames of the downloaded and converted models
MODEL_NAME = "paddlegan_sr"
MODEL_DIR = Path("model")
OUTPUT_DIR = Path("output")
OUTPUT_DIR.mkdir(exist_ok=True)

model_path = MODEL_DIR / MODEL_NAME
ir_path = model_path.with_suffix(".xml")
onnx_path = model_path.with_suffix(".onnx")

Inference on PaddlePaddle Model

Investigate PaddleGAN Model

The PaddleGAN documentation explains how to run the model with sr.run(). Let’s see what that function does, and check the other relevant functions it calls. Adding ?? to a method shows its docstring and source code.
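
Outside IPython, where ?? is unavailable, the standard library's inspect module returns the same information. A minimal sketch on a stand-in function; the name normalize is hypothetical and chosen so the example runs without downloading model weights:

```python
import inspect


def normalize(img):
    """Divide pixel values by 255 (stand-in for a predictor method)."""
    return img / 255


# getdoc returns the docstring with its indentation cleaned up
doc = inspect.getdoc(normalize)

# getsource reads the definition from the source file; it raises OSError when
# no source is available (for example, for code compiled from a string)
try:
    source = inspect.getsource(normalize)
except OSError:
    source = None

print(doc)
print(source if source is not None else "<source not available>")
```

With the predictor loaded, print(inspect.getsource(sr.run)) should print the same source code that sr.run?? shows, since getsource also accepts bound methods.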

# Running this cell will download the model weights if they have not been downloaded before
# This may take a while
sr = RealSRPredictor()
[07/13 22:49:38] ppgan INFO: Found /opt/home/k8sworker/.cache/ppgan/DF2K_JPEG.pdparams
sr.run??
sr.run_image??
sr.norm??
sr.denorm??

The run method checks whether the input is an image or a video. For an image, it loads the image as an RGB image, normalizes it, and converts it to a Paddle tensor. The tensor is propagated through the network by calling self.model(), and the output is then “denormalized”. The normalization function simply divides all image values by 255, which converts an image with integer values in the range 0 to 255 into an image with floating-point values in the range 0 to 1. The denormalization function transforms the output from network shape (C,H,W) to image shape (H,W,C). It then multiplies the values by 255, clips them between 0 and 255, and converts the result to a standard RGB image with integer values in the range 0 to 255.
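
The pre- and postprocessing described above can be sketched in NumPy. This is a re-implementation for illustration, not the PaddleGAN source: the hypothetical helpers to_network_input and to_image follow the steps this notebook itself uses for the PaddlePaddle and IR models (divide by 255 and reorder to (1, C, H, W); multiply by 255, clip, and reorder to (H, W, C)). One deliberate difference: squeeze(0) removes only the batch axis, whereas a plain squeeze() would also drop the height or width axis when either is 1.

```python
import numpy as np


def to_network_input(image_hwc: np.ndarray) -> np.ndarray:
    """Convert an RGB uint8 image (H, W, C) to a float32 batch (1, C, H, W) in [0, 1]."""
    chw = image_hwc.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
    return (chw[None, :, :, :] / 255).astype(np.float32)  # add batch axis, scale to [0, 1]


def to_image(output_nchw: np.ndarray) -> np.ndarray:
    """Convert a network output (1, C, H, W) back to a uint8 RGB image (H, W, C)."""
    chw = output_nchw.squeeze(0)  # drop only the batch axis -> (C, H, W)
    scaled = (chw * 255).clip(0, 255)  # back to 0..255, out-of-range values clipped
    return scaled.astype("uint8").transpose((1, 2, 0))  # (C, H, W) -> (H, W, C)


# Usage: round trip on a small random RGB image
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(4, 6, 3), dtype=np.uint8)
batch = to_network_input(image)
restored = to_image(batch)
print(batch.shape, restored.shape)  # (1, 3, 4, 6) (4, 6, 3)
```

Because astype("uint8") truncates rather than rounds, a value that comes back as, say, 2.9999998 becomes 2, so restored pixels can differ from the originals by 1.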

To get more information about the model, we can check what it looks like with sr.model??.

# sr.model??

Do Inference

To show inference on the PaddlePaddle model, set PADDLEGAN_INFERENCE to True in the cell below. Performing inference may take some time.

# Set PADDLEGAN_INFERENCE to True to show inference on the PaddlePaddle model.
# This may take a long time, especially for larger images.
#
PADDLEGAN_INFERENCE = False
if PADDLEGAN_INFERENCE:
    # load the input image and convert to tensor with input shape
    IMAGE_PATH = Path("data/coco_tulips.jpg")
    image = cv2.cvtColor(cv2.imread(str(IMAGE_PATH)), cv2.COLOR_BGR2RGB)
    input_image = image.transpose(2, 0, 1)[None, :, :, :] / 255
    input_tensor = paddle.to_tensor(input_image.astype(np.float32))
    if max(image.shape) > 400:
        NotebookAlert(
            f"This image has shape {image.shape}. Doing inference will be slow "
            "and the notebook may stop responding. Set PADDLEGAN_INFERENCE to False "
            "to skip doing inference on the PaddlePaddle model.",
            "warning",
        )
if PADDLEGAN_INFERENCE:
    # Do inference, and measure how long it takes
    print(f"Start superresolution inference for {IMAGE_PATH.name} with shape {image.shape}...")
    start_time = time.perf_counter()
    sr.model.eval()
    with paddle.no_grad():
        result = sr.model(input_tensor)
    end_time = time.perf_counter()
    duration = end_time - start_time
    result_image = (
        (result.numpy().squeeze() * 255).clip(0, 255).astype("uint8").transpose((1, 2, 0))
    )
    print(f"Superresolution image shape: {result_image.shape}")
    print(f"Inference duration: {duration:.2f} seconds")
    plt.imshow(result_image);

Convert PaddleGAN Model to ONNX and OpenVINO IR

To convert the PaddlePaddle model to OpenVINO IR, we first convert the model to ONNX, and then convert the ONNX model to the IR format.

Convert PaddlePaddle Model to ONNX

# Ignore PaddlePaddle warnings:
# The behavior of expression A + B has been unified with elementwise_add(X, Y, axis=-1)
warnings.filterwarnings("ignore")
sr.model.eval()
# ONNX export requires the input described as an InputSpec; None leaves the batch size unspecified
x_spec = InputSpec([None, 3, 299, 299], "float32", "x")
paddle.onnx.export(sr.model, str(model_path), input_spec=[x_spec], opset_version=13)
2022-07-13 22:49:44 [INFO]  ONNX model saved in model/paddlegan_sr.onnx

Convert ONNX Model to OpenVINO IR

## Uncomment the command below to show Model Optimizer help, which shows the possible arguments for Model Optimizer
# ! mo --help
if not ir_path.exists():
    print("Exporting ONNX model to IR... This may take a few minutes.")
    ! mo --input_model $onnx_path --input_shape "[1,3,299,299]" --model_name $MODEL_NAME --output_dir "$MODEL_DIR" --data_type "FP16" --log_level "CRITICAL"
Exporting ONNX model to IR... This may take a few minutes.
Model Optimizer arguments:
Common parameters:
    - Path to the Input Model:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/207-vision-paddlegan-superresolution/model/paddlegan_sr.onnx
    - Path for generated IR:    /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/207-vision-paddlegan-superresolution/model
    - IR output name:   paddlegan_sr
    - Log level:    CRITICAL
    - Batch:    Not specified, inherited from the model
    - Input layers:     Not specified, inherited from the model
    - Output layers:    Not specified, inherited from the model
    - Input shapes:     [1,3,299,299]
    - Source layout:    Not specified
    - Target layout:    Not specified
    - Layout:   Not specified
    - Mean values:  Not specified
    - Scale values:     Not specified
    - Scale factor:     Not specified
    - Precision of IR:  FP16
    - Enable fusing:    True
    - User transformations:     Not specified
    - Reverse input channels:   False
    - Enable IR generation for fixed input shape:   False
    - Use the transformations config file:  None
Advanced parameters:
    - Force the usage of legacy Frontend of Model Optimizer for model conversion into IR:   False
    - Force the usage of new Frontend of Model Optimizer for model conversion into IR:  False
OpenVINO runtime found in:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/openvino
OpenVINO runtime version:   2022.1.0-7019-cdb9bec7210-releases/2022/1
Model Optimizer version:    2022.1.0-7019-cdb9bec7210-releases/2022/1
[ SUCCESS ] Generated IR version 11 model.
[ SUCCESS ] XML file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/207-vision-paddlegan-superresolution/model/paddlegan_sr.xml
[ SUCCESS ] BIN file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-188/.workspace/scm/ov-notebook/notebooks/207-vision-paddlegan-superresolution/model/paddlegan_sr.bin
[ SUCCESS ] Total execution time: 1.51 seconds.
[ SUCCESS ] Memory consumed: 275 MB.
[ INFO ] The model was converted to IR v11, the latest model format that corresponds to the source DL framework input/output format. While IR v11 is backwards compatible with OpenVINO Inference Engine API v1.0, please use API v2.0 (as of 2022.1) to take advantage of the latest improvements in IR v11.
Find more information about API v2.0 and IR v11 at https://docs.openvino.ai

Do Inference on IR Model

# Read the network and get its input layer
ie = Core()
model = ie.read_model(model=ir_path)
input_layer = model.input(0)
# Load and show image
IMAGE_PATH = Path("data/coco_tulips.jpg")
image = cv2.cvtColor(cv2.imread(str(IMAGE_PATH)), cv2.COLOR_BGR2RGB)
if max(image.shape) > 800:
    NotebookAlert(
        f"This image has shape {image.shape}. The notebook works best with images with "
        "a maximum side of 800x600. Larger images may work well, but inference may "
        "be slow",
        "warning",
    )
plt.imshow(image)
../_images/207-vision-paddlegan-superresolution-with-output_23_1.png
# Reshape network to image size
model.reshape({input_layer.any_name: PartialShape([1, 3, image.shape[0], image.shape[1]])})
# Load network to the CPU device (this may take a few seconds)
compiled_model = ie.compile_model(model=model, device_name="CPU")
output_layer = compiled_model.output(0)
# Convert image to network input shape and divide pixel values by 255
# See "Investigate PaddleGAN model" section
input_image = image.transpose(2, 0, 1)[None, :, :, :] / 255
start_time = time.perf_counter()
# Do inference
ir_result = compiled_model([input_image])[output_layer]
end_time = time.perf_counter()
duration = end_time - start_time
print(f"Inference duration: {duration:.2f} seconds")
Inference duration: 2.99 seconds
# Get result array in CHW format
result_array = ir_result.squeeze()
# Convert array to image with same method as PaddleGAN:
# Multiply by 255, clip values between 0 and 255, convert to an HWC uint8 image
# See "Investigate PaddleGAN model" section
image_super = (result_array * 255).clip(0, 255).astype("uint8").transpose((1, 2, 0))
# Resize image with bicubic upsampling for comparison
image_bicubic = cv2.resize(image, tuple(image_super.shape[:2][::-1]), interpolation=cv2.INTER_CUBIC)
plt.imshow(image_super)
../_images/207-vision-paddlegan-superresolution-with-output_27_1.png

Show Animated GIF

To visualize the difference between the bicubic image and the superresolution image, we create an animated GIF that switches between both versions.

result_pil = Image.fromarray(image_super)
bicubic_pil = Image.fromarray(image_bicubic)
gif_image_path = OUTPUT_DIR / Path(IMAGE_PATH.stem + "_comparison.gif")
final_image_path = OUTPUT_DIR / Path(IMAGE_PATH.stem + "_super.png")

result_pil.save(
    fp=str(gif_image_path),
    format="GIF",
    append_images=[bicubic_pil],
    save_all=True,
    duration=1000,
    loop=0,
)

result_pil.save(fp=str(final_image_path), format="png")
DisplayImage(gif_image_path.read_bytes(), width=1920 // 2)
../_images/207-vision-paddlegan-superresolution-with-output_29_0.png
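
The Pillow call above can be checked in isolation. The sketch below, assuming only NumPy and Pillow, saves two synthetic frames with the same arguments and reads the file back; save_two_frame_gif is a hypothetical helper, not part of the notebook. save_all=True is what makes Pillow write the frames passed in append_images, duration is the display time per frame in milliseconds, and loop=0 makes the animation repeat indefinitely.

```python
import tempfile
from pathlib import Path

import numpy as np
from PIL import Image


def save_two_frame_gif(first: np.ndarray, second: np.ndarray, path: Path, duration_ms: int = 1000) -> None:
    """Save two RGB uint8 arrays as a looping GIF that alternates between them."""
    Image.fromarray(first).save(
        fp=str(path),
        format="GIF",
        append_images=[Image.fromarray(second)],  # frames that follow the first one
        save_all=True,  # write every frame, not only the first
        duration=duration_ms,  # display time per frame, in milliseconds
        loop=0,  # 0 means repeat forever
    )


# Usage: alternate between a black and a white 8x8 frame
black = np.zeros((8, 8, 3), dtype=np.uint8)
white = np.full((8, 8, 3), 255, dtype=np.uint8)
gif_path = Path(tempfile.mkdtemp()) / "comparison.gif"
save_two_frame_gif(black, white, gif_path)

with Image.open(gif_path) as gif:
    frame_count = gif.n_frames
print(frame_count)  # 2
```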

Create Comparison Video

Create a video with a “slider”, showing the bicubic image on the right and the superresolution image on the left.

For the video, the superresolution and bicubic images are resized to half the width and height of the superresolution image to improve processing speed. This still gives an indication of the superresolution effect. The video is saved as an .avi file. You can click the link to download the video, or open it directly from the output directory and play it locally.
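
The per-frame logic of the loop that follows can be isolated and tested without OpenCV. A minimal NumPy sketch; slider_frame, rgb_to_bgr, and video_duration_seconds are hypothetical helper names, not part of the notebook. The channel reversal matters because OpenCV's VideoWriter expects frames in BGR order while the images here are RGB, which is why the code indexes the resized images with [:, :, (2, 1, 0)].

```python
import numpy as np


def slider_frame(left_image: np.ndarray, right_image: np.ndarray, split: int) -> np.ndarray:
    """Combine two (H, W, C) images: columns < split from left_image, the rest from right_image.

    A 2-pixel black line is drawn at columns split - 1 and split.
    """
    # np.hstack returns a new array, so drawing the line does not modify the inputs
    frame = np.hstack((left_image[:, :split, :], right_image[:, split:, :]))
    frame[:, split - 1 : split + 1, :] = 0
    return frame


def rgb_to_bgr(image: np.ndarray) -> np.ndarray:
    """Reverse the channel axis (RGB -> BGR), the order cv2.VideoWriter expects."""
    return image[:, :, (2, 1, 0)]


def video_duration_seconds(frame_width: int, fps: int = 90) -> float:
    """Length of a video with one frame per split position in range(2, frame_width)."""
    return (frame_width - 2) / fps


# Usage: a 3x10 frame that splits a bright image from a dark one at column 4
left = np.full((3, 10, 3), 200, dtype=np.uint8)
right = np.full((3, 10, 3), 50, dtype=np.uint8)
frame = slider_frame(left, right, split=4)
print(frame[0, :, 0].tolist())  # [200, 200, 200, 0, 0, 50, 50, 50, 50, 50]
```

Since the loop writes one frame per split position from 2 to video_target_width - 1, the video has video_target_width - 2 frames and, at 90 frames per second, lasts (video_target_width - 2) / 90 seconds.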

FOURCC = cv2.VideoWriter_fourcc(*"MJPG")
IMAGE_PATH = Path(IMAGE_PATH)
result_video_path = OUTPUT_DIR / Path(f"{IMAGE_PATH.stem}_comparison_paddlegan.avi")
video_target_height, video_target_width = (
    image_super.shape[0] // 2,
    image_super.shape[1] // 2,
)

out_video = cv2.VideoWriter(
    str(result_video_path),
    FOURCC,
    90,
    (video_target_width, video_target_height),
)

resized_result_image = cv2.resize(image_super, (video_target_width, video_target_height))[
    :, :, (2, 1, 0)
]
resized_bicubic_image = cv2.resize(image_bicubic, (video_target_width, video_target_height))[
    :, :, (2, 1, 0)
]

progress_bar = ProgressBar(total=video_target_width)
progress_bar.display()

for i in range(2, video_target_width):
    # Create a frame where the left part (until i pixels width) contains the
    # superresolution image, and the right part (from i pixels width) contains
    # the bicubic image
    comparison_frame = np.hstack(
        (
            resized_result_image[:, :i, :],
            resized_bicubic_image[:, i:, :],
        )
    )

    # create a small black border line between the superresolution
    # and bicubic part of the image
    comparison_frame[:, i - 1 : i + 1, :] = 0
    out_video.write(comparison_frame)
    progress_bar.progress = i
    progress_bar.update()
out_video.release()
clear_output()

video_link = FileLink(result_video_path)
video_link.html_link_str = "<a href='%s' download>%s</a>"
display(HTML(f"The video has been saved to {video_link._repr_html_()}"))
The video has been saved to output/coco_tulips_comparison_paddlegan.avi