Image generation with Stable Cascade and OpenVINO#

This Jupyter notebook can be launched after a local installation only.

Stable Cascade is built upon the Würstchen architecture. Its main difference from other models such as Stable Diffusion is that it works in a much smaller latent space. Why is this important? The smaller the latent space, the faster inference runs and the cheaper training becomes. How small is the latent space? Stable Diffusion uses a compression factor of 8, so a 1024x1024 image is encoded to 128x128. Stable Cascade achieves a compression factor of 42, meaning that a 1024x1024 image can be encoded to 24x24 while maintaining crisp reconstructions. The text-conditional model is then trained in this highly compressed latent space.
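
The arithmetic behind these figures can be checked with a short sketch. The helper name `latent_stats` is hypothetical and only reproduces the numbers quoted above. Note that 1024 / 24 is closer to 42.7, so the quoted factor of 42 is a rounded value.

```python
def latent_stats(image_side: int, latent_side: int) -> tuple[float, int]:
    """Return (per-side compression factor, number of spatial latent positions)."""
    return image_side / latent_side, latent_side * latent_side


# Figures quoted in the text: a 1024x1024 image with 128x128 vs. 24x24 latents.
sd_factor, sd_positions = latent_stats(1024, 128)
sc_factor, sc_positions = latent_stats(1024, 24)

print(f"Stable Diffusion: factor {sd_factor:.1f}, {sd_positions} latent positions")
print(f"Stable Cascade:   factor {sc_factor:.1f}, {sc_positions} latent positions")
print(f"Spatial positions ratio: {sd_positions / sc_positions:.1f}x")
```

The last line shows how many times fewer spatial positions the diffusion model has to process per step, which is where the speed and training-cost savings come from.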

This is a self-contained example that relies solely on its own code.

We recommend running the notebook in a virtual environment. You only need a Jupyter server to start. For details, please refer to Installation Guide.

Prerequisites#

%pip install -q "diffusers>=0.27.0" accelerate datasets gradio transformers "nncf>=2.10.0" "protobuf>=3.20" "openvino>=2024.1.0" "torch>=2.1" --extra-index-url https://download.pytorch.org/whl/cpu

Load and run the original pipeline#

import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
import requests

r = requests.get(
    url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py",
)
# Write the helper module to disk; the context manager closes the file handle.
with open("notebook_utils.py", "w") as f:
    f.write(r.text)

# Read more about telemetry collection at https://github.com/openvinotoolkit/openvino_notebooks?tab=readme-ov-file#-telemetry
from notebook_utils import collect_telemetry

collect_telemetry("stable-cascade-image-generation.ipynb")

prompt = "an image of a shiba inu, donning a spacesuit and helmet"
negative_prompt = ""

prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch.float32)
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.float32)

To reduce memory usage, we skip the original inference by default. If you want to run it, tick the checkbox below.

import ipywidgets as widgets


run_original_inference = widgets.Checkbox(
    value=False,
    description="Run original inference",
    disabled=False,
)

run_original_inference
Checkbox(value=False, description='Run original inference')
if run_original_inference.value:
    prior.to(torch.device("cpu"))
    prior_output = prior(
        prompt=prompt,
        height=1024,
        width=1024,
        negative_prompt=negative_prompt,
        guidance_scale=4.0,
        num_images_per_prompt=1,
        num_inference_steps=20,
    )

    decoder_output = decoder(
        image_embeddings=prior_output.image_embeddings,
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=0.0,
        output_type="pil",
        num_inference_steps=10,
    ).images[0]
    display(decoder_output)
../_images/stable-cascade-image-generation-with-output_8_2.png

Convert the model to OpenVINO IR#

Stable Cascade has 2 components:

- Prior stage prior: create low-dimensional latent space representation of the image using text-conditional LDM
- Decoder stage decoder: using representation from Prior Stage, produce a latent image in latent space of higher dimensionality using LDM and using VQGAN-decoder, decode the latent image to yield a full-resolution output image.
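
The data flow between the stages can be sketched by tracking tensor shapes only. The shapes below are taken from the example inputs used later in this notebook for a 1024x1024 image. The function names are hypothetical placeholders, and the 3-channel output and 4x upscaling in the last step are assumptions made for illustration.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def prior_stage(batch: int) -> Shape:
    # Stage 1: text-conditional diffusion in the small latent space.
    return (batch, 16, 24, 24)


def decoder_stage(prior_latents: Shape) -> Shape:
    # Stage 2: diffusion in a larger latent space, conditioned on the prior output.
    batch = prior_latents[0]
    return (batch, 4, 256, 256)


def vqgan_decode(latents: Shape, upscale: int = 4) -> Shape:
    # Stage 3: decode latents to pixels; upscale=4 is an assumption (256 -> 1024).
    batch, _, height, width = latents
    return (batch, 3, height * upscale, width * upscale)


prior_latents = prior_stage(batch=1)
decoder_latents = decoder_stage(prior_latents)
image = vqgan_decode(decoder_latents)
print(prior_latents, "->", decoder_latents, "->", image)
```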

Let’s define the conversion function for PyTorch modules. We use the ov.convert_model function to obtain an OpenVINO Intermediate Representation object and the ov.save_model function to save it as an XML file. We use nncf.compress_weights to compress the model weights to 8-bit, which reduces the model size.

import gc
from pathlib import Path

import openvino as ov
import nncf


MODELS_DIR = Path("models")


def convert(model: torch.nn.Module, xml_path: str, example_input, input_shape=None):
    xml_path = Path(xml_path)
    if not xml_path.exists():
        model.eval()
        xml_path.parent.mkdir(parents=True, exist_ok=True)
        with torch.no_grad():
            if not input_shape:
                converted_model = ov.convert_model(model, example_input=example_input)
            else:
                converted_model = ov.convert_model(model, example_input=example_input, input=input_shape)
        converted_model = nncf.compress_weights(converted_model)
        ov.save_model(converted_model, xml_path)
        del converted_model

        # cleanup memory
        torch._C._jit_clear_class_registry()
        torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
        torch.jit._state._clear_class_state()

        gc.collect()
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, openvino
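
To give an intuition for why 8-bit weight compression shrinks the model, here is a minimal sketch of symmetric 8-bit quantization with a single shared scale. This is a simplified illustration of the general idea and is not the algorithm NNCF uses. The function names `quantize_int8` and `dequantize` are hypothetical.

```python
def quantize_int8(weights):
    """Map floats to integers in [-127, 127] using one shared scale."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs else 1.0
    return [round(w / scale) for w in weights], scale


def dequantize(q_weights, scale):
    """Recover approximate float values from the stored integers."""
    return [q * scale for q in q_weights]


weights = [0.5, -1.27, 0.003, 0.9]
q_weights, scale = quantize_int8(weights)
restored = dequantize(q_weights, scale)

# The rounding error per weight is bounded by half a quantization step.
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(q_weights, scale, max_error)
```

Each weight is stored in one byte instead of the four bytes of a float32 value, at the cost of a small rounding error.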

Prior pipeline#

This pipeline consists of a text encoder and a prior diffusion model. From here on, we pass an input_shape parameter during conversion to fix the input shapes (only dimensions given as -1, such as the batch size, stay dynamic), which yields a less memory-demanding model.

PRIOR_TEXT_ENCODER_OV_PATH = MODELS_DIR / "prior_text_encoder_model.xml"

prior.text_encoder.config.output_hidden_states = True


class TextEncoderWrapper(torch.nn.Module):
    def __init__(self, text_encoder):
        super().__init__()
        self.text_encoder = text_encoder

    def forward(self, input_ids, attention_mask):
        outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        return outputs["text_embeds"], outputs["last_hidden_state"], outputs["hidden_states"]


convert(
    TextEncoderWrapper(prior.text_encoder),
    PRIOR_TEXT_ENCODER_OV_PATH,
    example_input={
        "input_ids": torch.zeros(1, 77, dtype=torch.int32),
        "attention_mask": torch.zeros(1, 77),
    },
    input_shape={"input_ids": ((1, 77),), "attention_mask": ((1, 77),)},
)
del prior.text_encoder
gc.collect();
PRIOR_PRIOR_MODEL_OV_PATH = MODELS_DIR / "prior_prior_model.xml"

convert(
    prior.prior,
    PRIOR_PRIOR_MODEL_OV_PATH,
    example_input={
        "sample": torch.zeros(2, 16, 24, 24),
        "timestep_ratio": torch.ones(2),
        "clip_text_pooled": torch.zeros(2, 1, 1280),
        "clip_text": torch.zeros(2, 77, 1280),
        "clip_img": torch.zeros(2, 1, 768),
    },
    input_shape=[((-1, 16, 24, 24),), ((-1),), ((-1, 1, 1280),), ((-1, 77, 1280),), (-1, 1, 768)],
)
del prior.prior
gc.collect();
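
In the input_shape values above, -1 marks a dimension that is left dynamic, while every other dimension is fixed. The sketch below shows that convention in plain Python. The helper `matches` is hypothetical and is not part of the OpenVINO API.

```python
def matches(spec, shape):
    """Return True if a concrete shape satisfies a spec where -1 means 'any size'."""
    if len(spec) != len(shape):
        return False
    return all(s == -1 or s == d for s, d in zip(spec, shape))


prior_sample_spec = (-1, 16, 24, 24)  # batch is dynamic, the rest is fixed

batch_one_ok = matches(prior_sample_spec, (1, 16, 24, 24))
batch_two_ok = matches(prior_sample_spec, (2, 16, 24, 24))
other_resolution_ok = matches(prior_sample_spec, (1, 16, 32, 32))
print(batch_one_ok, batch_two_ok, other_resolution_ok)
```

Any batch size is accepted, but a different latent resolution is not, so the converted model is tied to 1024x1024 generation.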

Decoder pipeline#

The decoder pipeline consists of 3 parts: the decoder, a text encoder and VQGAN.

DECODER_TEXT_ENCODER_MODEL_OV_PATH = MODELS_DIR / "decoder_text_encoder_model.xml"

convert(
    TextEncoderWrapper(decoder.text_encoder),
    DECODER_TEXT_ENCODER_MODEL_OV_PATH,
    example_input={
        "input_ids": torch.zeros(1, 77, dtype=torch.int32),
        "attention_mask": torch.zeros(1, 77),
    },
    input_shape={"input_ids": ((1, 77),), "attention_mask": ((1, 77),)},
)

del decoder.text_encoder
gc.collect();
DECODER_DECODER_MODEL_OV_PATH = MODELS_DIR / "decoder_decoder_model.xml"

convert(
    decoder.decoder,
    DECODER_DECODER_MODEL_OV_PATH,
    example_input={
        "sample": torch.zeros(1, 4, 256, 256),
        "timestep_ratio": torch.ones(1),
        "clip_text_pooled": torch.zeros(1, 1, 1280),
        "effnet": torch.zeros(1, 16, 24, 24),
    },
    input_shape=[((-1, 4, 256, 256),), ((-1),), ((-1, 1, 1280),), ((-1, 16, 24, 24),)],
)
del decoder.decoder
gc.collect();
VQGAN_PATH = MODELS_DIR / "vqgan_model.xml"


class VqganDecoderWrapper(torch.nn.Module):
    def __init__(self, vqgan):
        super().__init__()
        self.vqgan = vqgan

    def forward(self, h):
        return self.vqgan.decode(h)


convert(
    VqganDecoderWrapper(decoder.vqgan),
    VQGAN_PATH,
    example_input=torch.zeros(1, 4, 256, 256),
    input_shape=(1, 4, 256, 256),
)
del decoder.vqgan
gc.collect();

Select inference device#

Select a device from the dropdown list to run inference with OpenVINO.

from notebook_utils import device_widget

device = device_widget()

device
Dropdown(description='Device:', index=4, options=('CPU', 'GPU.0', 'GPU.1', 'GPU.2', 'AUTO'), value='AUTO')

Building the pipeline#

Let’s create callable wrapper classes for the compiled models so that they can be used in place of the original pipeline components. Note that all of the wrapper classes return torch.Tensor objects instead of NumPy arrays.

from collections import namedtuple

core = ov.Core()


BaseModelOutputWithPooling = namedtuple("BaseModelOutputWithPooling", ["text_embeds", "last_hidden_state", "hidden_states"])


class TextEncoderWrapper:
    dtype = torch.float32  # accessed in the original workflow

    def __init__(self, text_encoder_path, device):
        self.text_encoder = core.compile_model(text_encoder_path, device.value)

    def __call__(self, input_ids, attention_mask, output_hidden_states=True):
        output = self.text_encoder({"input_ids": input_ids, "attention_mask": attention_mask})
        text_embeds = output[0]
        last_hidden_state = output[1]
        hidden_states = list(output.values())[1:]
        return BaseModelOutputWithPooling(torch.from_numpy(text_embeds), torch.from_numpy(last_hidden_state), [torch.from_numpy(hs) for hs in hidden_states])


class PriorPriorWrapper:
    def __init__(self, prior_path, device):
        self.prior = core.compile_model(prior_path, device.value)
        self.config = namedtuple("PriorWrapperConfig", ["clip_image_in_channels", "in_channels"])(768, 16)  # accessed in the original workflow
        self.parameters = lambda: (torch.zeros(i, dtype=torch.float32) for i in range(1))  # accessed in the original workflow

    def __call__(self, sample, timestep_ratio, clip_text_pooled, clip_text=None, clip_img=None, **kwargs):
        inputs = {
            "sample": sample,
            "timestep_ratio": timestep_ratio,
            "clip_text_pooled": clip_text_pooled,
            "clip_text": clip_text,
            "clip_img": clip_img,
        }
        output = self.prior(inputs)
        return [torch.from_numpy(output[0])]


class DecoderWrapper:
    dtype = torch.float32  # accessed in the original workflow

    def __init__(self, decoder_path, device):
        self.decoder = core.compile_model(decoder_path, device.value)

    def __call__(self, sample, timestep_ratio, clip_text_pooled, effnet, **kwargs):
        inputs = {"sample": sample, "timestep_ratio": timestep_ratio, "clip_text_pooled": clip_text_pooled, "effnet": effnet}
        output = self.decoder(inputs)
        return [torch.from_numpy(output[0])]


VqganOutput = namedtuple("VqganOutput", "sample")


class VqganWrapper:
    config = namedtuple("VqganWrapperConfig", "scale_factor")(0.3764)  # accessed in the original workflow

    def __init__(self, vqgan_path, device):
        self.vqgan = core.compile_model(vqgan_path, device.value)

    def decode(self, h):
        output = self.vqgan(h)[0]
        output = torch.tensor(output)
        return VqganOutput(output)

Now insert the wrapper instances into the pipelines:

prior.text_encoder = TextEncoderWrapper(PRIOR_TEXT_ENCODER_OV_PATH, device)
prior.prior = PriorPriorWrapper(PRIOR_PRIOR_MODEL_OV_PATH, device)
decoder.decoder = DecoderWrapper(DECODER_DECODER_MODEL_OV_PATH, device)
decoder.text_encoder = TextEncoderWrapper(DECODER_TEXT_ENCODER_MODEL_OV_PATH, device)
decoder.vqgan = VqganWrapper(VQGAN_PATH, device)
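
The wrapper pattern used above can be sketched without OpenVINO or PyTorch: a thin class exposes the interface the pipeline expects and delegates the computation to a callable. Everything in this sketch, including `fake_compiled_vqgan`, `FakeVqganOutput` and `FakeVqganWrapper`, is a hypothetical stand-in for illustration.

```python
from collections import namedtuple


def fake_compiled_vqgan(latents):
    # Stand-in for a compiled model: returns a list of raw outputs.
    return [[value * 2 for value in latents]]


FakeVqganOutput = namedtuple("FakeVqganOutput", "sample")


class FakeVqganWrapper:
    """Expose a `.decode(...).sample` interface on top of a plain callable."""

    def __init__(self, compiled_model):
        self.compiled_model = compiled_model

    def decode(self, h):
        raw = self.compiled_model(h)[0]  # take the first output
        return FakeVqganOutput(raw)  # wrap it so callers can read `.sample`


wrapper = FakeVqganWrapper(fake_compiled_vqgan)
result = wrapper.decode([0.5, 1.0])
print(result.sample)
```

Because the calling code only relies on the attribute names and call signatures, it does not need to know whether a PyTorch module or a compiled model sits behind the wrapper.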

Inference#

prior_output = prior(
    prompt=prompt,
    height=1024,
    width=1024,
    negative_prompt=negative_prompt,
    guidance_scale=4.0,
    num_images_per_prompt=1,
    num_inference_steps=20,
)

decoder_output = decoder(
    image_embeddings=prior_output.image_embeddings,
    prompt=prompt,
    negative_prompt=negative_prompt,
    guidance_scale=0.0,
    output_type="pil",
    num_inference_steps=10,
).images[0]
display(decoder_output)
../_images/stable-cascade-image-generation-with-output_29_2.png

Interactive inference#

def generate(prompt, negative_prompt, prior_guidance_scale, decoder_guidance_scale, seed):
    generator = torch.Generator().manual_seed(seed)
    prior_output = prior(
        prompt=prompt,
        height=1024,
        width=1024,
        negative_prompt=negative_prompt,
        guidance_scale=prior_guidance_scale,
        num_images_per_prompt=1,
        num_inference_steps=20,
        generator=generator,
    )

    decoder_output = decoder(
        image_embeddings=prior_output.image_embeddings,
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=decoder_guidance_scale,
        output_type="pil",
        num_inference_steps=10,
        generator=generator,
    ).images[0]

    return decoder_output


import requests

if not Path("gradio_helper.py").exists():
    r = requests.get(
        url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/stable-cascade-image-generation/gradio_helper.py"
    )
    # Write the helper module to disk; the context manager closes the file handle.
    with open("gradio_helper.py", "w") as f:
        f.write(r.text)

from gradio_helper import make_demo

demo = make_demo(generate)

try:
    demo.queue().launch(debug=True)
except Exception:
    demo.queue().launch(debug=True, share=True)
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/