Humans perceive the world through vision and language. A longtime goal
of AI is to build intelligent agents that can understand the world
through vision and language inputs and communicate with humans in
natural language. To achieve this goal, vision-language
pre-training has emerged as an effective approach, where deep neural
network models are pre-trained on large-scale image-text datasets to
improve performance on downstream vision-language tasks, such as
image-text retrieval, image captioning, and visual question answering.
BLIP is a language-image
pre-training framework for unified vision-language understanding and
generation. BLIP achieves state-of-the-art results on a wide range of
vision-language tasks. This tutorial demonstrates how to use BLIP for
visual question answering and image captioning. An additional part of
the tutorial demonstrates how to speed up the model by applying 8-bit
post-training quantization and data-free INT8 weight compression from
NNCF (Neural Network
Compression Framework) to OpenVINO IR models, and how to run the optimized
BLIP model with the OpenVINO™ Toolkit.
The tutorial consists of the following parts:
Instantiate a BLIP model.
Convert the BLIP model to OpenVINO IR.
Run visual question answering and image captioning with OpenVINO.
Visual language processing is a branch of artificial intelligence that
focuses on creating algorithms designed to enable computers to more
accurately understand images and their content.
Popular tasks include:
Text to Image Retrieval - a semantic task that aims to find the
most relevant image for a given text description.
Image Captioning - a semantic task that aims to provide a text
description for image content.
Visual Question Answering - a semantic task that aims to answer
questions based on image content.
As shown in the diagram below, these three tasks differ in the input
provided to the AI system. For text-to-image retrieval, you have a
predefined gallery of images for search and a user-requested text
description (query). Image captioning can be represented as a particular
case of visual question answering, where you have a predefined question
“What is in the picture?” and various images provided by a user. For
visual question answering, both the text-based question and image
context are variables requested by a user.
This notebook does not focus on text-to-image retrieval. Instead, it
considers image captioning and visual question answering.
Image Captioning is the task of describing the content of an image in
words. This task lies at the intersection of computer vision and natural
language processing. Most image captioning systems use an
encoder-decoder framework, where an input image is encoded into an
intermediate representation of the information in the image, and then
decoded into a descriptive text sequence.
Visual Question Answering (VQA) is the task of answering text-based
questions about image content.
For a better understanding of how VQA works, let us consider a
traditional NLP task like Question Answering, which aims to retrieve the
answer to a question from a given text input. Typically, a question
answering pipeline consists of three steps:
Question analysis - analysis of the provided natural-language question
to understand the object in the question and any additional context.
For example, for a question like “How many bridges in Paris?”, the
question words “how many” hint that the answer is likely to be a
number, “bridges” is the target object of the question, and “in Paris”
serves as additional context for the search.
Build query for search - use the analysis results to formulate a query
for finding the most relevant information.
Perform a search in the knowledge base - send the query to a
knowledge base; typically, text documents or databases serve as the
source of knowledge.
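Here is a deliberately tiny, rule-based sketch of the question-analysis step. It is not part of BLIP or of this tutorial's pipeline. The prefix table, the `analyze_question` helper, and the "in <place>" heuristic are all hypothetical. They only show how question words can hint at the answer type, the target object, and extra context.

```python
import re

# Hypothetical mapping from leading question words to the expected answer type.
ANSWER_TYPE_HINTS = {
    "how many": "number",
    "what color": "color",
    "where": "location",
    "is": "yes/no",
    "are": "yes/no",
    "does": "yes/no",
}


def analyze_question(question):
    """Toy question analysis: answer-type hint, target object, extra context."""
    text = question.lower().strip().rstrip("?").strip()
    # Treat a trailing "in <something>" phrase as additional search context.
    match = re.search(r"\s+in\s+(.+)$", text)
    context = match.group(1) if match else None
    core = text[: match.start()] if match else text
    answer_type, target = "other", core
    for prefix, hint in ANSWER_TYPE_HINTS.items():
        if core == prefix or core.startswith(prefix + " "):
            # Whatever follows the question words is taken as the target object.
            answer_type, target = hint, core[len(prefix):].strip()
            break
    return {"answer_type": answer_type, "target": target, "context": context}


print(analyze_question("How many bridges in Paris?"))
```

Real systems learn this mapping instead of hard-coding it. In VQA models like BLIP, this step happens implicitly inside the text encoder.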
The difference between text-based question answering and visual question
answering is that in VQA the image serves as both the context and the
knowledge base.
Answering arbitrary questions about images is a complex problem because
it involves many computer vision sub-tasks. The table below lists
example questions and the computer vision skills required to answer them.
| Computer vision task | Question examples |
|---|---|
| Object recognition | What is shown in the picture? What is it? |
| Object detection | Is there any object (dog, man, book) in the image? Where is … located? |
| Object and image attribute recognition | What color is an umbrella? Does this man wear glasses? Is there color in the image? |
| Scene recognition | Is it rainy? What celebration is pictured? |
| Object counting | How many players are there on the football field? How many steps are there on the stairs? |
| Activity recognition | Is the baby crying? What is the woman cooking? What are they doing? |
| Spatial relationships among objects | What is located between the sofa and the armchair? What is in the bottom left corner? |
| Commonsense reasoning | Does she have 100% vision? Does this person have children? |
| Knowledge-based reasoning | Is it a vegetarian pizza? |
| Text recognition | What is the title of the book? What is shown on the screen? |
There are a lot of applications for visual question answering:
Aid Visually Impaired Persons: VQA models can be used to reduce
barriers for visually impaired people by helping them get information
about images from the web and the real world.
Education: VQA models can be used to improve visitor experiences at
museums by enabling observers to directly ask questions they are
interested in or to bring more interactivity to schoolbooks for
children interested in acquiring specific knowledge.
E-commerce: VQA models can retrieve information about products using
photos from online stores.
Independent expert assessment: VQA models can provide objective
assessments in sports competitions, medical diagnosis, and forensic
examination.
To pre-train a unified vision-language model with both understanding and
generation capabilities, BLIP introduces a multimodal mixture of
encoder-decoder (MED), a multi-task model that can operate in one of
three modes:
Unimodal encoders, which separately encode images and text. The
image encoder is a vision transformer. The text encoder is the same
as BERT.
Image-grounded text encoder, which injects visual information by
inserting a cross-attention layer between the self-attention layer
and the feed-forward network for each transformer block of the text
encoder.
Image-grounded text decoder, which replaces the bi-directional
self-attention layers in the text encoder with causal self-attention
layers.
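The difference between the bi-directional self-attention in the text encoder and the causal self-attention in the text decoder comes down to the attention mask. The NumPy sketch below simplifies things to a single score matrix with no heads or batches. The helper names are hypothetical, and this is not BLIP's implementation.

```python
import numpy as np


def bidirectional_mask(seq_len):
    """Text encoder style: every token may attend to every other token."""
    return np.ones((seq_len, seq_len), dtype=bool)


def causal_mask(seq_len):
    """Text decoder style: token i may attend only to tokens 0..i."""
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))


def masked_softmax(scores, mask):
    """Row-wise softmax that gives zero weight to masked-out positions."""
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


# With equal raw scores, row 1 of the causal weights spreads attention
# only over positions 0 and 1; future positions receive zero weight.
scores = np.zeros((4, 4))
print(masked_softmax(scores, causal_mask(4))[1])
print(masked_softmax(scores, bidirectional_mask(4))[1])
```

Because a decoder token cannot see future tokens, the decoder can be trained to predict the next token. That property is what makes it usable for generation.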
In this tutorial, you will use the
blip-vqa-base
model available for download from Hugging
Face. The same actions are also applicable
to other similar models from the BLIP family. Although this model class
is designed to perform question answering, its components can also be
reused for image captioning.
To start working with the model, you need to instantiate the
BlipForQuestionAnswering class using the from_pretrained method.
BlipProcessor is a helper class for preparing input data for both
text and vision modalities and for postprocessing generation results.
import time
from PIL import Image
from transformers import BlipProcessor, BlipForQuestionAnswering

# Fetch `notebook_utils` module
import requests

r = requests.get(
    url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py",
)
open("notebook_utils.py", "w").write(r.text)

from notebook_utils import download_file, device_widget, quantization_widget

# get model and processor
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

# setup test input: download and read image, prepare question
img_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
download_file(img_url, "demo.jpg")
raw_image = Image.open("demo.jpg").convert("RGB")
question = "how many dogs are in the picture?"
# preprocess input data
inputs = processor(raw_image, question, return_tensors="pt")

start = time.perf_counter()
# perform generation
out = model.generate(**inputs)
end = time.perf_counter() - start

# postprocess result
answer = processor.decode(out[0], skip_special_tokens=True)
Starting from the OpenVINO 2023.0 release, OpenVINO supports direct
conversion of PyTorch models to the OpenVINO Intermediate Representation
(IR) format to take advantage of advanced OpenVINO optimization tools and
features. You need to provide a model object and example input data for
model tracing to the OpenVINO Model Conversion API. The ov.convert_model
function converts a PyTorch model instance to an ov.Model object that can
be compiled on a device or saved to disk using ov.save_model, which
compresses weights to FP16 by default.
The model consists of three parts:
vision_model - an encoder for image representation.
text_encoder - an encoder for input query, used for question
answering and text-to-image retrieval only.
text_decoder - a decoder for output answer.
To perform multiple tasks with the same model components, you should
convert each part independently.
The vision model accepts float input tensors of shape [1,3,384,384]
containing RGB image pixel values, which BlipProcessor rescales to the
[0,1] range and then normalizes with a per-channel mean and standard
deviation.
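The NumPy sketch below approximates how such a tensor is produced from an image that has already been resized to 384x384. It is not BlipProcessor's code. The mean and standard deviation values are the CLIP constants that BlipProcessor is assumed to use by default, and `to_pixel_values` is a hypothetical helper. In the tutorial itself, always use BlipProcessor.

```python
import numpy as np

# Assumption: CLIP-style normalization constants, believed to match BlipProcessor defaults.
MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)


def to_pixel_values(image_hwc_uint8):
    """Turn an already-resized 384x384 RGB uint8 image into a [1, 3, 384, 384] float tensor."""
    if image_hwc_uint8.shape != (384, 384, 3):
        raise ValueError("expected a resized RGB image of shape (384, 384, 3)")
    x = image_hwc_uint8.astype(np.float32) / 255.0  # rescale to [0, 1]
    x = (x - MEAN) / STD                            # per-channel normalization
    x = x.transpose(2, 0, 1)                        # HWC -> CHW
    return x[np.newaxis, ...]                       # add the batch dimension


pixel_values = to_pixel_values(np.zeros((384, 384, 3), dtype=np.uint8))
print(pixel_values.shape, pixel_values.dtype)
```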
import torch
from pathlib import Path
import openvino as ov

VISION_MODEL_OV = Path("blip_vision_model.xml")
vision_model = model.vision_model
vision_model.eval()

# check that model works and save its outputs for reuse as text encoder input
with torch.no_grad():
    vision_outputs = vision_model(inputs["pixel_values"])

# if openvino model does not exist, convert it to IR
if not VISION_MODEL_OV.exists():
    # export pytorch model to ov.Model
    with torch.no_grad():
        ov_vision_model = ov.convert_model(vision_model, example_input=inputs["pixel_values"])
    # save model on disk for next usages
    ov.save_model(ov_vision_model, VISION_MODEL_OV)
    print(f"Vision model successfully converted and saved to {VISION_MODEL_OV}")
else:
    print(f"Vision model will be loaded from {VISION_MODEL_OV}")
The text encoder is used by visual question answering tasks to build a
question embedding representation. It takes input_ids with the
tokenized question, the image embeddings produced by the vision model,
and the attention masks for both.
TEXT_ENCODER_OV = Path("blip_text_encoder.xml")

text_encoder = model.text_encoder
text_encoder.eval()

# if openvino model does not exist, convert it to IR
if not TEXT_ENCODER_OV.exists():
    # prepare example inputs
    image_embeds = vision_outputs[0]
    image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long)
    input_dict = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "encoder_hidden_states": image_embeds,
        "encoder_attention_mask": image_attention_mask,
    }
    # export PyTorch model
    with torch.no_grad():
        ov_text_encoder = ov.convert_model(text_encoder, example_input=input_dict)
    # save model on disk for next usages
    ov.save_model(ov_text_encoder, TEXT_ENCODER_OV)
    print(f"Text encoder successfully converted and saved to {TEXT_ENCODER_OV}")
else:
    print(f"Text encoder will be loaded from {TEXT_ENCODER_OV}")
The text decoder is responsible for generating the sequence of tokens
that represents the model output (an answer to a question or a caption),
using the image representation (and the question representation, if
required). The generation approach is based on the assumption that the
probability distribution of a word sequence can be decomposed into the
product of conditional next-word distributions. In other words, the
model predicts the next token in a loop, guided by the previously
generated tokens, until a stop condition is met (the generated sequence
reaches the maximum length or an end-of-sequence token is produced). How
the next token is selected from the predicted probabilities depends on
the chosen decoding method. You can find more information about the most
popular decoding methods in this blog. The entry point
for the generation process for models from the Hugging Face Transformers
library is the generate method. You can find more information about
its parameters and configuration in the
documentation.
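In symbols, P(y_1, ..., y_T | x) = P(y_1 | x) · P(y_2 | y_1, x) · ... · P(y_T | y_1..y_{T-1}, x). The loop described above can be sketched with greedy decoding, the simplest decoding method, which always picks the most probable next token. `toy_logits` is a hypothetical stand-in for one text decoder step that simply favors the token after the last one, and the token IDs are made up. In the tutorial, the Hugging Face generate method drives this loop and may use a different decoding method.

```python
from typing import Callable, List, Sequence

import numpy as np


def greedy_decode(next_token_logits: Callable[[Sequence[int]], np.ndarray],
                  bos_id: int, eos_id: int, max_length: int) -> List[int]:
    """Generate tokens one at a time, always taking the arg-max next token."""
    tokens = [bos_id]
    while len(tokens) < max_length:       # stop condition 1: maximum length
        logits = next_token_logits(tokens)  # scores for every vocabulary entry
        next_id = int(np.argmax(logits))    # greedy choice
        tokens.append(next_id)
        if next_id == eos_id:               # stop condition 2: end-of-sequence token
            break
    return tokens


def toy_logits(tokens: Sequence[int], vocab_size: int = 5) -> np.ndarray:
    """Hypothetical one-step 'decoder': prefers the token after the last one."""
    logits = np.zeros(vocab_size)
    logits[(tokens[-1] + 1) % vocab_size] = 1.0
    return logits


print(greedy_decode(toy_logits, bos_id=0, eos_id=4, max_length=20))  # [0, 1, 2, 3, 4]
print(greedy_decode(toy_logits, bos_id=0, eos_id=4, max_length=3))   # [0, 1, 2]
```

Because only `next_token_logits` touches the model, swapping greedy selection for beam search or sampling leaves the model call untouched.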
To keep the choice of decoding method flexible, you will convert the
model inference for a single decoding step only.
To optimize the generation process and use memory more efficiently, the
use_cache=True option is enabled. Since the output side is
auto-regressive, the hidden state of an output token stays the same for
every further generation step once it has been computed. Therefore,
recomputing it every time you want to generate a new token is wasteful.
With the cache, the model saves the hidden states once they have been
computed. At each time step, the model only computes the hidden state
for the most recently generated output token and reuses the saved ones
for the previous tokens. This reduces the generation complexity from
O(n^3) to O(n^2) for a transformer model. More details about how it
works can be found in this article.
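A minimal NumPy sketch of this idea uses a single attention head with random projection weights, both assumptions made for illustration. The cache stores the key and value projections of earlier tokens, which is what the (key, value) pairs in past_key_values hold. The sketch checks that cached and uncached decoding produce the same outputs. `KVCache` and the other names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 8
# Random projections standing in for one attention head's weights.
W_q, W_k, W_v = (rng.standard_normal((d_model, d_model)) for _ in range(3))


def attend(q, K, V):
    """Scaled dot-product attention of one query over all keys/values."""
    scores = K @ q / np.sqrt(d_model)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ V


def step_without_cache(prefix):
    """Recompute keys and values for the whole prefix at every step."""
    K = prefix @ W_k
    V = prefix @ W_v
    return attend(prefix[-1] @ W_q, K, V)


class KVCache:
    """Keep keys/values of earlier tokens; compute them only for the new token."""

    def __init__(self):
        self.keys, self.values = [], []

    def step(self, x):
        self.keys.append(x @ W_k)
        self.values.append(x @ W_v)
        return attend(x @ W_q, np.stack(self.keys), np.stack(self.values))


token_states = rng.standard_normal((5, d_model))
cache = KVCache()
cached = [cache.step(x) for x in token_states]
uncached = [step_without_cache(token_states[: t + 1]) for t in range(len(token_states))]
print(all(np.allclose(a, b) for a, b in zip(cached, uncached)))
```

Without the cache, step t projects all t tokens again. With the cache, it projects only the newest token, which is where the saving comes from.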
With this option, the model gets the previous step’s hidden states as
input and additionally returns the hidden states for the current step as
output. Initially, there are no previous-step hidden states, so the
first step does not require you to provide them, but they should be
initialized with default values. In PyTorch, past hidden state outputs
are represented as a list of pairs (hidden state for key, hidden state
for value) for each transformer layer in the model. OpenVINO models do
not support nested outputs, so these outputs are flattened.
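The flattening can be pictured as turning a per-layer list of (key, value) pairs into one flat list and back. This pure-Python sketch assumes layer-major, key-then-value ordering. The helper names are hypothetical, and the real input and output names and order of the converted IR should be checked on the compiled model rather than assumed.

```python
from typing import Any, List, Sequence, Tuple


def flatten_past(past_key_values: Sequence[Tuple[Any, Any]]) -> List[Any]:
    """[(k0, v0), (k1, v1), ...] -> [k0, v0, k1, v1, ...]"""
    flat: List[Any] = []
    for key, value in past_key_values:
        flat.extend((key, value))
    return flat


def unflatten_past(flat: Sequence[Any]) -> Tuple[Tuple[Any, Any], ...]:
    """[k0, v0, k1, v1, ...] -> ((k0, v0), (k1, v1), ...)"""
    if len(flat) % 2 != 0:
        raise ValueError("expected one key and one value tensor per layer")
    return tuple((flat[i], flat[i + 1]) for i in range(0, len(flat), 2))


past = (("k0", "v0"), ("k1", "v1"))
print(flatten_past(past))  # ['k0', 'v0', 'k1', 'v1']
```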
Similar to text_encoder, text_decoder can work with input
sequences of different lengths and requires preserving dynamic input
shapes.
text_decoder = model.text_decoder
text_decoder.eval()

TEXT_DECODER_OV = Path("blip_text_decoder_with_past.xml")

# prepare example inputs
input_ids = torch.tensor([[30522]])  # begin of sequence token id
attention_mask = torch.tensor([[1]])  # attention mask for input_ids
encoder_hidden_states = torch.rand((1, 10, 768))  # encoder last hidden state from text_encoder
encoder_attention_mask = torch.ones((1, 10), dtype=torch.long)  # attention mask for encoder hidden states

input_dict = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
    "encoder_hidden_states": encoder_hidden_states,
    "encoder_attention_mask": encoder_attention_mask,
}
text_decoder_outs = text_decoder(**input_dict)
# extend input dictionary with hidden states from previous step
input_dict["past_key_values"] = text_decoder_outs["past_key_values"]

text_decoder.config.torchscript = True
if not TEXT_DECODER_OV.exists():
    # export PyTorch model
    with torch.no_grad():
        ov_text_decoder = ov.convert_model(text_decoder, example_input=input_dict)
    # save model on disk for next usages
    ov.save_model(ov_text_decoder, TEXT_DECODER_OV)
    print(f"Text decoder successfully converted and saved to {TEXT_DECODER_OV}")
else:
    print(f"Text decoder will be loaded from {TEXT_DECODER_OV}")
As discussed before, the model consists of several blocks that can be
reused to build pipelines for different tasks. In the diagram below,
you can see how image captioning works:
The vision model accepts the image preprocessed by BlipProcessor as
input and produces image embeddings, which are passed directly to the
text decoder to generate caption tokens. When generation is finished,
the output token sequence is passed to BlipProcessor, which decodes it
to text using its tokenizer.
The pipeline for question answering looks similar, but with additional
question processing. In this case, the image embeddings and the question
tokenized by BlipProcessor are provided to the text encoder, and the
resulting multimodal question embedding is passed to the text decoder to
generate the answer.
The next step is implementing both pipelines using OpenVINO models.
Select a device from the dropdown list to run inference with OpenVINO.
device = device_widget()

device
# load models on device
core = ov.Core()
ov_vision_model = core.compile_model(VISION_MODEL_OV, device.value)
ov_text_encoder = core.compile_model(TEXT_ENCODER_OV, device.value)
ov_text_decoder_with_past = core.compile_model(TEXT_DECODER_OV, device.value)
The model helper class has two methods for generation:
generate_answer - used for visual question answering.
generate_caption - used for caption generation.
For initialization, the model class accepts the model configuration, the
decoder start token ID, the compiled OpenVINO vision model and text
encoder, and a text decoder whose forward pass calls the compiled
OpenVINO decoder.
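The data flow of the two methods can be sketched with stub components. `ToyBlipPipeline` and the `fake_*` functions are hypothetical and are not the tutorial's helper class. They only show that captioning feeds the image embeddings straight to the decoder, while VQA first passes the image embeddings and the question through the text encoder. Both share the same token-by-token generation loop.

```python
from typing import Any, Callable, List, Sequence


class ToyBlipPipeline:
    """Hypothetical illustration of the captioning vs. VQA data flow."""

    def __init__(self, vision_model: Callable[[Any], Any],
                 text_encoder: Callable[[Any, Any], Any],
                 decode_step: Callable[[Sequence[int], Any], int],
                 bos_id: int, eos_id: int):
        self.vision_model = vision_model
        self.text_encoder = text_encoder
        self.decode_step = decode_step
        self.bos_id = bos_id
        self.eos_id = eos_id

    def _generate(self, encoder_hidden_states: Any, max_length: int) -> List[int]:
        # Ask the decoder for one token at a time until EOS or max_length.
        tokens = [self.bos_id]
        while len(tokens) < max_length:
            next_id = self.decode_step(tokens, encoder_hidden_states)
            tokens.append(next_id)
            if next_id == self.eos_id:
                break
        return tokens

    def generate_caption(self, pixel_values: Any, max_length: int = 20) -> List[int]:
        # Captioning: image embeddings go straight to the decoder.
        return self._generate(self.vision_model(pixel_values), max_length)

    def generate_answer(self, pixel_values: Any, input_ids: Any, max_length: int = 20) -> List[int]:
        # VQA: the text encoder fuses the question with the image embeddings first.
        image_embeds = self.vision_model(pixel_values)
        question_embeds = self.text_encoder(input_ids, image_embeds)
        return self._generate(question_embeds, max_length)


BOS, EOS = 0, 9


def fake_vision(pixel_values):
    return ("image", pixel_values)


def fake_text_encoder(input_ids, image_embeds):
    return ("question", tuple(input_ids), image_embeds)


def fake_decode_step(tokens, encoder_state):
    # Emit two content tokens, then EOS; the token depends on what conditioned the decoder.
    if len(tokens) >= 3:
        return EOS
    return 1 if encoder_state[0] == "question" else 2


pipe = ToyBlipPipeline(fake_vision, fake_text_encoder, fake_decode_step, BOS, EOS)
print(pipe.generate_caption("img"))         # [0, 2, 2, 9]
print(pipe.generate_answer("img", [5, 6]))  # [0, 1, 1, 9]
```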
NNCF enables
post-training quantization by adding the quantization layers into the
model graph and then using a subset of the training dataset to
initialize the parameters of these additional quantization layers. The
framework is designed so that modifications to your original training
code are minor.
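To show roughly what "initializing the parameters of the quantization layers" from data means, the sketch below estimates an asymmetric 8-bit scale and zero point from the minimum and maximum values seen in a few calibration batches, then simulates quantization. This plain min-max scheme was chosen for simplicity. NNCF's actual range estimators and quantization schemes are more elaborate, and the helper names are hypothetical.

```python
import numpy as np


def calibrate_minmax(samples, num_levels=256):
    """Estimate an asymmetric 8-bit scale/zero point from calibration activations."""
    lo = min(float(s.min()) for s in samples)
    hi = max(float(s.max()) for s in samples)
    lo, hi = min(lo, 0.0), max(hi, 0.0)          # keep 0.0 exactly representable
    scale = (hi - lo) / (num_levels - 1) or 1.0  # avoid a zero scale for constant inputs
    zero_point = int(round(-lo / scale))
    return scale, zero_point


def fake_quantize(x, scale, zero_point, num_levels=256):
    """Quantize to integers in [0, num_levels - 1], then map back to floats."""
    q = np.clip(np.round(x / scale) + zero_point, 0, num_levels - 1)
    return (q - zero_point) * scale


rng = np.random.default_rng(0)
calibration_batches = [rng.normal(size=(8, 16)) for _ in range(4)]
scale, zero_point = calibrate_minmax(calibration_batches)
error = np.abs(fake_quantize(calibration_batches[0], scale, zero_point) - calibration_batches[0])
print(scale, zero_point, error.max())
```

Values inside the calibrated range are reproduced to within half a quantization step. This is why the calibration subset should be representative of real inputs.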
The optimization process contains the following steps:
Create a dataset for quantization.
Run nncf.quantize to get a quantized model from the pre-trained
FP16 model.
Serialize the INT8 model using openvino.save_model function.
NOTE: Quantization is a time- and memory-consuming operation.
Running the quantization code below may take some time. You can disable
it using the widget below:
VQAv2 is a dataset containing
open-ended questions about images. These questions require an
understanding of vision, language and commonsense knowledge to answer.
%%skip not $to_quantize.value

import numpy as np
from datasets import load_dataset
from tqdm.notebook import tqdm


def preprocess_batch(batch, vision_model, inputs_info):
    """
    Preprocesses a dataset batch by loading and transforming image and text data.
    The VQAv2 dataset contains multiple questions per image.
    To reduce dataset preparation time, preprocessed images are cached in `inputs_info`.
    """
    image_id = batch["image_id"]
    if image_id in inputs_info:
        inputs = processor(text=batch['question'], return_tensors="np")
        pixel_values = inputs_info[image_id]["pixel_values"]
        encoder_hidden_states = inputs_info[image_id]["encoder_hidden_states"]
    else:
        inputs = processor(images=batch["image"], text=batch["question"], return_tensors="np")
        pixel_values = inputs["pixel_values"]
        encoder_hidden_states = vision_model(pixel_values)[vision_model.output(0)]
        inputs_info[image_id] = {
            "pixel_values": pixel_values,
            "encoder_hidden_states": encoder_hidden_states,
            "text_encoder_inputs": []
        }

    text_encoder_inputs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"]
    }
    inputs_info[image_id]["text_encoder_inputs"].append(text_encoder_inputs)


def prepare_input_data(dataloader, vision_model, opt_init_steps):
    """
    Store calibration subset in List to reduce quantization time.
    """
    inputs_info = {}
    for idx, batch in enumerate(tqdm(dataloader, total=opt_init_steps, desc="Prepare calibration data")):
        preprocess_batch(batch, vision_model, inputs_info)

    calibration_subset = []
    for image_id in inputs_info:
        pixel_values = inputs_info[image_id]["pixel_values"]
        encoder_hidden_states = inputs_info[image_id]["encoder_hidden_states"]
        encoder_attention_mask = np.ones(encoder_hidden_states.shape[:-1], dtype=int)
        for text_encoder_inputs in inputs_info[image_id]["text_encoder_inputs"]:
            text_encoder_inputs["encoder_hidden_states"] = encoder_hidden_states
            text_encoder_inputs["encoder_attention_mask"] = encoder_attention_mask
            blip_inputs = {
                "vision_model_inputs": {"pixel_values": pixel_values},
                "text_encoder_inputs": text_encoder_inputs,
            }
            calibration_subset.append(blip_inputs)
    return calibration_subset


def prepare_dataset(vision_model, opt_init_steps=300, streaming=False):
    """
    Prepares a vision-text dataset for quantization.
    """
    split = f"train[:{opt_init_steps}]" if not streaming else "train"
    dataset = load_dataset("HuggingFaceM4/VQAv2", split=split, streaming=streaming, trust_remote_code=True)
    dataset = dataset.shuffle(seed=42)
    if streaming:
        dataset = dataset.take(opt_init_steps)
    calibration_subset = prepare_input_data(dataset, vision_model, opt_init_steps)
    return calibration_subset
Loading and processing the dataset in streaming mode may take a long
time and depends on your internet connection.
Quantizing the text decoder leads to significant accuracy loss.
Instead of post-training quantization, we can use data-free weight
compression to reduce the model footprint.
The optimization process contains the following steps:
Run nncf.compress_weights to get a model with compressed weights.
Serialize the OpenVINO model using openvino.save_model
function.
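The sketch below illustrates data-free weight compression. It compresses a weight matrix to INT8 with one scale per output channel, computed from the weights alone, so no calibration data is needed. This simplified symmetric scheme was chosen for illustration. nncf.compress_weights supports several modes and its default may differ, and the helper names are hypothetical.

```python
import numpy as np


def compress_weights_int8(weight):
    """Symmetric per-output-channel INT8 compression (hypothetical helper)."""
    # One scale per row (output channel), derived from the weights only.
    max_abs = np.abs(weight).max(axis=1, keepdims=True)
    scale = np.where(max_abs > 0, max_abs / 127.0, 1.0).astype(np.float32)
    q = np.clip(np.round(weight / scale), -127, 127).astype(np.int8)
    return q, scale


def decompress_weights(q, scale):
    """Conceptually what happens before the weights are used at inference time."""
    return q.astype(np.float32) * scale


rng = np.random.default_rng(0)
weight = rng.standard_normal((16, 32)).astype(np.float32)
q, scale = compress_weights_int8(weight)
print(q.dtype, q.nbytes, weight.nbytes)  # int8 512 2048
```

Storing INT8 values plus one float scale per row takes roughly a quarter of the FP32 size, or half of the FP16 size. Activations stay in floating point, which is why this approach tends to preserve accuracy better than full quantization of the decoder.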
The steps for making predictions with the optimized OpenVINO BLIP model
are similar to those for the PyTorch model. Let us check the model result
using the same input data as for the model before quantization.
%%skip not $to_quantize.value
from functools import partial
from transformers import BlipForQuestionAnswering
from blip_model import OVBlipModel, text_decoder_forward
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
text_decoder = model.text_decoder
text_decoder.eval()
text_decoder.forward = partial(text_decoder_forward, ov_text_decoder_with_past=q_ov_text_decoder_with_past)
int8_model = OVBlipModel(model.config, model.decoder_start_token_id, q_ov_vision_model, q_ov_text_encoder, text_decoder)
%%skip not $to_quantize.value
raw_image = Image.open("demo.jpg").convert('RGB')
question = "how many dogs are in the picture?"
# preprocess input data
inputs = processor(raw_image, question, return_tensors="pt")
Compare inference time of the FP16 and optimized models
To measure the inference performance of the FP16 and INT8
models, we use the median inference time on 100 samples of the
calibration dataset. This gives an approximate estimate of the speed-up
of the optimized models.
NOTE: For the most accurate performance estimation, it is
recommended to run benchmark_app with static shapes in a
terminal/command prompt after closing other applications.
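Below is a generic sketch of median-based latency measurement. It is independent of BLIP, and `median_latency` is a hypothetical helper. The median is less sensitive than the mean to occasional slow runs, for example the first call or background activity. An untimed warm-up pass is included, and an injectable clock makes the helper easy to test.

```python
import statistics
import time
from typing import Callable, Iterable


def median_latency(fn: Callable[[object], object], samples: Iterable[object],
                   warmup: int = 1, clock: Callable[[], float] = time.perf_counter) -> float:
    """Median wall-clock time of fn over samples, after `warmup` untimed calls."""
    samples = list(samples)
    for sample in samples[:warmup]:
        fn(sample)  # warm-up runs are not timed
    timings = []
    for sample in samples:
        start = clock()
        fn(sample)
        timings.append(clock() - start)
    return statistics.median(timings)


print(median_latency(lambda s: sum(range(10_000)), range(20)))
```

The speed-up estimate is then the ratio of the FP16 median to the optimized model's median.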
%%skip not $to_quantize.value

import time

import numpy as np
import torch


def calculate_inference_time(blip_model, calibration_data, generate_caption):
    inference_time = []
    for inputs in calibration_data:
        pixel_values = torch.from_numpy(inputs["vision_model_inputs"]["pixel_values"])
        input_ids = torch.from_numpy(inputs["text_encoder_inputs"]["input_ids"])
        attention_mask = torch.from_numpy(inputs["text_encoder_inputs"]["attention_mask"])

        start = time.perf_counter()
        if generate_caption:
            _ = blip_model.generate_caption(pixel_values, max_length=20)
        else:
            _ = blip_model.generate_answer(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, max_length=20)
        end = time.perf_counter()
        delta = end - start
        inference_time.append(delta)

    return np.median(inference_time)
import gradio as gr

ov_model = int8_model if use_quantized_model.value else ov_model


def generate_answer(img, question):
    if img is None:
        raise gr.Error("Please upload an image or choose one from the examples list")
    start = time.perf_counter()
    inputs = processor(img, question, return_tensors="pt")
    output = ov_model.generate_answer(**inputs, max_length=20) if len(question) else ov_model.generate_caption(inputs["pixel_values"], max_length=20)
    answer = processor.decode(output[0], skip_special_tokens=True)
    elapsed = time.perf_counter() - start
    html = f"<p>Processing time: {elapsed:.4f}</p>"
    return answer, html
if not Path("gradio_helper.py").exists():
    r = requests.get(url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/blip-visual-language-processing/gradio_helper.py")
    open("gradio_helper.py", "w").write(r.text)

from gradio_helper import make_demo

demo = make_demo(fn=generate_answer)

try:
    demo.launch(debug=False)
except Exception:
    demo.launch(share=True, debug=False)
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/