This Jupyter notebook can be launched online, opening an interactive environment in a browser window. You can also make a local installation.
This notebook shows text prediction with OpenVINO. The notebook can
work in two different modes, Text Generation and Conversation, which the
user selects by choosing a model in the Model Selection section.
We use three models: GPT-2, GPT-Neo, and PersonaGPT, all part of
the Generative Pre-trained Transformer (GPT) family. GPT-2 and GPT-Neo
can be used for text generation, whereas PersonaGPT is trained for the
downstream task of conversation.
GPT-2 and GPT-Neo are pre-trained on a large corpus of English text
using unsupervised training. They both display a broad set of
capabilities, including the ability to generate conditional synthetic
text samples of unprecedented quality, where we prime the model with an
input and have it generate a lengthy continuation.
More details about the models are provided on their HuggingFace model cards.
PersonaGPT is an open-domain conversational agent that can decode
personalized and controlled responses based on user input. It is
built on the pretrained
DialoGPT-medium model,
following the GPT-2 architecture.
PersonaGPT is fine-tuned on the
Persona-Chat dataset. The model
is available from
HuggingFace. PersonaGPT
displays a broad set of capabilities, including the ability to take on
a persona: we prime the model with a few facts and have it generate
responses based on them. It can also be used to build a chatbot over a
knowledge base.
The following image illustrates the complete demo pipeline used for text
generation:
This is a demonstration in which the user can type the beginning of a
text and the network will generate a continuation. This procedure can be
repeated as many times as the user desires.
For Text Generation, the model input is tokenized text, which serves as
the initial condition for generation. Logits are then obtained from the
model's inference results, and the next token is selected using the
top-k sampling strategy and joined to the input sequence. This procedure
repeats until the end-of-sequence token is received or the specified
maximum length is reached. After that, the token IDs are decoded back to
text.
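The full generate_sequence implementation appears later in this notebook. As a preview of the loop described above, here is a small self-contained toy sketch in which a random "model" and a ten-word vocabulary stand in for the real network and tokenizer (all names and values are made up for illustration):

import numpy as np

# Toy illustration of the generation loop (hypothetical vocabulary and random logits).
rng = np.random.default_rng(0)
vocab = ["<eos>", "deep", "learning", "is", "a", "type", "of", "machine", "neural", "networks"]
eos_id = 0

def toy_model(token_ids):
    # stand-in for the real network: fake logits of shape (len(token_ids), vocab_size)
    return rng.normal(size=(len(token_ids), len(vocab)))

tokens = [1, 2]                                        # "deep learning" as the initial condition
max_length = 10
while len(tokens) < max_length:
    logits = toy_model(tokens)[-1]                     # logits for the last position
    candidates = np.argsort(logits)[-3:]               # keep the 3 most likely next tokens (top-k)
    probs = np.exp(logits[candidates]) / np.exp(logits[candidates]).sum()
    next_token = int(rng.choice(candidates, p=probs))  # sample from the top-k distribution
    if next_token == eos_id:                           # stop at the end-of-sequence token
        break
    tokens.append(next_token)

print(" ".join(vocab[t] for t in tokens))              # decode token ids back to text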
The following image illustrates the demo pipeline for conversation:
For Conversation, the user input is tokenized with the eos_token
concatenated at the end. Then, the text is generated as detailed
above. The generated response is added to the history with the
eos_token at the end. Additional user input is appended to the history,
and the sequence is passed back into the model.
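For illustration, the growth of the chat history can be sketched with made-up token ids (the real converse helper later in the notebook does this with the tokenizer and the model):

# Toy illustration with hypothetical token ids of how the chat history grows.
eos_id = 0                                 # stand-in for the tokenizer's eos_token_id
history = []

user_turn_1 = [11, 12, 13]                 # tokenized user input (made up)
history += user_turn_1 + [eos_id]          # user input + eos_token

bot_turn_1 = [21, 22]                      # generated response (made up)
history += bot_turn_1 + [eos_id]           # response + eos_token

user_turn_2 = [31, 32]                     # the next user input is appended ...
history += user_turn_2 + [eos_id]          # ... and the whole history is fed back to the model
print(history)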
To start working with the GPT-Neo model in OpenVINO, the model should be
converted to OpenVINO Intermediate Representation (IR) format.
HuggingFace provides the GPT-Neo model in PyTorch format, which is
supported in OpenVINO via the Model Conversion API. The ov.convert_model
Python function of the model conversion API can be used for converting
the model. The function returns an instance of the OpenVINO Model class,
which is ready to use in the Python interface. The model can also be
saved on disk in OpenVINO IR format for future execution using
ov.save_model. In our case, dynamic input shapes with a possible shape
range (from 1 token to a maximum length defined in our processing
function) are specified to optimize memory consumption.
from pathlib import Path
import torch
import openvino as ov

# define path for saving openvino model
model_path = Path("model/text_generator.xml")

example_input = {
    "input_ids": torch.ones((1, 10), dtype=torch.long),
    "attention_mask": torch.ones((1, 10), dtype=torch.long),
}
pt_model.config.torchscript = True

# convert model to openvino
if model_name.value == "PersonaGPT (Converastional)":
    ov_model = ov.convert_model(
        pt_model,
        example_input=example_input,
        input=[
            ("input_ids", [1, -1], ov.Type.i64),
            ("attention_mask", [1, -1], ov.Type.i64),
        ],
    )
else:
    ov_model = ov.convert_model(
        pt_model,
        example_input=example_input,
        input=[
            ("input_ids", [1, ov.Dimension(1, 128)], ov.Type.i64),
            ("attention_mask", [1, ov.Dimension(1, 128)], ov.Type.i64),
        ],
    )

# serialize openvino model
ov.save_model(ov_model, str(model_path))
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py:801: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if batch_size <= 0:
We start by building an OpenVINO Core object. Then we read the network
architecture and model weights from the .xml and .bin files,
respectively. Finally, we compile the model for the desired device.
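The core object used below and the device selection are assumed to have been created earlier in the notebook; a minimal sketch of that setup (simply listing the available devices rather than using a selection widget) could look like this:

import openvino as ov

# build an OpenVINO Core object and list the devices available for inference
core = ov.Core()
print(core.available_devices)  # e.g. ['CPU', 'GPU']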
# read the model and corresponding weights from file
model = core.read_model(model_path)
# compile the model for the selected device
compiled_model = core.compile_model(model=model, device_name=device.value)

# get output tensors
output_key = compiled_model.output(0)
Input keys are the names of the network's input nodes, and output keys
are the names of its output nodes. In the case of GPT-Neo, the input
shape is [batch_size, sequence_length] and the output shape is
[batch_size, sequence_length, vocab_size].
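To check this, the names and (partial) shapes of the compiled model's inputs and outputs can be printed; this is illustrative, and the exact names and dimensions depend on the selected model:

# inspect the network inputs and outputs of the compiled model
for model_input in compiled_model.inputs:
    print("Input: ", model_input.any_name, model_input.partial_shape)
for model_output in compiled_model.outputs:
    print("Output:", model_output.any_name, model_output.partial_shape)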
NLP models often take a list of tokens as a standard input. A token is a
word or a part of a word mapped to an integer. To provide the proper
input, we use a vocabulary file to handle the mapping. So first let’s
load the vocabulary file.
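The tokenizer is assumed to have been created in the model selection section; for the GPT-2 case, a minimal sketch using the Hugging Face transformers library would be:

from transformers import GPT2Tokenizer

# the tokenizer holds the vocabulary that maps between text and token ids
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")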
from typing import List, Tuple


# this function converts text to tokens
def tokenize(text: str) -> Tuple[List[int], List[int]]:
    """
    tokenize input text using GPT2 tokenizer

    Parameters:
      text, str - input text
    Returns:
      input_ids - np.array with input token ids
      attention_mask - np.array with 0 where padding should be and 1 where the original tokens are located, represents attention mask for model
    """
    inputs = tokenizer(text, return_tensors="np")
    return inputs["input_ids"], inputs["attention_mask"]
eos_token is a special token that marks the end of generation. We store
the index of this token so that we can use it for padding at a later
stage.
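A minimal sketch of storing that index, assuming the tokenizer loaded above:

# end-of-sequence marker and its index in the model vocabulary
eos_token = tokenizer.eos_token        # "<|endoftext|>" for GPT-2 style tokenizers
eos_token_id = tokenizer.eos_token_id
print(f"eos_token: {eos_token}, eos_token_id: {eos_token_id}")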
If the minimum sequence length has not been reached, the following code
reduces the probability of the eos token occurring, which keeps the
model generating further words.
import numpy as np


def process_logits(cur_length: int, scores: np.array, eos_token_id: int, min_length: int = 0) -> np.array:
    """
    Reduce probability for padded indices.

    Parameters:
      cur_length: Current length of input sequence.
      scores: Model output logits.
      eos_token_id: Index of end of string token in model vocab.
      min_length: Minimum length for applying postprocessing.

    Returns:
      Processed logits with reduced probability for padded indices.
    """
    if cur_length < min_length:
        scores[:, eos_token_id] = -float("inf")
    return scores
In Top-K sampling, we filter the K most likely next words and
redistribute the probability mass among only those K next words.
def get_top_k_logits(scores: np.array, top_k: int) -> np.array:
    """
    Perform top-k sampling on the logits scores.

    Parameters:
      scores: np.array, model output logits.
      top_k: int, number of elements with the highest probability to select.

    Returns:
      np.array, shape (batch_size, sequence_length, vocab_size),
      filtered logits scores where only the top-k elements with the highest
      probability are kept and the rest are replaced with -inf
    """
    filter_value = -float("inf")
    top_k = min(max(top_k, 1), scores.shape[-1])
    top_k_scores = -np.sort(-scores)[:, :top_k]
    indices_to_remove = scores < np.min(top_k_scores)
    filtered_scores = np.ma.array(scores, mask=indices_to_remove, fill_value=filter_value).filled()
    return filtered_scores
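The generate_sequence function below also relies on a softmax helper that turns the filtered logits into a probability distribution; it is used in the code but not shown in this section. A minimal NumPy sketch of such a helper would be:

import numpy as np

def softmax(x: np.array) -> np.array:
    """Convert logits to a probability distribution along the last axis."""
    e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e_x / e_x.sum(axis=-1, keepdims=True)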
def generate_sequence(
    input_ids: List[int],
    attention_mask: List[int],
    max_sequence_length: int = 128,
    eos_token_id: int = eos_token_id,
    dynamic_shapes: bool = True,
) -> List[int]:
    """
    Generates a sequence of tokens using a pre-trained language model.

    Parameters:
      input_ids: np.array, tokenized input ids for model
      attention_mask: np.array, attention mask for model
      max_sequence_length: int, maximum sequence length for stopping iteration
      eos_token_id: int, index of the end-of-sequence token in the model's vocabulary
      dynamic_shapes: bool, whether to use dynamic shapes for inference or pad model input to max_sequence_length

    Returns:
      np.array, the predicted sequence of token ids
    """
    while True:
        cur_input_len = len(input_ids[0])
        if not dynamic_shapes:
            pad_len = max_sequence_length - cur_input_len
            model_input_ids = np.concatenate((input_ids, [[eos_token_id] * pad_len]), axis=-1)
            model_input_attention_mask = np.concatenate((attention_mask, [[0] * pad_len]), axis=-1)
        else:
            model_input_ids = input_ids
            model_input_attention_mask = attention_mask
        outputs = compiled_model({"input_ids": model_input_ids, "attention_mask": model_input_attention_mask})[output_key]
        next_token_logits = outputs[:, cur_input_len - 1, :]
        # pre-process distribution
        next_token_scores = process_logits(cur_input_len, next_token_logits, eos_token_id)
        top_k = 20
        next_token_scores = get_top_k_logits(next_token_scores, top_k)
        # get next token id
        probs = softmax(next_token_scores)
        next_tokens = np.random.choice(probs.shape[-1], 1, p=probs[0], replace=True)
        # break the loop if max length or end of text token is reached
        if cur_input_len == max_sequence_length or next_tokens[0] == eos_token_id:
            break
        else:
            input_ids = np.concatenate((input_ids, [next_tokens]), axis=-1)
            attention_mask = np.concatenate((attention_mask, [[1] * len(next_tokens)]), axis=-1)
    return input_ids
The text variable below is the input used to generate a predicted
sequence.
import time

if not model_name.value == "PersonaGPT (Converastional)":
    text = "Deep learning is a type of machine learning that uses neural networks"
    input_ids, attention_mask = tokenize(text)

    start = time.perf_counter()
    output_ids = generate_sequence(input_ids, attention_mask)
    end = time.perf_counter()
    output_text = " "
    # Convert IDs to words and make the sentence from it
    for i in output_ids[0]:
        output_text += tokenizer.batch_decode([i])[0]
    print(f"Generation took {end - start:.3f} s")
    print(f"Input Text: {text}")
    print()
    print(f"{model_name.value}: {output_text}")
else:
    print("Selected Model is PersonaGPT. Please select GPT-Neo or GPT-2 in the first cell to generate text sequences")
User input is tokenized with the eos_token concatenated at the end. The
tokenized text serves as the initial condition for generation; logits
are obtained from the model's inference results, and the next token is
selected using the top-k sampling strategy and joined to the input
sequence. The procedure repeats until the end-of-sequence token is
received or the specified maximum length is reached. After that, the
token IDs are decoded back to text using the tokenizer.
The generated response is added to the history with the eos_token at
the end. Further user input is appended to this history, and the
sequence is passed back into the model.
Wrapper around the generate_sequence function to support conversation:
def converse(
    input: str,
    history: List[int],
    eos_token: str = eos_token,
    eos_token_id: int = eos_token_id,
) -> Tuple[str, List[int]]:
    """
    Converse with the Model.

    Parameters:
      input: Text input given by the User
      history: Chat History, ids of tokens of chat occurred so far
      eos_token: end of sequence string
      eos_token_id: end of sequence index from vocab

    Returns:
      response: Text Response generated by the model
      history: Chat History, ids of the tokens of chat occurred so far, including the tokens of the generated response
    """
    # Get Input Ids of the User Input
    new_user_input_ids, _ = tokenize(input + eos_token)

    # append the new user input tokens to the chat history, if history exists
    if len(history) == 0:
        bot_input_ids = new_user_input_ids
    else:
        bot_input_ids = np.concatenate([history, new_user_input_ids[0]])
        bot_input_ids = np.expand_dims(bot_input_ids, axis=0)

    # Create Attention Mask
    bot_attention_mask = np.ones_like(bot_input_ids)

    # Generate Response from the model
    history = generate_sequence(bot_input_ids, bot_attention_mask, max_sequence_length=1000)

    # Add the eos_token to mark the end of the sequence
    history = np.append(history[0], eos_token_id)

    # convert the tokens to text, split on eos_token, and retrieve the latest response from the Model
    response = "".join(tokenizer.batch_decode(history)).split(eos_token)[-2]
    return response, history
class Conversation:
    def __init__(self):
        # Initialize Empty History
        self.history = []
        self.messages = []

    def chat(self, input_text):
        """
        Wrapper Over Converse Function.

        Parameters:
          input_text: Text input given by the User
        Returns:
          response: Text Response generated by the model
        """
        response, self.history = converse(input_text, self.history)
        self.messages.append(f"Person: {input_text}")
        self.messages.append(f"PersonaGPT: {response}")
        return response
This notebook provides two styles of inference, Plain and Interactive.
The style of inference can be selected in the next cell.
import gradio as gr

if model_name.value == "PersonaGPT (Converastional)":
    if interactive_mode.value == "Plain":
        conversation = Conversation()
        user_prompt = None
        pre_written_prompts = [
            "Hi,How are you?",
            "What are you doing?",
            "I like to dance,do you?",
            "Can you recommend me some books?",
        ]
        # Number of responses generated by model
        n_prompts = 10
        for i in range(n_prompts):
            # Uncomment for taking User Input
            # user_prompt = input()
            if not user_prompt:
                user_prompt = pre_written_prompts[i % len(pre_written_prompts)]
            conversation.chat(user_prompt)
            print(conversation.messages[-2])
            print(conversation.messages[-1])
            user_prompt = None
    else:
        def add_text(history, text):
            history = history + [(text, None)]
            return history, ""

        conversation = Conversation()

        def bot(history):
            conversation.chat(history[-1][0])
            response = conversation.messages[-1]
            history[-1][1] = response
            return history

        with gr.Blocks() as demo:
            chatbot = gr.Chatbot([], elem_id="chatbot")

            with gr.Row():
                with gr.Column():
                    txt = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and press enter, or upload an image",
                        container=False,
                    )

            txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then(bot, chatbot, chatbot)

        try:
            demo.launch(debug=False)
        except Exception:
            demo.launch(debug=False, share=True)
        # if you are launching remotely, specify server_name and server_port
        # demo.launch(server_name='your server name', server_port='server port in int')
        # Read more in the docs: https://gradio.app/docs/
else:
    print("Selected Model is not PersonaGPT, Please select PersonaGPT in the first cell to have a conversation")
Person: Hi,How are you?
PersonaGPT: i am alright. do you have any siblings?
Person: What are you doing?
PersonaGPT: i am busy with school. do you like to read?
Person: I like to dance,do you?
PersonaGPT: i do not. are you a professional dancer?
Person: Can you recommend me some books?
PersonaGPT: i think the bible is a good starting point
Person: Hi,How are you?
PersonaGPT: i'm okay thanks for asking.
Person: What are you doing?
PersonaGPT: i'm just reading.
Person: I like to dance,do you?
PersonaGPT: i do not but i like reading.
Person: Can you recommend me some books?
PersonaGPT: i guess not. i don't have any siblings.
Person: Hi,How are you?
PersonaGPT: i'm good thanks for asking.
Person: What are you doing?
PersonaGPT: i am practicing my dance moves.