Sound Generation with AudioLDM2 and OpenVINO™#
This Jupyter notebook can be launched after a local installation only.
AudioLDM 2 is a latent text-to-audio diffusion model capable of generating realistic audio samples given any text input.
AudioLDM 2 was proposed in the paper AudioLDM 2: Learning Holistic Audio Generation with Self-supervised Pretraining by Haohe Liu et al.
The model takes a text prompt as input and predicts the corresponding audio. It can generate text-conditional sound effects, human speech and music.
In this tutorial, we will try out the pipeline, convert the models backing it one by one, and run an interactive app with Gradio!
Installation Instructions#
This is a self-contained example that relies solely on its own code.
We recommend running the notebook in a virtual environment. You only need a Jupyter server to start. For details, please refer to Installation Guide.
Prerequisites#
%pip install -qU accelerate "diffusers>=0.30.0" "transformers>=4.43" "torch>=2.1" "gradio>=4.19" "peft>=0.6.2" --extra-index-url https://download.pytorch.org/whl/cpu
%pip install -q "openvino>=2024.0.0"
Instantiating Generation Pipeline#
To work with AudioLDM 2 by the Centre for Vision, Speech and Signal Processing - University of Surrey, we will use the Hugging Face Diffusers package. The Diffusers package exposes the AudioLDM2Pipeline class, which simplifies model instantiation and weights loading. The code below demonstrates how to create an AudioLDM2Pipeline and generate a text-conditioned sound sample.
from collections import namedtuple
from functools import partial
import gc
from pathlib import Path
from diffusers import AudioLDM2Pipeline
from IPython.display import Audio
import numpy as np
import openvino as ov
import torch
MODEL_ID = "cvssp/audioldm2"
pipe = AudioLDM2Pipeline.from_pretrained(MODEL_ID)
prompt = "birds singing in the forest"
negative_prompt = "Low quality"
audio = pipe(
    prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=150,
    audio_length_in_s=3.0,
).audios[0]
sampling_rate = 16000
Audio(audio, rate=sampling_rate)
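If you want to keep the generated sample, you can write it to a WAV file. The snippet below is a minimal sketch that assumes scipy is available in the environment (it is not installed explicitly by this notebook); the output file name birds_singing.wav is arbitrary.

from scipy.io import wavfile

# `audio` is a float waveform in [-1, 1]; wavfile.write stores it as a floating-point WAV file
wavfile.write("birds_singing.wav", rate=sampling_rate, data=audio)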
Convert models to OpenVINO Intermediate representation (IR) format#
Model conversion API enables direct conversion of the PyTorch models backing the pipeline. We need to provide a model object and input data for model tracing to the ov.convert_model function to obtain an OpenVINO ov.Model object instance. The model can be saved on disk for later deployment using the ov.save_model function.
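To make the workflow concrete, here is a minimal, self-contained sketch of the same three calls applied to a toy torch.nn.Linear module (the toy model and file name are purely illustrative and are not used elsewhere in this notebook):

# Toy illustration of the conversion workflow used for every pipeline model below
toy_model = torch.nn.Linear(4, 2)
ov_toy = ov.convert_model(toy_model, example_input=torch.zeros(1, 4))  # trace the PyTorch module into an ov.Model
ov.save_model(ov_toy, "toy_model.xml")  # serialize the IR (toy_model.xml + toy_model.bin)
compiled_toy = ov.Core().compile_model("toy_model.xml", "CPU")  # compile the IR for a device
print(compiled_toy(np.zeros((1, 4), dtype=np.float32))[0].shape)  # run inference -> (1, 2)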
The pipeline consists of seven important parts (a short inspection sketch follows this list):
- T5 and CLAP text encoders to create the condition for sound generation from a text prompt.
- Projection model to merge the outputs of the two text encoders.
- GPT-2 language model to generate a sequence of hidden-states conditioned on the projected outputs of the two text encoders.
- Vocoder to convert the mel-spectrogram produced by the VAE decoder into the final audio waveform.
- UNet for step-by-step denoising of the latent audio representation.
- Autoencoder (VAE) for decoding the denoised latents into a mel-spectrogram.
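To see which classes back these parts in the loaded pipeline, you can inspect the corresponding pipeline attributes; a small sketch using the attribute names that the conversion code below relies on:

# Print the class backing each pipeline component we are going to convert
for name in ["text_encoder", "text_encoder_2", "projection_model", "language_model", "unet", "vae", "vocoder"]:
    print(f"{name}: {type(getattr(pipe, name)).__name__}")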
models_base_folder = Path("models")
def cleanup_torchscript_cache():
    """
    Helper for removing cached model representation
    """
    torch._C._jit_clear_class_registry()
    torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
    torch.jit._state._clear_class_state()
CLAP Text Encoder Conversion#
The first frozen text encoder: AudioLDM2 uses the joint audio-text embedding model CLAP, specifically the laion/clap-htsat-unfused variant. The text branch is used to encode the text prompt into a prompt embedding. The full audio-text model is used to rank generated waveforms against the text prompt by computing similarity scores.
class ClapEncoderWrapper(torch.nn.Module):
    def __init__(self, encoder):
        super().__init__()
        encoder.eval()
        self.encoder = encoder

    def forward(self, input_ids, attention_mask):
        return self.encoder.get_text_features(input_ids, attention_mask)
clap_text_encoder_ir_path = models_base_folder / "clap_text_encoder.xml"
if not clap_text_encoder_ir_path.exists():
    with torch.no_grad():
        ov_model = ov.convert_model(
            ClapEncoderWrapper(pipe.text_encoder),  # model instance
            example_input={
                "input_ids": torch.ones((1, 512), dtype=torch.long),
                "attention_mask": torch.ones((1, 512), dtype=torch.long),
            },  # inputs for model tracing
        )
    ov.save_model(ov_model, clap_text_encoder_ir_path)
    del ov_model
    cleanup_torchscript_cache()
    gc.collect()
    print("Text Encoder successfully converted to IR")
else:
    print(f"Text Encoder will be loaded from {clap_text_encoder_ir_path}")
Text Encoder successfully converted to IR
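Before moving on, it can be useful to sanity-check the converted encoder against its PyTorch counterpart. The sketch below is an optional check: it tokenizes the prompt with the CLAP tokenizer (pipe.tokenizer), padded to the 512-token length used for tracing, and compares the resulting embeddings (compiling on CPU just for this check).

# Optional sanity check: compare OpenVINO and PyTorch CLAP text embeddings
tokens = pipe.tokenizer(prompt, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
with torch.no_grad():
    torch_embeds = pipe.text_encoder.get_text_features(tokens.input_ids, tokens.attention_mask)
compiled_clap = ov.Core().compile_model(clap_text_encoder_ir_path, "CPU")
ov_embeds = compiled_clap([tokens.input_ids, tokens.attention_mask])[0]
print("max absolute difference:", np.abs(torch_embeds.numpy() - ov_embeds).max())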
T5 Text Encoder Conversion#
As the second frozen text encoder, AudioLDM2 uses T5, specifically the google/flan-t5-large variant.
The text encoder is responsible for transforming the input prompt, for example, “birds singing in the forest”, into an embedding space that can be understood by the UNet. It is usually a simple transformer-based encoder that maps a sequence of input tokens to a sequence of latent text embeddings.
The input of the text encoder is the tensor input_ids, which contains indexes of tokens from the text processed by the tokenizer. The model output is the last_hidden_state tensor, the hidden states from the last encoder layer.
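For illustration, the sketch below tokenizes our prompt with pipe.tokenizer_2 (the T5 tokenizer paired with this encoder) to show what input_ids look like; the shape you get depends on the prompt, while the tracing example below uses a sequence length of 7.

# Inspect the input_ids produced by the T5 tokenizer for our prompt
t5_tokens = pipe.tokenizer_2(prompt, return_tensors="pt")
print(t5_tokens.input_ids.shape)  # (1, sequence_length)
print(pipe.tokenizer_2.batch_decode(t5_tokens.input_ids))  # tokens decoded back to text, including the end-of-sequence marker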
t5_text_encoder_ir_path = models_base_folder / "t5_text_encoder.xml"
if not t5_text_encoder_ir_path.exists():
    pipe.text_encoder_2.eval()
    with torch.no_grad():
        ov_model = ov.convert_model(
            pipe.text_encoder_2,  # model instance
            example_input=torch.ones((1, 7), dtype=torch.long),  # inputs for model tracing
        )
    ov.save_model(ov_model, t5_text_encoder_ir_path)
    del ov_model
    cleanup_torchscript_cache()
    gc.collect()
    print("Text Encoder successfully converted to IR")
else:
    print(f"Text Encoder will be loaded from {t5_text_encoder_ir_path}")
Text Encoder successfully converted to IR
Projection model conversion#
A trained model used to linearly project the hidden-states from the first and second text encoder models and insert learned Start Of Sequence and End Of Sequence token embeddings. The projected hidden-states from the two text encoders are concatenated to give the input to the language model.
projection_model_ir_path = models_base_folder / "projection_model.xml"
projection_model_inputs = {
    "hidden_states": torch.randn((1, 1, 512), dtype=torch.float32),
    "hidden_states_1": torch.randn((1, 7, 1024), dtype=torch.float32),
    "attention_mask": torch.ones((1, 1), dtype=torch.int64),
    "attention_mask_1": torch.ones((1, 7), dtype=torch.int64),
}

if not projection_model_ir_path.exists():
    pipe.projection_model.eval()
    with torch.no_grad():
        ov_model = ov.convert_model(
            pipe.projection_model,  # model instance
            example_input=projection_model_inputs,  # inputs for model tracing
        )
    ov.save_model(ov_model, projection_model_ir_path)
    del ov_model
    cleanup_torchscript_cache()
    gc.collect()
    print("The Projection Model successfully converted to IR")
else:
    print(f"The Projection Model will be loaded from {projection_model_ir_path}")
The Projection Model successfully converted to IR
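To confirm the projected and concatenated shapes, you can run the converted projection model on the example inputs from above; a quick sketch (compiled on CPU just for this check, with the expected shapes inferred from the GPT-2 example input used in the next step):

# Quick shape check of the converted projection model
compiled_projection = ov.Core().compile_model(projection_model_ir_path, "CPU")
proj_out = compiled_projection(projection_model_inputs)
print(proj_out[0].shape, proj_out[1].shape)  # expected: concatenated hidden states (1, 12, 768) and attention mask (1, 12)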
GPT-2 conversion#
GPT-2 is an auto-regressive language model used to generate a sequence of hidden-states conditioned on the projected outputs from the two text encoders.
language_model_ir_path = models_base_folder / "language_model.xml"
language_model_inputs = {
    "inputs_embeds": torch.randn((1, 12, 768), dtype=torch.float32),
    "attention_mask": torch.ones((1, 12), dtype=torch.int64),
}

if not language_model_ir_path.exists():
    pipe.language_model.config.torchscript = True
    pipe.language_model.eval()
    pipe.language_model.__call__ = partial(
        pipe.language_model.__call__,
        kwargs={"past_key_values": None, "use_cache": False, "return_dict": False},
    )
    with torch.no_grad():
        ov_model = ov.convert_model(
            pipe.language_model,  # model instance
            example_input=language_model_inputs,  # inputs for model tracing
        )

    ov_model.inputs[0].get_node().set_partial_shape(ov.PartialShape([1, -1]))
    ov_model.inputs[0].get_node().set_element_type(ov.Type.i64)
    ov_model.inputs[1].get_node().set_partial_shape(ov.PartialShape([1, -1, 768]))
    ov_model.inputs[1].get_node().set_element_type(ov.Type.f32)
    ov_model.validate_nodes_and_infer_types()

    ov.save_model(ov_model, language_model_ir_path)
    del ov_model
    cleanup_torchscript_cache()
    gc.collect()
    print("GPT-2 successfully converted to IR")
else:
    print(f"GPT-2 will be loaded from {language_model_ir_path}")
GPT-2 successfully converted to IR
Vocoder conversion#
The SpeechT5 HiFi-GAN vocoder is used to convert the mel-spectrogram produced by the VAE decoder into the final audio waveform.
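The vocoder takes a mel-spectrogram with 64 mel bins and produces a waveform whose length grows with the number of frames. A quick shape sketch with the original PyTorch vocoder (the 700-frame input mirrors the tracing example below):

# Inspect vocoder input/output shapes with the original PyTorch model
with torch.no_grad():
    test_waveform = pipe.vocoder(torch.zeros((1, 700, 64)))
print(test_waveform.shape)  # (1, num_samples); more mel frames produce a longer waveform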
vocoder_ir_path = models_base_folder / "vocoder.xml"
if not vocoder_ir_path.exists():
    pipe.vocoder.eval()
    with torch.no_grad():
        ov_model = ov.convert_model(
            pipe.vocoder,  # model instance
            example_input=torch.ones((1, 700, 64), dtype=torch.float32),  # inputs for model tracing
        )
    ov.save_model(ov_model, vocoder_ir_path)
    del ov_model
    cleanup_torchscript_cache()
    gc.collect()
    print("The Vocoder successfully converted to IR")
else:
    print(f"The Vocoder will be loaded from {vocoder_ir_path}")
The Vocoder successfully converted to IR
UNet conversion#
The UNet model is used to denoise the encoded audio latents. The process of UNet model conversion remains the same as for the original Stable Diffusion model.
unet_ir_path = models_base_folder / "unet.xml"
pipe.unet.eval()

unet_inputs = {
    "sample": torch.randn((2, 8, 75, 16), dtype=torch.float32),
    "timestep": torch.tensor(1, dtype=torch.int64),
    "encoder_hidden_states": torch.randn((2, 8, 768), dtype=torch.float32),
    "encoder_hidden_states_1": torch.randn((2, 7, 1024), dtype=torch.float32),
    "encoder_attention_mask_1": torch.ones((2, 7), dtype=torch.int64),
}

if not unet_ir_path.exists():
    with torch.no_grad():
        ov_model = ov.convert_model(pipe.unet, example_input=unet_inputs)

    ov_model.inputs[0].get_node().set_partial_shape(ov.PartialShape((2, 8, -1, 16)))
    ov_model.inputs[2].get_node().set_partial_shape(ov.PartialShape((2, 8, 768)))
    ov_model.inputs[3].get_node().set_partial_shape(ov.PartialShape((2, -1, 1024)))
    ov_model.inputs[4].get_node().set_partial_shape(ov.PartialShape((2, -1)))
    ov_model.validate_nodes_and_infer_types()

    ov.save_model(ov_model, unet_ir_path)
    del ov_model
    cleanup_torchscript_cache()
    gc.collect()
    print("Unet successfully converted to IR")
else:
    print(f"Unet will be loaded from {unet_ir_path}")
Unet successfully converted to IR
VAE Decoder conversion#
The VAE model has two parts, an encoder and a decoder. The encoder is used to convert the mel-spectrogram into a low-dimensional latent representation, which serves as the input to the UNet model. The decoder, conversely, transforms the latent representation back into a mel-spectrogram.
During latent diffusion training, the encoder is used to get the latent representations (latents) for the forward diffusion process, which applies more and more noise at each step. During inference, the denoised latents generated by the reverse diffusion process are converted back into mel-spectrograms using the VAE decoder, so we only need the decoder here. You can find instructions on how to convert the encoder part in a Stable Diffusion notebook.
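To make the decode path concrete, the sketch below mirrors what the pipeline does after denoising: scale the latents by the VAE scaling factor, decode them into a mel-spectrogram, and pass that to the vocoder. This is a simplified illustration with random latents (so the result is noise), assuming the standard diffusers scaling_factor convention.

# Simplified decode path: latents -> mel-spectrogram -> waveform
dummy_latents = torch.randn((1, 8, 175, 16))
with torch.no_grad():
    mel = pipe.vae.decode(dummy_latents / pipe.vae.config.scaling_factor).sample  # (1, 1, frames, 64)
    dummy_waveform = pipe.vocoder(mel.squeeze(1))
print(mel.shape, dummy_waveform.shape)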
vae_ir_path = models_base_folder / "vae.xml"
class VAEDecoderWrapper(torch.nn.Module):
    def __init__(self, vae):
        super().__init__()
        vae.eval()
        self.vae = vae

    def forward(self, latents):
        return self.vae.decode(latents)


if not vae_ir_path.exists():
    vae_decoder = VAEDecoderWrapper(pipe.vae)
    latents = torch.zeros((1, 8, 175, 16))

    vae_decoder.eval()
    with torch.no_grad():
        ov_model = ov.convert_model(vae_decoder, example_input=latents)
    ov.save_model(ov_model, vae_ir_path)
    del ov_model
    cleanup_torchscript_cache()
    gc.collect()
    print("VAE decoder successfully converted to IR")
else:
    print(f"VAE decoder will be loaded from {vae_ir_path}")
VAE decoder successfully converted to IR
Select inference device for AudioLDM2 pipeline#
Select a device from the dropdown list to run inference with OpenVINO.
import ipywidgets as widgets
core = ov.Core()
device = widgets.Dropdown(
    options=core.available_devices + ["AUTO"],
    value="CPU",
    description="Device:",
    disabled=False,
)
device
Dropdown(description='Device:', options=('CPU', 'AUTO'), value='CPU')
Adapt OpenVINO models to the original pipeline#
Here we create wrapper classes for all OpenVINO models that we want to embed in the original inference pipeline. Here are some of the things to consider when adapting an OV model:
- Make sure that parameters passed by the original pipeline are forwarded to the compiled OV model properly; sometimes the OV model uses only a portion of the input arguments and some are ignored, and sometimes you need to convert an argument to another data type or unwrap data structures such as tuples or dictionaries.
- Guarantee that the wrapper class returns results to the pipeline in the expected format. In the example below you can see how we pack OV model outputs into special named tuples to adapt them for the pipeline.
- Pay attention to the model method used in the original pipeline for calling the model - it may not be the forward method! Refer to OVClapEncoderWrapper to see how we wrap OV model inference into the get_text_features method.
class OVClapEncoderWrapper:
    def __init__(self, encoder_ir, config):
        self.encoder = core.compile_model(encoder_ir, device.value)
        self.config = config

    def get_text_features(self, input_ids, attention_mask, **_):
        last_hidden_state = self.encoder([input_ids, attention_mask])[0]
        return torch.from_numpy(last_hidden_state)
class OVT5EncoderWrapper:
    def __init__(self, encoder_ir, config):
        self.encoder = core.compile_model(encoder_ir, device.value)
        self.config = config
        self.dtype = self.config.torch_dtype

    def __call__(self, input_ids, **_):
        last_hidden_state = self.encoder(input_ids)[0]
        return torch.from_numpy(last_hidden_state)[None, ...]
class OVVocoderWrapper:
    def __init__(self, vocoder_ir, config):
        self.vocoder = core.compile_model(vocoder_ir, device.value)
        self.config = config

    def __call__(self, mel_spectrogram, **_):
        waveform = self.vocoder(mel_spectrogram)[0]
        return torch.from_numpy(waveform)
class OVProjectionModelWrapper:
    def __init__(self, proj_model_ir, config):
        self.proj_model = core.compile_model(proj_model_ir, device.value)
        self.config = config
        self.output_type = namedtuple("ProjectionOutput", ["hidden_states", "attention_mask"])

    def __call__(self, hidden_states, hidden_states_1, attention_mask, attention_mask_1, **_):
        output = self.proj_model(
            {
                "hidden_states": hidden_states,
                "hidden_states_1": hidden_states_1,
                "attention_mask": attention_mask,
                "attention_mask_1": attention_mask_1,
            }
        )
        return self.output_type(torch.from_numpy(output[0]), torch.from_numpy(output[1]))
class OVUnetWrapper:
    def __init__(self, unet_ir, config):
        self.unet = core.compile_model(unet_ir, device.value)
        self.config = config

    def __call__(self, sample, timestep, encoder_hidden_states, encoder_hidden_states_1, encoder_attention_mask_1, **_):
        output = self.unet(
            {
                "sample": sample,
                "timestep": timestep,
                "encoder_hidden_states": encoder_hidden_states,
                "encoder_hidden_states_1": encoder_hidden_states_1,
                "encoder_attention_mask_1": encoder_attention_mask_1,
            }
        )
        return (torch.from_numpy(output[0]),)
class OVVaeDecoderWrapper:
    def __init__(self, vae_ir, config):
        self.vae = core.compile_model(vae_ir, device.value)
        self.config = config
        self.output_type = namedtuple("VaeOutput", ["sample"])

    def decode(self, latents, **_):
        last_hidden_state = self.vae(latents)[0]
        return self.output_type(torch.from_numpy(last_hidden_state))
def generate_language_model(
    gpt_2: ov.CompiledModel,
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor,
    max_new_tokens: int = 8,
    **_,
) -> torch.Tensor:
    """
    Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs.
    """
    if not max_new_tokens:
        max_new_tokens = 8

    inputs_embeds = inputs_embeds.cpu().numpy()
    attention_mask = attention_mask.cpu().numpy()

    for _ in range(max_new_tokens):
        # forward pass to get next hidden states
        output = gpt_2({"inputs_embeds": inputs_embeds, "attention_mask": attention_mask})

        next_hidden_states = output[0]

        # Update the model input
        inputs_embeds = np.concatenate([inputs_embeds, next_hidden_states[:, -1:, :]], axis=1)
        attention_mask = np.concatenate([attention_mask, np.ones((attention_mask.shape[0], 1))], axis=1)

    return torch.from_numpy(inputs_embeds[:, -max_new_tokens:, :])
Now we initialize the wrapper objects and load them into the HF pipeline.
pipe = AudioLDM2Pipeline.from_pretrained(MODEL_ID)
pipe.config.torchscript = True
pipe.config.return_dict = False
np.random.seed(0)
torch.manual_seed(0)
pipe.text_encoder = OVClapEncoderWrapper(clap_text_encoder_ir_path, pipe.text_encoder.config)
pipe.text_encoder_2 = OVT5EncoderWrapper(t5_text_encoder_ir_path, pipe.text_encoder_2.config)
pipe.projection_model = OVProjectionModelWrapper(projection_model_ir_path, pipe.projection_model.config)
pipe.vocoder = OVVocoderWrapper(vocoder_ir_path, pipe.vocoder.config)
pipe.unet = OVUnetWrapper(unet_ir_path, pipe.unet.config)
pipe.vae = OVVaeDecoderWrapper(vae_ir_path, pipe.vae.config)
pipe.generate_language_model = partial(generate_language_model, core.compile_model(language_model_ir_path, device.value))
gc.collect()
prompt = "birds singing in the forest"
negative_prompt = "Low quality"
audio = pipe(
    prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=150,
    audio_length_in_s=3.0,
).audios[0]
sampling_rate = 16000
Audio(audio, rate=sampling_rate)
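If you want to compare the OpenVINO pipeline with the original PyTorch run, a simple wall-clock measurement around the pipeline call is usually enough. A minimal sketch (numbers will vary with hardware and the selected device):

import time

start = time.perf_counter()
_ = pipe(prompt, negative_prompt=negative_prompt, num_inference_steps=150, audio_length_in_s=3.0).audios[0]
print(f"Generation took {time.perf_counter() - start:.1f} s for 150 denoising steps")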
Try out the converted pipeline#
Now, we are ready to start generation. To improve the generation process, we also provide the opportunity to specify a negative prompt. Technically, the positive prompt steers the diffusion toward the output associated with it, while the negative prompt steers the diffusion away from it. The demo app below is created using the Gradio package.
import gradio as gr
def _generate(
    prompt,
    negative_prompt,
    audio_length_in_s,
    num_inference_steps,
    _=gr.Progress(track_tqdm=True),
):
    """Gradio backing function."""
    audio_values = pipe(
        prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        audio_length_in_s=audio_length_in_s,
    )
    waveform = audio_values[0].squeeze() * 2**15
    return (sampling_rate, waveform.astype(np.int16))
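Gradio's Audio component accepts a (sampling_rate, numpy_array) tuple, and multiplying by 2**15 rescales the float waveform from [-1, 1] to the int16 range. If you prefer to guard against clipping at exactly +/-1.0, an alternative conversion (a sketch, not part of the original demo) could look like this:

# Alternative conversion that clips to the valid int16 range before casting
def to_int16(waveform: np.ndarray) -> np.ndarray:
    return np.clip(waveform * 2**15, -(2**15), 2**15 - 1).astype(np.int16)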
import requests
if not Path("gradio_helper.py").exists():
r = requests.get(url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/sound-generation-audioldm2/gradio_helper.py")
open("gradio_helper.py", "w").write(r.text)
from gradio_helper import make_demo
demo = make_demo(fn=_generate)
try:
demo.queue().launch(debug=False)
except Exception:
demo.queue().launch(share=True, debug=False)
# If you are launching remotely, specify server_name and server_port
# EXAMPLE: `demo.launch(server_name="your server name", server_port="server port in int")`
# To learn more please refer to the Gradio docs: https://gradio.app/docs/
# please uncomment and run this cell for stopping gradio interface
# demo.close()