Photos to Anime with PaddleGAN and OpenVINO

This tutorial is also available as a Jupyter notebook that can be cloned directly from GitHub. See the installation guide for instructions to run this tutorial locally on Windows, Linux, or macOS. To run it without installing anything, launch it on Binder.


This tutorial demonstrates how to convert a PaddlePaddle/PaddleGAN AnimeGAN model to OpenVINO IR format, and compares inference results of the PaddleGAN and IR models.

For more information about the AnimeGAN model, see PaddleGAN’s AnimeGAN documentation.

Imports

import sys
import time
import os
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML, display
from openvino.inference_engine import IECore

# PaddlePaddle depends on C++ libraries. If importing the Paddle packages fails,
# show instructions for installing a C++ compiler or runtime, then re-raise the error.
try:
    import paddle
    from paddle.static import InputSpec
    from ppgan.apps import AnimeGANPredictor
except (ImportError, OSError):
    # A failed import raises ImportError; a missing native library can also raise OSError
    if sys.platform == "win32":
        install_message = (
            "To use this notebook, please install the free Microsoft "
            "Visual C++ redistributable from <a href='https://aka.ms/vs/16/release/vc_redist.x64.exe'>"
            "https://aka.ms/vs/16/release/vc_redist.x64.exe</a>"
        )
    else:
        install_message = (
            "To use this notebook, please install a C++ compiler. On macOS, "
            "`xcode-select --install` installs many developer tools, including C++. On Linux, "
            "install gcc with your distribution's package manager."
        )
    display(
        HTML(
            f"""<div class="alert alert-danger"><i>
    <b>Error: </b>PaddlePaddle requires installation of C++. {install_message}</i></div>"""
        )
    )
    raise

Settings

MODEL_DIR = "model"
MODEL_NAME = "paddlegan_anime"

os.makedirs(MODEL_DIR, exist_ok=True)

# Create filenames of the models that will be converted in this notebook.
model_path = Path(f"{MODEL_DIR}/{MODEL_NAME}")
ir_path = model_path.with_suffix(".xml")
onnx_path = model_path.with_suffix(".onnx")

Functions

def resize_to_max_width(image, max_width):
    """
    Resize `image` to `max_width`, preserving the aspect ratio of the image.
    """
    if image.shape[1] > max_width:
        hw_ratio = image.shape[0] / image.shape[1]
        new_height = int(max_width * hw_ratio)
        image = cv2.resize(image, (max_width, new_height))
    return image
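Note that cv2.resize() takes its target size as (width, height), while NumPy image shapes are (height, width, channels); Step 5 later uses image.shape[:2][::-1] for the same reason. The following pure-Python sketch shows the size arithmetic that resize_to_max_width performs, without OpenCV. The helper names fit_to_max_width and to_cv2_dsize are hypothetical and not part of the notebook.

```python
def fit_to_max_width(height, width, max_width):
    """Return (new_height, new_width) for an image limited to max_width.

    Mirrors the arithmetic of resize_to_max_width(): images that are already
    narrow enough keep their size; wider images are scaled down so that the
    height/width ratio is preserved (the new height is truncated with int()).
    """
    if width <= max_width:
        return height, width
    hw_ratio = height / width
    return int(max_width * hw_ratio), max_width


def to_cv2_dsize(shape):
    """Convert a NumPy-style (height, width[, channels]) shape to cv2's (width, height)."""
    return shape[1], shape[0]


new_h, new_w = fit_to_max_width(900, 1200, 600)
print(new_h, new_w)                     # 450 600
print(to_cv2_dsize((new_h, new_w, 3)))  # (600, 450)
```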

Inference on PaddleGAN Model

The PaddleGAN documentation explains how to run the model with .run(). Let’s use Jupyter’s ?? shortcut to show the docstring and source code of that function.

# This cell will initialize the AnimeGANPredictor() and download the weights from PaddlePaddle.
# This may take a while. The weights are stored in a cache and are only downloaded once.
predictor = AnimeGANPredictor()
[02/18 22:45:21] ppgan INFO: Found /opt/home/k8sworker/.cache/ppgan/animeganv2_hayao.pdparams
# In a Jupyter Notebook, ?? shows the source and docstring
predictor.run??
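The ?? shortcut only works in IPython and Jupyter. In a plain Python script, the standard-library inspect module gives similar information. A minimal sketch, shown here on a standard-library function; the helper names describe and source_of are hypothetical. With PaddleGAN installed, you could pass predictor.run instead.

```python
import inspect
import textwrap


def describe(obj):
    """Return the name, signature, and docstring of obj, similar to Jupyter's `?`."""
    name = getattr(obj, "__qualname__", repr(obj))
    try:
        signature = str(inspect.signature(obj))
    except (TypeError, ValueError):
        # Some built-in callables do not expose a signature
        signature = "(...)"
    doc = inspect.getdoc(obj) or "(no docstring)"
    return f"{name}{signature}\n{doc}"


def source_of(obj):
    """Return the source code of obj, similar to Jupyter's `??`, or None if it is unavailable."""
    try:
        return inspect.getsource(obj)
    except (OSError, TypeError):
        return None


print(describe(textwrap.dedent).splitlines()[0])  # dedent(text)
source = source_of(textwrap.dedent)
print(source.splitlines()[0] if source else "source not available")
```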

The AnimeGANPredictor.run() method:

  1. loads an image with OpenCV and converts it to RGB

  2. transforms the image

  3. propagates the transformed image through the generator model and postprocesses the results to return an array with a [0,255] range

  4. transposes the result from (C,H,W) to (H,W,C) shape

  5. resizes the result image to the original image size

  6. optionally adjusts the brightness of the result image

  7. saves the image
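Steps 3 and 4 amount to a range mapping and an axis reorder. The * 0.5 + 0.5 followed by * 255 in the notebook's Step 4 implies that the generator outputs values in [-1, 1]. The pure-Python sketch below illustrates both operations on a tiny hand-made tensor; the function name generator_output_to_hwc is hypothetical, and in the notebook NumPy's transpose() does the axis reorder.

```python
def generator_output_to_hwc(chw):
    """Convert a CHW nested list with values in [-1, 1] to HWC values in [0, 255].

    Mirrors Steps 3 and 4 of predictor.run(): scale with (v * 0.5 + 0.5) * 255,
    then move the channel axis last. Values are not clipped or rounded here.
    """
    channels, height, width = len(chw), len(chw[0]), len(chw[0][0])
    return [
        [[(chw[c][y][x] * 0.5 + 0.5) * 255 for c in range(channels)] for x in range(width)]
        for y in range(height)
    ]


# Three channels, height 1, width 2
chw = [[[-1.0, 1.0]], [[0.0, 0.0]], [[1.0, -1.0]]]
print(generator_output_to_hwc(chw))  # [[[0.0, 127.5, 255.0], [255.0, 127.5, 0.0]]]
```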

We can execute these steps manually and confirm that the result looks correct. To reduce inference time, we resize large images before propagating them through the network. The inference step in the next cell will still take some time to execute. To skip this step, set PADDLEGAN_INFERENCE = False in the first line of the next cell.

PADDLEGAN_INFERENCE = True
OUTPUT_DIR = "output"

os.makedirs(OUTPUT_DIR, exist_ok=True)
# Step 1. Load the image and convert to RGB
image_path = Path("data/coco_bricks.png")

image = cv2.cvtColor(cv2.imread(str(image_path), flags=cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)

## Inference takes a long time on large images. Resize to a max width of 600
image = resize_to_max_width(image, 600)

# Step 2. Transform the image
transformed_image = predictor.transform(image)
input_tensor = paddle.to_tensor(transformed_image[None, ::])

if PADDLEGAN_INFERENCE:
    # Step 3. Do inference.
    predictor.generator.eval()
    with paddle.no_grad():
        result = predictor.generator(input_tensor)

    # Step 4. Convert the inference result to an image following the same steps as
    # PaddleGAN's predictor.run() function
    result_image_pg = (result * 0.5 + 0.5)[0].numpy() * 255
    result_image_pg = result_image_pg.transpose((1, 2, 0))

    # Step 5. Resize the result image
    result_image_pg = cv2.resize(result_image_pg, image.shape[:2][::-1])

    # Step 6. Adjust the brightness
    result_image_pg = predictor.adjust_brightness(result_image_pg, image)

    # Step 7. Save the result image
    anime_image_path_pg = Path(f"{OUTPUT_DIR}/{image_path.stem}_anime_pg").with_suffix(".jpg")
    if cv2.imwrite(str(anime_image_path_pg), result_image_pg[:, :, (2, 1, 0)]):
        print(f"The anime image was saved to {anime_image_path_pg}")
The anime image was saved to output/coco_bricks_anime_pg.jpg

Show Inference Results on PaddleGAN Model

if PADDLEGAN_INFERENCE:
    fig, ax = plt.subplots(1, 2, figsize=(25, 15))
    ax[0].imshow(image)
    ax[1].imshow(result_image_pg)
else:
    print("PADDLEGAN_INFERENCE is not enabled. Set PADDLEGAN_INFERENCE = True in the previous cell and run that cell to show inference results.")
[Figure: input image (left) and PaddleGAN anime result (right)]

Model Conversion to ONNX and IR

We convert the PaddleGAN model to OpenVINO IR by first converting the PaddleGAN model to ONNX with paddle2onnx and then converting the ONNX model to IR with OpenVINO’s Model Optimizer.

Convert to ONNX

Exporting to ONNX requires specifying an input shape with PaddlePaddle’s InputSpec and calling paddle.onnx.export. We read the height and width of the transformed image and use them as the input shape of the ONNX model; the batch dimension is left as None, so it is not fixed at export time. Exporting to ONNX should not take long. If the export succeeds, the output of the next cell includes the message ONNX model saved in model/paddlegan_anime.onnx.

target_height, target_width = transformed_image.shape[1:]
target_height, target_width
(448, 576)
predictor.generator.eval()
x_spec = InputSpec([None, 3, target_height, target_width], "float32", "x")
paddle.onnx.export(predictor.generator, str(model_path), input_spec=[x_spec], opset_version=11)
2022-02-18 22:45:28 [INFO]  ONNX model saved in model/paddlegan_anime.onnx

Convert to IR

The OpenVINO IR format allows storing the preprocessing normalization in the model file. It is then no longer necessary to normalize input images manually. Let’s check the transforms that the .run() method used:

predictor.__init__??
t = predictor.transform.transforms[0]
t.params
{'taget_size': (448, 576)}
## Uncomment the line below to see the documentation and code of the ResizeToScale transformation
# t??

There are three transformations: resize, transpose, and normalize, where normalize uses a mean and scale of [127.5, 127.5, 127.5].

The ResizeToScale class is called with (256, 256) as its size argument. Further analysis shows that this is the minimum size that images are resized to. The ResizeToScale transform resizes images to the size specified in its params, with width and height as multiples of 32.
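A minimal sketch of how such a target size could be computed. This is an assumption inferred from the description above, not PaddleGAN's verified code: each side smaller than the 256-pixel minimum is raised to the minimum, and larger sides are rounded down to the nearest multiple of 32. The function name reduce_to_multiple is hypothetical. Under this assumption, a hypothetical 450x600 image maps to (448, 576), the same shape that this notebook uses for target_height and target_width.

```python
def reduce_to_multiple(height, width, min_hw=(256, 256), scale=32):
    """Hypothetical sketch of a ResizeToScale-style target size.

    Assumption: each side at or below its minimum is set to the minimum;
    otherwise it is rounded down to the nearest multiple of `scale`.
    """
    def fit(side, minimum):
        if side <= minimum:
            return minimum
        return side - side % scale

    return fit(height, min_hw[0]), fit(width, min_hw[1])


print(reduce_to_multiple(450, 600))  # (448, 576)
print(reduce_to_multiple(200, 300))  # (256, 288)
```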

Now that we know the mean and scale values and the shape of the model input, we can call Model Optimizer to convert the model to IR with these values. We use FP16 precision and set the log level to CRITICAL to suppress warnings that are irrelevant for this demo. See the Model Optimizer documentation for more information about its parameters.
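Model Optimizer subtracts --mean_values from each input channel and then divides by --scale_values. With 127.5 for both, as in the mo command in this section, 8-bit pixel values are mapped to [-1, 1]. This is the inverse of the * 0.5 + 0.5 and * 255 postprocessing in Step 4. The small worked check below is plain Python; the constants and function names are illustrative only. It also shows that this normalization equals scaling to [0, 1] and normalizing with mean 0.5 and std 0.5.

```python
MEAN = 127.5   # value passed to --mean_values for each channel
SCALE = 127.5  # value passed to --scale_values for each channel


def mo_normalize(pixel):
    """Preprocessing that Model Optimizer embeds in the IR: (pixel - mean) / scale."""
    return (pixel - MEAN) / SCALE


def normalize_unit_range(pixel):
    """Equivalent formulation: scale to [0, 1], then normalize with mean 0.5 and std 0.5."""
    return (pixel / 255 - 0.5) / 0.5


# The 8-bit pixel range [0, 255] maps onto [-1, 1]
print(mo_normalize(0), mo_normalize(127.5), mo_normalize(255))  # -1.0 0.0 1.0
```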

Convert Model to IR with Model Optimizer

onnx_path = model_path.with_suffix(".onnx")
print("Exporting ONNX model to IR... This may take a few minutes.")
!mo --input_model $onnx_path --output_dir $MODEL_DIR --input_shape [1,3,$target_height,$target_width] --model_name $MODEL_NAME --data_type "FP16" --mean_values="[127.5,127.5,127.5]" --scale_values="[127.5,127.5,127.5]" --log_level "CRITICAL"
Exporting ONNX model to IR... This may take a few minutes.
Model Optimizer arguments:
Common parameters:
    - Path to the Input Model:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/notebooks/206-vision-paddlegan-anime/model/paddlegan_anime.onnx
    - Path for generated IR:    /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/notebooks/206-vision-paddlegan-anime/model
    - IR output name:   paddlegan_anime
    - Log level:    CRITICAL
    - Batch:    Not specified, inherited from the model
    - Input layers:     Not specified, inherited from the model
    - Output layers:    Not specified, inherited from the model
    - Input shapes:     [1,3,448,576]
    - Mean values:  [127.5,127.5,127.5]
    - Scale values:     [127.5,127.5,127.5]
    - Scale factor:     Not specified
    - Precision of IR:  FP16
    - Enable fusing:    True
    - Enable grouped convolutions fusing:   True
    - Move mean values to preprocess section:   None
    - Reverse input channels:   False
ONNX specific parameters:
    - Inference Engine found in:    /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/openvino
Inference Engine version:   2021.4.2-3976-0943ed67223-refs/pull/539/head
Model Optimizer version:    2021.4.2-3976-0943ed67223-refs/pull/539/head
[ SUCCESS ] Generated IR version 10 model.
[ SUCCESS ] XML file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/notebooks/206-vision-paddlegan-anime/model/paddlegan_anime.xml
[ SUCCESS ] BIN file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-80/.workspace/scm/ov-notebook/notebooks/206-vision-paddlegan-anime/model/paddlegan_anime.bin
[ SUCCESS ] Total execution time: 10.97 seconds.
[ SUCCESS ] Memory consumed: 112 MB.

Show Inference Results on IR and PaddleGAN Models

If the Model Optimizer output in the cell above showed SUCCESS, model conversion succeeded and the IR model is generated.

We could now do inference with the IR model and use the adjust_brightness() method of the PaddleGAN predictor for postprocessing. To use the IR model without installing PaddleGAN, however, it is useful to check what these functions do and extract them.

Create Postprocessing Functions

predictor.adjust_brightness??
predictor.calc_avg_brightness??

The average brightness is computed with a standard formula (see https://www.w3.org/TR/AERT/#color-contrast): 0.299 R + 0.587 G + 0.114 B, using the mean of each color channel. To adjust the brightness, the ratio of the average brightness of the source image to that of the destination (anime) image is computed, and the destination image is multiplied by that ratio. The result is clipped to [0, 255] and converted to an 8-bit image.
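As a minimal sketch, here is the same computation in plain Python, operating on lists of (R, G, B) tuples instead of NumPy arrays. The function names avg_brightness and match_brightness are hypothetical. The sketch makes explicit that the adjustment is a single multiplicative factor applied to every channel, followed by clipping and truncation to integers, which is what np.clip and np.uint8 do for non-negative values.

```python
def avg_brightness(pixels):
    """Weighted average brightness of (R, G, B) tuples: 0.299 R + 0.587 G + 0.114 B."""
    n = len(pixels)
    r = sum(p[0] for p in pixels) / n
    g = sum(p[1] for p in pixels) / n
    b = sum(p[2] for p in pixels) / n
    return 0.299 * r + 0.587 * g + 0.114 * b


def match_brightness(dst, src):
    """Scale dst by avg_brightness(src) / avg_brightness(dst), clip to [0, 255], truncate to int.

    Raises ZeroDivisionError if dst is completely black.
    """
    ratio = avg_brightness(src) / avg_brightness(dst)
    return [tuple(int(min(max(c * ratio, 0), 255)) for c in p) for p in dst]


src = [(210, 210, 210)]                # source photo: average brightness about 210
dst = [(81, 81, 81), (199, 199, 199)]  # anime result: average brightness about 140
print(match_brightness(dst, src))      # [(121, 121, 121), (255, 255, 255)]
```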

We copy these functions to the next cell and use them for inference on the IR model.

# Copyright (c) 2020 PaddlePaddle Authors. Licensed under the Apache License, Version 2.0


def calc_avg_brightness(img):
    R = img[..., 0].mean()
    G = img[..., 1].mean()
    B = img[..., 2].mean()

    brightness = 0.299 * R + 0.587 * G + 0.114 * B
    return brightness, B, G, R


def adjust_brightness(dst, src):
    # Call the local calc_avg_brightness() defined above, so that PaddleGAN
    # is not needed for postprocessing the IR model results.
    brightness1, B1, G1, R1 = calc_avg_brightness(src)
    brightness2, B2, G2, R2 = calc_avg_brightness(dst)
    brightness_difference = brightness1 / brightness2
    dstf = dst * brightness_difference
    dstf = np.clip(dstf, 0, 255)
    dstf = np.uint8(dstf)
    return dstf

Do Inference on IR Model

Load the IR model, and do inference, following the same steps as for the PaddleGAN model. See the OpenVINO Inference Engine API notebook for more information about inference on IR models.

The IR model is generated with a fixed input shape that was computed from the input image above. Other images are resized to this shape before inference, so if their aspect ratio differs, the results may differ from the PaddleGAN results.
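To estimate how much an image is distorted by this fixed-shape resize, compare its vertical and horizontal scale factors. This is a hedged sketch in plain Python, and the helper names are hypothetical. A result of 1.0 means the aspect ratio is preserved. A result below 1.0 means the image is squeezed horizontally relative to vertically.

```python
def resize_scale_factors(src_hw, target_hw):
    """Return (scale_y, scale_x) when resizing an image of size src_hw to target_hw."""
    return target_hw[0] / src_hw[0], target_hw[1] / src_hw[1]


def aspect_distortion(src_hw, target_hw):
    """Horizontal scale divided by vertical scale; 1.0 means no distortion."""
    scale_y, scale_x = resize_scale_factors(src_hw, target_hw)
    return scale_x / scale_y


TARGET_HW = (448, 576)  # target_height, target_width computed earlier in this notebook
print(aspect_distortion((896, 1152), TARGET_HW))            # 1.0
print(round(aspect_distortion((1080, 1920), TARGET_HW), 3))  # 0.723
```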

# Load and prepare the IR model.
ie = IECore()
net = ie.read_network(ir_path)
exec_net = ie.load_network(net, "CPU")
input_key = next(iter(net.input_info.keys()))
output_key = next(iter(net.outputs.keys()))
# Step 1. Load an image and convert to RGB
image_path = Path("data/coco_bricks.png")
image = cv2.cvtColor(cv2.imread(str(image_path), flags=cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)

# Step 2. Transform the image (only resize and transpose are still required)
resized_image = cv2.resize(image, (target_width, target_height))
input_image = resized_image.transpose(2, 0, 1)[None, :, :, :]

# Step 3. Do inference.
result_ir = exec_net.infer({input_key: input_image})

# Step 4. Convert the inference result to an image following the same steps as
# PaddleGAN's predictor.run() function
result_image_ir = (result_ir[output_key] * 0.5 + 0.5)[0] * 255
result_image_ir = result_image_ir.transpose((1, 2, 0))

# Step 5. Resize the result image
result_image_ir = cv2.resize(result_image_ir, image.shape[:2][::-1])

# Step 6. Adjust the brightness
result_image_ir = adjust_brightness(result_image_ir, image)

# Step 7. Save the result image
anime_fn_ir = Path(f"{OUTPUT_DIR}/{image_path.stem}_anime_ir").with_suffix(".jpg")
if cv2.imwrite(str(anime_fn_ir), result_image_ir[:, :, (2, 1, 0)]):
    print(f"The anime image was saved to {anime_fn_ir}")
The anime image was saved to output/coco_bricks_anime_ir.jpg

Show Inference Results

fig, ax = plt.subplots(1, 2, figsize=(25, 15))
ax[0].imshow(image)
ax[1].imshow(result_image_ir)
ax[0].set_title("Image")
ax[1].set_title("OpenVINO IR result");
[Figure: "Image" (left) and "OpenVINO IR result" (right)]

Performance Comparison

Measure the time it takes to do inference on an image. This gives an indication of performance, but it is not a perfect measure. Since the PaddleGAN model requires a lot of memory for inference, we only measure inference on one image. For more accurate benchmarking, use the OpenVINO benchmark tool.
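The timing pattern in the next cell can be wrapped in a small reusable helper. This minimal sketch uses only the standard library, and the name benchmark is hypothetical. As a design choice, it runs untimed warm-up calls first, on the assumption that the first call may include one-time overhead. It still does not replace the OpenVINO benchmark tool.

```python
import time
from typing import Callable, Tuple


def benchmark(fn: Callable[[], object], num_runs: int = 10, warmup: int = 1) -> Tuple[float, float]:
    """Call fn() `warmup` times untimed, then `num_runs` times timed.

    Returns (seconds_per_call, calls_per_second).
    """
    if num_runs < 1:
        raise ValueError("num_runs must be at least 1")
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(num_runs):
        fn()
    elapsed = time.perf_counter() - start
    return elapsed / num_runs, num_runs / elapsed
```

With the IR model from this notebook, it could be called as benchmark(lambda: exec_net.infer(inputs={input_key: input_image}), num_runs=10).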

NUM_IMAGES = 1
start = time.perf_counter()
for _ in range(NUM_IMAGES):
    exec_net.infer(inputs={input_key: input_image})
end = time.perf_counter()
time_ir = end - start
print(
    f"IR model in Inference Engine/CPU: {time_ir/NUM_IMAGES:.3f} "
    f"seconds per image, FPS: {NUM_IMAGES/time_ir:.2f}"
)

## Uncomment the lines below to measure inference time on an Intel iGPU.
## Note that it will take some time to load the model to the GPU

# if "GPU" in ie.available_devices:
#     # Loading the IR model on the GPU takes some time
#     exec_net_multi = ie.load_network(net, "GPU")
#     start = time.perf_counter()
#     for _ in range(NUM_IMAGES):
#         exec_net_multi.infer(inputs={input_key: input_image})
#     end = time.perf_counter()
#     time_ir = end - start
#     print(
#         f"IR model in Inference Engine/GPU: {time_ir/NUM_IMAGES:.3f} "
#         f"seconds per image, FPS: {NUM_IMAGES/time_ir:.2f}"
#     )
# else:
#     print("A supported iGPU device is not available on this system.")

## PADDLEGAN_INFERENCE is defined in the section "Inference on PaddleGAN Model"
## Uncomment the next line to enable a performance comparison with the PaddleGAN model
## if you disabled it earlier.

# PADDLEGAN_INFERENCE = True

if PADDLEGAN_INFERENCE:
    with paddle.no_grad():
        start = time.perf_counter()
        for _ in range(NUM_IMAGES):
            predictor.generator(input_tensor)
        end = time.perf_counter()
        time_paddle = end - start
    print(
        f"PaddleGAN model on CPU: {time_paddle/NUM_IMAGES:.3f} seconds per image, "
        f"FPS: {NUM_IMAGES/time_paddle:.2f}"
    )
IR model in Inference Engine/CPU: 0.489 seconds per image, FPS: 2.05
PaddleGAN model on CPU: 5.571 seconds per image, FPS: 0.18

References

The PaddleGAN code shown in this notebook was written by the PaddlePaddle Authors and is licensed under the Apache 2.0 license. The license for this code is displayed below.

#  Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.