Visual-language assistant with Video-LLaVA and OpenVINO

This Jupyter notebook can be launched after a local installation only.


Video-LLaVA (Learning United Visual Representation by Alignment Before Projection, paper) is a Large Vision-Language Model (LVLM) that breaks new ground by understanding both images and videos through a single, unified visual representation. While LLaVA excels at image-based tasks, Video-LLaVA expands this fluency to the dynamic world of videos, enabling seamless comprehension and reasoning across both visual domains. This means it can answer questions, generate text, and perform other tasks with equal ease, regardless of whether it’s presented with a still image or a moving scene.

In this tutorial, we consider how to use the Video-LLaVA model to build a multimodal chatbot. For demonstration purposes, we will convert the Video-LLaVA-7B model.

The tutorial consists of the following steps:

  • Install prerequisites

  • Prepare input processor and tokenizer

  • Download original model

  • Compress model weights to 4 and 8 bits using NNCF

  • Convert model to OpenVINO Intermediate Representation (IR) format

  • Prepare OpenVINO-based inference pipeline

  • Run OpenVINO model


About model

Video-LLaVA connects pre-trained CLIP ViT-L/14 visual encoders and a large language model using a simple projection matrix.


More details about the model can be found in the original paper and repository.
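To make the "projection matrix" idea concrete, here is a minimal NumPy sketch. It assumes a single linear projection (as described above), toy dimensions, and one placeholder token in the text that is replaced by all projected visual tokens. The function `splice_visual_tokens` is hypothetical; in this notebook, the real work is done by the model's own `prepare_inputs_labels_for_multimodal` method (exposed later as `preprocess_fn`), which handles many more details.

```python
import numpy as np


def splice_visual_tokens(text_embeds, placeholder_pos, visual_feats, proj):
    """Project visual features into the LLM embedding space and put them
    where a single placeholder token used to be."""
    visual_embeds = visual_feats @ proj  # (num_patches, llm_dim)
    return np.concatenate(
        [text_embeds[:placeholder_pos], visual_embeds, text_embeds[placeholder_pos + 1:]],
        axis=0,
    )


# Toy sizes (assumptions): 5 text tokens, 4 visual patches,
# visual feature size 8, LLM embedding size 16.
rng = np.random.default_rng(0)
text_embeds = rng.standard_normal((5, 16))
visual_feats = rng.standard_normal((4, 8))
proj = rng.standard_normal((8, 16))

sequence = splice_visual_tokens(text_embeds, 2, visual_feats, proj)
print(sequence.shape)  # (8, 16): 5 text tokens - 1 placeholder + 4 visual tokens
```

The LLM then runs on this combined embedding sequence exactly as it would on text-only embeddings, which is why the first conversion stage below takes `inputs_embeds` rather than token IDs.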

Prerequisites

Install required dependencies

%pip install -q torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu
%pip install -q "transformers>=4.31.0,<4.35.0" einops peft opencv_python decord pytorchvideo sentencepiece protobuf "openvino>=2023.2.0" "nncf>=2.7.0" gradio
from pathlib import Path
import sys

repo_dir = Path("Video-LLaVA")

if not repo_dir.exists():
    !git clone https://github.com/PKU-YuanGroup/Video-LLaVA.git

sys.path.insert(0, str(repo_dir.resolve()))

Warning: this tutorial requires the ffmpeg package. To install it for your system, visit the official FFmpeg download page.

import gc

import transformers
from llava.model import LlavaLlamaForCausalLM
from llava.constants import (
    DEFAULT_X_PATCH_TOKEN,
    DEFAULT_X_START_TOKEN,
    DEFAULT_X_END_TOKEN,
    DEFAULT_X_TOKEN
)
transformers.logging.set_verbosity_error()

model_id = "LanguageBind/Video-LLaVA-7B"

config = transformers.AutoConfig.from_pretrained(model_id)

tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
model = LlavaLlamaForCausalLM.from_pretrained(model_id)
image_tower = model.get_image_tower()
video_tower = model.get_video_tower()
image_tower.load_model()
video_tower.load_model()
image_processor = image_tower.image_processor
video_processor = video_tower.video_processor
mm_use_x_start_end = getattr(config, "mm_use_x_start_end", False)
mm_use_x_patch_token = getattr(config, "mm_use_x_patch_token", True)
if mm_use_x_patch_token:
    for x in config.X:
        tokenizer.add_tokens([DEFAULT_X_PATCH_TOKEN[x.upper()]], special_tokens=True)
if mm_use_x_start_end:
    for x in config.X:
        tokenizer.add_tokens([DEFAULT_X_START_TOKEN[x.upper()], DEFAULT_X_END_TOKEN[x.upper()]], special_tokens=True)
preprocess_fn = model.prepare_inputs_labels_for_multimodal

del model
gc.collect()

Build model and convert it to OpenVINO IR format

Video-LLaVA is an autoregressive transformer generative model, which means that each model step depends on the model output from the previous step. The generation approach is based on the assumption that the probability distribution of a word sequence can be decomposed into the product of conditional next-word distributions, \(P(w_{1:T} \mid W_0) = \prod_{t=1}^{T} P(w_t \mid w_{1:t-1}, W_0)\), where \(W_0\) is the initial context. In other words, the model predicts the next token in a loop guided by previously generated tokens until a stop condition is reached (a sequence of maximum length is generated or an end-of-string token is obtained). The way the next token is selected from the predicted probabilities is driven by the selected decoding method. You can find more information about the most popular decoding methods in this blog. The entry point for the generation process for models from the Hugging Face Transformers library is the generate method. You can find more information about its parameters and configuration in the documentation. To preserve flexibility in the choice of decoding method, we will convert only the model inference for a single step.
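The loop described above can be sketched with greedy decoding, one of the simplest decoding methods (the notebook itself later uses sampling through `generate`). This is a minimal sketch under stated assumptions: `toy_step` is a hypothetical stand-in for one converted model step, it receives the whole token history for simplicity, and it ignores the KV cache discussed next.

```python
import numpy as np


def greedy_decode(step_fn, prompt_ids, eos_id, max_new_tokens):
    """Call a single-step model in a loop, always appending the most likely token."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):        # stop condition 1: maximum length
        logits = step_fn(ids)              # scores for the next token only
        next_id = int(np.argmax(logits))
        ids.append(next_id)
        if next_id == eos_id:              # stop condition 2: end-of-sequence token
            break
    return ids


def toy_step(ids, vocab_size=5):
    """Hypothetical stand-in for one model step: prefers (last token + 1) mod vocab_size."""
    logits = np.zeros(vocab_size)
    logits[(ids[-1] + 1) % vocab_size] = 1.0
    return logits


print(greedy_decode(toy_step, [0], eos_id=3, max_new_tokens=10))  # [0, 1, 2, 3]
print(greedy_decode(toy_step, [0], eos_id=3, max_new_tokens=2))   # [0, 1, 2]
```

Because only `step_fn` touches the model, swapping greedy selection for sampling or beam search changes the loop, not the converted network. That is the flexibility gained by converting a single step.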

The inference flow differs between the first step and the subsequent ones. On the first step, the model accepts the preprocessed input instruction and video; after that, the LLM-based part of the model runs on the input embeddings to predict the probability of the next generated token. On subsequent steps, the model accepts only the next token ID, selected according to the sampling strategy, together with the cached attention keys and values. Since the output side is autoregressive, the key and value states of an output token remain the same once computed for every further generation step. Therefore, recomputing them every time you want to generate a new token is wasteful. With the cache, the model saves these states once they have been computed. At each time step, the model only computes them for the most recently generated output token, reusing the saved ones for the previous tokens. This reduces the generation complexity from \(O(n^3)\) to \(O(n^2)\) for a transformer model. More details about how it works can be found in this article.
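The claim that caching does not change the result can be checked on a toy example. The sketch below assumes a single attention layer with one head, random projection matrices and no positional encoding. It computes causal attention step by step, once by re-projecting the whole prefix at every step and once by appending only the newest key and value to a cache.

```python
import numpy as np


def attend(q, k, v):
    """Scaled dot-product attention for one query row over all keys/values."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


rng = np.random.default_rng(0)
d = 8
w_q, w_k, w_v = (rng.standard_normal((d, d)) for _ in range(3))
hidden = rng.standard_normal((6, d))  # toy hidden states of 6 tokens

# Without a cache: re-project keys/values of the whole prefix at every step.
no_cache = []
for t in range(len(hidden)):
    prefix = hidden[: t + 1]
    no_cache.append(attend(hidden[t : t + 1] @ w_q, prefix @ w_k, prefix @ w_v))

# With a cache: project only the newest token and append it to the stored keys/values.
k_cache = np.empty((0, d))
v_cache = np.empty((0, d))
with_cache = []
for t in range(len(hidden)):
    x = hidden[t : t + 1]
    k_cache = np.vstack([k_cache, x @ w_k])
    v_cache = np.vstack([v_cache, x @ w_v])
    with_cache.append(attend(x @ w_q, k_cache, v_cache))

print(all(np.allclose(a, b) for a, b in zip(no_cache, with_cache)))  # True
```

The cached variant does one projection per new token instead of one per prefix token, which is the saving that the second conversion stage (`videollava_with_past`) exploits through its `past_key_values` inputs.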

Prepare helpers for model conversion

The code below prepares a function for converting the Video-LLaVA model to OpenVINO Intermediate Representation format. It splits the model into the parts described above, prepares example inputs for each part, and converts each part using the OpenVINO Model Conversion API. The ov.convert_model function accepts a PyTorch model instance and returns an ov.Model object that represents the model in OpenVINO format. The resulting model can be loaded on a device using ov.compile_model or saved to disk using ov.save_model.

import torch
import openvino as ov
import nncf
from typing import Optional, Tuple, List


class ModelWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        outputs = self.model.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )

        hidden_states = outputs[0]
        logits = self.model.lm_head(hidden_states)

        return (logits, outputs.past_key_values)


def set_node_names(ov_model, input_names=None, output_names=None):
    if input_names is not None:
        for inp, name in zip(ov_model.inputs, input_names):
            inp.get_tensor().set_names({name})
    if output_names is not None:
        for out, name in zip(ov_model.outputs, output_names):
            out.get_tensor().set_names({name})

    ov_model.validate_nodes_and_infer_types()


def cleanup_torchscript_cache():
    """
    Helper for removing cached model representation
    """
    torch._C._jit_clear_class_registry()
    torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
    torch.jit._state._clear_class_state()

def convert_videollava(
    pt_model: torch.nn.Module,
    model_path: Path,
    videollava_wc_parameters: Optional[dict] = None
):
    """
    Video-LLaVA model conversion function

    Params:
      pt_model: PyTorch model
      model_path: path for saving model
    Returns:
      None
    """
    ov_out_path = Path(model_path)
    pt_model.config.save_pretrained(ov_out_path)
    pt_model.config.use_cache = True
    pt_model.config.torchscript = True
    wrapped = ModelWrapper(pt_model)
    first_stage_model_path = ov_out_path / "videollava_input_embed.xml"
    second_stage_model_path = ov_out_path / "videollava_with_past.xml"

    if first_stage_model_path.exists() and second_stage_model_path.exists():
        print("Video-LLaVA model successfully converted")
        del pt_model
        return
    example_input_first_stage = {
        "inputs_embeds": torch.zeros((1, 307, 4096)),
        "attention_mask": torch.ones((1, 307), dtype=torch.long)
    }
    outs = wrapped(**example_input_first_stage)
    input_names = ["input_ids", "attention_mask"]
    output_names = ["logits"]
    for idx in range(len(outs[1])):
        input_names.extend([f"past_key_values.{idx}.key", f"past_key_values.{idx}.value"])
        output_names.extend([f"present.{idx}.key", f"present.{idx}.value"])

    if not first_stage_model_path.exists():
        ov_model = ov.convert_model(
            wrapped, example_input=example_input_first_stage
        )
        set_node_names(ov_model, output_names=output_names)
        if videollava_wc_parameters is not None:
            print("Applying weight compression to first stage Video-LLaVA model")
            ov_model = nncf.compress_weights(ov_model, **videollava_wc_parameters)
        ov.save_model(ov_model, first_stage_model_path)
        cleanup_torchscript_cache()
        del ov_model
        gc.collect()


    if not second_stage_model_path.exists():
        example_input_second_stage = {
            "input_ids": torch.ones((1, 1), dtype=torch.long),
            "attention_mask": torch.ones((1, outs[1][-1][-1].shape[-2] + 1), dtype=torch.long),
            "past_key_values": outs[1],
        }
        ov_model = ov.convert_model(wrapped, example_input=example_input_second_stage)
        set_node_names(ov_model, input_names, output_names)

        if videollava_wc_parameters is not None:
            print("Applying weight compression to second stage Video-LLaVA model")
            ov_model = nncf.compress_weights(ov_model, **videollava_wc_parameters)
        ov.save_model(ov_model, second_stage_model_path)
        cleanup_torchscript_cache()
        del ov_model
        gc.collect()
    print("Video-LLaVA model successfully converted")
    del wrapped
    del pt_model
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, openvino

Convert and Optimize Model

Our model conversion and optimization consist of the following steps:

  1. Download the original PyTorch model.

  2. Compress the model weights using NNCF.

  3. Convert the model to OpenVINO format and save it to disk.

Let’s consider each step in more detail.

Instantiate PyTorch model

To create the PyTorch model, use the from_pretrained method of the LlavaLlamaForCausalLM class. Model weights will be downloaded from the Hugging Face Hub during the first run. This may take some time and requires at least 13 GB of free disk space.

Compress Model weights to 4 and 8 bits using NNCF

To reduce memory consumption, weight compression can be applied using NNCF. Weight compression aims to reduce the memory footprint of a model. It can also lead to significant performance improvement for large memory-bound models, such as Large Language Models (LLMs). LLMs and other models that require extensive memory to store the weights during inference can benefit from weight compression in the following ways:

  • enabling the inference of exceptionally large models that cannot be accommodated in the memory of the device;

  • improving the inference performance of the models by reducing the latency of the memory access when computing the operations with weights, for example, Linear layers.
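A back-of-the-envelope estimate shows the scale of the memory saving. This sketch assumes roughly 7 billion weights, counts only raw weight storage (no quantization scales or zero points, no activations, no KV cache), and models the mixed INT4/INT8 setting as a simple 80%/20% split of the weights, which only approximates how NNCF applies its `ratio` parameter.

```python
def weight_footprint_gb(num_params, bits_per_weight):
    """Approximate storage for weights only, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_weight / 8 / 1e9


num_params = 7e9  # assumption: roughly 7 billion LLM parameters
for name, bits in [("FP16", 16), ("INT8", 8), ("INT4", 4)]:
    print(f"{name}: ~{weight_footprint_gb(num_params, bits):.1f} GB")

# Rough model of a mixed setting: 80% of the weights in INT4, 20% in INT8.
mixed_bits = 0.8 * 4 + 0.2 * 8
print(f"INT4/INT8 mix: ~{weight_footprint_gb(num_params, mixed_bits):.1f} GB")
```

This prints about 14.0, 7.0, 3.5 and 4.2 GB respectively. Every generated token has to stream these weights from memory, so fewer bytes per weight also means less memory traffic per step.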

Neural Network Compression Framework (NNCF) provides 4-bit / 8-bit mixed weight quantization as a compression method primarily designed to optimize LLMs. The main difference between weight compression and full model quantization (post-training quantization) is that activations remain floating-point in the case of weight compression, which leads to better accuracy. Weight compression for LLMs provides a solid inference performance improvement, on par with the performance of full model quantization. In addition, weight compression is data-free and does not require a calibration dataset, making it easy to use.
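To illustrate what "weights quantized, activations floating-point" means, here is a simplified NumPy sketch of asymmetric group-wise quantization, similar in spirit to the INT4_ASYM mode with group_size=128 selected in the compression cell below. It is not NNCF's implementation: the rounding details, storage layout and the choice of which layers stay in INT8 are all simplifications. Each group of 128 consecutive input weights gets its own scale and zero point. The weights are stored as small integers and dequantized before the matrix multiplication with floating-point activations.

```python
import numpy as np


def quantize_asym_groupwise(w, bits=4, group_size=128):
    """Asymmetric per-group quantization along the input (last) axis of a 2-D weight."""
    rows, cols = w.shape
    assert cols % group_size == 0, "cols must be divisible by group_size"
    groups = w.reshape(rows, cols // group_size, group_size)
    w_min = groups.min(axis=-1, keepdims=True)
    w_max = groups.max(axis=-1, keepdims=True)
    levels = 2**bits - 1
    scale = (w_max - w_min) / levels
    scale = np.where(scale == 0, 1.0, scale)  # guard against constant groups
    zero_point = np.round(-w_min / scale)
    q = np.clip(np.round(groups / scale) + zero_point, 0, levels).astype(np.uint8)
    return q, scale, zero_point


def dequantize(q, scale, zero_point, shape):
    """Recover approximate floating-point weights from integer codes."""
    return ((q.astype(np.float64) - zero_point) * scale).reshape(shape)


rng = np.random.default_rng(0)
w = rng.standard_normal((64, 256))  # toy Linear weight: (out_features, in_features)
x = rng.standard_normal((3, 256))   # activations stay floating-point

q, scale, zp = quantize_asym_groupwise(w, bits=4, group_size=128)
w_hat = dequantize(q, scale, zp, w.shape)

y_ref = x @ w.T
y_q = x @ w_hat.T
rel_err = np.linalg.norm(y_q - y_ref) / np.linalg.norm(y_ref)
print(q.dtype, int(q.min()), int(q.max()), f"relative output error: {rel_err:.3f}")
```

No calibration data is involved: scales and zero points come from the weights alone, which is why weight compression is data-free. Smaller groups lower the error but add more scale and zero-point overhead.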

The nncf.compress_weights function can be used to perform weight compression. The function accepts an OpenVINO model and other compression parameters. Compared to INT8 compression, INT4 compression improves performance even more, but introduces a minor drop in prediction quality.

More details about weight compression can be found in the OpenVINO documentation.

NOTE: There is no speedup for INT4 compressed models on dGPU.

Convert model to OpenVINO IR format

Convert the model to OpenVINO format using the conversion helper function defined above.

Please select below whether you would like to run INT4 weight compression instead of INT8 weight compression.

import ipywidgets as widgets

compression_mode = widgets.Dropdown(
    options=["INT4", "INT8"],
    value="INT4",
    description="Compression mode:",
    disabled=False,
)

compression_mode
if compression_mode.value == "INT4":
    compressed_model_dir = Path("videollava/INT4_compressed_weights")
    videollava_wc_parameters = dict(mode=nncf.CompressWeightsMode.INT4_ASYM, group_size=128, ratio=0.8)
else:
    compressed_model_dir = Path("videollava/INT8_compressed_weights")
    videollava_wc_parameters = dict(mode=nncf.CompressWeightsMode.INT8)

if not compressed_model_dir.exists():
    compressed_model_dir.mkdir(exist_ok=True, parents=True)
    model = LlavaLlamaForCausalLM.from_pretrained(model_id)
    model.resize_token_embeddings(len(tokenizer))

    if hasattr(config, "max_sequence_length"):
        context_len = config.max_sequence_length
    else:
        context_len = 2048
    image_tower = model.get_image_tower()
    if not image_tower.is_loaded:
        image_tower.load_model()
    video_tower = model.get_video_tower()
    if not video_tower.is_loaded:
        video_tower.load_model()

    model.eval()
    with torch.no_grad():
        convert_videollava(
            model,
            compressed_model_dir,
            videollava_wc_parameters=videollava_wc_parameters
        )
    del model
    gc.collect();
Applying weight compression to first stage Video-LLaVA model
INFO:nncf:Statistics of the bitwidth distribution:
+--------------+-----------------+--------------------+
| Num bits (N) |  % all weight   | % internal weights |
+==============+=================+====================+
| 8            | 22% (58 / 225)  | 20% (56 / 223)     |
+--------------+-----------------+--------------------+
| 4            | 78% (167 / 225) | 80% (167 / 223)    |
+--------------+-----------------+--------------------+
Applying weight compression to second stage Video-LLaVA model
INFO:nncf:Statistics of the bitwidth distribution:
+--------------+-----------------+--------------------+
| Num bits (N) |  % all weight   | % internal weights |
+==============+=================+====================+
| 8            | 23% (58 / 226)  | 20% (56 / 224)     |
+--------------+-----------------+--------------------+
| 4            | 77% (168 / 226) | 80% (168 / 224)    |
+--------------+-----------------+--------------------+
Video-LLaVA model successfully converted

Prepare OpenVINO based inference pipeline

The OVLlavaLlamaForCausalLM class provides an easy-to-use interface for using the model in a generation scenario. It is based on transformers.generation.GenerationMixin, which lets us reuse all the rich generation capabilities implemented in the Hugging Face Transformers library. More details about this interface can be found in the Hugging Face documentation.

from transformers.generation import GenerationConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
import numpy as np
import torch


class OVLlavaLlamaForCausalLM(GenerationMixin):
    def __init__(self, core, model_dir, device):
        self.model = core.read_model(model_dir / "videollava_with_past.xml")
        self.model_input_embed = core.compile_model(
            model_dir / "videollava_input_embed.xml", device
        )
        self.input_names = {
            key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)
        }
        self.output_names = {
            key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)
        }
        self.key_value_input_names = [
            key for key in self.input_names if "key_values" in key
        ]
        self.key_value_output_names = [
            key for key in self.output_names if "present" in key
        ]
        compiled_model = core.compile_model(self.model, device)
        self.request = compiled_model.create_infer_request()
        self.config = transformers.AutoConfig.from_pretrained(model_dir)
        self.generation_config = GenerationConfig.from_model_config(self.config)
        self.main_input_name = "input_ids"
        self.device = torch.device("cpu")
        self.num_pkv = 2

    def can_generate(self):
        """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
        return True

    def __call__(
        self,
        input_ids: torch.LongTensor,
        images: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        prefix_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        return self.forward(
            input_ids, images, attention_mask, prefix_mask, past_key_values
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        images: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        prefix_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """General inference method"""
        inputs = {}
        if past_key_values is not None:
            # Flatten the past_key_values
            attention_mask = torch.ones(
                (input_ids.shape[0], past_key_values[-1][-1].shape[-2] + 1),
                dtype=input_ids.dtype,
            )
            past_key_values = (
                past_key_value
                for pkv_per_layer in past_key_values
                for past_key_value in pkv_per_layer
            )
            # Add the past_key_values to the decoder inputs
            inputs = dict(zip(self.key_value_input_names, past_key_values))

        else:
            return self.forward_with_image(input_ids, images, attention_mask)
        inputs["input_ids"] = np.array(input_ids)

        if "attention_mask" in self.input_names:
            inputs["attention_mask"] = np.array(attention_mask)

        # Run inference
        self.request.start_async(inputs, share_inputs=True)
        self.request.wait()

        logits = torch.from_numpy(self.request.get_tensor("logits").data)

        # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
        past_key_values = tuple(
            self.request.get_tensor(key).data for key in self.key_value_output_names
        )
        # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)

        past_key_values = tuple(
            past_key_values[i : i + self.num_pkv]
            for i in range(0, len(past_key_values), self.num_pkv)
        )
        return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

    def forward_with_image(self, input_ids, images, attention_mask):
        """First step inference method, that resolves multimodal data"""
        _, attention_mask, _, input_embeds, _ = preprocess_fn(
            input_ids, attention_mask, past_key_values=None, labels=None, X_modalities=images
        )
        outs = self.model_input_embed({"inputs_embeds": input_embeds, "attention_mask": attention_mask})
        logits = outs[0]
        pkv = list(outs.values())[1:]
        pkv = tuple(pkv[i : i + self.num_pkv] for i in range(0, len(pkv), self.num_pkv))
        return CausalLMOutputWithPast(
            logits=torch.from_numpy(logits), past_key_values=pkv
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        """
        This function is used during running GenerationMixin.generate for preparing model specific inputs for
        each generation step
        """
        past_len = 0
        if past_key_values is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            past_len = past_key_values[-1][-1].shape[-2]
        attention_mask = kwargs.get(
            "attention_mask",
            torch.ones(input_ids.shape[0], input_ids.shape[1] + past_len),
        )
        if not kwargs.get("use_cache", True):
            raise NotImplementedError("The OpenVINO Video-LLaVA model does not support use_cache=False.")
        else:
            prefix_mask = None
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prefix_mask": prefix_mask,
            "past_key_values": past_key_values,
            "images": kwargs.get("images", None),
        }

    def _reorder_cache(
        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called.
        This is required to match `past_key_values` with the correct beam_idx at every generation step.
        """

        # from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
        return tuple(
            tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past)
            for layer_past in past_key_values
        )
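The forward method above converts between the nested cache layout that GenerationMixin works with (one (key, value) pair per decoder layer) and the flat list of named OpenVINO inputs and outputs (past_key_values.{idx}.key / past_key_values.{idx}.value, as assigned by set_node_names). The pure-Python sketch below shows that round trip in isolation. The helper names are hypothetical, and strings stand in for the real per-layer tensors.

```python
from typing import Any, Tuple


def flatten_pkv(past_key_values: Tuple[Tuple[Any, ...], ...]) -> Tuple[Any, ...]:
    """((k0, v0), (k1, v1), ...) -> (k0, v0, k1, v1, ...)"""
    return tuple(t for layer in past_key_values for t in layer)


def group_pkv(flat: Tuple[Any, ...], num_pkv: int = 2) -> Tuple[Tuple[Any, ...], ...]:
    """(k0, v0, k1, v1, ...) -> ((k0, v0), (k1, v1), ...)"""
    return tuple(tuple(flat[i : i + num_pkv]) for i in range(0, len(flat), num_pkv))


# Strings stand in for per-layer key/value tensors in this toy example.
cache = (("k0", "v0"), ("k1", "v1"), ("k2", "v2"))
flat = flatten_pkv(cache)
input_names = [f"past_key_values.{i}.{kind}" for i in range(len(cache)) for kind in ("key", "value")]
feed = dict(zip(input_names, flat))  # name -> tensor mapping, as passed to the infer request

print(feed["past_key_values.1.value"])  # v1
print(group_pkv(flat) == cache)         # True
```

The zip in the class relies on the same ordering assumption as this sketch: the model's cache inputs appear in layer order, with the key before the value in each layer.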

Run model inference

Now that we have the model and the generation pipeline defined, we can run model inference.

Select inference device

Select a device from the dropdown list for running inference with OpenVINO.

NOTE: There is no speedup for INT4 compressed models on dGPU.

import ipywidgets as widgets

core = ov.Core()

device = widgets.Dropdown(
    options=core.available_devices + ["AUTO"],
    value="AUTO",
    description="Device:",
    disabled=False,
)

device

Load OpenVINO model

ov_model = OVLlavaLlamaForCausalLM(core, compressed_model_dir, device.value)

Prepare input data

To prepare input data, we will use the tokenizer and the image and video processors defined at the beginning of the tutorial. For alignment with the original PyTorch implementation, we will use PyTorch tensors as input.

from IPython.display import display, Video, Image


examples_dir = Path("Video-LLaVA/llava/serve/examples")
video_file = examples_dir / "sample_demo_22.mp4"
image_file = examples_dir / "sample_img_22.png"


video_tensor = video_processor.preprocess(str(video_file), return_tensors="pt")["pixel_values"][0]
image_tensor = image_processor.preprocess(str(image_file), return_tensors="pt")["pixel_values"][0]
X_modalities = [[video_tensor, image_tensor], ["video", "image"]]

text_message = "Are the instruments in the pictures used in the video?"
print(f"Question: {text_message}")
display(Video(video_file, embed=True))
Image(image_file, embed=True)
Question: Are the instruments in the pictures used in the video?

Test model inference

The generation process for a long response may be time-consuming. To access partial results as soon as they are generated, without waiting for the whole process to finish, the Streaming API can be used. Token streaming is the mode in which the generative system returns the tokens one by one as the model generates them. This enables showing progressive generations to the user rather than waiting for the whole generation. Streaming is an essential aspect of the end-user experience as it reduces latency, one of the most critical aspects of a smooth experience. You can find more details about how streaming works in the Hugging Face documentation.
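Hugging Face streamers implement a small put()/end() protocol: the generation loop passes new token IDs to put() as soon as they are chosen and calls end() when it stops. The toy sketch below mimics that contract with a hypothetical PrintingStreamer and a hypothetical toy_generate loop. It is an assumption-laden illustration, not the TextStreamer source. The real TextStreamer receives tensors of token IDs, buffers them until they decode into printable text, and, with skip_prompt=True, ignores the prompt that generate passes in first.

```python
class PrintingStreamer:
    """Toy streamer following the put()/end() protocol of Hugging Face streamers."""

    def __init__(self, decode):
        self.decode = decode      # maps a token id to its text piece
        self.pieces = []
        self.finished = False

    def put(self, token_id):
        piece = self.decode(token_id)
        self.pieces.append(piece)
        print(piece, end="", flush=True)  # show text as soon as it is available

    def end(self):
        self.finished = True
        print()


def toy_generate(token_ids, streamer):
    """Hypothetical generation loop: hands each new token to the streamer immediately."""
    for token_id in token_ids:
        streamer.put(token_id)
    streamer.end()


vocab = {1: "Yes", 2: ",", 3: " drums"}
toy_streamer = PrintingStreamer(vocab.__getitem__)
toy_generate([1, 2, 3], toy_streamer)  # prints "Yes, drums"
```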

To simplify preparing input in conversational mode, we will also use the Conversation Template helper provided by the model authors to accumulate the history of provided messages and images.
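A conversation template turns a list of (role, message) turns into one prompt string. Below is a simplified, hypothetical stand-in, loosely modeled on the two-separator style (SeparatorStyle.TWO) checked in the next cell. In this sketch, the first separator ends user turns, the second (sep2) ends assistant turns, and a final role with no message leaves an open slot for the model's answer. The system prompt, role names and separator values here are assumptions and do not reproduce the llava_v1 template exactly.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class SimpleConversation:
    """Simplified, hypothetical stand-in for a two-separator conversation template."""
    system: str
    roles: Tuple[str, str] = ("USER", "ASSISTANT")
    sep: str = " "
    sep2: str = "</s>"
    messages: List[Tuple[str, Optional[str]]] = field(default_factory=list)

    def append_message(self, role: str, message: Optional[str]) -> None:
        self.messages.append((role, message))

    def get_prompt(self) -> str:
        seps = [self.sep, self.sep2]
        prompt = self.system + seps[0]
        for i, (role, message) in enumerate(self.messages):
            if message:
                # Even turns (user) end with sep, odd turns (assistant) with sep2.
                prompt += f"{role}: {message}{seps[i % 2]}"
            else:
                prompt += f"{role}:"  # open slot for the model's answer
        return prompt


demo_conv = SimpleConversation(system="A chat.")
demo_conv.append_message(demo_conv.roles[0], "Hi")
demo_conv.append_message(demo_conv.roles[1], None)
print(demo_conv.get_prompt())  # A chat. USER: Hi ASSISTANT:
```

Because sep2 marks the end of an assistant turn, it also serves as a natural stop string for generation, which matches how the next cell chooses stop_str for this separator style.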

from llava.mm_utils import tokenizer_X_token, KeywordsStoppingCriteria
from llava.constants import X_TOKEN_INDEX
from transformers import TextStreamer
from llava.conversation import conv_templates, SeparatorStyle

# Prepare
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
conv_mode = "llava_v1"

conv = conv_templates[conv_mode].copy()
roles = ("user", "assistant")

if mm_use_x_start_end:
    inp = DEFAULT_X_START_TOKEN["VIDEO"] + DEFAULT_X_TOKEN["VIDEO"] + DEFAULT_X_END_TOKEN["VIDEO"] + "\n"
    inp += DEFAULT_X_START_TOKEN["IMAGE"] + DEFAULT_X_TOKEN["IMAGE"] + DEFAULT_X_END_TOKEN["IMAGE"] + "\n"
    inp += text_message
else:
    inp = DEFAULT_X_TOKEN["VIDEO"] + DEFAULT_X_TOKEN["IMAGE"] + "\n" + text_message
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)

prompt = conv.get_prompt()
input_ids1 = tokenizer_X_token(prompt.split(f'\n{X_TOKEN_INDEX["IMAGE"]}')[0], tokenizer, X_TOKEN_INDEX['VIDEO'], return_tensors='pt').unsqueeze(0)
input_ids2 = tokenizer_X_token(prompt.split(f'\n{X_TOKEN_INDEX["IMAGE"]}')[-1], tokenizer, X_TOKEN_INDEX['VIDEO'], return_tensors='pt').unsqueeze(0)
input_ids3 = tokenizer_X_token(f'\n{X_TOKEN_INDEX["IMAGE"]}', tokenizer, X_TOKEN_INDEX['IMAGE'], return_tensors='pt').unsqueeze(0)

input_ids = torch.cat([input_ids1, input_ids3[:, 1:], input_ids2[:, 1:]], dim=-1)
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
print("Answer:")

output_ids = ov_model.generate(
    input_ids,
    images=X_modalities,
    do_sample=True,
    temperature=0.2,
    max_new_tokens=1024,
    streamer=streamer,
    use_cache=True,
    stopping_criteria=[stopping_criteria],
)
Answer:
['video', 'image']
Yes, the instruments in the pictures are used in the video. The man is playing a drum set, which includes a bass drum, snare drum, and cymbals. The cymbals are used to produce different sounds, such as crashes and hi-hats. The man is also seen playing a guitar, which is another instrument used in the video.
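The generate call above samples with do_sample=True and temperature=0.2. A minimal sketch of what temperature does to a next-token distribution: the logits are divided by the temperature before the softmax, so values below 1 sharpen the distribution towards the top token, while sampling still keeps some randomness. The logits are toy values. The Transformers generation pipeline applies this together with other logits processors, so this is an illustration, not its code.

```python
import numpy as np


def softmax_with_temperature(logits, temperature):
    """Divide logits by the temperature, then normalize into probabilities."""
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled -= scaled.max()  # numerical stability
    probs = np.exp(scaled)
    return probs / probs.sum()


logits = [2.0, 1.0, 0.5]  # toy next-token scores
for t in (1.0, 0.2):
    print(t, np.round(softmax_with_temperature(logits, t), 3))

# Draw one token id from the sharpened distribution.
rng = np.random.default_rng(0)
next_token = rng.choice(len(logits), p=softmax_with_temperature(logits, 0.2))
```

With temperature 1.0 the top token gets about 63% of the probability mass; with 0.2 it gets over 99%, which is why low-temperature answers tend to be stable across runs.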

Interactive demo

import torch
import gradio as gr

from llava.constants import DEFAULT_X_TOKEN, X_TOKEN_INDEX
from llava.conversation import conv_templates, SeparatorStyle


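def generate(image, video, textbox_in):
    # Both inputs are optional in the UI; without either one, no multimodal prompt can be built
    if image is None and video is None:
        return "Please provide an image or a video."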
    if video is not None:
        textbox_in = DEFAULT_X_TOKEN["VIDEO"] + "\n" + textbox_in
        if image is not None:
            textbox_in += "\n" + DEFAULT_X_TOKEN["IMAGE"]
    elif image is not None:
        textbox_in = DEFAULT_X_TOKEN['IMAGE'] + '\n' + textbox_in

    conv_mode = "llava_v1"
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], textbox_in)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    image_tensor, video_tensor = None, None
    if image is not None:
        image_tensor = image_processor(image, return_tensors="pt")["pixel_values"][0]
        if video is not None:
            video_tensor = video_processor(video, return_tensors="pt")["pixel_values"][0]
            input_ids1 = tokenizer_X_token(prompt.split(f'\n{X_TOKEN_INDEX["IMAGE"]}')[0], tokenizer, X_TOKEN_INDEX['VIDEO'], return_tensors='pt').unsqueeze(0)
            input_ids2 = tokenizer_X_token(prompt.split(f'\n{X_TOKEN_INDEX["IMAGE"]}')[-1], tokenizer, X_TOKEN_INDEX['VIDEO'], return_tensors='pt').unsqueeze(0)
            input_ids3 = tokenizer_X_token(f'\n{X_TOKEN_INDEX["IMAGE"]}', tokenizer, X_TOKEN_INDEX['IMAGE'], return_tensors='pt').unsqueeze(0)

            input_ids = torch.cat([input_ids1, input_ids3[:, 1:], input_ids2[:, 1:]], dim=-1)
        else:
            input_ids = tokenizer_X_token(prompt, tokenizer, X_TOKEN_INDEX['IMAGE'], return_tensors='pt').unsqueeze(0)
    elif video is not None:
        video_tensor = video_processor(video, return_tensors="pt")["pixel_values"][0]
        input_ids = tokenizer_X_token(prompt, tokenizer, X_TOKEN_INDEX['VIDEO'], return_tensors='pt').unsqueeze(0)
    X_modalities = [[], []]
    if video is not None:
        X_modalities[0] += [video_tensor]
        X_modalities[1] += ["video"]
    if image is not None:
        X_modalities[0] += [image_tensor]
        X_modalities[1] += ["image"]
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

    generate_kwargs = dict(
        input_ids=input_ids,
        images=X_modalities,
        max_new_tokens=1024,
        temperature=0.2,
        do_sample=True,
        use_cache=True,
        stopping_criteria=[stopping_criteria],
    )

    output_ids = ov_model.generate(**generate_kwargs)

    input_token_len = input_ids.shape[1]
    outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
    outputs = outputs.strip()
    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
    outputs = outputs.strip()

    return outputs

demo = gr.Interface(
    generate,
    [
        gr.Image(label="Input Image", type="filepath"),
        gr.Video(label="Input Video"),
        gr.Textbox(label="Question")
    ],
    gr.Textbox(lines=10),
    examples=[
        [
            f"{examples_dir}/extreme_ironing.jpg",
            None,
            "What is unusual about this image?",
        ],
        [
            f"{examples_dir}/waterview.jpg",
            None,
            "What are the things I should be cautious about when I visit here?",
        ],
        [
            f"{examples_dir}/desert.jpg",
            None,
            "If there are factual errors in the questions, point it out; if not, proceed answering the question. What’s happening in the desert?",
        ],
        [
            None,
            f"{examples_dir}/sample_demo_1.mp4",
            "Why is this video funny?",
        ],
        [
            None,
            f"{examples_dir}/sample_demo_3.mp4",
            "Can you identify any safety hazards in this video?"
        ],
        [
            None,
            f"{examples_dir}/sample_demo_9.mp4",
            "Describe the video.",
        ],
        [
            None,
            f"{examples_dir}/sample_demo_22.mp4",
            "Describe the activity in the video.",
        ],
        [
            f"{examples_dir}/sample_img_22.png",
            f"{examples_dir}/sample_demo_22.mp4",
            "Are the instruments in the pictures used in the video?",
        ],
        [
            f"{examples_dir}/sample_img_13.png",
            f"{examples_dir}/sample_demo_13.mp4",
            "Does the flag in the image appear in the video?",
        ],
        [
            f"{examples_dir}/sample_img_8.png",
            f"{examples_dir}/sample_demo_8.mp4",
            "Are the image and the video depicting the same place?",
        ],
    ],
    title="Video-LLaVA🚀",
    allow_flagging="never"
)
try:
    demo.queue().launch(debug=False)
except Exception:
    demo.queue().launch(share=True, debug=False)
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/