PixArt-α: Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis with OpenVINO#

This Jupyter notebook can be launched after a local installation only.

Github

This paper introduces PIXART-α, a Transformer-based T2I diffusion model whose image generation quality is competitive with state-of-the-art image generators, reaching near-commercial application standards. Additionally, it supports high-resolution image synthesis up to 1024px resolution with low training cost. To achieve this goal, three core designs are proposed: 1. Training strategy decomposition: We devise three distinct training steps that separately optimize pixel dependency, text-image alignment, and image aesthetic quality; 2. Efficient T2I Transformer: We incorporate cross-attention modules into Diffusion Transformer (DiT) to inject text conditions and streamline the computation-intensive class-condition branch; 3. High-informative data: We emphasize the significance of concept density in text-image pairs and leverage a large Vision-Language model to auto-label dense pseudo-captions to assist text-image alignment learning.

image0

Table of contents:

- Prerequisites
- Load and run the original pipeline
- Convert the model to OpenVINO IR
  - Convert text encoder
  - Convert transformer
  - Convert VAE decoder
- Compiling models
- Building the pipeline
- Interactive inference

Prerequisites#

%pip install -q "diffusers>=0.14.0" sentencepiece "datasets>=2.14.6" "transformers>=4.25.1" "gradio>=4.19" "torch>=2.1" Pillow opencv-python --extra-index-url https://download.pytorch.org/whl/cpu
%pip install --pre -Uq openvino --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
Note: you may need to restart the kernel to use updated packages.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
openvino-dev 2024.2.0 requires openvino==2024.2.0, but you have openvino 2024.3.0.dev20240627 which is incompatible.
Note: you may need to restart the kernel to use updated packages.

Load and run the original pipeline#

We use PixArt-LCM-XL-2-1024-MS, which is based on LCM (Latent Consistency Model). LCM is a diffusion distillation method that predicts the solution of the PF-ODE directly in latent space, achieving very fast inference in only a few steps.

import torch
from diffusers import PixArtAlphaPipeline


pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-LCM-XL-2-1024-MS", use_safetensors=True)

prompt = "A small cactus with a happy face in the Sahara desert."
generator = torch.Generator().manual_seed(42)

image = pipe(prompt, guidance_scale=0.0, num_inference_steps=4, generator=generator).images[0]
2024-07-02 01:22:15.056189: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable TF_ENABLE_ONEDNN_OPTS=0.
2024-07-02 01:22:15.090603: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-07-02 01:22:15.773162: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]
Some weights of the model checkpoint were not used when initializing PixArtTransformer2DModel:
 ['caption_projection.y_embedding']
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the legacy (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set legacy=False. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in huggingface/transformers#24565
Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]
The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.
0%|          | 0/4 [00:00<?, ?it/s]
image
../_images/pixart-with-output_5_0.png
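
The pipeline returns a PIL image, so it can be saved to disk as a reference result from the original PyTorch pipeline (the file name below is arbitrary):

image.save("pixart_torch_reference.png")  # keep the PyTorch result for later comparison with the OpenVINO pipeline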

Convert the model to OpenVINO IR#

Let’s define the conversion function for PyTorch modules. We use the ov.convert_model function to obtain an OpenVINO Intermediate Representation object and the ov.save_model function to save it as an XML file.

from pathlib import Path

import numpy as np
import torch

import openvino as ov


def convert(model: torch.nn.Module, xml_path: str, example_input):
    xml_path = Path(xml_path)
    if not xml_path.exists():
        xml_path.parent.mkdir(parents=True, exist_ok=True)
        model.eval()
        with torch.no_grad():
            converted_model = ov.convert_model(model, example_input=example_input)
        ov.save_model(converted_model, xml_path)

        # cleanup memory
        torch._C._jit_clear_class_registry()
        torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
        torch.jit._state._clear_class_state()

PixArt-α consists of pure transformer blocks for latent diffusion: it can directly generate 1024px images from text prompts within a single sampling process.

image01.

During inference it uses the text encoder T5EncoderModel, the transformer Transformer2DModel, and the VAE decoder AutoencoderKL. Let’s convert these models from the pipeline one by one.
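
Before converting, we can optionally check which component classes the loaded pipeline holds (an illustrative check; exact class names may differ between diffusers versions):

# Optional: print the classes of the components that will be converted.
for name in ("text_encoder", "transformer", "vae"):
    print(name, "->", type(getattr(pipe, name)).__name__)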

MODEL_DIR = Path("model")

TEXT_ENCODER_PATH = MODEL_DIR / "text_encoder.xml"
TRANSFORMER_OV_PATH = MODEL_DIR / "transformer_ir.xml"
VAE_DECODER_PATH = MODEL_DIR / "vae_decoder.xml"

Convert text encoder#

example_input = {
    "input_ids": torch.zeros(1, 120, dtype=torch.int64),
    "attention_mask": torch.zeros(1, 120, dtype=torch.int64),
}

convert(pipe.text_encoder, TEXT_ENCODER_PATH, example_input)
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.base has been moved to tensorflow.python.trackable.base. The old module will be deleted in version 2.11.
[ WARNING ]  Please fix your imports. Module %s has been moved to %s. The old module will be deleted in version %s.
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-717/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/modeling_utils.py:4371: FutureWarning: _is_quantized_training_enabled is going to be deprecated in transformers 4.39.0. Please use model.hf_quantizer.is_trainable instead
  warnings.warn(
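
As an optional sanity check, the saved IR can be read back to inspect its input and output names and shapes (the text_encoder_ir variable name is only illustrative):

# Optional: read the saved text encoder IR back and list its inputs and outputs.
text_encoder_ir = ov.Core().read_model(TEXT_ENCODER_PATH)
print([(inp.any_name, inp.partial_shape) for inp in text_encoder_ir.inputs])
print([(out.any_name, out.partial_shape) for out in text_encoder_ir.outputs])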

Convert transformer#

class TransformerWrapper(torch.nn.Module):
    def __init__(self, transformer):
        super().__init__()
        self.transformer = transformer

    def forward(self, hidden_states=None, timestep=None, encoder_hidden_states=None, encoder_attention_mask=None, resolution=None, aspect_ratio=None):

        return self.transformer.forward(
            hidden_states,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            added_cond_kwargs={"resolution": resolution, "aspect_ratio": aspect_ratio},
        )


example_input = {
    "hidden_states": torch.rand([2, 4, 128, 128], dtype=torch.float32),
    "timestep": torch.tensor([999, 999]),
    "encoder_hidden_states": torch.rand([2, 120, 4096], dtype=torch.float32),
    "encoder_attention_mask": torch.rand([2, 120], dtype=torch.float32),
    "resolution": torch.tensor([[1024.0, 1024.0], [1024.0, 1024.0]]),
    "aspect_ratio": torch.tensor([[1.0], [1.0]]),
}


w_transformer = TransformerWrapper(pipe.transformer)
convert(w_transformer, TRANSFORMER_OV_PATH, example_input)
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-717/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/diffusers/models/embeddings.py:219: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if self.height != height or self.width != width:
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-717/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/diffusers/models/attention_processor.py:682: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if current_length != target_length:
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-717/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/diffusers/models/attention_processor.py:697: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attention_mask.shape[0] < batch_size * head_size:

Convert VAE decoder#

class VAEDecoderWrapper(torch.nn.Module):

    def __init__(self, vae):
        super().__init__()
        self.vae = vae

    def forward(self, latents):
        return self.vae.decode(latents, return_dict=False)


convert(VAEDecoderWrapper(pipe.vae), VAE_DECODER_PATH, (torch.zeros((1, 4, 128, 128))))
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-717/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/diffusers/models/upsampling.py:146: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert hidden_states.shape[1] == self.channels
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-717/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/diffusers/models/upsampling.py:162: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if hidden_states.shape[0] >= 64:
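
At this point all three models should be exported. A minimal check that the IR files are in place:

# Verify that the converted IR files exist in the model directory.
for path in (TEXT_ENCODER_PATH, TRANSFORMER_OV_PATH, VAE_DECODER_PATH):
    print(path, "exists:", path.exists())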

Compiling models#

Select the device from the dropdown list for running inference using OpenVINO.

import ipywidgets as widgets

core = ov.Core()
device = widgets.Dropdown(
    options=core.available_devices + ["AUTO"],
    value="AUTO",
    description="Device:",
    disabled=False,
)

device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
compiled_model = core.compile_model(TRANSFORMER_OV_PATH, device.value)
compiled_vae = core.compile_model(VAE_DECODER_PATH, device.value)
compiled_text_encoder = core.compile_model(TEXT_ENCODER_PATH, device.value)
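
A compiled OpenVINO model is callable directly: it accepts a dictionary of named inputs and returns results that can be indexed by output. The wrapper classes in the next section rely on this calling convention. A minimal sketch with dummy token ids (shapes match the example input used for conversion):

# Minimal sketch: call the compiled text encoder directly with dummy inputs.
dummy_inputs = {
    "input_ids": np.zeros((1, 120), dtype=np.int64),
    "attention_mask": np.zeros((1, 120), dtype=np.int64),
}
last_hidden_state = compiled_text_encoder(dummy_inputs)[0]
print(last_hidden_state.shape)  # expected (1, 120, 4096) for the T5 encoder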

Building the pipeline#

Let’s create callable wrapper classes for the compiled models to allow interaction with the original pipeline. Note that all wrapper classes return torch.Tensor objects instead of np.array.

from collections import namedtuple

EncoderOutput = namedtuple("EncoderOutput", "last_hidden_state")


class TextEncoderWrapper(torch.nn.Module):
    def __init__(self, text_encoder, dtype):
        super().__init__()
        self.text_encoder = text_encoder
        self.dtype = dtype

    def forward(self, input_ids=None, attention_mask=None):
        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        last_hidden_state = self.text_encoder(inputs)[0]
        return EncoderOutput(torch.from_numpy(last_hidden_state))


class TransformerWrapper(torch.nn.Module):
    def __init__(self, transformer, config):
        super().__init__()
        self.transformer = transformer
        self.config = config

    def forward(
        self,
        hidden_states=None,
        timestep=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        resolution=None,
        aspect_ratio=None,
        added_cond_kwargs=None,
        **kwargs
    ):
        inputs = {
            "hidden_states": hidden_states,
            "timestep": timestep,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_attention_mask": encoder_attention_mask,
        }
        resolution = added_cond_kwargs["resolution"]
        aspect_ratio = added_cond_kwargs["aspect_ratio"]
        if resolution is not None:
            inputs["resolution"] = resolution
            inputs["aspect_ratio"] = aspect_ratio
        outputs = self.transformer(inputs)[0]

        return [torch.from_numpy(outputs)]


class VAEWrapper(torch.nn.Module):
    def __init__(self, vae, config):
        super().__init__()
        self.vae = vae
        self.config = config

    def decode(self, latents=None, **kwargs):
        inputs = {
            "latents": latents,
        }

        outs = self.vae(inputs)
        outs = namedtuple("VAE", "sample")(torch.from_numpy(outs[0]))

        return outs

And insert the wrapper instances into the pipeline:

pipe.__dict__["_internal_dict"]["_execution_device"] = pipe._execution_device  # workaround: store the execution device in the pipeline config to avoid device resolution issues with the wrapped models

pipe.register_modules(
    text_encoder=TextEncoderWrapper(compiled_text_encoder, pipe.text_encoder.dtype),
    transformer=TransformerWrapper(compiled_model, pipe.transformer.config),
    vae=VAEWrapper(compiled_vae, pipe.vae.config),
)
generator = torch.Generator().manual_seed(42)

image = pipe(prompt=prompt, guidance_scale=0.0, num_inference_steps=4, generator=generator).images[0]
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-717/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/diffusers/configuration_utils.py:140: FutureWarning: Accessing config attribute _execution_device directly via 'PixArtAlphaPipeline' object attribute is deprecated. Please access '_execution_device' over 'PixArtAlphaPipeline's config object instead, e.g. 'scheduler.config._execution_device'.
  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
0%|          | 0/4 [00:00<?, ?it/s]
image
../_images/pixart-with-output_26_0.png

Interactive inference#

import gradio as gr


def generate(prompt, seed, negative_prompt, num_inference_steps):
    generator = torch.Generator().manual_seed(seed)
    image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, generator=generator, guidance_scale=0.0).images[0]
    return image


demo = gr.Interface(
    generate,
    [
        gr.Textbox(label="Caption"),
        gr.Slider(0, np.iinfo(np.int32).max, label="Seed"),
        gr.Textbox(label="Negative prompt"),
        gr.Slider(2, 20, step=1, label="Number of inference steps", value=4),
    ],
    "image",
    examples=[
        ["A small cactus with a happy face in the Sahara desert.", 42],
        ["an astronaut sitting in a diner, eating fries, cinematic, analog film", 42],
        [
            "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
            0,
        ],
        ["professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.", 0],
    ],
    allow_flagging="never",
)
try:
    demo.queue().launch(debug=False)
except Exception:
    demo.queue().launch(debug=False, share=True)
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/
Running on local URL:  http://127.0.0.1:7860

To create a public link, set share=True in launch().
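
When you are done experimenting, the Gradio server can be stopped (optional):

# Stop the Gradio server when finished.
# demo.close()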