Image Background Removal with U^2-Net and OpenVINO™

This tutorial is also available as a Jupyter notebook that can be cloned directly from GitHub. See the installation guide for instructions to run this tutorial locally on Windows, Linux or macOS. To run without installing anything, click the launch binder button.


This notebook demonstrates background removal in images using U^2-Net and OpenVINO.

For more information about U^2-Net, including source code and test data, see the GitHub page and the research paper: U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection.

The PyTorch U^2-Net model is converted to ONNX and then to OpenVINO IR, which is loaded with OpenVINO Runtime. The model source code is available in the U^2-Net GitHub repository. For a more detailed overview of loading PyTorch models in OpenVINO, including how to load an ONNX model in OpenVINO directly without converting it to OpenVINO IR format, see the PyTorch/ONNX notebook.

Prepare

Import the PyTorch Library and U^2-Net

import os
import time
from collections import namedtuple
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import HTML, FileLink, display
from model.u2net import U2NET, U2NETP
from openvino.runtime import Core

Settings

This tutorial supports using the original U^2-Net salient object detection model, as well as the smaller U2NETP version. Two sets of weights are supported for the original model: salient object detection and human segmentation.

IMAGE_DIR = "data"
model_config = namedtuple("ModelConfig", ["name", "url", "model", "model_args"])

u2net_lite = model_config(
    name="u2net_lite",
    url="https://drive.google.com/uc?id=1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy",
    model=U2NETP,
    model_args=(),
)
u2net = model_config(
    name="u2net",
    url="https://drive.google.com/uc?id=1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
    model=U2NET,
    model_args=(3, 1),
)
u2net_human_seg = model_config(
    name="u2net_human_seg",
    url="https://drive.google.com/uc?id=1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P",
    model=U2NET,
    model_args=(3, 1),
)

# Set u2net_model to one of the three configurations listed above.
u2net_model = u2net_lite
# The filenames of the downloaded and converted models.
MODEL_DIR = "model"
model_path = Path(MODEL_DIR) / u2net_model.name / Path(u2net_model.name).with_suffix(".pth")
onnx_path = model_path.with_suffix(".onnx")
ir_path = model_path.with_suffix(".xml")
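Each configuration gets its own subdirectory under model/, and the ONNX and IR file names are derived from the weights path by swapping the suffix. The sketch below uses a hypothetical model_paths helper (not part of the tutorial) that repeats the same pattern to print the layout for all three configurations. The .bin entry reflects that an OpenVINO IR consists of an .xml topology file plus a .bin weights file written next to it.

```python
from pathlib import Path

MODEL_DIR = "model"


def model_paths(name: str) -> dict:
    """Hypothetical helper: derive the weights, ONNX and IR paths for a model
    name, following the same pattern as the settings cell above."""
    model_path = Path(MODEL_DIR) / name / Path(name).with_suffix(".pth")
    return {
        "weights": model_path,
        "onnx": model_path.with_suffix(".onnx"),
        "ir_xml": model_path.with_suffix(".xml"),
        # Model Optimizer writes the .bin weights file next to the .xml file.
        "ir_bin": model_path.with_suffix(".bin"),
    }


for name in ("u2net_lite", "u2net", "u2net_human_seg"):
    # as_posix() prints forward slashes on every platform.
    print(name, [p.as_posix() for p in model_paths(name).values()])
```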

Load the U^2-Net Model

The weights for the selected U^2-Net model are stored on Google Drive and are downloaded if they are not present yet. The next cell loads the model and the pre-trained weights.

if not model_path.exists():
    import gdown

    os.makedirs(name=model_path.parent, exist_ok=True)
    print("Start downloading model weights file... ")
    with open(model_path, "wb") as model_file:
        gdown.download(url=u2net_model.url, output=model_file)
        print(f"Model weights have been downloaded to {model_path}")
# Load the model.
net = u2net_model.model(*u2net_model.model_args)
net.eval()

# Load the weights.
print(f"Loading model weights from: '{model_path}'")
net.load_state_dict(state_dict=torch.load(model_path, map_location="cpu"))

# Save the model if it does not exist yet.
if not model_path.exists():
    print("\nSaving the model")
    torch.save(obj=net.state_dict(), f=str(model_path))
    print(f"Model saved at {model_path}")
Loading model weights from: 'model/u2net_lite/u2net_lite.pth'

Convert PyTorch U^2-Net model to ONNX and OpenVINO IR

Convert PyTorch model to ONNX

The output for this cell will show some warnings. These are most likely harmless. When the conversion succeeds, the last line of the output will read ONNX model exported to [filename].onnx.

if not onnx_path.exists():
    dummy_input = torch.randn(1, 3, 512, 512)
    torch.onnx.export(model=net, args=dummy_input, f=onnx_path, opset_version=11)
    print(f"ONNX model exported to {onnx_path}.")
else:
    print(f"ONNX model {onnx_path} already exists.")
/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-231/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/torch/nn/functional.py:3328: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.
  warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.")
/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-231/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/torch/nn/functional.py:3454: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
  warnings.warn(
/opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-231/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/torch/nn/functional.py:1709: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
  warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
ONNX model exported to model/u2net_lite/u2net_lite.onnx.

Convert ONNX model to OpenVINO IR Format

Use Model Optimizer to convert the ONNX model to OpenVINO IR format with FP16 precision. The IR files are saved to the same directory as the model weights (model/u2net_lite for the default configuration). Model Optimizer also embeds input normalization in the model: --mean_values subtracts the per-channel mean from the input, and --scale_values divides the result by the per-channel standard deviation. With these options, it is not necessary to normalize input data before propagating it through the network. The mean and standard deviation values can be found in the dataloader file in the U^2-Net repository; they are multiplied by 255 to support images with pixel values from 0 to 255.
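Why multiplying by 255 works: for a pixel value x in the 0 to 255 range, (x - 255·mean) / (255·std) equals (x/255 - mean) / std, so the converted model performs exactly the normalization that would otherwise be done in Python. The sketch below checks this numerically with NumPy. It assumes the per-channel RGB mean [0.485, 0.456, 0.406] and standard deviation [0.229, 0.224, 0.225] (the standard ImageNet statistics that, multiplied by 255, give the values in the command below) and works on an HWC array for simplicity, whereas the model itself operates on NCHW input.

```python
import numpy as np

# Per-channel (RGB) statistics for inputs scaled to [0, 1].
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

# Scaled to the 0-255 range, as passed to --mean_values and --scale_values.
mean_values = MEAN * 255
scale_values = STD * 255


def normalize_in_model(pixels_0_255: np.ndarray) -> np.ndarray:
    """What the converted model computes: (x - mean_values) / scale_values."""
    return (pixels_0_255 - mean_values) / scale_values


def normalize_in_python(pixels_0_255: np.ndarray) -> np.ndarray:
    """The manual alternative: scale to [0, 1], then standardize."""
    return (pixels_0_255 / 255.0 - MEAN) / STD


# A random RGB image in HWC layout with values 0-255 (channels on the last axis).
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(4, 4, 3)).astype(np.float64)

print(np.round(mean_values, 3), np.round(scale_values, 3))
print(np.allclose(normalize_in_model(image), normalize_in_python(image)))
```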

For more information, refer to the Model Optimizer Developer Guide.

Executing the following command may take a while. The output may contain some errors or warnings. When the model optimization is successful, the output will include a line such as [ SUCCESS ] Generated IR version 11 model.

# Construct the command for Model Optimizer.
# Set log_level to CRITICAL to suppress warnings that can be ignored for this demo.
mo_command = f"""mo
                 --input_model "{onnx_path}"
                 --input_shape "[1,3, 512, 512]"
                 --mean_values="[123.675, 116.28 , 103.53]"
                 --scale_values="[58.395, 57.12 , 57.375]"
                 --data_type FP16
                 --output_dir "{model_path.parent}"
                 --log_level "CRITICAL"
                 """
mo_command = " ".join(mo_command.split())
print("Model Optimizer command to convert the ONNX model to OpenVINO:")
print(mo_command)
Model Optimizer command to convert the ONNX model to OpenVINO:
mo --input_model "model/u2net_lite/u2net_lite.onnx" --input_shape "[1,3, 512, 512]" --mean_values="[123.675, 116.28 , 103.53]" --scale_values="[58.395, 57.12 , 57.375]" --data_type FP16 --output_dir "model/u2net_lite" --log_level "CRITICAL"
if not ir_path.exists():
    print("Exporting ONNX model to IR... This may take a few minutes.")
    mo_result = %sx $mo_command
    print("\n".join(mo_result))
else:
    print(f"IR model {ir_path} already exists.")
Exporting ONNX model to IR... This may take a few minutes.
Model Optimizer arguments:
Common parameters:
    - Path to the Input Model:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-231/.workspace/scm/ov-notebook/notebooks/205-vision-background-removal/model/u2net_lite/u2net_lite.onnx
    - Path for generated IR:    /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-231/.workspace/scm/ov-notebook/notebooks/205-vision-background-removal/model/u2net_lite
    - IR output name:   u2net_lite
    - Log level:    CRITICAL
    - Batch:    Not specified, inherited from the model
    - Input layers:     Not specified, inherited from the model
    - Output layers:    Not specified, inherited from the model
    - Input shapes:     [1,3, 512, 512]
    - Source layout:    Not specified
    - Target layout:    Not specified
    - Layout:   Not specified
    - Mean values:  [123.675, 116.28 , 103.53]
    - Scale values:     [58.395, 57.12 , 57.375]
    - Scale factor:     Not specified
    - Precision of IR:  FP16
    - Enable fusing:    True
    - User transformations:     Not specified
    - Reverse input channels:   False
    - Enable IR generation for fixed input shape:   False
    - Use the transformations config file:  None
Advanced parameters:
    - Force the usage of legacy Frontend of Model Optimizer for model conversion into IR:   False
    - Force the usage of new Frontend of Model Optimizer for model conversion into IR:  False
OpenVINO runtime found in:  /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-231/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/openvino
OpenVINO runtime version:   2022.1.0-7019-cdb9bec7210-releases/2022/1
Model Optimizer version:    2022.1.0-7019-cdb9bec7210-releases/2022/1
[ SUCCESS ] Generated IR version 11 model.
[ SUCCESS ] XML file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-231/.workspace/scm/ov-notebook/notebooks/205-vision-background-removal/model/u2net_lite/u2net_lite.xml
[ SUCCESS ] BIN file: /opt/home/k8sworker/cibuilds/ov-notebook/OVNotebookOps-231/.workspace/scm/ov-notebook/notebooks/205-vision-background-removal/model/u2net_lite/u2net_lite.bin
[ SUCCESS ] Total execution time: 0.77 seconds.
[ SUCCESS ] Memory consumed: 118 MB.
It's been a while, check for a new version of Intel(R) Distribution of OpenVINO(TM) toolkit here https://software.intel.com/content/www/us/en/develop/tools/openvino-toolkit/download.html?cid=other&source=prod&campid=ww_2022_bu_IOTG_OpenVINO-2022-1&content=upg_all&medium=organic or on the GitHub*
[ INFO ] The model was converted to IR v11, the latest model format that corresponds to the source DL framework input/output format. While IR v11 is backwards compatible with OpenVINO Inference Engine API v1.0, please use API v2.0 (as of 2022.1) to take advantage of the latest improvements in IR v11.
Find more information about API v2.0 and IR v11 at https://docs.openvino.ai

Load and Pre-Process Input Image

OpenCV reads images in BGR format, but the OpenVINO IR model expects images in RGB. Therefore, convert the images to RGB, resize them to 512 x 512, and transpose the dimensions from the HWC layout returned by OpenCV to the NCHW layout (1, 3, 512, 512) that the OpenVINO IR model expects.
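The channel swap and layout change can be checked on a tiny array. The sketch below uses a hypothetical 2 x 3 image of pure blue pixels in OpenCV's HWC/BGR convention (in the tutorial, H and W are 512 after resizing). Reversing the last axis gives the same result as cv2.cvtColor with COLOR_BGR2RGB for a 3-channel image, and the transpose plus expand_dims matches the preprocessing cell below. Note that the data type stays uint8; only the shape and channel order change.

```python
import numpy as np

# Hypothetical 2x3 image in HWC layout, BGR order: every pixel is pure blue
# (B=255, G=0, R=0).
bgr = np.zeros((2, 3, 3), dtype=np.uint8)
bgr[..., 0] = 255

# Reversing the channel axis converts BGR to RGB.
rgb = bgr[..., ::-1]

# HWC -> CHW, then add a batch dimension -> NCHW.
nchw = np.expand_dims(np.transpose(rgb, (2, 0, 1)), 0)

print(nchw.shape)
# Blue now lives in channel index 2 (R=0, G=1, B=2).
print(nchw[0, 2].min(), nchw[0, 0].max())
```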

IMAGE_PATH = Path(IMAGE_DIR) / "coco_hollywood.jpg"
image = cv2.cvtColor(
    src=cv2.imread(filename=str(IMAGE_PATH)),
    code=cv2.COLOR_BGR2RGB,
)

resized_image = cv2.resize(src=image, dsize=(512, 512))
# Convert the image from HWC to the NCHW shape expected by the
# OpenVINO IR model: (1, 3, 512, 512).
input_image = np.expand_dims(np.transpose(resized_image, (2, 0, 1)), 0)

Do Inference on OpenVINO IR Model

Load the OpenVINO IR model to OpenVINO Runtime and do inference.

# Load the network to OpenVINO Runtime.
ie = Core()
model_ir = ie.read_model(model=ir_path)
compiled_model_ir = ie.compile_model(model=model_ir, device_name="CPU")
# Get the input and output ports of the compiled model.
input_layer_ir = compiled_model_ir.input(0)
output_layer_ir = compiled_model_ir.output(0)

# Do inference on the input image.
start_time = time.perf_counter()
result = compiled_model_ir([input_image])[output_layer_ir]
end_time = time.perf_counter()
print(
    f"Inference finished. Inference time: {end_time-start_time:.3f} seconds, "
    f"FPS: {1/(end_time-start_time):.2f}."
)
Inference finished. Inference time: 0.117 seconds, FPS: 8.52.
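The FPS value above is derived from a single call, so it is only a rough estimate; the first call on a freshly compiled model can include one-time overhead. A minimal sketch of a steadier measurement, assuming a hypothetical measure_fps helper that runs a few warm-up calls and then averages over several timed calls, is shown below. A stand-in workload keeps the sketch self-contained; in the notebook, infer_fn could be lambda: compiled_model_ir([input_image])[output_layer_ir].

```python
import time
from statistics import mean


def measure_fps(infer_fn, n_warmup: int = 2, n_runs: int = 10) -> float:
    """Hypothetical helper: average throughput of infer_fn over n_runs calls.

    Warm-up calls are executed first and excluded from the measurement.
    """
    for _ in range(n_warmup):
        infer_fn()
    durations = []
    for _ in range(n_runs):
        start = time.perf_counter()
        infer_fn()
        durations.append(time.perf_counter() - start)
    return 1.0 / mean(durations)


# Stand-in workload: sleeps for about 10 ms per call and records each call.
calls = []


def fake_infer():
    calls.append(1)
    time.sleep(0.01)


fps = measure_fps(fake_infer, n_warmup=2, n_runs=5)
print(f"Average FPS: {fps:.1f}")
```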

Visualize Results

Show the original image, the segmentation result, and the original image with the background removed.

# Resize the network result to the image shape and round the values
# to 0 (background) and 1 (foreground).
# The network result has (1,1,512,512) shape. The `np.squeeze` function converts this to (512, 512).
resized_result = np.rint(
    cv2.resize(src=np.squeeze(result), dsize=(image.shape[1], image.shape[0]))
).astype(np.uint8)

# Create a copy of the image and set all background values to 255 (white).
bg_removed_result = image.copy()
bg_removed_result[resized_result == 0] = 255

fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(20, 7))
ax[0].imshow(image)
ax[1].imshow(resized_result, cmap="gray")
ax[2].imshow(bg_removed_result)
for a in ax:
    a.axis("off")
[Figure: the original image, the segmentation result, and the original image with the background removed.]
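The masking step above can be traced on a hypothetical 2 x 2 network output. np.rint turns values above 0.5 into 1 (foreground) and values below 0.5 into 0 (background); because it rounds half to even, a value of exactly 0.5 also becomes 0. Indexing the H x W x 3 image with the boolean H x W mask selects whole pixels, so all three channels of each background pixel are set to 255 (white). The probabilities and pixel values here are made up for illustration.

```python
import numpy as np

# Hypothetical resized network output with values in [0, 1].
probabilities = np.array([[0.1, 0.6], [0.49, 0.95]], dtype=np.float32)

# Round to 0 (background) or 1 (foreground).
mask = np.rint(probabilities).astype(np.uint8)

# Uniform gray RGB image of the same height and width.
image = np.full((2, 2, 3), 100, dtype=np.uint8)

# The boolean (H, W) mask selects whole pixels in the (H, W, 3) image.
bg_removed = image.copy()
bg_removed[mask == 0] = 255

print(mask)
print(bg_removed[..., 0])
```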

Add a Background Image

In the resized segmentation result, all foreground pixels have a value of 1 and all background pixels have a value of 0. Replace the background as follows:

  • Load a new background_image.

  • Resize this image to the same size as the original image.

  • In background_image, set all the pixels where the resized segmentation result has a value of 1 (the foreground pixels in the original image) to 0.

  • Add the foreground pixels of the original image, taken from bg_removed_result in the previous step, to background_image.
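These steps can be sketched on a made-up two-pixel image (the first pixel is foreground, the second is background). The sketch also shows an equivalent np.where formulation and a pitfall to keep in mind: bg_removed_result from the previous step has its background pixels set to 255 rather than 0, and NumPy uint8 addition wraps around modulo 256, so adding it unchanged shifts the new background by one gray level.

```python
import numpy as np

mask = np.array([[1, 0]], dtype=np.uint8)  # 1 = foreground, 0 = background
image = np.array([[[200, 150, 100], [10, 20, 30]]], dtype=np.uint8)
background = np.array([[[50, 60, 70], [5, 6, 7]]], dtype=np.uint8)

# Zero the foreground region of the background image.
bg_zeroed = background.copy()
bg_zeroed[mask == 1] = 0

# Foreground-only copy of the original image, with background pixels set to 0.
foreground = image.copy()
foreground[mask == 0] = 0

# The two arrays never overlap, so the uint8 sum cannot overflow.
composite = bg_zeroed + foreground

# Equivalent selection with np.where (mask broadcast over the channel axis).
composite_where = np.where(mask[..., np.newaxis] == 1, image, background)

# Pitfall: with white (255) background pixels, uint8 addition wraps around.
white_bg = image.copy()
white_bg[mask == 0] = 255
wrapped = bg_zeroed + white_bg

print(composite[0].tolist())
print(wrapped[0, 1].tolist())
```

The np.where form avoids the intermediate zeroed copies and cannot overflow, which makes it a convenient alternative when only the final composite is needed.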

BACKGROUND_FILE = "data/wall.jpg"
OUTPUT_DIR = "output"

os.makedirs(name=OUTPUT_DIR, exist_ok=True)

background_image = cv2.cvtColor(src=cv2.imread(filename=BACKGROUND_FILE), code=cv2.COLOR_BGR2RGB)
background_image = cv2.resize(src=background_image, dsize=(image.shape[1], image.shape[0]))

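# Set all the foreground pixels from the result to 0 in the background image.
background_image[resized_result == 1] = 0
# bg_removed_result has white (255) background pixels; set them to 0 so that
# the uint8 addition below does not wrap around.
foreground_image = bg_removed_result.copy()
foreground_image[resized_result == 0] = 0
new_image = background_image + foreground_image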

# Save the generated image.
new_image_path = Path(f"{OUTPUT_DIR}/{IMAGE_PATH.stem}-{Path(BACKGROUND_FILE).stem}.jpg")
cv2.imwrite(filename=str(new_image_path), img=cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR))

# Display the original image and the image with the new background side by side
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(18, 7))
ax[0].imshow(image)
ax[1].imshow(new_image)
for a in ax:
    a.axis("off")
plt.show()

# Create a link to download the image.
image_link = FileLink(new_image_path)
image_link.html_link_str = "<a href='%s' download>%s</a>"
display(
    HTML(
        f"The generated image <code>{new_image_path.name}</code> is saved in "
        f"the directory <code>{new_image_path.parent}</code>. You can also "
        "download the image by clicking on this link: "
        f"{image_link._repr_html_()}"
    )
)
[Figure: the original image and the image with the new background, side by side.]
The generated image coco_hollywood-wall.jpg is saved in the directory output. You can also download the image by clicking on this link: output/coco_hollywood-wall.jpg