Run LLM Inference on OpenVINO with the GenAI Flavor

This guide shows how to integrate the OpenVINO GenAI flavor into your application, covering how to load a model and how to pass the input context to receive generated text. The vanilla flavor of OpenVINO will not work with these instructions, so make sure to install OpenVINO GenAI.

Note

The examples use the CPU as the target device, but the GPU is also supported. For the LLM pipeline, the GPU is used only for inference, while token selection, tokenization, and detokenization remain on the CPU for efficiency. Tokenizers are represented as a separate model and also run on the CPU.

  1. Export an LLM model via Hugging Face Optimum-Intel. A chat-tuned TinyLlama model is used in this example:

    optimum-cli export openvino --model "TinyLlama/TinyLlama-1.1B-Chat-v1.0" --weight-format fp16 --trust-remote-code "TinyLlama-1.1B-Chat-v1.0"
    

    Optional. Optimize the model:

    The model exported above is an OpenVINO IR with FP16 weights. For better LLM performance, it is recommended to use lower-precision weights, such as INT4. NNCF can compress the weights directly during model export:

    optimum-cli export openvino --model "TinyLlama/TinyLlama-1.1B-Chat-v1.0" --weight-format int4 --trust-remote-code "TinyLlama-1.1B-Chat-v1.0"
    
  2. Perform generation using the new GenAI API:

    import openvino_genai as ov_genai

    # Folder produced by the export in step 1.
    model_path = "TinyLlama-1.1B-Chat-v1.0"
    pipe = ov_genai.LLMPipeline(model_path, "CPU")
    print(pipe.generate("The Sun is yellow because", max_new_tokens=100))
    
    #include "openvino/genai/llm_pipeline.hpp"
    #include <iostream>
    
    int main(int argc, char* argv[]) {
       std::string model_path = argv[1];
       ov::genai::LLMPipeline pipe(model_path, "CPU");
       std::cout << pipe.generate("The Sun is yellow because", ov::genai::max_new_tokens(100));
    }
    

The LLMPipeline is the main object used for decoding. You can construct it directly from the folder with the converted model. It will automatically load the main model, tokenizer, detokenizer, and the default generation configuration.

Once the model is exported from Hugging Face Optimum-Intel, it already contains all the information necessary for execution, including the tokenizer/detokenizer and the generation config, ensuring that its results match those generated by Hugging Face.
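Before constructing the pipeline, it can help to confirm that the export folder is complete. The sketch below is a minimal check, assuming the export writes files named `openvino_model.xml`, `openvino_tokenizer.xml`, `openvino_detokenizer.xml`, and `generation_config.json`. These file names and the helper `missing_export_files` are assumptions made for illustration, so adjust the list to whatever your export actually produced.

```python
from pathlib import Path

# Assumed file names; verify them against the contents of your own export folder.
EXPECTED_FILES = (
    "openvino_model.xml",
    "openvino_tokenizer.xml",
    "openvino_detokenizer.xml",
    "generation_config.json",
)


def missing_export_files(model_dir, expected=EXPECTED_FILES):
    """Return the expected file names that are absent from model_dir."""
    root = Path(model_dir)
    return [name for name in expected if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_export_files("TinyLlama-1.1B-Chat-v1.0")
    if missing:
        print("Export folder is incomplete, missing:", ", ".join(missing))
    else:
        print("All expected files are present.")
```

An empty list means every expected file was found, so the folder can be passed to LLMPipeline.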

Streaming the Output

For more interactive UIs during generation, streaming of model output tokens is supported. See the example below, where a lambda function outputs words to the console immediately upon generation:

import openvino_genai as ov_genai
pipe = ov_genai.LLMPipeline(model_path, "CPU")

# Print each chunk of text as soon as it is generated.
streamer = lambda x: print(x, end='', flush=True)
pipe.generate("The Sun is yellow because", streamer=streamer, max_new_tokens=100)

#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>

int main(int argc, char* argv[]) {
   std::string model_path = argv[1];
   ov::genai::LLMPipeline pipe(model_path, "CPU");

   auto streamer = [](std::string word) {
      std::cout << word << std::flush;
      // Return flag indicating whether generation should be stopped.
      // false means continue generation.
      return false;
   };
   pipe.generate("The Sun is yellow because", ov::genai::streamer(streamer), ov::genai::max_new_tokens(100));
}
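A callback can also decide when to stop. The sketch below assumes that the Python callback follows the same convention as the C++ lambda above, where returning a true value stops generation. `make_limited_streamer` and its `max_chars` parameter are hypothetical names used only for illustration, and the loop at the bottom stands in for the pipeline so that the example runs without a model.

```python
def make_limited_streamer(max_chars):
    """Build a streamer callback that prints chunks and asks to stop after max_chars."""
    chunks = []

    def streamer(subword):
        chunks.append(subword)
        print(subword, end="", flush=True)
        # A true value asks the pipeline to stop; False lets generation continue.
        return sum(len(chunk) for chunk in chunks) >= max_chars

    return streamer, chunks


if __name__ == "__main__":
    streamer, chunks = make_limited_streamer(max_chars=10)
    # Stand-in for the pipeline, which would call the callback with each new chunk.
    for piece in ["The", " Sun", " is", " yellow"]:
        if streamer(piece):
            break
    print()
```

With a real pipeline, the callback would be passed as `pipe.generate(prompt, streamer=streamer, max_new_tokens=100)`, and `"".join(chunks)` would give the collected text afterwards.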

You can also create your custom streamer for more sophisticated processing:

import openvino_genai as ov_genai

class CustomStreamer(ov_genai.StreamerBase):
   def __init__(self, tokenizer):
      ov_genai.StreamerBase.__init__(self)
      self.tokenizer = tokenizer
      self.tokens_cache = []
   def put(self, token_id) -> bool:
      # Decode tokens and process them.
      self.tokens_cache.append(token_id)
      text = self.tokenizer.decode(self.tokens_cache)
      # Streamer returns a flag indicating whether generation should be stopped.
      # In Python, `return` can be omitted. In that case, the function will return None
      # which will be converted to False, meaning that generation should continue.
      return False
   def end(self):
      # Process the remaining decoded text, then reset the cache.
      self.tokens_cache.clear()

pipe = ov_genai.LLMPipeline(model_path, "CPU")
pipe.generate("The Sun is yellow because", streamer=CustomStreamer(pipe.get_tokenizer()), max_new_tokens=100)

#include "openvino/genai/llm_pipeline.hpp"
#include "openvino/genai/streamer_base.hpp"
#include <memory>
#include <string>

class CustomStreamer: public ov::genai::StreamerBase {
public:
   bool put(int64_t token) override {
      bool stop_flag = false;
      /*
      custom decoding/tokens processing code
      tokens_cache.push_back(token);
      std::string text = m_tokenizer.decode(tokens_cache);
      ...
      */
      return stop_flag;  // Flag indicating whether generation should be stopped. If true, generation stops.
   }

   void end() override {
      /* custom finalization */
   }
};

int main(int argc, char* argv[]) {
   auto custom_streamer = std::make_shared<CustomStreamer>();

   std::string model_path = argv[1];
   ov::genai::LLMPipeline pipe(model_path, "CPU");
   pipe.generate("The Sun is yellow because", ov::genai::streamer(custom_streamer), ov::genai::max_new_tokens(100));
}
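One common way to fill in the decoding step is to decode the whole token cache on every call and emit only the text that has not been shown yet. The sketch below isolates that logic so that it runs without a model. `IncrementalDecoder` is a hypothetical helper, and `decode` stands for any callable that turns a list of token IDs into text, such as a tokenizer's decode method. Holding back text that ends in the Unicode replacement character is an assumption about how incomplete multi-byte sequences show up in decoded output.

```python
class IncrementalDecoder:
    """Turn a stream of token IDs into a stream of newly decoded text."""

    def __init__(self, decode):
        self.decode = decode          # callable: list of token IDs -> str
        self.tokens_cache = []
        self.emitted_len = 0          # number of characters already returned

    def put(self, token_id):
        """Add one token and return the text that became available."""
        self.tokens_cache.append(token_id)
        text = self.decode(self.tokens_cache)
        # A trailing replacement character suggests an incomplete character,
        # so wait for the next token before emitting anything.
        if text.endswith("\ufffd"):
            return ""
        new_text = text[self.emitted_len:]
        self.emitted_len = len(text)
        return new_text

    def end(self):
        """Return whatever text was held back and reset the state."""
        text = self.decode(self.tokens_cache)
        tail = text[self.emitted_len:]
        self.tokens_cache.clear()
        self.emitted_len = 0
        return tail


if __name__ == "__main__":
    # Toy vocabulary standing in for a real tokenizer.
    vocab = {1: "The", 2: " Sun", 3: " is"}
    decoder = IncrementalDecoder(lambda ids: "".join(vocab[i] for i in ids))
    for token_id in (1, 2, 3):
        print(decoder.put(token_id), end="", flush=True)
    print(decoder.end())
```

A custom streamer's put and end methods could delegate to such a helper and print or forward the returned text.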

Using GenAI in Chat Scenario

For chat scenarios, where inputs and outputs form a conversation, keeping the KV cache between inputs can be beneficial because earlier turns do not have to be recomputed. The chat-specific methods start_chat and finish_chat mark the beginning and the end of a conversation session, as shown in these simple examples:

import openvino_genai as ov_genai
pipe = ov_genai.LLMPipeline(model_path, "CPU")

config = pipe.get_generation_config()
config.max_new_tokens = 100
pipe.set_generation_config(config)

pipe.start_chat()
while True:
   print('question:')
   prompt = input()
   if prompt == 'Stop!':
      break
   print(pipe.generate(prompt))
pipe.finish_chat()

#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>
#include <string>

int main(int argc, char* argv[]) {
   std::string prompt;

   std::string model_path = argv[1];
   ov::genai::LLMPipeline pipe(model_path, "CPU");

   ov::genai::GenerationConfig config = pipe.get_generation_config();
   config.max_new_tokens = 100;
   pipe.set_generation_config(config);

   pipe.start_chat();
   std::cout << "question:\n";
   // Read prompts until the input ends or the user types "Stop!".
   while (std::getline(std::cin, prompt) && prompt != "Stop!") {
      std::cout << pipe.generate(prompt) << std::endl;
      std::cout << "question:\n";
   }
   pipe.finish_chat();
}
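If generation raises an error in the middle of a conversation, the session should still be closed. The sketch below wraps the start_chat and finish_chat calls in try/finally. `run_chat` is a hypothetical helper, and `EchoPipe` is a stand-in object exposing the same three methods so that the example runs without a model; with a real pipeline you would pass the LLMPipeline instance instead.

```python
def run_chat(pipe, prompts):
    """Send prompts through one chat session and return the replies as strings."""
    replies = []
    pipe.start_chat()
    try:
        for prompt in prompts:
            replies.append(str(pipe.generate(prompt)))
    finally:
        # Close the session even if generate() raised.
        pipe.finish_chat()
    return replies


class EchoPipe:
    """Stand-in for a pipeline; it only tracks session state and echoes prompts."""

    def __init__(self):
        self.in_chat = False
        self.turns = 0

    def start_chat(self):
        self.in_chat = True

    def generate(self, prompt):
        self.turns += 1
        return f"turn {self.turns}: {prompt}"

    def finish_chat(self):
        self.in_chat = False


if __name__ == "__main__":
    for reply in run_chat(EchoPipe(), ["Hello", "Why is the Sun yellow?"]):
        print(reply)
```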

Comparing with Hugging Face Results

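You can compare the OpenVINO GenAI results with those generated by the original Hugging Face model. The example below disables sampling on the Hugging Face side, generates the same number of new tokens with both, and checks that the two outputs match: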

from transformers import AutoTokenizer, AutoModelForCausalLM
import openvino_genai as ov_genai

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

max_new_tokens = 32
prompt = 'table is made of'

encoded_prompt = tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=False)
hf_encoded_output = model.generate(encoded_prompt, max_new_tokens=max_new_tokens, do_sample=False)
hf_output = tokenizer.decode(hf_encoded_output[0, encoded_prompt.shape[1]:])
print(f'hf_output: {hf_output}')

pipe = ov_genai.LLMPipeline('TinyLlama-1.1B-Chat-v1.0', 'CPU')
ov_output = pipe.generate(prompt, max_new_tokens=max_new_tokens)
print(f'ov_output: {ov_output}')

assert hf_output == ov_output
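A bare assert only reports that the two outputs differ. When analyzing a mismatch, it is more useful to see where the texts diverge. The sketch below is plain Python with no dependency on either library; `first_mismatch` and `describe_mismatch` are hypothetical helper names, and the sample strings are made up for the demonstration.

```python
def first_mismatch(reference, candidate):
    """Return the index of the first differing character, or -1 if the strings are equal."""
    for i, (a, b) in enumerate(zip(reference, candidate)):
        if a != b:
            return i
    if len(reference) != len(candidate):
        # One string is a prefix of the other.
        return min(len(reference), len(candidate))
    return -1


def describe_mismatch(reference, candidate, context=20):
    """Summarize where two outputs diverge, with some surrounding text."""
    pos = first_mismatch(reference, candidate)
    if pos < 0:
        return "outputs match"
    start = max(0, pos - context)
    return (
        f"outputs diverge at character {pos}: "
        f"{reference[start:pos + context]!r} vs {candidate[start:pos + context]!r}"
    )


if __name__ == "__main__":
    # Made-up strings; in practice pass hf_output and str(ov_output).
    print(describe_mismatch("wood and metal", "wood and glass"))
```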

GenAI API

OpenVINO GenAI Flavor includes the following API:

  • generation_config - defines a configuration class for text generation, enabling customization of the generation process such as the maximum length of the generated text, whether to ignore end-of-sequence tokens, and the specifics of the decoding strategy (greedy, beam search, or multinomial sampling).

  • llm_pipeline - provides classes and utilities for text generation, including a pipeline for processing inputs, generating text, and managing outputs with configurable options.

  • streamer_base - an abstract base class for creating streamers.

  • tokenizer - the tokenizer class for text encoding and decoding.

  • visibility - controls the visibility of the GenAI library.
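The three decoding strategies named above can be expressed as sets of keyword arguments for generate. The sketch below is only an illustration: the field names (`do_sample`, `num_beams`, `temperature`, `top_p`, `top_k`) are assumed to follow the Hugging Face naming that the generation configuration mirrors, the numeric values are arbitrary, and `generation_kwargs` is a hypothetical helper. Check the names against the API reference before relying on them.

```python
# Assumed field names and arbitrary example values; verify against the API reference.
DECODING_PRESETS = {
    "greedy": {"do_sample": False, "num_beams": 1},
    "beam_search": {"do_sample": False, "num_beams": 4},
    "multinomial": {"do_sample": True, "temperature": 0.7, "top_p": 0.9, "top_k": 50},
}


def generation_kwargs(strategy, max_new_tokens=100):
    """Return keyword arguments for the chosen decoding strategy."""
    if strategy not in DECODING_PRESETS:
        raise ValueError(f"unknown strategy: {strategy!r}")
    return {"max_new_tokens": max_new_tokens, **DECODING_PRESETS[strategy]}


if __name__ == "__main__":
    for name in DECODING_PRESETS:
        print(name, generation_kwargs(name))
```

With a pipeline at hand, the result would be unpacked into the call, for example `pipe.generate(prompt, **generation_kwargs("beam_search"))`.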

Learn more in the GenAI API reference.

Additional Resources