Image generation with Stable Cascade and OpenVINO#
This Jupyter notebook can be launched after a local installation only.
Stable Cascade is built upon the Würstchen architecture, and its main difference from other models, like Stable Diffusion, is that it works on a much smaller latent space. Why is this important? The smaller the latent space, the faster you can run inference and the cheaper the training becomes. How small is the latent space? Stable Diffusion uses a compression factor of 8, so a 1024x1024 image is encoded to 128x128. Stable Cascade achieves a compression factor of 42, meaning that a 1024x1024 image can be encoded to 24x24 while maintaining crisp reconstructions. The text-conditional model is then trained in this highly compressed latent space.
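As a quick, purely illustrative check of these numbers, the latent spatial size follows directly from the compression factor:
image_size = 1024
print(image_size // 8)         # Stable Diffusion, compression factor 8: 128 x 128 latent
print(round(image_size / 42))  # Stable Cascade, compression factor 42: ~24 x 24 latent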
Table of contents:
- Prerequisites
- Load and run the original pipeline
- Convert the model to OpenVINO IR
  - Prior pipeline
  - Decoder pipeline
- Select inference device
- Building the pipeline
- Inference
- Interactive inference
This is a self-contained example that relies solely on its own code.
We recommend running the notebook in a virtual environment. You only need a Jupyter server to start. For details, please refer to Installation Guide.
Prerequisites#
%pip install -q "diffusers>=0.27.0" accelerate datasets gradio transformers "nncf>=2.10.0" "protobuf>=3.20" "openvino>=2024.1.0" "torch>=2.1" --extra-index-url https://download.pytorch.org/whl/cpu
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
mobileclip 0.1.0 requires torchvision==0.14.1, but you have torchvision 0.19.1+cpu which is incompatible.
s3fs 2024.10.0 requires fsspec==2024.10.0.*, but you have fsspec 2024.9.0 which is incompatible.
Note: you may need to restart the kernel to use updated packages.
Load and run the original pipeline#
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
prompt = "an image of a shiba inu, donning a spacesuit and helmet"
negative_prompt = ""
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch.float32)
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.float32)
2024-11-05 05:00:13.157963: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable TF_ENABLE_ONEDNN_OPTS=0. 2024-11-05 05:00:13.192172: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. 2024-11-05 05:00:13.880563: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Loading pipeline components...: 0%| | 0/6 [00:00<?, ?it/s]
Loading pipeline components...: 0%| | 0/5 [00:00<?, ?it/s]
To reduce memory usage, we skip the original inference by default. If you want to run it, enable the checkbox below.
import ipywidgets as widgets
run_original_inference = widgets.Checkbox(
value=False,
description="Run original inference",
disabled=False,
)
run_original_inference
Checkbox(value=False, description='Run original inference')
if run_original_inference.value:
    prior.to(torch.device("cpu"))
    prior_output = prior(
        prompt=prompt,
        height=1024,
        width=1024,
        negative_prompt=negative_prompt,
        guidance_scale=4.0,
        num_images_per_prompt=1,
        num_inference_steps=20,
    )
    decoder_output = decoder(
        image_embeddings=prior_output.image_embeddings,
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=0.0,
        output_type="pil",
        num_inference_steps=10,
    ).images[0]
    display(decoder_output)
Convert the model to OpenVINO IR#
Stable Cascade has 2 components:
- Prior stage (prior): creates a low-dimensional latent representation of the image using a text-conditional LDM.
- Decoder stage (decoder): using the representation from the Prior stage, produces a latent image in a latent space of higher dimensionality with an LDM, then decodes it with a VQGAN decoder to yield a full-resolution output image.
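To make the two stages concrete, here is a minimal sketch of the tensor shapes each one works with. The shapes are taken from the example inputs used in the conversion cells below; the variable names are illustrative only (torch is already imported above).
# Illustrative only: tensor shapes flowing between the two stages (batch size 1)
image_embeddings = torch.zeros(1, 16, 24, 24)  # prior-stage output: highly compressed latent (compression factor 42)
decoder_latent = torch.zeros(1, 4, 256, 256)   # decoder-stage LDM output, conditioned on image_embeddings
# The VQGAN decoder then turns decoder_latent into the final 1024x1024 RGB image
print(image_embeddings.shape, decoder_latent.shape)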
Let’s define the conversion function for PyTorch modules. We use the ov.convert_model function to obtain the OpenVINO Intermediate Representation object and ov.save_model to save it as an XML file. We use nncf.compress_weights to compress the model weights to 8-bit to reduce the model size.
import gc
from pathlib import Path
import openvino as ov
import nncf
MODELS_DIR = Path("models")
def convert(model: torch.nn.Module, xml_path: str, example_input, input_shape=None):
    xml_path = Path(xml_path)
    if not xml_path.exists():
        model.eval()
        xml_path.parent.mkdir(parents=True, exist_ok=True)
        with torch.no_grad():
            if not input_shape:
                converted_model = ov.convert_model(model, example_input=example_input)
            else:
                converted_model = ov.convert_model(model, example_input=example_input, input=input_shape)
            converted_model = nncf.compress_weights(converted_model)
            ov.save_model(converted_model, xml_path)
        del converted_model

        # cleanup memory
        torch._C._jit_clear_class_registry()
        torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
        torch.jit._state._clear_class_state()

        gc.collect()
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, tensorflow, onnx, openvino
Prior pipeline#
This pipeline consists of a text encoder and a prior diffusion model. From here on, we always use fixed shapes during conversion by passing the input_shape parameter, which produces a less memory-demanding model.
PRIOR_TEXT_ENCODER_OV_PATH = MODELS_DIR / "prior_text_encoder_model.xml"
prior.text_encoder.config.output_hidden_states = True
class TextEncoderWrapper(torch.nn.Module):
    def __init__(self, text_encoder):
        super().__init__()
        self.text_encoder = text_encoder

    def forward(self, input_ids, attention_mask):
        outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        return outputs["text_embeds"], outputs["last_hidden_state"], outputs["hidden_states"]
convert(
TextEncoderWrapper(prior.text_encoder),
PRIOR_TEXT_ENCODER_OV_PATH,
example_input={
"input_ids": torch.zeros(1, 77, dtype=torch.int32),
"attention_mask": torch.zeros(1, 77),
},
input_shape={"input_ids": ((1, 77),), "attention_mask": ((1, 77),)},
)
del prior.text_encoder
gc.collect();
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.base has been moved to tensorflow.python.trackable.base. The old module will be deleted in version 2.11.
[ WARNING ] Please fix your imports. Module %s has been moved to %s. The old module will be deleted in version %s. /opt/home/k8sworker/ci-ai/cibuilds/jobs/ov-notebook/jobs/OVNotebookOps/builds/810/archive/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/modeling_utils.py:5006: FutureWarning: _is_quantized_training_enabled is going to be deprecated in transformers 4.39.0. Please use model.hf_quantizer.is_trainable instead warnings.warn( loss_type=None was set in the config but it is unrecognised.Using the default loss: ForCausalLMLoss. /opt/home/k8sworker/ci-ai/cibuilds/jobs/ov-notebook/jobs/OVNotebookOps/builds/810/archive/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/modeling_attn_mask_utils.py:88: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! if input_shape[-1] > 1 or self.sliding_window is not None: /opt/home/k8sworker/ci-ai/cibuilds/jobs/ov-notebook/jobs/OVNotebookOps/builds/810/archive/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/modeling_attn_mask_utils.py:164: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! if past_key_values_length > 0:
INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│ Num bits (N) │ % all parameters (layers) │ % ratio-defining parameters (layers) │
┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│ 8 │ 100% (194 / 194) │ 100% (194 / 194) │
┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙
Output()
PRIOR_PRIOR_MODEL_OV_PATH = MODELS_DIR / "prior_prior_model.xml"
convert(
prior.prior,
PRIOR_PRIOR_MODEL_OV_PATH,
example_input={
"sample": torch.zeros(2, 16, 24, 24),
"timestep_ratio": torch.ones(2),
"clip_text_pooled": torch.zeros(2, 1, 1280),
"clip_text": torch.zeros(2, 77, 1280),
"clip_img": torch.zeros(2, 1, 768),
},
input_shape=[((-1, 16, 24, 24),), ((-1),), ((-1, 1, 1280),), ((-1, 77, 1280),), (-1, 1, 768)],
)
del prior.prior
gc.collect();
/opt/home/k8sworker/ci-ai/cibuilds/jobs/ov-notebook/jobs/OVNotebookOps/builds/810/archive/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/diffusers/models/unets/unet_stable_cascade.py:548: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│ Num bits (N) │ % all parameters (layers) │ % ratio-defining parameters (layers) │
┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│ 8 │ 100% (711 / 711) │ 100% (711 / 711) │
┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙
Output()
Decoder pipeline#
The decoder pipeline consists of 3 parts: decoder, text encoder, and VQGAN.
DECODER_TEXT_ENCODER_MODEL_OV_PATH = MODELS_DIR / "decoder_text_encoder_model.xml"
convert(
TextEncoderWrapper(decoder.text_encoder),
DECODER_TEXT_ENCODER_MODEL_OV_PATH,
example_input={
"input_ids": torch.zeros(1, 77, dtype=torch.int32),
"attention_mask": torch.zeros(1, 77),
},
input_shape={"input_ids": ((1, 77),), "attention_mask": ((1, 77),)},
)
del decoder.text_encoder
gc.collect();
INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│ Num bits (N) │ % all parameters (layers) │ % ratio-defining parameters (layers) │
┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│ 8 │ 100% (194 / 194) │ 100% (194 / 194) │
┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙
Output()
DECODER_DECODER_MODEL_OV_PATH = MODELS_DIR / "decoder_decoder_model.xml"
convert(
decoder.decoder,
DECODER_DECODER_MODEL_OV_PATH,
example_input={
"sample": torch.zeros(1, 4, 256, 256),
"timestep_ratio": torch.ones(1),
"clip_text_pooled": torch.zeros(1, 1, 1280),
"effnet": torch.zeros(1, 16, 24, 24),
},
input_shape=[((-1, 4, 256, 256),), ((-1),), ((-1, 1, 1280),), ((-1, 16, 24, 24),)],
)
del decoder.decoder
gc.collect();
INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│ Num bits (N) │ % all parameters (layers) │ % ratio-defining parameters (layers) │
┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│ 8 │ 100% (779 / 779) │ 100% (779 / 779) │
┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙
Output()
VQGAN_PATH = MODELS_DIR / "vqgan_model.xml"
class VqganDecoderWrapper(torch.nn.Module):
    def __init__(self, vqgan):
        super().__init__()
        self.vqgan = vqgan

    def forward(self, h):
        return self.vqgan.decode(h)
convert(
VqganDecoderWrapper(decoder.vqgan),
VQGAN_PATH,
example_input=torch.zeros(1, 4, 256, 256),
input_shape=(1, 4, 256, 256),
)
del decoder.vqgan
gc.collect();
INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│ Num bits (N) │ % all parameters (layers) │ % ratio-defining parameters (layers) │
┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│ 8 │ 100% (42 / 42) │ 100% (42 / 42) │
┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙
Output()
Select inference device#
Select device from dropdown list for running inference using OpenVINO.
import requests
r = requests.get(
url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py",
)
open("notebook_utils.py", "w").write(r.text)
from notebook_utils import device_widget
device = device_widget()
device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
Building the pipeline#
Let’s create callable wrapper classes for the compiled models to allow interaction with the original pipelines. Note that all of the wrapper classes return torch.Tensor objects instead of np.array objects.
from collections import namedtuple
core = ov.Core()
BaseModelOutputWithPooling = namedtuple("BaseModelOutputWithPooling", ["text_embeds", "last_hidden_state", "hidden_states"])
class TextEncoderWrapper:
    dtype = torch.float32  # accessed in the original workflow

    def __init__(self, text_encoder_path, device):
        self.text_encoder = core.compile_model(text_encoder_path, device.value)

    def __call__(self, input_ids, attention_mask, output_hidden_states=True):
        output = self.text_encoder({"input_ids": input_ids, "attention_mask": attention_mask})
        text_embeds = output[0]
        last_hidden_state = output[1]
        hidden_states = list(output.values())[1:]
        return BaseModelOutputWithPooling(torch.from_numpy(text_embeds), torch.from_numpy(last_hidden_state), [torch.from_numpy(hs) for hs in hidden_states])
class PriorPriorWrapper:
    def __init__(self, prior_path, device):
        self.prior = core.compile_model(prior_path, device.value)
        self.config = namedtuple("PriorWrapperConfig", ["clip_image_in_channels", "in_channels"])(768, 16)  # accessed in the original workflow
        self.parameters = lambda: (torch.zeros(i, dtype=torch.float32) for i in range(1))  # accessed in the original workflow

    def __call__(self, sample, timestep_ratio, clip_text_pooled, clip_text=None, clip_img=None, **kwargs):
        inputs = {
            "sample": sample,
            "timestep_ratio": timestep_ratio,
            "clip_text_pooled": clip_text_pooled,
            "clip_text": clip_text,
            "clip_img": clip_img,
        }
        output = self.prior(inputs)
        return [torch.from_numpy(output[0])]
class DecoderWrapper:
    dtype = torch.float32  # accessed in the original workflow

    def __init__(self, decoder_path, device):
        self.decoder = core.compile_model(decoder_path, device.value)

    def __call__(self, sample, timestep_ratio, clip_text_pooled, effnet, **kwargs):
        inputs = {"sample": sample, "timestep_ratio": timestep_ratio, "clip_text_pooled": clip_text_pooled, "effnet": effnet}
        output = self.decoder(inputs)
        return [torch.from_numpy(output[0])]
VqganOutput = namedtuple("VqganOutput", "sample")
class VqganWrapper:
    config = namedtuple("VqganWrapperConfig", "scale_factor")(0.3764)  # accessed in the original workflow

    def __init__(self, vqgan_path, device):
        self.vqgan = core.compile_model(vqgan_path, device.value)

    def decode(self, h):
        output = self.vqgan(h)[0]
        output = torch.tensor(output)
        return VqganOutput(output)
And insert the wrapper instances into the pipelines:
prior.text_encoder = TextEncoderWrapper(PRIOR_TEXT_ENCODER_OV_PATH, device)
prior.prior = PriorPriorWrapper(PRIOR_PRIOR_MODEL_OV_PATH, device)
decoder.decoder = DecoderWrapper(DECODER_DECODER_MODEL_OV_PATH, device)
decoder.text_encoder = TextEncoderWrapper(DECODER_TEXT_ENCODER_MODEL_OV_PATH, device)
decoder.vqgan = VqganWrapper(VQGAN_PATH, device)
Inference#
prior_output = prior(
prompt=prompt,
height=1024,
width=1024,
negative_prompt=negative_prompt,
guidance_scale=4.0,
num_images_per_prompt=1,
num_inference_steps=20,
)
decoder_output = decoder(
image_embeddings=prior_output.image_embeddings,
prompt=prompt,
negative_prompt=negative_prompt,
guidance_scale=0.0,
output_type="pil",
num_inference_steps=10,
).images[0]
display(decoder_output)
0%| | 0/20 [00:00<?, ?it/s]
0%| | 0/10 [00:00<?, ?it/s]
Interactive inference#
def generate(prompt, negative_prompt, prior_guidance_scale, decoder_guidance_scale, seed):
    generator = torch.Generator().manual_seed(seed)
    prior_output = prior(
        prompt=prompt,
        height=1024,
        width=1024,
        negative_prompt=negative_prompt,
        guidance_scale=prior_guidance_scale,
        num_images_per_prompt=1,
        num_inference_steps=20,
        generator=generator,
    )
    decoder_output = decoder(
        image_embeddings=prior_output.image_embeddings,
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=decoder_guidance_scale,
        output_type="pil",
        num_inference_steps=10,
        generator=generator,
    ).images[0]
    return decoder_output
import requests
if not Path("gradio_helper.py").exists():
r = requests.get(
url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/stable-cascade-image-generation/gradio_helper.py"
)
open("gradio_helper.py", "w").write(r.text)
from gradio_helper import make_demo
demo = make_demo(generate)
try:
    demo.queue().launch(debug=False)
except Exception:
    demo.queue().launch(debug=False, share=True)
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/
Running on local URL: http://127.0.0.1:7860 To create a public link, set share=True in launch().