Sound Generation with Stable Audio Open and OpenVINO™

This Jupyter notebook can be launched after a local installation only.

Stable Audio Open is an open-source model optimized for generating short audio samples, sound effects, and production elements using text prompts. The model was trained on data from Freesound and the Free Music Archive, respecting creator rights.

Key Takeaways:

  • Stable Audio Open is an open-source text-to-audio model for generating up to 47 seconds of samples and sound effects.

  • Users can create drum beats, instrument riffs, ambient sounds, foley and production elements.

  • The model enables audio variations and style transfer of audio samples.
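The 47-second limit corresponds to a fixed number of audio samples. The sketch below converts a requested duration into a sample count and caps it at the model's window, which could serve as a validation step before generation. It assumes a 44.1 kHz sample rate and a window of 2,097,152 samples (about 47.55 s); both constants are assumptions for illustration, so read the real values from `model_config` (the notebook uses `model_config["sample_rate"]`; a `"sample_size"` key holding the window length is assumed).

```python
# Assumed values for illustration only; read the real ones from model_config.
ASSUMED_SAMPLE_RATE = 44_100      # Hz
ASSUMED_MAX_SAMPLES = 2_097_152   # assumed training window length, in samples


def max_duration_seconds(max_samples=ASSUMED_MAX_SAMPLES, sample_rate=ASSUMED_SAMPLE_RATE):
    """Longest clip, in seconds, that fits in the (assumed) model window."""
    return max_samples / sample_rate


def samples_for(seconds, max_samples=ASSUMED_MAX_SAMPLES, sample_rate=ASSUMED_SAMPLE_RATE):
    """Sample count for a requested duration, capped at the window length."""
    return max(0, min(int(seconds * sample_rate), max_samples))


print(round(max_duration_seconds(), 2))  # 47.55
print(samples_for(20))                   # 882000
print(samples_for(60))                   # 2097152 (capped)
```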

The model is intended to be used with the stable-audio-tools library for inference.

Installation Instructions

This is a self-contained example that relies solely on its own code.

We recommend running the notebook in a virtual environment. You only need a Jupyter server to start. For details, please refer to the Installation Guide.

Prerequisites

import platform

%pip install -q "torch>=2.2" --extra-index-url https://download.pytorch.org/whl/cpu
%pip install -q "stable-audio-tools" "nncf>=2.12.0" --extra-index-url https://download.pytorch.org/whl/cpu
if platform.system() == "Darwin":
    %pip install -q "numpy>=1.26,<2.0.0" "pandas>2.0.2"
else:
    %pip install -q "numpy>=1.26" "pandas>2.0.2"
%pip install -q "openvino>=2024.4.0"
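After installation, it can help to confirm that the installed packages meet the minimum versions pinned above. The following is a minimal stdlib-only sketch using `importlib.metadata`; the helper names are hypothetical, and the version parsing is deliberately simplified (pre-release and local suffixes such as `rc1` or `+cpu` are dropped), so prefer `packaging.version.Version` when it is available.

```python
import re
from importlib import metadata


def release_tuple(version: str) -> tuple:
    """Leading numeric release of a version string: '2.5.1+cpu' -> (2, 5, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    return tuple(int(part) for part in match.group(0).split(".")) if match else ()


def at_least(installed: str, minimum: str) -> bool:
    """Compare release tuples after zero-padding, so '2.2' counts as '2.2.0'."""
    a, b = release_tuple(installed), release_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_installed(minimums: dict) -> dict:
    """Map each package name to True/False, or None if it is not installed."""
    report = {}
    for name, minimum in minimums.items():
        try:
            report[name] = at_least(metadata.version(name), minimum)
        except metadata.PackageNotFoundError:
            report[name] = None
    return report


# Minimums taken from the %pip lines above
print(check_installed({"torch": "2.2", "nncf": "2.12.0", "openvino": "2024.4.0", "numpy": "1.26"}))
```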

Load the original model and run inference

Note: to run the model in this notebook, you need to accept the license agreement. You must be a registered user of the Hugging Face Hub. Please visit the Hugging Face model card, carefully read the terms of use, and click the accept button. You will need an access token for the code below to run. For more information on access tokens, refer to this section of the documentation. You can log in to the Hugging Face Hub in a notebook environment using the following code:

# Uncomment these lines to log in to the Hugging Face Hub and get access to the pretrained model
# from huggingface_hub import notebook_login, whoami

# try:
#     whoami()
#     print('Authorization token already provided')
# except OSError:
#     notebook_login()
import torch
import torchaudio
from einops import rearrange
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond


# Download model
model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
sample_rate = model_config["sample_rate"]

model = model.to("cpu")
total_seconds = 20

# Set up text and timing conditioning
conditioning = [{"prompt": "128 BPM tech house drum loop", "seconds_start": 0, "seconds_total": total_seconds}]

# Generate stereo audio
output = generate_diffusion_cond(
    model,
    steps=100,
    seed=42,
    cfg_scale=7,
    conditioning=conditioning,
    sample_size=sample_rate * total_seconds,
    sigma_min=0.3,
    sigma_max=500,
    sampler_type="dpmpp-3m-sde",
    device="cpu",
)

# Rearrange audio batch to a single sequence
output = rearrange(output, "b d n -> d (b n)")

# Peak normalize, clip, convert to int16, and save to file
output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
torchaudio.save("output.wav", output, sample_rate)
from IPython.display import Audio

Audio("output.wav")
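The post-processing lines above flatten the batch and convert floating-point audio to 16-bit PCM. Below is a pure-Python sketch of the same two steps on nested lists, meant only to make the tensor operations concrete: `flatten_batch` mirrors `rearrange(output, "b d n -> d (b n)")`, and `peak_normalize_int16` mirrors the divide-by-peak, clamp, scale-by-32767 and int16 cast. The silence guard is an addition of this sketch; the original line would produce NaN values for an all-zero output.

```python
def flatten_batch(batch):
    """Pure-Python analogue of rearrange(x, "b d n -> d (b n)").

    batch[b][d][n] becomes out[d], which holds channel d of every batch
    item concatenated in batch order.
    """
    channels = len(batch[0])
    return [[s for item in batch for s in item[d]] for d in range(channels)]


def _to_int16(x, peak):
    clamped = max(-1.0, min(1.0, x / peak))
    return int(clamped * 32767)  # int() truncates toward zero, like the int16 cast


def peak_normalize_int16(channels):
    """Divide by the peak, clamp to [-1, 1], scale by 32767 and truncate."""
    peak = max(abs(s) for ch in channels for s in ch)
    if peak == 0:
        # Addition for this sketch: avoid 0/0 on silent input
        return [[0] * len(ch) for ch in channels]
    return [[_to_int16(s, peak) for s in ch] for ch in channels]


batch = [
    [[0.5, -0.25], [0.1, 0.0]],   # item 0: left, right
    [[0.0, 0.2], [-0.1, 0.05]],   # item 1: left, right
]
stereo = flatten_batch(batch)
print(stereo)                        # [[0.5, -0.25, 0.0, 0.2], [0.1, 0.0, -0.1, 0.05]]
print(peak_normalize_int16(stereo))  # [[32767, -16383, 0, 13106], [6553, 0, -6553, 3276]]
```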

Convert the model to OpenVINO IR

Let’s define a conversion function for PyTorch modules. We use the ov.convert_model function to obtain an OpenVINO Intermediate Representation (IR) object and the ov.save_model function to save it to disk as an XML file (the weights go into an accompanying .bin file).

To reduce memory consumption, weight compression can be applied using NNCF. Weight compression aims to reduce the memory footprint of a model. Models that require extensive memory to store their weights during inference can benefit from weight compression in the following ways:

  • enabling the inference of exceptionally large models that cannot be accommodated in the memory of the device;

  • improving the inference performance of the models by reducing the latency of the memory access when computing the operations with weights, for example, Linear layers.
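As a rough illustration of the memory savings, raw weight storage scales linearly with bit width. The sketch below uses a hypothetical parameter count (not this model's) and ignores the small overhead of scales and other metadata that compressed formats also store.

```python
BITS = {"fp32": 32, "fp16": 16, "int8": 8, "int4": 4}


def weight_bytes(num_params: int, bits: int) -> int:
    """Bytes for the raw weights alone (no scales, zero points or metadata)."""
    return num_params * bits // 8


def footprint_gib(num_params: int) -> dict:
    """Approximate weight storage in GiB for each precision."""
    return {name: weight_bytes(num_params, bits) / 2**30 for name, bits in BITS.items()}


hypothetical_params = 1_000_000_000  # illustration only, not this model's size
for name, gib in footprint_gib(hypothetical_params).items():
    print(f"{name}: {gib:.2f} GiB")
# fp32: 3.73 GiB, fp16: 1.86 GiB, int8: 0.93 GiB, int4: 0.47 GiB
```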

The Neural Network Compression Framework (NNCF) provides 4-bit / 8-bit mixed weight quantization as a compression method. The main difference between weight compression and full model quantization (post-training quantization) is that activations remain floating-point in the case of weight compression, which leads to better accuracy. In addition, weight compression is data-free and does not require a calibration dataset, making it easy to use.

The nncf.compress_weights function performs weight compression. It accepts an OpenVINO model and other compression parameters. Different parameters may suit different models. For this model, the default parameters give poor results, so we set the mode to CompressWeightsMode.INT8_SYM, which compresses weights symmetrically to the 8-bit integer data type and produces inference results that are the same as the original model's.
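To make "symmetric 8-bit" concrete: each output channel gets one scale derived from its largest absolute weight, zero maps exactly to integer 0, and no zero point is stored (an asymmetric scheme would add one). The following is a simplified per-row sketch in pure Python, assuming a per-channel scheme; it is not NNCF's implementation, and details such as the exact integer range and rounding in NNCF may differ.

```python
def quantize_row_int8_sym(row):
    """Symmetric INT8 quantization of one weight row (one output channel).

    Simplified sketch, not NNCF's code: a single scale per row, no zero point,
    integers clipped to [-127, 127].
    """
    max_abs = max(abs(w) for w in row)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in row]
    return q, scale


def dequantize(q, scale):
    """Recover approximate float weights: w ~= q * scale."""
    return [v * scale for v in q]


row = [0.6, -1.0, 0.25, 0.0]
q, scale = quantize_row_int8_sym(row)
restored = dequantize(q, scale)
print(q)  # [76, -127, 32, 0]; 0.0 maps exactly to 0
# Rounding error per weight is at most half a quantization step
print(max(abs(a - b) for a, b in zip(row, restored)) <= scale / 2 + 1e-12)  # True
```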

More details about weight compression can be found in the OpenVINO documentation.

from pathlib import Path
import torch
from nncf import compress_weights, CompressWeightsMode
import openvino as ov


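def convert(model: torch.nn.Module, xml_path: str, example_input):
    """Convert a PyTorch module to OpenVINO IR with INT8 symmetric weight compression.

    Conversion is skipped if the target .xml file already exists.
    """
    xml_path = Path(xml_path)
    if not xml_path.exists():
        xml_path.parent.mkdir(parents=True, exist_ok=True)
        model.eval()
        with torch.no_grad():
            # Trace the module with the example inputs to obtain an ov.Model
            converted_model = ov.convert_model(model, example_input=example_input)
            # Default compression settings degrade quality for this model, so use INT8_SYM
            converted_model = compress_weights(converted_model, mode=CompressWeightsMode.INT8_SYM)
        # Writes the .xml topology and a .bin weights file next to it
        ov.save_model(converted_model, xml_path)

        # Clean up TorchScript state left over from tracing to free memory
        torch._C._jit_clear_class_registry()
        torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
        torch.jit._state._clear_class_state()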
MODEL_DIR = Path("model")

CONDITIONER_ENCODER_PATH = MODEL_DIR / "conditioner_encoder.xml"
DIFFUSION_PATH = MODEL_DIR / "diffusion.xml"
PRETRANSFORM_PATH = MODEL_DIR / "pretransform.xml"
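The `convert` helper skips conversion when the `.xml` file exists, but an OpenVINO IR also has a `.bin` file holding the weights, which `ov.save_model` writes next to the `.xml`. If a previous run was interrupted, a `.xml` without its `.bin` could be reused. A possible stricter check is sketched below; `ir_is_complete` is a hypothetical helper, not part of OpenVINO, and could replace the `xml_path.exists()` test inside `convert`.

```python
from pathlib import Path


def ir_is_complete(xml_path) -> bool:
    """Hypothetical helper: True only if both IR files exist.

    An OpenVINO IR is a pair of files: <name>.xml (topology) and <name>.bin (weights).
    """
    xml_path = Path(xml_path)
    return xml_path.is_file() and xml_path.with_suffix(".bin").is_file()


# Possible use inside convert():
#     if not ir_is_complete(xml_path):   # instead of: if not xml_path.exists():
print(ir_is_complete("model/diffusion.xml"))
```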

The pipeline comprises three components: an autoencoder that compresses waveforms into a manageable sequence length, a T5-based text embedding for text conditioning, and a transformer-based diffusion (DiT) model that operates in the latent space of the autoencoder. Because this example does not use an initial audio input, we only need to convert the T5-based text embedding model, the DiT model, and the decoder part of the autoencoder.
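The sketch below traces tensor shapes through the components for a given duration. It rests on assumptions inferred from this notebook's example inputs rather than from the model's documentation: 44.1 kHz audio, 64 latent channels, and an autoencoder downsampling ratio of 2,048 samples per latent frame (which would make the DiT example input of 1,024 latent frames equal to a 2,097,152-sample window). The real ratio should be available as `model.pretransform.downsampling_ratio`, and integer division when converting samples to frames is also assumed.

```python
# Assumed constants, inferred from the example inputs in this notebook.
ASSUMED_SAMPLE_RATE = 44_100
ASSUMED_DOWNSAMPLING_RATIO = 2048  # audio samples per latent frame
ASSUMED_LATENT_CHANNELS = 64


def latent_shape(seconds, batch=1, sample_rate=ASSUMED_SAMPLE_RATE, ratio=ASSUMED_DOWNSAMPLING_RATIO):
    """Shape of the latent the DiT denoises for a clip of `seconds` seconds."""
    sample_size = int(seconds * sample_rate)
    return (batch, ASSUMED_LATENT_CHANNELS, sample_size // ratio)


def decoded_samples(latent_frames, ratio=ASSUMED_DOWNSAMPLING_RATIO):
    """Audio samples per channel produced by the decoder from `latent_frames` frames."""
    return latent_frames * ratio


print(latent_shape(20))       # (1, 64, 430) for a 20-second request
print(decoded_samples(215))   # 440320, about 10 s (the decoder example input)
print(decoded_samples(1024))  # 2097152 (the DiT example input length)
```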

T5-based text embedding

example_input = {
    "input_ids": torch.zeros(1, 120, dtype=torch.int64),
    "attention_mask": torch.zeros(1, 120, dtype=torch.int64),
}

convert(model.conditioner.conditioners["prompt"].model, CONDITIONER_ENCODER_PATH, example_input)
INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│   Num bits (N) │ % all parameters (layers)   │ % ratio-defining parameters (layers)   │
┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│              8 │ 100% (74 / 74)              │ 100% (74 / 74)                         │
┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙
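The conversion above traces the text encoder with all-zero `input_ids` and `attention_mask` tensors of length 120, which serve only as tracing inputs; their values do not matter. At inference time the prompt tokens are padded or truncated to a fixed length, and the attention mask marks which positions hold real tokens (1) and which are padding (0). A pure-Python sketch of that step follows; `pad_and_mask` is a hypothetical helper, and `pad_id=0` assumes the T5 tokenizer's default padding id.

```python
def pad_and_mask(token_ids, max_length, pad_id=0):
    """Truncate or pad token ids to max_length; mask is 1 for real tokens, 0 for padding.

    pad_id=0 assumes the T5 tokenizer's default padding id.
    """
    ids = list(token_ids)[:max_length]
    padding = max_length - len(ids)
    return ids + [pad_id] * padding, [1] * len(ids) + [0] * padding


ids, mask = pad_and_mask([37, 209, 5], max_length=6)  # hypothetical token ids
print(ids)   # [37, 209, 5, 0, 0, 0]
print(mask)  # [1, 1, 1, 0, 0, 0]
```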

Transformer-based diffusion (DiT) model

import types


def continious_transformer_forward(self, x, mask=None, prepend_embeds=None, prepend_mask=None, global_cond=None, return_info=False, **kwargs):
    batch, seq, device = *x.shape[:2], x.device

    info = {
        "hidden_states": [],
    }

    x = self.project_in(x)

    if prepend_embeds is not None:
        prepend_length, prepend_dim = prepend_embeds.shape[1:]

        assert prepend_dim == x.shape[-1], "prepend dimension must match sequence dimension"

        x = torch.cat((prepend_embeds, x), dim=-2)

        if prepend_mask is not None or mask is not None:
            mask = mask if mask is not None else torch.ones((batch, seq), device=device, dtype=torch.bool)
            prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device=device, dtype=torch.bool)

            mask = torch.cat((prepend_mask, mask), dim=-1)

    # Attention layers

    if self.rotary_pos_emb is not None:
        rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
    else:
        rotary_pos_emb = None

    if self.use_sinusoidal_emb or self.use_abs_pos_emb:
        x = x + self.pos_emb(x)

    # Iterate over the transformer layers
    for layer in self.layers:
        x = layer(x, rotary_pos_emb=rotary_pos_emb, global_cond=global_cond, **kwargs)
        if return_info:
            info["hidden_states"].append(x)

    x = self.project_out(x)

    if return_info:
        return x, info

    return x


class DiffusionWrapper(torch.nn.Module):
    def __init__(self, diffusion):
        super().__init__()
        self.diffusion = diffusion

    def forward(self, x=None, t=None, cross_attn_cond=None, cross_attn_cond_mask=None, global_embed=None):
        model_inputs = {"cross_attn_cond": cross_attn_cond, "cross_attn_cond_mask": cross_attn_cond_mask, "global_embed": global_embed}

        return self.diffusion.forward(x, t, cfg_scale=7, **model_inputs)


example_input = {
    "x": torch.rand([1, 64, 1024], dtype=torch.float32),
    "t": torch.rand([1], dtype=torch.float32),
    "cross_attn_cond": torch.rand([1, 130, 768], dtype=torch.float32),
    "cross_attn_cond_mask": torch.ones([1, 130], dtype=torch.float32),
    "global_embed": torch.rand(torch.Size([1, 1536]), dtype=torch.float32),
}

diffuser = model.model.model
diffuser.transformer.forward = types.MethodType(continious_transformer_forward, diffuser.transformer)
convert(DiffusionWrapper(diffuser), DIFFUSION_PATH, example_input)
INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│   Num bits (N) │ % all parameters (layers)   │ % ratio-defining parameters (layers)   │
┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│              8 │ 100% (179 / 179)            │ 100% (179 / 179)                       │
┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙
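`DiffusionWrapper` calls the diffusion model with `cfg_scale=7`, so classifier-free guidance at that scale is built into the exported graph. Because `OVWrapper.forward` swallows extra keyword arguments, passing a different `cfg_scale` to `generate_diffusion_cond` will presumably have no effect on the OpenVINO model unless it is re-exported. The guidance step is commonly written as `uncond + scale * (cond - uncond)`; we assume stable-audio-tools applies this standard form inside the diffusion model's forward pass, and the sketch below operates on plain lists rather than tensors.

```python
def cfg_combine(cond, uncond, scale):
    """Classifier-free guidance on plain lists: uncond + scale * (cond - uncond).

    scale=1 returns the conditional prediction, scale=0 the unconditional one,
    and scale>1 (e.g. the notebook's 7) extrapolates further toward the prompt.
    """
    if len(cond) != len(uncond):
        raise ValueError("predictions must have the same length")
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


print(cfg_combine([1.0, 2.0], [0.0, 1.0], 7))  # [7.0, 8.0]
```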

Decoder part of the autoencoder

convert(model.pretransform.model.decoder, PRETRANSFORM_PATH, torch.rand([1, 64, 215], dtype=torch.float32))
INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│   Num bits (N) │ % all parameters (layers)   │ % ratio-defining parameters (layers)   │
┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│              8 │ 100% (37 / 37)              │ 100% (37 / 37)                         │
┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙

Compile models and run inference

Select a device from the dropdown list to run inference with OpenVINO.

import requests

r = requests.get(
    url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py",
)
with open("notebook_utils.py", "w") as f:
    f.write(r.text)

from notebook_utils import device_widget

device = device_widget("CPU")

device

Let’s create callable wrapper classes for the compiled models so that they can be used by the original pipeline. Note that all wrapper classes return torch.Tensor objects instead of NumPy arrays.

core = ov.Core()


class TextEncoderWrapper(torch.nn.Module):
    def __init__(self, text_encoder, dtype, device="CPU"):
        super().__init__()
        self.text_encoder = core.compile_model(text_encoder, device)
        self.dtype = dtype

    def __call__(self, input_ids=None, attention_mask=None):
        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        last_hidden_state = self.text_encoder(inputs)[0]

        return {"last_hidden_state": torch.from_numpy(last_hidden_state)}


class OVWrapper(torch.nn.Module):
    def __init__(self, ov_model, old_model, device="CPU") -> None:
        super().__init__()
        self.mock = torch.nn.Parameter(torch.zeros(1))  # this is only mock to not change the pipeline
        self.dif_transformer = core.compile_model(ov_model, device)

    def forward(self, x=None, t=None, cross_attn_cond=None, cross_attn_cond_mask=None, global_embed=None, **kwargs):
        inputs = {
            "x": x,
            "t": t,
            "cross_attn_cond": cross_attn_cond,
            "cross_attn_cond_mask": cross_attn_cond_mask,
            "global_embed": global_embed,
        }
        result = self.dif_transformer(inputs)

        return torch.from_numpy(result[0])


class PretransformDecoderWrapper(torch.nn.Module):
    def __init__(self, ov_model, device="CPU"):
        super().__init__()
        self.decoder = core.compile_model(ov_model, device)

    def forward(self, latents=None):
        result = self.decoder(latents)

        return torch.from_numpy(result[0])

Now we can replace the original models with the wrapped OpenVINO models and run inference.

model.model.model = OVWrapper(DIFFUSION_PATH, model.model.model, device.value)
model.conditioner.conditioners["prompt"].model = TextEncoderWrapper(
    CONDITIONER_ENCODER_PATH, model.conditioner.conditioners["prompt"].model.dtype, device.value
)
model.pretransform.model.decoder = PretransformDecoderWrapper(PRETRANSFORM_PATH, device.value)

output = generate_diffusion_cond(
    model,
    steps=100,
    seed=42,
    cfg_scale=7,
    conditioning=conditioning,
    sample_size=sample_rate * total_seconds,
    sigma_min=0.3,
    sigma_max=500,
    sampler_type="dpmpp-3m-sde",
    device="cpu",
)

# Rearrange audio batch to a single sequence
output = rearrange(output, "b d n -> d (b n)")

# Peak normalize, clip, convert to int16, and save to file
output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
torchaudio.save("output.wav", output, sample_rate)
Audio("output.wav")
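The notebook writes both the PyTorch and the OpenVINO results to `output.wav`, so the second run overwrites the first. To check numerically how close the INT8-compressed pipeline is to the original, one option is to save the PyTorch result under another name first and compare the two files, for example with a signal-to-noise ratio (SNR). The sketch below uses only the standard library (`wave`, `array`); SNR is just one possible similarity measure, and because of peak normalization a small overall gain difference between runs also counts as "noise" here.

```python
import array
import io
import math
import sys
import wave


def encode_int16_wav(sample_rate, channels, samples):
    """Build 16-bit PCM WAV bytes from interleaved integer samples."""
    data = array.array("h", samples)
    if sys.byteorder == "big":  # WAV stores samples little-endian
        data.byteswap()
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as f:
        f.setnchannels(channels)
        f.setsampwidth(2)
        f.setframerate(sample_rate)
        f.writeframes(data.tobytes())
    return buffer.getvalue()


def decode_int16_wav(data):
    """Parse 16-bit PCM WAV bytes into (sample_rate, channels, interleaved samples)."""
    with wave.open(io.BytesIO(data), "rb") as f:
        if f.getsampwidth() != 2:
            raise ValueError("expected 16-bit PCM")
        sample_rate, channels = f.getframerate(), f.getnchannels()
        frames = f.readframes(f.getnframes())
    samples = array.array("h")
    samples.frombytes(frames)
    if sys.byteorder == "big":
        samples.byteswap()
    return sample_rate, channels, samples


def snr_db(reference, test):
    """Signal-to-noise ratio of `test` relative to `reference`, in dB (higher is closer)."""
    if len(reference) != len(test):
        raise ValueError("signals must have the same length")
    signal = sum(float(r) * r for r in reference)
    noise = sum((float(r) - t) ** 2 for r, t in zip(reference, test))
    if noise == 0:
        return math.inf
    if signal == 0:
        return -math.inf
    return 10 * math.log10(signal / noise)
```

For example, with the PyTorch result saved as `output_torch.wav` and the OpenVINO result as `output_ov.wav` (hypothetical file names), `snr_db(decode_int16_wav(Path("output_torch.wav").read_bytes())[2], decode_int16_wav(Path("output_ov.wav").read_bytes())[2])` gives a single number for comparing runs.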

Interactive inference

def _generate(prompt, total_seconds, steps, seed):
    sample_rate = model_config["sample_rate"]

    # Set up text and timing conditioning
    conditioning = [{"prompt": prompt, "seconds_start": 0, "seconds_total": total_seconds}]

    output = generate_diffusion_cond(
        model,
        steps=steps,
        seed=seed,
        cfg_scale=7,
        conditioning=conditioning,
        sample_size=sample_rate * total_seconds,
        sigma_min=0.3,
        sigma_max=500,
        sampler_type="dpmpp-3m-sde",
        device="cpu",
    )

    # Rearrange audio batch to a single sequence
    output = rearrange(output, "b d n -> d (b n)")

    # Peak normalize, clip, convert to int16, and save to file
    output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
    return (sample_rate, output.numpy().transpose())
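`_generate` returns `output.numpy().transpose()` because the generated tensor is laid out as (channels, samples), while Gradio's `Audio` component takes a `(sample_rate, data)` tuple whose NumPy array is shaped (samples, channels). A pure-Python sketch of the same reshaping, using a hypothetical helper name:

```python
def to_gradio_audio(sample_rate, channels_first):
    """Turn [[left...], [right...]] into (sample_rate, [[l0, r0], [l1, r1], ...])."""
    if len({len(ch) for ch in channels_first}) != 1:
        raise ValueError("all channels must have the same length")
    frames = [list(frame) for frame in zip(*channels_first)]
    return sample_rate, frames


print(to_gradio_audio(44100, [[1, 2, 3], [4, 5, 6]]))  # (44100, [[1, 4], [2, 5], [3, 6]])
```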
if not Path("gradio_helper.py").exists():
    r = requests.get(url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/stable-audio/gradio_helper.py")
    with open("gradio_helper.py", "w") as f:
        f.write(r.text)

from gradio_helper import make_demo

demo = make_demo(fn=_generate)

try:
    demo.launch(debug=True)
except Exception:
    demo.launch(share=True, debug=True)
# If you are launching remotely, specify server_name and server_port
# EXAMPLE: `demo.launch(server_name='your server name', server_port='server port in int')`
# To learn more please refer to the Gradio docs: https://gradio.app/docs/