Visual Content Search using MobileCLIP and OpenVINO

This Jupyter notebook can be launched online, opening an interactive environment in a browser window (for example, via Google Colab), or you can install and run it locally from the GitHub repository.

Semantic visual content search is a machine learning task that uses either a text query or an input image to search a database of images (a photo gallery, video frames) for images that are semantically similar to the query. Historically, building a robust search engine for images was difficult: one could search by features such as file name and image metadata, and use any context around an image (e.g., alt text or surrounding text if the image appears in a passage) to provide richer search features. This was before the advent of neural networks that can identify images semantically related to a given user query.

Contrastive Language-Image Pre-Training (CLIP) models make it possible to implement a semantic search engine with a few dozen lines of code. The CLIP model has been trained on millions of text-image pairs, encoding the combined semantics of images and text. Using CLIP, you can provide a text query, and CLIP will return the images most related to it.

In this tutorial, we consider how to use MobileCLIP to implement a visual content search engine for finding relevant frames in a video.

#### Table of contents:

- Prerequisites
- Select model
- Run model inference
- Convert Model to OpenVINO Intermediate Representation format
- Run OpenVINO model inference
- Interactive Demo

## Prerequisites

from pathlib import Path

repo_dir = Path("./ml-mobileclip")

if not repo_dir.exists():
    !git clone https://github.com/apple/ml-mobileclip.git
Cloning into 'ml-mobileclip'...
remote: Enumerating objects: 45, done.
remote: Counting objects: 100% (45/45), done.
remote: Compressing objects: 100% (36/36), done.
remote: Total 45 (delta 9), reused 44 (delta 8), pack-reused 0
Unpacking objects: 100% (45/45), 428.50 KiB | 3.15 MiB/s, done.
%pip install -q "./ml-mobileclip" --no-deps

%pip install -q "clip-benchmark>=1.4.0" "datasets>=2.8.0" "open-clip-torch>=2.20.0" "timm>=0.9.5" "torch>=1.13.1" "torchvision>=0.14.1" --extra-index-url https://download.pytorch.org/whl/cpu

%pip install -q "openvino>=2024.0.0" "gradio>=4.19" "matplotlib" "Pillow"  "altair" "pandas" "opencv-python" "tqdm"
Note: you may need to restart the kernel to use updated packages.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
mobileclip 0.1.0 requires torch==1.13.1, but you have torch 2.3.0+cpu which is incompatible.
mobileclip 0.1.0 requires torchvision==0.14.1, but you have torchvision 0.18.0+cpu which is incompatible.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.

## Select model

To get started, we should select the model that will be used in our demonstration. By default, we will use the MobileCLIP model, but for comparison purposes you can select one of the following models:

  • CLIP - CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be instructed in natural language to predict the most relevant text snippet given an image, without directly optimizing for the task. CLIP uses a ViT-like transformer to obtain visual features and a causal language model to obtain text features. The text and visual features are then projected into a latent space of identical dimension, and the dot product between the projected image and text features is used as a similarity score. You can find more information about this model in the research paper, OpenAI blog, model card and GitHub repository.

  • SigLIP - The SigLIP model was proposed in Sigmoid Loss for Language Image Pre-Training. SigLIP proposes to replace the loss function used in CLIP (Contrastive Language-Image Pre-training) with a simple pairwise sigmoid loss, which results in better zero-shot classification accuracy on ImageNet. You can find more information about this model in the research paper and GitHub repository.

  • MobileCLIP - MobileCLIP is a new family of efficient image-text models optimized for runtime performance, trained with a novel and efficient training approach, namely multi-modal reinforced training. The smallest variant, MobileCLIP-S0, obtains zero-shot performance similar to OpenAI's CLIP ViT-B/16 model while being several times faster and 2.8x smaller. More details about the model can be found in the research paper and GitHub repository.

import ipywidgets as widgets

model_dir = Path("checkpoints")

supported_models = {
    "MobileCLIP": {
        "mobileclip_s0": {
            "model_name": "mobileclip_s0",
            "pretrained": model_dir / "mobileclip_s0.pt",
            "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s0.pt",
            "image_size": 256,
        },
        "mobileclip_s1": {
            "model_name": "mobileclip_s1",
            "pretrained": model_dir / "mobileclip_s1.pt",
            "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s1.pt",
            "image_size": 256,
        },
        "mobileclip_s2": {
            "model_name": "mobileclip_s0",
            "pretrained": model_dir / "mobileclip_s2.pt",
            "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s2.pt",
            "image_size": 256,
        },
        "mobileclip_b": {
            "model_name": "mobileclip_b",
            "pretrained": model_dir / "mobileclip_b.pt",
            "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt",
            "image_size": 224,
        },
        "mobileclip_blt": {
            "model_name": "mobileclip_b",
            "pretrained": model_dir / "mobileclip_blt.pt",
            "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt",
            "image_size": 224,
        },
    },
    "CLIP": {
        "clip-vit-b-32": {
            "model_name": "ViT-B-32",
            "pretrained": "laion2b_s34b_b79k",
            "image_size": 224,
        },
        "clip-vit-b-16": {
            "image_name": "ViT-B-16",
            "pretrained": "openai",
            "image_size": 224,
        },
        "clip-vit-l-14": {
            "image_name": "ViT-L-14",
            "pretrained": "datacomp_xl_s13b_b90k",
            "image_size": 224,
        },
        "clip-vit-h-14": {
            "image_name": "ViT-H-14",
            "pretrained": "laion2b_s32b_b79k",
            "image_size": 224,
        },
    },
    "SigLIP": {
        "siglip-vit-b-16": {
            "model_name": "ViT-B-16-SigLIP",
            "pretrained": "webli",
            "image_size": 224,
        },
        "siglip-vit-l-16": {
            "model_name": "ViT-L-16-SigLIP-256",
            "pretrained": "webli",
            "image_size": 256,
        },
    },
}


model_type = widgets.Dropdown(options=supported_models.keys(), value="MobileCLIP", description="Model type:")
model_type
Dropdown(description='Model type:', options=('MobileCLIP', 'CLIP', 'SigLIP'), value='MobileCLIP')
available_models = supported_models[model_type.value]

model_checkpoint = widgets.Dropdown(
    options=available_models.keys(),
    value=list(available_models)[0],
    description="Model:",
)

model_checkpoint
Dropdown(description='Model:', options=('mobileclip_s0', 'mobileclip_s1', 'mobileclip_s2', 'mobileclip_b', 'mo…
import requests

r = requests.get(
    url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py",
)

open("notebook_utils.py", "w").write(r.text)

from notebook_utils import download_file

model_config = available_models[model_checkpoint.value]

## Run model inference

Now, let's see the model in action. Using embeddings, we will try to find the image in which a specific object is represented. Embeddings are a numeric representation of data such as text and images. The model has learned to encode the semantics of image contents in embedding format. This ability turns the model into a powerful tool for solving various tasks, including image-text retrieval. To reach our goal we should:

  1. Calculate embeddings for all of the images in our dataset;

  2. Calculate a text embedding for a user query (i.e. “black dog” or “car”);

  3. Compare the text embedding to the image embeddings to find related embeddings.

The closer two embeddings are, the more similar the contents they represent are.
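As a minimal illustration of step 3 (a sketch only, using random tensors instead of real model outputs and an assumed embedding size of 512), normalized embeddings can be compared with a simple dot product, which equals cosine similarity after normalization:

import torch

# Illustrative stand-ins for embeddings produced by the image and text encoders.
image_embeddings = torch.randn(5, 512)  # 5 images, 512-dimensional embeddings
text_embedding = torch.randn(1, 512)  # 1 text query

# Normalize so that the dot product equals cosine similarity.
image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
text_embedding /= text_embedding.norm(dim=-1, keepdim=True)

similarity = text_embedding @ image_embeddings.T  # shape [1, 5]
best_match = similarity.argmax().item()  # index of the most similar image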

### Prepare model

The code below downloads the model weights and creates the model class instance together with the preprocessing utilities.

import torch
import time
from PIL import Image
import mobileclip
import open_clip

# instantiate model
model_name = model_config["model_name"]
pretrained = model_config["pretrained"]
if model_type.value == "MobileCLIP":
    model_dir.mkdir(exist_ok=True)
    model_url = model_config["url"]
    download_file(model_url, directory=model_dir)
    model, _, preprocess = mobileclip.create_model_and_transforms(model_name, pretrained=pretrained)
    tokenizer = mobileclip.get_tokenizer(model_name)
else:
    model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
    tokenizer = open_clip.get_tokenizer(model_name)
checkpoints/mobileclip_s0.pt:   0%|          | 0.00/206M [00:00<?, ?B/s]
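The conversion and inference cells below rely on an example image tensor (image_tensor) and a tokenized text query (text); in the full notebook these are prepared from a downloaded image gallery and user-provided labels (which also define images, input_labels and visualize_result used in the search step). A minimal sketch of how such inputs could be prepared, assuming a hypothetical local image sample.png and the query "a dog":

# Sketch only: the full notebook builds these inputs from a downloaded image gallery.
sample_image = Image.open("sample.png").convert("RGB")  # hypothetical local image
image_tensor = preprocess(sample_image).unsqueeze(0)  # [1, 3, H, W] model input
text = tokenizer(["a dog"])  # tokenized text query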

## Convert Model to OpenVINO Intermediate Representation format

For best results with OpenVINO, it is recommended to convert the model to OpenVINO IR format. OpenVINO supports PyTorch via the model conversion API. To convert the PyTorch model to OpenVINO IR format, we will use the ov.convert_model function of the model conversion API. The ov.convert_model Python function returns an OpenVINO Model object ready to be loaded on a device and to start making predictions.

Our model consists of two parts, an image encoder and a text encoder, that can be used separately. Let's convert each part to OpenVINO.

import types
import torch.nn.functional as F


def se_block_forward(self, inputs):
    """SE block forward pass with a fixed 8x8 average-pooling kernel, so that the
    MobileCLIP-S* image encoder can be traced with a static kernel size during conversion."""
    b, c, h, w = inputs.size()
    x = F.avg_pool2d(inputs, kernel_size=[8, 8])
    x = self.reduce(x)
    x = F.relu(x)
    x = self.expand(x)
    x = torch.sigmoid(x)
    x = x.view(-1, c, 1, 1)
    return inputs * x
import openvino as ov
import gc

ov_models_dir = Path("ov_models")
ov_models_dir.mkdir(exist_ok=True)

image_encoder_path = ov_models_dir / f"{model_checkpoint.value}_im_encoder.xml"

if not image_encoder_path.exists():
    if "mobileclip_s" in model_name:
        model.image_encoder.model.conv_exp.se.forward = types.MethodType(se_block_forward, model.image_encoder.model.conv_exp.se)
    model.forward = model.encode_image
    ov_image_encoder = ov.convert_model(
        model,
        example_input=image_tensor,
        input=[-1, 3, image_tensor.shape[2], image_tensor.shape[3]],
    )
    ov.save_model(ov_image_encoder, image_encoder_path)
    del ov_image_encoder
    gc.collect()

text_encoder_path = ov_models_dir / f"{model_checkpoint.value}_text_encoder.xml"

if not text_encoder_path.exists():
    model.forward = model.encode_text
    ov_text_encoder = ov.convert_model(model, example_input=text, input=[-1, text.shape[1]])
    ov.save_model(ov_text_encoder, text_encoder_path)
    del ov_text_encoder
    gc.collect()

del model
gc.collect();
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-674/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/mobileclip/modules/common/transformer.py:125: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if seq_len != self.num_embeddings:

## Run OpenVINO model inference

### Select device for image encoder

core = ov.Core()

import ipywidgets as widgets

device = widgets.Dropdown(
    options=core.available_devices + ["AUTO"],
    value="AUTO",
    description="Device:",
    disabled=False,
)

device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
ov_compiled_image_encoder = core.compile_model(image_encoder_path, device.value)
ov_compiled_image_encoder(image_tensor);

### Select device for text encoder

device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
ov_compiled_text_encoder = core.compile_model(text_encoder_path, device.value)
ov_compiled_text_encoder(text);

### Perform search

image_encoding_start = time.perf_counter()
image_features = torch.from_numpy(ov_compiled_image_encoder(image_tensor)[0])
image_encoding_end = time.perf_counter()
print(f"Image encoding took {image_encoding_end - image_encoding_start:.3} ms")
text_encoding_start = time.perf_counter()
text_features = torch.from_numpy(ov_compiled_text_encoder(text)[0])
text_encoding_end = time.perf_counter()
print(f"Text encoding took {text_encoding_end - text_encoding_start:.3} ms")
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)

image_probs = (100.0 * text_features @ image_features.T).softmax(dim=-1)
selected_image = [torch.argmax(image_probs).item()]

visualize_result(images, input_labels[0], selected_image);
Image encoding took 0.0293 sec
Text encoding took 0.0058 sec
(Output image: visualization of the search result for the text query.)

## Interactive Demo

In this section, you can try the different models supported by this tutorial to search for frames in a video using a text query or an image. Upload a video and provide a text query or a reference image for the search, and the model will find the most relevant frames according to the provided query. Please note that different models can require different optimal thresholds for search.

import altair as alt
import cv2
import gradio as gr
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms.functional import to_pil_image, to_tensor
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Resize,
    ToTensor,
)
from open_clip.transform import image_transform


current_device = device.value
current_model = image_encoder_path.name.split("_im_encoder")[0]

available_converted_models = [model_file.name.split("_im_encoder")[0] for model_file in ov_models_dir.glob("*_im_encoder.xml")]
available_devices = list(core.available_devices) + ["AUTO"]

sample_path = Path("data")

download_file(
    "https://github.com/intel-iot-devkit/sample-videos/raw/master/car-detection.mp4",
    directory=sample_path,
)
download_file(
    "https://storage.openvinotoolkit.org/repositories/openvino_notebooks/data/data/video/Coco%20Walking%20in%20Berkeley.mp4",
    directory=sample_path,
    filename="coco.mp4",
)


def get_preprocess_and_tokenizer(model_name):
    if "mobileclip" in model_name:
        resolution = supported_models["MobileCLIP"][model_name]["image_size"]
        resize_size = resolution
        centercrop_size = resolution
        aug_list = [
            Resize(
                resize_size,
                interpolation=InterpolationMode.BILINEAR,
            ),
            CenterCrop(centercrop_size),
            ToTensor(),
        ]
        preprocess = Compose(aug_list)
        tokenizer = mobileclip.get_tokenizer(supported_models["MobileCLIP"][model_name]["model_name"])
    else:
        model_configs = supported_models["SigLIP"] if "siglip" in model_name else supported_models["CLIP"]
        resize_size = model_configs[model_name]["image_size"]
        preprocess = image_transform((resize_size, resize_size), is_train=False, resize_mode="longest")
        tokenizer = open_clip.get_tokenizer(model_configs[model_name]["model_name"])

    return preprocess, tokenizer


def run(
    path: str,
    text_search: str,
    image_search: Image.Image,
    model_name: str,
    device: str,
    thresh: float,
    stride: int,
    batch_size: int,
):
    assert path, "An input video should be provided"
    assert text_search is not None or image_search is not None, "A text or image query should be provided"
    global current_model
    global current_device
    global preprocess
    global tokenizer
    global ov_compiled_image_encoder
    global ov_compiled_text_encoder

    if current_model != model_name or device != current_device:
        ov_compiled_image_encoder = core.compile_model(ov_models_dir / f"{model_name}_im_encoder.xml", device)
        ov_compiled_text_encoder = core.compile_model(ov_models_dir / f"{model_name}_text_encoder.xml", device)
        preprocess, tokenizer = get_preprocess_and_tokenizer(model_name)
        current_model = model_name
        current_device = device
    # Load video
    dataset = LoadVideo(path, transforms=preprocess, vid_stride=stride)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # Get image query features
    if image_search:
        image = preprocess(image_search).unsqueeze(0)
        query_features = torch.from_numpy(ov_compiled_image_encoder(image)[0])
        query_features /= query_features.norm(dim=-1, keepdim=True)
    # Get text query features
    else:
        # Tokenize search phrase
        text = tokenizer([text_search])
        # Encode text query
        query_features = torch.from_numpy(ov_compiled_text_encoder(text)[0])
        query_features /= query_features.norm(dim=-1, keepdim=True)
    # Encode each frame and compare with query features
    matches = []
    matches_probs = []
    res = pd.DataFrame(columns=["Frame", "Timestamp", "Similarity"])
    for image, orig, frame, timestamp in dataloader:
        with torch.no_grad():
            image_features = torch.from_numpy(ov_compiled_image_encoder(image)[0])

        image_features /= image_features.norm(dim=-1, keepdim=True)
        probs = query_features.cpu().numpy() @ image_features.cpu().numpy().T
        probs = probs[0]

        # Save frame similarity values
        df = pd.DataFrame(
            {
                "Frame": frame.tolist(),
                "Timestamp": torch.round(timestamp / 1000, decimals=2).tolist(),
                "Similarity": probs.tolist(),
            }
        )
        res = pd.concat([res, df])

        # Check if frame is over threshold
        for i, p in enumerate(probs):
            if p > thresh:
                matches.append(to_pil_image(orig[i]))
                matches_probs.append(p)

        print(f"Frames: {frame.tolist()} - Probs: {probs}")

    # Create plot of similarity values
    lines = (
        alt.Chart(res)
        .mark_line(color="firebrick")
        .encode(
            alt.X("Timestamp", title="Timestamp (seconds)"),
            alt.Y("Similarity", scale=alt.Scale(zero=False)),
        )
    ).properties(width=600)
    rule = alt.Chart().mark_rule(strokeDash=[6, 3], size=2).encode(y=alt.datum(thresh))

    selected_frames = np.argsort(-1 * np.array(matches_probs))[:20]
    matched_sorted_frames = [matches[idx] for idx in selected_frames]

    return (
        lines + rule,
        matched_sorted_frames,
    )  # Only return up to 20 images to not crash the UI


class LoadVideo(Dataset):
    def __init__(self, path, transforms, vid_stride=1):
        self.transforms = transforms
        self.vid_stride = vid_stride
        self.cur_frame = 0
        self.cap = cv2.VideoCapture(path)
        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)

    def __getitem__(self, _):
        # Read video
        # Skip over frames
        for _ in range(self.vid_stride):
            self.cap.grab()
            self.cur_frame += 1

        # Read frame
        _, img = self.cap.retrieve()
        timestamp = self.cap.get(cv2.CAP_PROP_POS_MSEC)

        # Convert to PIL
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(np.uint8(img))

        # Apply transforms
        img_t = self.transforms(img)

        return img_t, to_tensor(img), self.cur_frame, timestamp

    def __len__(self):
        return self.total_frames
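# Usage sketch (illustrative only, not part of the demo flow): wrapped in a DataLoader,
# LoadVideo yields batches of (preprocessed tensor, original frame tensor, frame index,
# timestamp in ms), which is exactly how `run` above consumes it, e.g.:
#
#   probe = LoadVideo(str(sample_path / "coco.mp4"), transforms=preprocess, vid_stride=8)
#   image_batch, orig_batch, frame_idx, timestamp_ms = next(iter(DataLoader(probe, batch_size=2)))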


desc_text = """
Search the contents of a video with a text description.
__Note__: Long videos (over a few minutes) may cause UI performance issues.
    """
text_app = gr.Interface(
    description=desc_text,
    fn=run,
    inputs=[
        gr.Video(label="Video"),
        gr.Textbox(label="Text Search Query"),
        gr.Image(label="Image Search Query", visible=False),
        gr.Dropdown(
            label="Model",
            choices=available_converted_models,
            value=model_checkpoint.value,
        ),
        gr.Dropdown(label="Device", choices=available_devices, value=device.value),
        gr.Slider(label="Threshold", maximum=1.0, value=0.2),
        gr.Slider(label="Frame-rate Stride", value=4, step=1),
        gr.Slider(label="Batch Size", value=4, step=1),
    ],
    outputs=[
        gr.Plot(label="Similarity Plot"),
        gr.Gallery(label="Matched Frames", columns=2, object_fit="contain", height="auto"),
    ],
    examples=[[sample_path / "car-detection.mp4", "white car"]],
    allow_flagging="never",
)

desc_image = """
Search the contents of a video with an image query.
__Note__: Long videos (over a few minutes) may cause UI performance issues.
    """
image_app = gr.Interface(
    description=desc_image,
    fn=run,
    inputs=[
        gr.Video(label="Video"),
        gr.Textbox(label="Text Search Query", visible=False),
        gr.Image(label="Image Search Query", type="pil"),
        gr.Dropdown(
            label="Model",
            choices=available_converted_models,
            value=model_checkpoint.value,
        ),
        gr.Dropdown(label="Device", choices=available_devices, value=device.value),
        gr.Slider(label="Threshold", maximum=1.0, value=0.2),
        gr.Slider(label="Frame-rate Stride", value=4, step=1),
        gr.Slider(label="Batch Size", value=4, step=1),
    ],
    outputs=[
        gr.Plot(label="Similarity Plot"),
        gr.Gallery(label="Matched Frames", columns=2, object_fit="contain", height="auto"),
    ],
    allow_flagging="never",
    examples=[[sample_path / "coco.mp4", None, sample_path / "dog.png"]],
)
demo = gr.TabbedInterface(
    interface_list=[text_app, image_app],
    tab_names=["Text Query Search", "Image Query Search"],
    title="CLIP Video Content Search",
)


try:
    demo.launch(debug=False)
except Exception:
    demo.launch(share=True, debug=False)
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/
data/car-detection.mp4:   0%|          | 0.00/2.68M [00:00<?, ?B/s]
data/coco.mp4:   0%|          | 0.00/877k [00:00<?, ?B/s]
Running on local URL:  http://127.0.0.1:7860

To create a public link, set share=True in launch().