Controllable Music Generation with MusicGen and OpenVINO

MusicGen is a single-stage auto-regressive Transformer model capable of generating high-quality music samples conditioned on text descriptions or audio prompts. The text prompt is passed to a text encoder model (T5) to obtain a sequence of hidden-state representations. These hidden states are fed to MusicGen, which predicts discrete audio tokens (audio codes). Finally, the audio tokens are decoded by an audio compression model (EnCodec) to recover the audio waveform.



The MusicGen model does not require a self-supervised semantic representation of the text or audio prompts. Instead, it operates over several streams of compressed discrete music representation with efficient token interleaving patterns, which eliminates the need to cascade multiple models to predict a set of codebooks (e.g., hierarchically or by upsampling). Unlike prior music generation models, it predicts all the codebooks at each step in a single forward pass.
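To make the benefit of interleaving concrete, here is a small worked comparison of how many autoregressive steps are needed to produce T frames with K codebooks. It is a simplified sketch: "flattening" is assumed to emit every codebook token in its own step, while a delay-style interleaving pattern emits one token per codebook in parallel and only pays K - 1 extra steps for the offsets. The values K = 4 and 50 frames per second are illustrative assumptions matching the small MusicGen checkpoint.

```python
def steps_flattened(num_frames: int, num_codebooks: int) -> int:
    """Steps if every codebook token is generated sequentially (flattening)."""
    return num_frames * num_codebooks


def steps_delayed(num_frames: int, num_codebooks: int) -> int:
    """Steps if all codebooks are predicted per step, codebook k delayed by k steps."""
    return num_frames + num_codebooks - 1


frames = 8 * 50  # 8 seconds at an assumed frame rate of 50 Hz
print(steps_flattened(frames, 4))  # 1600
print(steps_delayed(frames, 4))    # 403
```

With the delay pattern, the number of decoder steps grows with the number of frames, not with frames times codebooks.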

In this tutorial, we consider how to run the MusicGen model using OpenVINO.

We will use a model implementation from the Hugging Face Transformers library.

Install requirements

%pip install -q "openvino>=2023.1.0"
%pip install -q torch onnx gradio
%pip install -q transformers

from collections import namedtuple
import gc
from pathlib import Path
from typing import Optional, Tuple
import warnings

from IPython.display import Audio
from openvino import Core, convert_model, PartialShape, save_model, Type
import numpy as np
import torch
from torch.jit import TracerWarning
from transformers import AutoProcessor, MusicgenForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions

# Ignore tracing warnings
warnings.filterwarnings("ignore", category=TracerWarning)

MusicGen in HF Transformers

To work with MusicGen by Meta AI, we will use the Hugging Face Transformers package. It exposes the MusicgenForConditionalGeneration class, which simplifies model instantiation and weight loading. The code below demonstrates how to create a MusicgenForConditionalGeneration instance and generate a text-conditioned music sample.

# Load the pipeline
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small", torchscript=True, return_dict=False)

In the cell below, you can change the PyTorch inference device and the desired length of the music sample.

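device = "cpu"
sample_length = 8  # seconds

n_tokens = sample_length * model.config.audio_encoder.frame_rate + 3
sampling_rate = model.config.audio_encoder.sampling_rate
print('Sampling rate is', sampling_rate, 'Hz')

# Move the PyTorch model to the selected device and switch it to inference mode
model.to(device)
model.eval();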
Sampling rate is 32000 Hz
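The `+ 3` in `n_tokens` matches the number of codebooks minus one: the delay interleaving pattern shifts each of the 4 codebooks of musicgen-small by one more step, so generating N complete frames takes N + 3 decoder steps, and the audio decoder later receives `n_tokens - 3` frames. The plain-Python sketch below reproduces this arithmetic. The frame rate of 50 Hz and the 4 codebooks are assumptions about facebook/musicgen-small (check `model.config.audio_encoder.frame_rate` and `model.decoder.num_codebooks`), and both helper functions are hypothetical.

```python
def tokens_to_generate(seconds: float, frame_rate: int, num_codebooks: int) -> int:
    """Decoder steps needed for `seconds` of audio under a delay interleaving pattern."""
    frames = int(seconds * frame_rate)    # one code per codebook per frame
    return frames + (num_codebooks - 1)   # extra steps for the codebook offsets


def audio_samples(seconds: float, sampling_rate: int) -> int:
    """Number of waveform samples for `seconds` of audio."""
    return int(seconds * sampling_rate)


# Assumed values for facebook/musicgen-small
n = tokens_to_generate(8, frame_rate=50, num_codebooks=4)
print(n, n - 3, audio_samples(8, 32000))  # 403 400 256000
```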

Original Pipeline Inference

Text preprocessing prepares the text prompt to be fed into the model; the processor object abstracts this step for us. Under the hood, it tokenizes the text, assigning a token ID to each word or word piece; in other words, token IDs are simply indices into the model vocabulary. This turns the sentence into a numeric sequence that the model can process.
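A toy illustration of the idea that token IDs are indices into a vocabulary. The whitespace tokenizer and the tiny vocabulary below are hypothetical; the real processor uses the T5 SentencePiece vocabulary, which splits text into subword pieces and appends an end-of-sequence token `</s>`.

```python
from typing import Dict, List

# Hypothetical toy vocabulary; real T5 vocabularies contain ~32k subword pieces.
VOCAB: Dict[str, int] = {"<pad>": 0, "</s>": 1, "<unk>": 2, "80s": 3, "pop": 4,
                         "track": 5, "with": 6, "bassy": 7, "drums": 8, "and": 9, "synth": 10}


def toy_tokenize(text: str) -> List[int]:
    """Map each whitespace-separated word to its vocabulary index, then append </s>."""
    ids = [VOCAB.get(word, VOCAB["<unk>"]) for word in text.lower().split()]
    return ids + [VOCAB["</s>"]]


print(toy_tokenize("80s pop track with bassy drums and synth"))  # [3, 4, 5, 6, 7, 8, 9, 10, 1]
print(toy_tokenize("90s rock"))  # unknown words map to <unk>: [2, 2, 1]
```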

processor = AutoProcessor.from_pretrained("facebook/musicgen-small")

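# Tokenize the prompt into PyTorch tensors
inputs = processor(
    text=["80s pop track with bassy drums and synth"],
    return_tensors="pt",
)

# Sample audio codes (guidance_scale enables classifier-free guidance) and decode them to a waveform
audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=n_tokens)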

Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)
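The `guidance_scale=3` argument enables classifier-free guidance: each prompt is processed twice, once conditioned on the text and once unconditionally, and the two sets of logits are combined before sampling. A minimal NumPy sketch of the combination step, assuming the standard formula `uncond + scale * (cond - uncond)`; the logits are made-up numbers and `apply_cfg` is a hypothetical helper.

```python
import numpy as np


def apply_cfg(logits: np.ndarray, guidance_scale: float) -> np.ndarray:
    """Combine a batch of [conditional; unconditional] logits into guided logits.

    `logits` has shape (2 * batch, vocab): the first half is conditioned on the
    text prompt, the second half on an empty prompt.
    """
    cond, uncond = np.split(logits, 2, axis=0)
    return uncond + guidance_scale * (cond - uncond)


logits = np.array([[2.0, 1.0, 0.0],   # conditional
                   [1.0, 1.0, 1.0]])  # unconditional
guided = apply_cfg(logits, guidance_scale=3.0)  # guided logits: 4, 1, -2
```

A scale of 1 returns the conditional logits unchanged; larger scales push the distribution further toward what the prompt favors. This doubling of the batch is also why the decoder sees a batch of 2 for a single prompt.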

Convert models to OpenVINO Intermediate representation (IR) format

The model conversion API enables direct conversion of PyTorch models. We will use the openvino.convert_model method to obtain OpenVINO IR versions of the models. The method requires a model object and an example input for model tracing. Under the hood, the converter uses the PyTorch JIT compiler to build a frozen model graph.

The pipeline consists of three important parts:

  • The T5 text encoder, which translates user prompts into vectors in a latent space that the next model, the MusicGen decoder, can use.

  • The MusicGen Language Model, which auto-regressively generates audio tokens (codes).

  • The EnCodec model (we use only its decoder part), which decodes the audio waveform from the audio tokens predicted by the MusicGen Language Model.

Let us convert each model step by step.

0. Set Up Variables

models_dir = Path("./models")
t5_ir_path = models_dir / "t5.xml"
musicgen_0_ir_path = models_dir / "mg_0.xml"
musicgen_ir_path = models_dir / "mg.xml"
audio_decoder_onnx_path = models_dir / "encodec.onnx"
audio_decoder_ir_path = models_dir / "encodec.xml"

1. Convert Text Encoder

The text encoder is responsible for converting the input prompt, such as “90s rock song with loud guitars and heavy drums”, into an embedding space that can be fed to the next model. Typically, it is a transformer-based encoder that maps a sequence of input tokens to a sequence of text embeddings.

The text encoder takes two inputs: input_ids, a tensor of token indices produced by the tokenizer, and attention_mask. We ignore attention_mask because we process one prompt at a time, so this vector would consist only of ones.
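The attention mask only carries information when prompts of different lengths are batched and padded. A small pure-Python sketch, assuming right padding with a pad ID of 0 (the T5 pad token ID); `pad_batch` is a hypothetical helper, not part of the processor API.

```python
from typing import List, Tuple


def pad_batch(batch: List[List[int]], pad_id: int = 0) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad token ID sequences and build the matching attention masks."""
    max_len = max(len(ids) for ids in batch)
    input_ids = [ids + [pad_id] * (max_len - len(ids)) for ids in batch]
    attention_mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in batch]
    return input_ids, attention_mask


# A single prompt: nothing to pad, so the mask is all ones.
print(pad_batch([[3, 4, 5, 1]]))  # ([[3, 4, 5, 1]], [[1, 1, 1, 1]])
# Two prompts of different lengths: the shorter one is padded and masked out.
print(pad_batch([[3, 4, 5, 1], [7, 1]]))
# ([[3, 4, 5, 1], [7, 1, 0, 0]], [[1, 1, 1, 1], [1, 1, 0, 0]])
```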

Below, we use the OpenVINO Converter (OVC) to convert the PyTorch model to the OpenVINO Intermediate Representation (IR) format, which you can later infer with OpenVINO Runtime.

if not t5_ir_path.exists():
    t5_ov = convert_model(model.text_encoder, example_input={'input_ids': inputs['input_ids']})

    save_model(t5_ov, t5_ir_path)
    del t5_ov

2. Convert MusicGen Language Model

This model is the central part of the whole pipeline: it takes the embedded text representation and generates audio codes that can then be decoded into actual music. The model outputs several streams of audio codes, i.e. tokens sampled from pre-trained codebooks that represent music efficiently at a lower frame rate. The model employs an innovative codebook interleaving strategy that makes single-stage generation possible.
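A NumPy sketch of a delay interleaving pattern: codebook k is shifted right by k steps, so at every decoding step the model emits one token per codebook, while codebook k at step t belongs to frame t - k. The pad value -1 and the function names are illustrative assumptions; Transformers has its own implementation (`build_delay_pattern_mask`), whose conventions differ in detail.

```python
import numpy as np

PAD = -1  # illustrative placeholder meaning "no token at this position"


def apply_delay(codes: np.ndarray) -> np.ndarray:
    """Shift codebook k right by k steps: (K, T) -> (K, T + K - 1)."""
    num_codebooks, num_frames = codes.shape
    out = np.full((num_codebooks, num_frames + num_codebooks - 1), PAD, dtype=codes.dtype)
    for k in range(num_codebooks):
        out[k, k:k + num_frames] = codes[k]
    return out


def undo_delay(delayed: np.ndarray) -> np.ndarray:
    """Invert apply_delay: (K, T + K - 1) -> (K, T)."""
    num_codebooks = delayed.shape[0]
    num_frames = delayed.shape[1] - num_codebooks + 1
    return np.stack([delayed[k, k:k + num_frames] for k in range(num_codebooks)])


codes = np.arange(12).reshape(4, 3)  # 4 codebooks, 3 frames
delayed = apply_delay(codes)
print(delayed.tolist())
# [[0, 1, 2, -1, -1, -1], [-1, 3, 4, 5, -1, -1],
#  [-1, -1, 6, 7, 8, -1], [-1, -1, -1, 9, 10, 11]]
```

Each column of `delayed` is what one decoding step produces: all four codebooks at once. Reading out 3 frames therefore takes 3 + 4 - 1 = 6 steps.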

At the 0th generation step, the model accepts input_ids, representing the indices of audio codes, together with encoder_hidden_states and encoder_attention_mask, which were provided by the text encoder.
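The example shapes used for tracing follow from a few quantities. This sketch assumes a single prompt doubled to a batch of 2 for classifier-free guidance (conditional and unconditional rows), 4 codebooks folded into the batch dimension of input_ids, an arbitrary example prompt length of 12 tokens, and a decoder hidden size of 1024 for musicgen-small. `step0_shapes` is a hypothetical helper that only does the arithmetic.

```python
from typing import Dict, Tuple


def step0_shapes(batch: int, num_codebooks: int, prompt_len: int, hidden: int,
                 use_cfg: bool = True) -> Dict[str, Tuple[int, ...]]:
    """Shapes of the decoder inputs at the first generation step (hypothetical helper)."""
    eff_batch = batch * 2 if use_cfg else batch  # CFG adds an unconditional copy
    return {
        "input_ids": (eff_batch * num_codebooks, 1),  # one start token per codebook
        "encoder_hidden_states": (eff_batch, prompt_len, hidden),
        "encoder_attention_mask": (eff_batch, prompt_len),
    }


print(step0_shapes(batch=1, num_codebooks=4, prompt_len=12, hidden=1024))
# {'input_ids': (8, 1), 'encoder_hidden_states': (2, 12, 1024), 'encoder_attention_mask': (2, 12)}
```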

# Set model config `torchscript` to True, so the model returns a tuple as output
model.decoder.config.torchscript = True

# Example inputs for tracing the decoder at the first generation step.
# Defined outside the `if` so that the next cell can reuse them.
decoder_input = {
    'input_ids': torch.ones(8, 1, dtype=torch.int64),
    'encoder_hidden_states': torch.ones(2, 12, 1024, dtype=torch.float32),
    'encoder_attention_mask': torch.ones(2, 12, dtype=torch.int64),
}

if not musicgen_0_ir_path.exists():
    mg_ov_0_step = convert_model(model.decoder, example_input=decoder_input)

    save_model(mg_ov_0_step, musicgen_0_ir_path)
    del mg_ov_0_step

On subsequent iterations, the model also receives a past_key_values argument that contains the keys and values computed by the attention blocks at previous steps, which saves computation. For us, however, it means that the signature of the model's forward method changes. Models in OpenVINO IR have frozen computation graphs and do not allow optional arguments, which is why the MusicGen model must be converted a second time, with more inputs.
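Why past_key_values saves computation: in causal self-attention, the keys and values of earlier positions do not change, so they can be cached and only the newest position has to be projected. A single-head NumPy sketch with random stand-in weights checks that incremental decoding with a cache gives the same result as recomputing attention over the whole prefix.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8                                               # toy head dimension
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))
x = rng.standard_normal((5, d))                     # 5 toy token embeddings


def attend(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Scaled dot-product attention of one query over all given keys and values."""
    scores = k @ q / np.sqrt(d)
    weights = np.exp(scores - scores.max())         # numerically stable softmax
    weights /= weights.sum()
    return weights @ v


# Incremental decoding: project only the newest token and append it to the cache.
k_cache, v_cache, cached_out = [], [], []
for t in range(len(x)):
    k_cache.append(x[t] @ Wk)
    v_cache.append(x[t] @ Wv)
    cached_out.append(attend(x[t] @ Wq, np.stack(k_cache), np.stack(v_cache)))

# Full recomputation for the last position, re-projecting the whole prefix.
full_last = attend(x[-1] @ Wq, x @ Wk, x @ Wv)
print(np.allclose(cached_out[-1], full_last))  # True
```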

# Add additional argument to the example_input dict
if not musicgen_ir_path.exists():
    # Add `past_key_values` to the converted model signature:
    # 24 layers, each with 2 self-attention and 2 cross-attention cache tensors
    decoder_input['past_key_values'] = tuple([(
            torch.ones(2, 16, 1, 64, dtype=torch.float32),
            torch.ones(2, 16, 1, 64, dtype=torch.float32),
            torch.ones(2, 16, 12, 64, dtype=torch.float32),
            torch.ones(2, 16, 12, 64, dtype=torch.float32),
        )] * 24)

    mg_ov = convert_model(model.decoder, example_input=decoder_input)

Moreover, the past_key_values argument is passed as Tuple[Tuple[torch.Tensor]], which is hard for the converter to handle. The converter flattens these tuples but has a hard time identifying the correct shapes for the resulting tensors. The code below goes over the flattened past_key_values-related model inputs and sets the correct shape and type for them.
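The bookkeeping behind `mg_ov.inputs[3:]` and `range(1, 97)`: 24 decoder layers, each contributing 4 cache tensors, flatten into 96 positional tensors that follow the 3 regular inputs. A plain-Python sketch of the flattening and the index arithmetic, with strings standing in for tensors; it assumes the converter preserves this layer-by-layer order, and the helper names are hypothetical.

```python
from typing import List, Sequence, Tuple

NUM_LAYERS = 24        # layers in the example past_key_values
TENSORS_PER_LAYER = 4  # self-attn K, self-attn V, cross-attn K, cross-attn V
NUM_FIXED_INPUTS = 3   # input_ids, encoder_hidden_states, encoder_attention_mask


def flatten(past: Sequence[Tuple[str, ...]]) -> List[str]:
    """Flatten a tuple of per-layer tuples into one flat list, layer by layer."""
    return [tensor for layer in past for tensor in layer]


def input_index(layer: int, slot: int) -> int:
    """Position of a cache tensor among the converted model's inputs."""
    return NUM_FIXED_INPUTS + layer * TENSORS_PER_LAYER + slot


past = tuple(tuple(f"L{layer}_{name}" for name in ("self_k", "self_v", "cross_k", "cross_v"))
             for layer in range(NUM_LAYERS))
flat = flatten(past)
print(len(flat))                              # 96
print(input_index(0, 0), input_index(23, 3))  # 3 98
```

On the output side, the same 96 tensors follow the logits at index 0, which matches reading outputs 1 to 96 as the updated cache.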

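if not musicgen_ir_path.exists():
    # Inputs 0-2 are input_ids, encoder_hidden_states and encoder_attention_mask;
    # the remaining inputs are the flattened past_key_values tensors
    for ov_input in mg_ov.inputs[3:]:
        ov_input.get_node().set_partial_shape(PartialShape([-1, 16, -1, 64]))
        ov_input.get_node().set_element_type(Type.f32)

    # Propagate the updated shapes and types through the graph
    mg_ov.validate_nodes_and_infer_types()

    save_model(mg_ov, musicgen_ir_path)
    del mg_ov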

3. Convert Audio Decoder

The audio decoder, which is part of the EnCodec model, recovers the audio waveform from the audio tokens predicted by the MusicGen decoder. To learn more about the model, please refer to the corresponding OpenVINO example.

The audio decoder computation graph contains an LSTM sub-network, which currently has to be converted through ONNX. To do this, we create a wrapper class whose forward() method calls encodec.decode().

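if not audio_decoder_onnx_path.exists():
    class AudioDecoder(torch.nn.Module):
        """Exposes EnCodec's decode() as forward() so that it can be exported to ONNX."""
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, output_ids):
            # Pass a single None as audio_scales (no rescaling)
            return self.model.decode(output_ids, [None])

    audio_decoder_input = {'output_ids': torch.ones(1, 1, 4, n_tokens - 3, dtype=torch.int64),}

    with torch.no_grad():
        torch.onnx.export(
            AudioDecoder(model.audio_encoder),
            audio_decoder_input,
            str(audio_decoder_onnx_path),
            input_names=['output_ids'],
            output_names=['decoded_audio'],
            # Allow variable-length code sequences and audio outputs
            dynamic_axes={
                'output_ids': {3: 'sequence_length'},
                'decoded_audio': {2: 'audio_values'}
            },
        )
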
Now we can convert the frozen ONNX computation graph to OpenVINO IR.

# Now we can convert the model to OpenVINO IR
if not audio_decoder_ir_path.exists():
    audio_decoder_ov = convert_model(str(audio_decoder_onnx_path))

    save_model(audio_decoder_ov, audio_decoder_ir_path)
    del audio_decoder_ov

Embedding the converted models into the original pipeline

The OpenVINO™ Runtime Python API is used to compile models in OpenVINO IR format. The Core class provides access to the OpenVINO Runtime API; the core object, an instance of the Core class, represents the API and is used to compile the models.

core = Core()

Select the device that OpenVINO will use for model inference from the dropdown list:

import ipywidgets as widgets

DEVICE = widgets.Dropdown(
    options=core.available_devices + ["AUTO"],
    value='AUTO',
    description='Device:',
    disabled=False,
)

DEVICE

Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')

Adapt OpenVINO models to the original pipeline

Here we create wrapper classes for all three OpenVINO models that we want to embed in the original inference pipeline. Here are some of the things to consider when adapting an OV model:

  • Make sure that the parameters passed by the original pipeline are forwarded to the compiled OV model properly. Sometimes the OV model uses only a portion of the input arguments and ignores the rest; sometimes you need to convert an argument to another data type or unwrap data structures such as tuples or dictionaries.

  • Make sure that the wrapper class returns results to the pipeline in the expected format. In the example below, you can see how we pack OV model outputs into special classes declared in the HF repo.

  • Pay attention to the model method that the original pipeline uses to call the model; it may not be the forward method! Refer to the AudioDecoderWrapper to see how we wrap OV model inference into the decode method.
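The three considerations above can be exercised without OpenVINO. In the sketch below, `FakeCompiledModel` is a hypothetical stand-in that, like a compiled model, is called with inputs and returns results indexed by output; `DecoderAdapter` forwards only the argument it needs, converts the input type, repacks the raw result into a named tuple, and exposes inference through `decode` rather than `forward`.

```python
from collections import namedtuple
from typing import Dict, List, Sequence

DecoderOutput = namedtuple("DecoderOutput", ["audio_values"])


class FakeCompiledModel:
    """Hypothetical stand-in for a compiled model: returns results keyed by output name."""

    outputs = ["decoded_audio"]

    def __call__(self, inputs: List[List[int]]) -> Dict[str, List[float]]:
        # Pretend "decoding": turn every code into a float sample.
        return {"decoded_audio": [c / 10 for row in inputs for c in row]}


class DecoderAdapter:
    """Adapts the fake model to a pipeline that calls .decode(codes, audio_scales)."""

    def __init__(self, compiled: FakeCompiledModel):
        self.compiled = compiled

    def decode(self, codes: Sequence[Sequence[int]], audio_scales=None) -> DecoderOutput:
        # 1. Forward only the argument the model uses; audio_scales is ignored.
        # 2. Convert the input to the type the model expects (plain lists here).
        result = self.compiled([list(row) for row in codes])
        # 3. Repack the raw result into the structure the pipeline expects.
        return DecoderOutput(audio_values=result[self.compiled.outputs[0]])


adapter = DecoderAdapter(FakeCompiledModel())
print(adapter.decode(((1, 2), (3, 4)), audio_scales=[None]).audio_values)  # [0.1, 0.2, 0.3, 0.4]
```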

class TextEncoderWrapper(torch.nn.Module):
    def __init__(self, encoder_ir, config):
        super().__init__()
        self.encoder = core.compile_model(encoder_ir, DEVICE.value)
        self.config = config

    def forward(self, input_ids, **kwargs):
        # attention_mask and other keyword arguments are ignored (single prompt)
        last_hidden_state = self.encoder(input_ids)[self.encoder.outputs[0]]
        last_hidden_state = torch.tensor(last_hidden_state)
        return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=last_hidden_state)

class MusicGenWrapper(torch.nn.Module):
    def __init__(self, music_gen_lm_0_ir, music_gen_lm_ir, config, num_codebooks, build_delay_pattern_mask,
                 apply_delay_pattern_mask):
        super().__init__()
        # Model for the 0th step (no cache) and model for the following steps (with cache)
        self.music_gen_lm_0 = core.compile_model(music_gen_lm_0_ir, DEVICE.value)
        self.music_gen_lm = core.compile_model(music_gen_lm_ir, DEVICE.value)
        self.config = config
        self.num_codebooks = num_codebooks
        self.build_delay_pattern_mask = build_delay_pattern_mask
        self.apply_delay_pattern_mask = apply_delay_pattern_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        encoder_hidden_states: torch.FloatTensor = None,
        encoder_attention_mask: torch.LongTensor = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        **kwargs
    ):
        if past_key_values is None:
            model = self.music_gen_lm_0
            arguments = (input_ids, encoder_hidden_states, encoder_attention_mask)
        else:
            model = self.music_gen_lm
            # past_key_values is the flat tuple of 96 tensors returned by the previous call
            arguments = (input_ids, encoder_hidden_states, encoder_attention_mask, *past_key_values)

        output = model(arguments)
        # Output 0 holds the logits; outputs 1..96 hold the updated cache tensors
        return CausalLMOutputWithCrossAttentions(
            logits=torch.tensor(output[model.outputs[0]]),
            past_key_values=tuple([output[model.outputs[i]] for i in range(1, 97)]),
        )

class AudioDecoderWrapper(torch.nn.Module):
    def __init__(self, decoder_ir, config):
        super().__init__()
        self.decoder = core.compile_model(decoder_ir, DEVICE.value)
        self.config = config
        self.output_type = namedtuple("AudioDecoderOutput", ["audio_values"])

    def decode(self, output_ids, audio_scales):
        # The pipeline calls decode(), not forward(); audio_scales is not used
        output = self.decoder(output_ids)[self.decoder.outputs[0]]
        return self.output_type(audio_values=torch.tensor(output))

Now we initialize the wrapper objects and load them into the HF pipeline.

text_encode_ov = TextEncoderWrapper(t5_ir_path, model.text_encoder.config)
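musicgen_decoder_ov = MusicGenWrapper(
    musicgen_0_ir_path,
    musicgen_ir_path,
    model.decoder.config,
    model.decoder.num_codebooks,
    model.decoder.build_delay_pattern_mask,
    model.decoder.apply_delay_pattern_mask,
)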
audio_encoder_ov = AudioDecoderWrapper(audio_decoder_ir_path, model.audio_encoder.config)

del model.text_encoder
del model.decoder
del model.audio_encoder

model.text_encoder = text_encode_ov
model.decoder = musicgen_decoder_ov
model.audio_encoder = audio_encoder_ov

We can now infer the pipeline backed by OpenVINO models.

processor = AutoProcessor.from_pretrained("facebook/musicgen-small")

inputs = processor(
    text=["80s pop track with bassy drums and synth"],
    return_tensors="pt",
)

audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=n_tokens)

Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)

Try out the converted pipeline

The demo app below is created using the Gradio package.

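def _generate(prompt):
    inputs = processor(
        text=[prompt],
        padding=True,
        return_tensors="pt",
    )
    audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=n_tokens)
    # Scale the float waveform to the 16-bit PCM range, clipping to avoid int16 overflow
    waveform = (audio_values[0].cpu().squeeze() * 2**15).clamp(-2**15, 2**15 - 1)
    return (sampling_rate, waveform.numpy().astype(np.int16))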
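The `_generate` function hands Gradio a `(sampling_rate, int16 array)` tuple. Scaling a float waveform in [-1, 1] by 2**15 maps 1.0 to 32768, one past the largest int16 value, and out-of-range float-to-int casts are not guaranteed to saturate. A NumPy sketch of the conversion with explicit clipping; `float_to_int16` is a hypothetical helper.

```python
import numpy as np


def float_to_int16(waveform: np.ndarray) -> np.ndarray:
    """Convert a float waveform in [-1.0, 1.0] to 16-bit PCM, clipping out-of-range values."""
    scaled = waveform * 2**15  # map [-1, 1) onto [-32768, 32768)
    return np.clip(scaled, -2**15, 2**15 - 1).astype(np.int16)


print(float_to_int16(np.array([-1.0, 0.0, 0.5, 1.0])).tolist())  # [-32768, 0, 16384, 32767]
```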
import gradio as gr

demo = gr.Interface(
    _generate,
    inputs=[
        gr.Textbox(label="Text Prompt"),
    ],
    outputs=[
        "audio"
    ],
    examples=[
        ["80s pop track with bassy drums and synth"],
        ["Earthy tones, environmentally conscious, ukulele-infused, harmonic, breezy, easygoing, organic instrumentation, gentle grooves"],
        ["90s rock song with loud guitars and heavy drums"],
        ["Heartful EDM with beautiful synths and chords"],
    ],
)
try:
    demo.launch(debug=False)
except Exception:
    demo.launch(share=True, debug=False)

# If you are launching remotely, specify server_name and server_port
# EXAMPLE: `demo.launch(server_name='your server name', server_port='server port in int')`
# To learn more please refer to the Gradio docs:
Running on local URL:

To create a public link, set share=True in launch().