Visual Content Search using MobileCLIP and OpenVINO

This Jupyter notebook can be launched online, opening an interactive environment in a browser window. You can also make a local installation. Choose one of the following options: Google Colab or GitHub.

Semantic visual content search is a machine learning task that uses either a text query or an input image to search a database of images (a photo gallery, video frames) for images that are semantically similar to the query. Historically, building a robust search engine for images was difficult. One could search by features such as file name and image metadata, and use any context around an image (e.g. alt text, or the surrounding text if the image appears in a passage) to enrich the search. This was before the advent of neural networks that can identify images semantically related to a given user query.

Contrastive Language-Image Pre-Training (CLIP) models provide the means through which you can implement a semantic search engine with a few dozen lines of code. The CLIP model has been trained on millions of pairs of text and images, learning to encode both images and text into a shared embedding space. Using CLIP, you can encode a text query and retrieve the images whose embeddings are most similar to it.

In this tutorial, we consider how to use MobileCLIP to implement a visual content search engine for finding relevant frames in video.


Installation Instructions

This is a self-contained example that relies solely on its own code.

We recommend running the notebook in a virtual environment. You only need a Jupyter server to start. For details, please refer to the Installation Guide.

Prerequisites

from pathlib import Path

repo_dir = Path("./ml-mobileclip")

if not repo_dir.exists():
    !git clone https://github.com/apple/ml-mobileclip.git
Cloning into 'ml-mobileclip'...
remote: Enumerating objects: 95, done.
remote: Counting objects: 100% (95/95), done.
remote: Compressing objects: 100% (66/66), done.
remote: Total 95 (delta 38), reused 85 (delta 28), pack-reused 0 (from 0)
Unpacking objects: 100% (95/95), 469.11 KiB | 3.13 MiB/s, done.
%pip install -q "./ml-mobileclip" --no-deps

%pip install -q "clip-benchmark>=1.4.0" "datasets>=2.8.0" "open-clip-torch>=2.20.0" "timm>=0.9.5" "torch>=1.13.1" "torchvision>=0.14.1" --extra-index-url https://download.pytorch.org/whl/cpu

%pip install -q "openvino>=2024.0.0" "gradio>=4.19" "Pillow" "altair" "pandas" "opencv-python" "tqdm" "matplotlib>=3.4"
Note: you may need to restart the kernel to use updated packages.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
mobileclip 0.1.0 requires torchvision==0.14.1, but you have torchvision 0.17.2+cpu which is incompatible.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.

Select model

To get started, select the model that will be used in the demonstration. By default, we use the MobileCLIP model, but for comparison purposes you can select one of the following models:

  • CLIP - CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on various (image, text) pairs. It can be instructed in natural language to predict the most relevant text snippet, given an image, without directly optimizing for the task. CLIP uses a ViT-like transformer to get visual features and a causal language model to get the text features. The text and visual features are then projected into a latent space with identical dimensions. The dot product between the projected image and text features is then used as a similarity score. You can find more information about this model in the research paper, OpenAI blog, model card and GitHub repository.

  • SigLIP - The SigLIP model was proposed in Sigmoid Loss for Language Image Pre-Training. SigLIP proposes to replace the loss function used in CLIP (Contrastive Language–Image Pre-training) with a simple pairwise sigmoid loss. This results in better zero-shot classification accuracy on ImageNet. You can find more information about this model in the research paper and GitHub repository.

  • MobileCLIP - MobileCLIP is a new family of efficient image-text models optimized for runtime performance, along with a novel and efficient training approach, namely multi-modal reinforced training. The smallest variant, MobileCLIP-S0, obtains zero-shot performance similar to OpenAI's CLIP ViT-B/16 model while being several times faster and 2.8x smaller. More details about the model can be found in the research paper and GitHub repository.
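The CLIP and SigLIP descriptions above differ mainly in their training objective. The following NumPy sketch contrasts the two losses on a toy batch of random vectors that stand in for encoder outputs. It is a simplified illustration, not the reference implementation from either paper. The fixed `logit_scale=100.0`, `t=10.0` and `b=-10.0` values are illustrative assumptions (the real models learn these), and the helper names are hypothetical.

```python
import numpy as np


def l2_normalize(x):
    # Unit-length rows, so dot products become cosine similarities
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def clip_contrastive_loss(img_emb, txt_emb, logit_scale=100.0):
    """CLIP-style loss: symmetric softmax cross-entropy; row i of each input is a matching pair."""
    logits = logit_scale * l2_normalize(img_emb) @ l2_normalize(txt_emb).T  # (B, B)

    def diag_cross_entropy(l):
        # Row-wise log-softmax (shifted by the row max for stability), keep the diagonal
        l = l - l.max(axis=1, keepdims=True)
        log_probs = l - np.log(np.exp(l).sum(axis=1, keepdims=True))
        return -np.mean(np.diag(log_probs))

    # Average the image->text and text->image directions
    return 0.5 * (diag_cross_entropy(logits) + diag_cross_entropy(logits.T))


def siglip_sigmoid_loss(img_emb, txt_emb, t=10.0, b=-10.0):
    """SigLIP-style loss: every (image, text) pair is an independent binary classification."""
    logits = t * l2_normalize(img_emb) @ l2_normalize(txt_emb).T + b
    labels = 2.0 * np.eye(len(img_emb)) - 1.0  # +1 for matching pairs, -1 otherwise
    # -log(sigmoid(z)) == log(1 + exp(-z)), computed stably via logaddexp
    return np.sum(np.logaddexp(0.0, -labels * logits)) / len(img_emb)


rng = np.random.default_rng(0)
toy_images = rng.normal(size=(4, 8))
toy_texts = toy_images + 0.1 * rng.normal(size=(4, 8))  # each text close to its image
toy_shuffled = toy_texts[[1, 2, 3, 0]]                  # pairing deliberately broken

print("CLIP  :", clip_contrastive_loss(toy_images, toy_texts), clip_contrastive_loss(toy_images, toy_shuffled))
print("SigLIP:", siglip_sigmoid_loss(toy_images, toy_texts), siglip_sigmoid_loss(toy_images, toy_shuffled))
```

Both losses are lower for the correctly paired batch. The structural difference is that the softmax loss normalizes each row over the whole batch. The sigmoid loss scores every pair on its own, which removes the need for that batch-wide normalization.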

import ipywidgets as widgets

model_dir = Path("checkpoints")

supported_models = {
    "MobileCLIP": {
        "mobileclip_s0": {
            "model_name": "mobileclip_s0",
            "pretrained": model_dir / "mobileclip_s0.pt",
            "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s0.pt",
            "image_size": 256,
        },
        "mobileclip_s1": {
            "model_name": "mobileclip_s1",
            "pretrained": model_dir / "mobileclip_s1.pt",
            "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s1.pt",
            "image_size": 256,
        },
        "mobileclip_s2": {
            "model_name": "mobileclip_s2",
            "pretrained": model_dir / "mobileclip_s2.pt",
            "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s2.pt",
            "image_size": 256,
        },
        "mobileclip_b": {
            "model_name": "mobileclip_b",
            "pretrained": model_dir / "mobileclip_b.pt",
            "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt",
            "image_size": 224,
        },
        "mobileclip_blt": {
            "model_name": "mobileclip_b",
            "pretrained": model_dir / "mobileclip_blt.pt",
            "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt",
            "image_size": 224,
        },
    },
    "CLIP": {
        "clip-vit-b-32": {
            "model_name": "ViT-B-32",
            "pretrained": "laion2b_s34b_b79k",
            "image_size": 224,
        },
        "clip-vit-b-16": {
            "model_name": "ViT-B-16",
            "pretrained": "openai",
            "image_size": 224,
        },
        "clip-vit-l-14": {
            "model_name": "ViT-L-14",
            "pretrained": "datacomp_xl_s13b_b90k",
            "image_size": 224,
        },
        "clip-vit-h-14": {
            "model_name": "ViT-H-14",
            "pretrained": "laion2b_s32b_b79k",
            "image_size": 224,
        },
    },
    "SigLIP": {
        "siglip-vit-b-16": {
            "model_name": "ViT-B-16-SigLIP",
            "pretrained": "webli",
            "image_size": 224,
        },
        "siglip-vit-l-16": {
            "model_name": "ViT-L-16-SigLIP-256",
            "pretrained": "webli",
            "image_size": 256,
        },
    },
}


model_type = widgets.Dropdown(options=list(supported_models.keys()), value="MobileCLIP", description="Model type:")
model_type
Dropdown(description='Model type:', options=('MobileCLIP', 'CLIP', 'SigLIP'), value='MobileCLIP')
available_models = supported_models[model_type.value]

model_checkpoint = widgets.Dropdown(
    options=list(available_models),
    value=list(available_models)[0],  # preselect the first checkpoint of the chosen family
    description="Model:",
)

model_checkpoint
Dropdown(description='Model:', options=('mobileclip_s0', 'mobileclip_s1', 'mobileclip_s2', 'mobileclip_b', 'mo…
import requests

# Download the shared notebook helpers once
if not Path("notebook_utils.py").exists():
    r = requests.get(
        url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py",
    )
    Path("notebook_utils.py").write_text(r.text)

from notebook_utils import download_file, device_widget

model_config = available_models[model_checkpoint.value]

Run model inference

Now, let’s see the model in action. We will use embeddings to find the image that shows a specific object. Embeddings are numeric representations of data such as text and images. The model has learned to encode the semantics of image contents as embeddings. This ability makes the model a powerful tool for solving various tasks, including image-text retrieval. To reach our goal, we should:

  1. Calculate embeddings for all of the images in our dataset;

  2. Calculate a text embedding for a user query (e.g. “black dog” or “car”);

  3. Compare the text embedding to the image embeddings to find related embeddings.

The closer two embeddings are, the more similar the contents they represent are.
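The three steps above can be sketched with NumPy. Small hand-made vectors serve as hypothetical stand-ins for the outputs of the image and text encoders. Real CLIP embeddings have hundreds of dimensions. The helper names `normalize` and `rank_images` are illustrative only.

```python
import numpy as np


def normalize(x):
    """L2-normalize embeddings along the last axis so dot products equal cosine similarity."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def rank_images(text_embedding, image_embeddings, top_k=3):
    """Return indices and scores of the top_k images most similar to the text query."""
    similarity = normalize(image_embeddings) @ normalize(text_embedding)  # shape: (num_images,)
    order = np.argsort(-similarity)[:top_k]  # highest similarity first
    return order, similarity[order]


# Step 1: toy "image embeddings" for a dataset of three images
image_embeddings = np.array([
    [1.0, 0.0, 0.0],  # image 0
    [0.7, 0.7, 0.0],  # image 1
    [0.0, 0.0, 1.0],  # image 2
])
# Step 2: toy "text embedding" for the user query
text_embedding = np.array([0.9, 0.1, 0.0])

# Step 3: compare and rank
indices, scores = rank_images(text_embedding, image_embeddings, top_k=2)
print(indices, scores.round(3))
```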

Prepare model

The code below downloads the model weights and creates the model instance, the preprocessing transforms, and the tokenizer.

import torch
import time
from PIL import Image
import mobileclip
import open_clip

# instantiate model
model_name = model_config["model_name"]
pretrained = model_config["pretrained"]
if model_type.value == "MobileCLIP":
    model_dir.mkdir(exist_ok=True)
    model_url = model_config["url"]
    download_file(model_url, directory=model_dir)
    model, _, preprocess = mobileclip.create_model_and_transforms(model_name, pretrained=pretrained)
    tokenizer = mobileclip.get_tokenizer(model_name)
else:
    model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
    tokenizer = open_clip.get_tokenizer(model_name)

Select device for image encoder

import openvino as ov

core = ov.Core()

device = device_widget()

device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
ov_compiled_image_encoder = core.compile_model(image_encoder_path, device.value)
ov_compiled_image_encoder(image_tensor);

Select device for text encoder

device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
ov_compiled_text_encoder = core.compile_model(text_encoder_path, device.value)
ov_compiled_text_encoder(text);
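The results of the `ov_compiled_image_encoder(image_tensor);` and `ov_compiled_text_encoder(text);` calls above are discarded. Calls like these typically serve as a warm-up, so that one-time costs of the first inference do not distort later timings. A more robust latency estimate skips a few warm-up calls and takes the median of several timed runs. The `measure_latency_ms` helper below is a hypothetical sketch of that idea and is not part of OpenVINO. It accepts any callable, such as a compiled model.

```python
import statistics
import time


def measure_latency_ms(infer, inputs, warmup=1, runs=10):
    """Median latency of `infer(inputs)` in milliseconds, excluding `warmup` untimed calls.

    `infer` can be any callable, e.g. a compiled OpenVINO model such as ov_compiled_image_encoder.
    """
    for _ in range(warmup):
        infer(inputs)  # untimed: absorbs one-time costs of the first call(s)
    timings_ms = []
    for _ in range(runs):
        start = time.perf_counter()  # perf_counter() measures seconds
        infer(inputs)
        timings_ms.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(timings_ms)


# Self-contained demo with a stand-in workload; in the notebook you could pass
# ov_compiled_image_encoder and image_tensor instead.
latency = measure_latency_ms(lambda n: sum(range(n)), 100_000, warmup=2, runs=5)
print(f"median latency: {latency:.2f} ms")
```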

Perform search

image_encoding_start = time.perf_counter()
image_features = torch.from_numpy(ov_compiled_image_encoder(image_tensor)[0])
image_encoding_end = time.perf_counter()
# time.perf_counter() returns seconds
print(f"Image encoding took {image_encoding_end - image_encoding_start:.3} s")
text_encoding_start = time.perf_counter()
text_features = torch.from_numpy(ov_compiled_text_encoder(text)[0])
text_encoding_end = time.perf_counter()
print(f"Text encoding took {text_encoding_end - text_encoding_start:.3} s")
# L2-normalize so that the dot product below is a cosine similarity
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)

# Scaled cosine similarities -> probabilities over images; pick the best match
image_probs = (100.0 * text_features @ image_features.T).softmax(dim=-1)
selected_image = [torch.argmax(image_probs).item()]

visualize_result(images, input_labels[0], selected_image);
Image encoding took 0.0294 s
Text encoding took 0.00498 s
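The search code above multiplies the cosine similarities by `100.0` before applying softmax. That factor plays the role of CLIP's logit scale, acting as an inverse temperature. It does not change which image ranks first, but it makes the resulting probabilities much sharper. The small sketch below shows the effect on hypothetical similarity values.

```python
import torch

# Hypothetical cosine similarities between one text query and three images
cosine = torch.tensor([[0.30, 0.25, 0.10]])

probs_by_scale = {}
for scale in (1.0, 100.0):
    # Multiplying by `scale` before softmax works like an inverse temperature
    probs_by_scale[scale] = (scale * cosine).softmax(dim=-1)
    print(f"scale={scale:g}: {probs_by_scale[scale].numpy().round(3)}")
```

With a scale of 1 the three probabilities are nearly uniform. With a scale of 100, almost all of the mass goes to the best match.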

Interactive Demo

In this part, you can try the models supported by this tutorial for searching video frames by a text query or an image. Upload a video and provide a text query or a reference image, and the model will find the frames most relevant to the query. Please note that different models may require different similarity thresholds for optimal search results.
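The `run` function defined below keeps every frame whose similarity to the query exceeds `thresh` and returns at most 20 of them, best first. That selection step can be sketched on its own with NumPy. Because the best threshold varies by model, the sketch also includes a hypothetical `percentile_threshold` helper that derives a per-video threshold from the score distribution. This is an illustrative alternative; the demo does not implement it.

```python
import numpy as np


def select_frames(similarities, thresh, max_frames=20):
    """Return indices of frames scoring above `thresh`, best first, capped at `max_frames`."""
    similarities = np.asarray(similarities)
    above = np.flatnonzero(similarities > thresh)         # frames passing the threshold
    best_first = above[np.argsort(-similarities[above])]  # sort those frames by score, descending
    return best_first[:max_frames]


def percentile_threshold(similarities, q=95.0):
    """Hypothetical helper: a per-video threshold taken from the score distribution."""
    return float(np.percentile(similarities, q))


scores = np.array([0.12, 0.31, 0.18, 0.29, 0.33, 0.05])  # made-up per-frame similarities
print(select_frames(scores, thresh=0.2))                                 # fixed threshold
print(select_frames(scores, thresh=percentile_threshold(scores, 75.0)))  # relative threshold
```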

import altair as alt
import cv2
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms.functional import to_pil_image, to_tensor
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Resize,
    ToTensor,
)
from open_clip.transform import image_transform
from typing import Optional


current_device = device.value
current_model = image_encoder_path.name.split("_im_encoder")[0]

available_converted_models = [model_file.name.split("_im_encoder")[0] for model_file in ov_models_dir.glob("*_im_encoder.xml")]
available_devices = list(core.available_devices) + ["AUTO"]

download_file(
    "https://storage.openvinotoolkit.org/data/test_data/videos/car-detection.mp4",
    directory=sample_path,
)
download_file(
    "https://storage.openvinotoolkit.org/repositories/openvino_notebooks/data/data/video/Coco%20Walking%20in%20Berkeley.mp4",
    directory=sample_path,
    filename="coco.mp4",
)


def get_preprocess_and_tokenizer(model_name):
    if "mobileclip" in model_name:
        resolution = supported_models["MobileCLIP"][model_name]["image_size"]
        resize_size = resolution
        centercrop_size = resolution
        aug_list = [
            Resize(
                resize_size,
                interpolation=InterpolationMode.BILINEAR,
            ),
            CenterCrop(centercrop_size),
            ToTensor(),
        ]
        preprocess = Compose(aug_list)
        tokenizer = mobileclip.get_tokenizer(supported_models["MobileCLIP"][model_name]["model_name"])
    else:
        model_configs = supported_models["SigLIP"] if "siglip" in model_name else supported_models["CLIP"]
        resize_size = model_configs[model_name]["image_size"]
        preprocess = image_transform((resize_size, resize_size), is_train=False, resize_mode="longest")
        tokenizer = open_clip.get_tokenizer(model_configs[model_name]["model_name"])

    return preprocess, tokenizer


def run(
    path: str,
    text_search: str,
    image_search: Optional[Image.Image],
    model_name: str,
    device: str,
    thresh: float,
    stride: int,
    batch_size: int,
):
    assert path, "An input video should be provided"
    assert text_search is not None or image_search is not None, "A text or image query should be provided"
    global current_model
    global current_device
    global preprocess
    global tokenizer
    global ov_compiled_image_encoder
    global ov_compiled_text_encoder

    if current_model != model_name or device != current_device:
        ov_compiled_image_encoder = core.compile_model(ov_models_dir / f"{model_name}_im_encoder.xml", device)
        ov_compiled_text_encoder = core.compile_model(ov_models_dir / f"{model_name}_text_encoder.xml", device)
        preprocess, tokenizer = get_preprocess_and_tokenizer(model_name)
        current_model = model_name
        current_device = device
    # Load video
    dataset = LoadVideo(path, transforms=preprocess, vid_stride=stride)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # Get image query features
    if image_search:
        image = preprocess(image_search).unsqueeze(0)
        query_features = torch.from_numpy(ov_compiled_image_encoder(image)[0])
        query_features /= query_features.norm(dim=-1, keepdim=True)
    # Get text query features
    else:
        # Tokenize search phrase
        text = tokenizer([text_search])
        # Encode text query
        query_features = torch.from_numpy(ov_compiled_text_encoder(text)[0])
        query_features /= query_features.norm(dim=-1, keepdim=True)
    # Encode each frame and compare with query features
    matches = []
    matches_probs = []
    res = pd.DataFrame(columns=["Frame", "Timestamp", "Similarity"])
    for image, orig, frame, timestamp in dataloader:
        with torch.no_grad():
            image_features = torch.from_numpy(ov_compiled_image_encoder(image)[0])

        image_features /= image_features.norm(dim=-1, keepdim=True)
        probs = query_features.cpu().numpy() @ image_features.cpu().numpy().T
        probs = probs[0]

        # Save frame similarity values
        df = pd.DataFrame(
            {
                "Frame": frame.tolist(),
                "Timestamp": torch.round(timestamp / 1000, decimals=2).tolist(),
                "Similarity": probs.tolist(),
            }
        )
        res = pd.concat([res, df])

        # Check if frame is over threshold
        for i, p in enumerate(probs):
            if p > thresh:
                matches.append(to_pil_image(orig[i]))
                matches_probs.append(p)

        print(f"Frames: {frame.tolist()} - Probs: {probs}")

    # Create plot of similarity values
    lines = (
        alt.Chart(res)
        .mark_line(color="firebrick")
        .encode(
            alt.X("Timestamp", title="Timestamp (seconds)"),
            alt.Y("Similarity", scale=alt.Scale(zero=False)),
        )
    ).properties(width=600)
    rule = alt.Chart().mark_rule(strokeDash=[6, 3], size=2).encode(y=alt.datum(thresh))

    selected_frames = np.argsort(-1 * np.array(matches_probs))[:20]
    matched_sorted_frames = [matches[idx] for idx in selected_frames]

    return (
        lines + rule,
        matched_sorted_frames,
    )  # Only return up to 20 images to not crash the UI


class LoadVideo(Dataset):
    def __init__(self, path, transforms, vid_stride=1):
        self.transforms = transforms
        self.vid_stride = vid_stride
        self.cur_frame = 0
        self.cap = cv2.VideoCapture(path)
        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)

    def __getitem__(self, _):
        # Read video
        # Skip over frames
        for _ in range(self.vid_stride):
            self.cap.grab()
            self.cur_frame += 1

        # Read frame
        _, img = self.cap.retrieve()
        timestamp = self.cap.get(cv2.CAP_PROP_POS_MSEC)

        # Convert to PIL
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(np.uint8(img))

        # Apply transforms
        img_t = self.transforms(img)

        return img_t, to_tensor(img), self.cur_frame, timestamp

    def __len__(self):
        return self.total_frames
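`LoadVideo` grabs `vid_stride` frames per item and encodes only the last one, so a larger stride trades temporal resolution for speed. The pure-Python sketch below lists which frame numbers such a reader would visit, along with a nominal timestamp for each. The `sampled_frames` helper is hypothetical, and its timestamps assume a constant frame rate. The notebook itself reads timestamps from OpenCV's `CAP_PROP_POS_MSEC`, which can differ, for example with variable-frame-rate videos.

```python
def sampled_frames(frame_count, fps, stride):
    """List (frame_number, nominal_seconds) pairs that a LoadVideo-style reader would visit.

    frame_number counts grabbed frames starting at 1, like `cur_frame` in LoadVideo;
    nominal_seconds assumes a constant frame rate.
    """
    num_items = frame_count // stride  # same as int(frame_count / stride) for positive values
    result = []
    for k in range(num_items):
        frame_number = stride * (k + 1)     # cur_frame after grabbing `stride` more frames
        seconds = (frame_number - 1) / fps  # start time of that frame
        result.append((frame_number, round(seconds, 2)))
    return result


print(sampled_frames(frame_count=10, fps=5.0, stride=3))
```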
if not Path("gradio_helper.py").exists():
    r = requests.get(url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/mobileclip-video-search/gradio_helper.py")
    open("gradio_helper.py", "w").write(r.text)

from gradio_helper import make_demo, Option

demo = make_demo(
    run=run,
    model_option=Option(choices=available_converted_models, value=model_checkpoint.value),
    device_option=Option(choices=available_devices, value=device.value),
)

try:
    demo.launch(debug=False)
except Exception:
    demo.launch(share=True, debug=False)
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/
Running on local URL:  http://127.0.0.1:7860

To create a public link, set share=True in launch().