TripoSR feedforward 3D reconstruction from a single image and OpenVINO#

This Jupyter notebook can be launched only after a local installation.


Warning

Important note: This notebook requires Python >= 3.9. Please make sure that your environment fulfills this requirement before running it.
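If you are unsure which Python version your Jupyter kernel uses, a quick check such as the following can be added to the first cell (a minimal sketch):

import sys

assert sys.version_info >= (3, 9), f"Python >= 3.9 is required, found {sys.version.split()[0]}"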

TripoSR is a state-of-the-art open-source model for fast feedforward 3D reconstruction from a single image, developed in collaboration between Tripo AI and Stability AI.

You can find the source code on GitHub and a demo on HuggingFace. You can also read the paper TripoSR: Fast 3D Object Reconstruction from a Single Image.

Teaser Video#

Table of contents:

- Prerequisites
- Get the original model
- Convert the model to OpenVINO IR
- Compiling models and preparing the pipeline
- Interactive inference

Installation Instructions#

This is a self-contained example that relies solely on its own code.

We recommend running the notebook in a virtual environment. You only need a Jupyter server to start. For details, please refer to the Installation Guide.

Prerequisites#

%pip install -q "gradio>=4.19" "torch==2.2.2" "torchvision<0.18.0" rembg trimesh einops "omegaconf>=2.3.0" "transformers>=4.35.0" "openvino>=2024.0.0" --extra-index-url https://download.pytorch.org/whl/cpu
%pip install -q "git+https://github.com/tatsy/torchmcubes.git" --extra-index-url https://download.pytorch.org/whl/cpu
import sys
from pathlib import Path

if not Path("TripoSR").exists():
    !git clone https://huggingface.co/spaces/stabilityai/TripoSR

sys.path.append("TripoSR")

Get the original model#

import os

from tsr.system import TSR


# load the TSR model weights and config from the Hugging Face Hub
model = TSR.from_pretrained(
    "stabilityai/TripoSR",
    config_name="config.yaml",
    weight_name="model.ckpt",
)
# adjust the chunk size to balance between speed and memory usage during rendering;
# set it to 0 for no memory limit
model.renderer.set_chunk_size(131072)
model.to("cpu");

Convert the model to OpenVINO IR#

Define the conversion function for PyTorch modules. We use the ov.convert_model function to obtain the OpenVINO Intermediate Representation object and the ov.save_model function to save it as an XML file.

import torch

import openvino as ov


def convert(model: torch.nn.Module, xml_path: str, example_input):
    xml_path = Path(xml_path)
    if not xml_path.exists():
        xml_path.parent.mkdir(parents=True, exist_ok=True)
        with torch.no_grad():
            converted_model = ov.convert_model(model, example_input=example_input)
        ov.save_model(converted_model, xml_path, compress_to_fp16=False)

        # cleanup memory
        torch._C._jit_clear_class_registry()
        torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
        torch.jit._state._clear_class_state()

The original model is a pipeline of several models: image_tokenizer, tokenizer, backbone, and post_processor. image_tokenizer contains a ViTModel that consists of ViTPatchEmbeddings, ViTEncoder, and ViTPooler. tokenizer is a Triplane1DTokenizer, backbone is a Transformer1D, and post_processor is a TriplaneUpsampleNetwork. Convert all internal models one by one.
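If you want to double-check this structure on your installation before converting, you can print the class names of the pipeline components (an optional sketch):

# Optional: print the class names of the TSR pipeline components
for name in ("image_tokenizer", "tokenizer", "backbone", "post_processor"):
    print(name, "->", type(getattr(model, name)).__name__)
print("image_tokenizer.model ->", type(model.image_tokenizer.model).__name__)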

VIT_PATCH_EMBEDDINGS_OV_PATH = Path("models/vit_patch_embeddings_ir.xml")


class PatchEmbeddingWrapper(torch.nn.Module):
    def __init__(self, patch_embeddings):
        super().__init__()
        self.patch_embeddings = patch_embeddings

    def forward(self, pixel_values, interpolate_pos_encoding=True):
        outputs = self.patch_embeddings(pixel_values=pixel_values, interpolate_pos_encoding=True)
        return outputs


example_input = {
    "pixel_values": torch.rand([1, 3, 512, 512], dtype=torch.float32),
}

convert(
    PatchEmbeddingWrapper(model.image_tokenizer.model.embeddings.patch_embeddings),
    VIT_PATCH_EMBEDDINGS_OV_PATH,
    example_input,
)
VIT_ENCODER_OV_PATH = Path("models/vit_encoder_ir.xml")


class EncoderWrapper(torch.nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(
        self,
        hidden_states=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        outputs = self.encoder(
            hidden_states=hidden_states,
        )

        return outputs.last_hidden_state


example_input = {
    "hidden_states": torch.rand([1, 1025, 768], dtype=torch.float32),
}

convert(
    EncoderWrapper(model.image_tokenizer.model.encoder),
    VIT_ENCODER_OV_PATH,
    example_input,
)
VIT_POOLER_OV_PATH = Path("models/vit_pooler_ir.xml")
convert(
    model.image_tokenizer.model.pooler,
    VIT_POOLER_OV_PATH,
    torch.rand([1, 1025, 768], dtype=torch.float32),
)
TOKENIZER_OV_PATH = Path("models/tokenizer_ir.xml")
# the tokenizer's only input is the batch size
convert(model.tokenizer, TOKENIZER_OV_PATH, torch.tensor(1))
example_input = {
    "hidden_states": torch.rand([1, 1024, 3072], dtype=torch.float32),
    "encoder_hidden_states": torch.rand([1, 1025, 768], dtype=torch.float32),
}

BACKBONE_OV_PATH = Path("models/backbone_ir.xml")
convert(model.backbone, BACKBONE_OV_PATH, example_input)
POST_PROCESSOR_OV_PATH = Path("models/post_processor_ir.xml")
convert(
    model.post_processor,
    POST_PROCESSOR_OV_PATH,
    torch.rand([1, 3, 1024, 32, 32], dtype=torch.float32),
)

Compiling models and preparing the pipeline#

Select the device from the dropdown list for running inference using OpenVINO.

import ipywidgets as widgets


core = ov.Core()
device = widgets.Dropdown(
    options=core.available_devices + ["AUTO"],
    value="AUTO",
    description="Device:",
    disabled=False,
)

device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
compiled_vit_patch_embeddings = core.compile_model(VIT_PATCH_EMBEDDINGS_OV_PATH, device.value)
compiled_vit_model_encoder = core.compile_model(VIT_ENCODER_OV_PATH, device.value)
compiled_vit_model_pooler = core.compile_model(VIT_POOLER_OV_PATH, device.value)

compiled_tokenizer = core.compile_model(TOKENIZER_OV_PATH, device.value)
compiled_backbone = core.compile_model(BACKBONE_OV_PATH, device.value)
compiled_post_processor = core.compile_model(POST_PROCESSOR_OV_PATH, device.value)
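Optionally, you can inspect the input and output ports of a compiled model to verify the shapes OpenVINO captured during conversion (a quick check, not required for the pipeline):

# Optional: inspect the ports of the compiled backbone
print(compiled_backbone.inputs)
print(compiled_backbone.outputs)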

Let’s create callable wrapper classes for the compiled models to allow interaction with the original TSR class. Note that all wrapper classes return torch.Tensor objects instead of np.array objects.

from collections import namedtuple


class VitPatchEmbeddingsWrapper(torch.nn.Module):
    def __init__(self, vit_patch_embeddings, model):
        super().__init__()
        self.vit_patch_embeddings = vit_patch_embeddings
        self.projection = model.projection

    def forward(self, pixel_values, interpolate_pos_encoding=False):
        inputs = {
            "pixel_values": pixel_values,
        }
        outs = self.vit_patch_embeddings(inputs)[0]

        return torch.from_numpy(outs)


class VitModelEncoderWrapper(torch.nn.Module):
    def __init__(self, vit_model_encoder):
        super().__init__()
        self.vit_model_encoder = vit_model_encoder

    def forward(
        self,
        hidden_states,
        head_mask,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        inputs = {
            "hidden_states": hidden_states.detach().numpy(),
        }

        outs = self.vit_model_encoder(inputs)
        outputs = namedtuple("BaseModelOutput", ("last_hidden_state", "hidden_states", "attentions"))

        return outputs(torch.from_numpy(outs[0]), None, None)


class VitModelPoolerWrapper(torch.nn.Module):
    def __init__(self, vit_model_pooler):
        super().__init__()
        self.vit_model_pooler = vit_model_pooler

    def forward(self, hidden_states):
        outs = self.vit_model_pooler(hidden_states.detach().numpy())[0]

        return torch.from_numpy(outs)


class TokenizerWrapper(torch.nn.Module):
    def __init__(self, tokenizer, model):
        super().__init__()
        self.tokenizer = tokenizer
        self.detokenize = model.detokenize

    def forward(self, batch_size):
        outs = self.tokenizer(batch_size)[0]

        return torch.from_numpy(outs)


class BackboneWrapper(torch.nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, hidden_states, encoder_hidden_states):
        inputs = {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states.detach().numpy(),
        }

        outs = self.backbone(inputs)[0]

        return torch.from_numpy(outs)


class PostProcessorWrapper(torch.nn.Module):
    def __init__(self, post_processor):
        super().__init__()
        self.post_processor = post_processor

    def forward(self, triplanes):
        outs = self.post_processor(triplanes)[0]

        return torch.from_numpy(outs)

Replace all models in the original pipeline with wrapper instances:

model.image_tokenizer.model.embeddings.patch_embeddings = VitPatchEmbeddingsWrapper(
    compiled_vit_patch_embeddings,
    model.image_tokenizer.model.embeddings.patch_embeddings,
)
model.image_tokenizer.model.encoder = VitModelEncoderWrapper(compiled_vit_model_encoder)
model.image_tokenizer.model.pooler = VitModelPoolerWrapper(compiled_vit_model_pooler)

model.tokenizer = TokenizerWrapper(compiled_tokenizer, model.tokenizer)
model.backbone = BackboneWrapper(compiled_backbone)
model.post_processor = PostProcessorWrapper(compiled_post_processor)
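Before building the interactive demo, you can optionally sanity-check the OpenVINO-backed pipeline on one of the example images shipped with the TripoSR repository (a minimal sketch; mesh extraction may take a while on CPU):

# Optional sanity check: run the patched pipeline end to end on one example image
from PIL import Image

example_files = sorted(Path("TripoSR/examples").iterdir())
test_image = Image.open(example_files[0]).convert("RGB")

scene_codes = model(test_image, "cpu")  # the device argument is used by the image processor
test_mesh = model.extract_mesh(scene_codes)[0]
print(f"Extracted mesh with {len(test_mesh.vertices)} vertices and {len(test_mesh.faces)} faces")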

Interactive inference#

import tempfile

import gradio as gr
import numpy as np
import rembg
from PIL import Image

from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation


rembg_session = rembg.new_session()


def check_input_image(input_image):
    if input_image is None:
        raise gr.Error("No image uploaded!")


def preprocess(input_image, do_remove_background, foreground_ratio):
    def fill_background(image):
        # composite the RGBA image over a neutral gray (0.5) background
        image = np.array(image).astype(np.float32) / 255.0
        image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
        image = Image.fromarray((image * 255.0).astype(np.uint8))
        return image

    if do_remove_background:
        image = input_image.convert("RGB")
        image = remove_background(image, rembg_session)
        image = resize_foreground(image, foreground_ratio)
        image = fill_background(image)
    else:
        image = input_image
        if image.mode == "RGBA":
            image = fill_background(image)
    return image


def generate(image):
    scene_codes = model(image, "cpu")  # the device is provided for the image processor
    mesh = model.extract_mesh(scene_codes)[0]
    mesh = to_gradio_3d_orientation(mesh)
    mesh_path = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
    mesh.export(mesh_path.name)
    return mesh_path.name


with gr.Blocks() as demo:
    with gr.Row(variant="panel"):
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(
                    label="Input Image",
                    image_mode="RGBA",
                    sources="upload",
                    type="pil",
                    elem_id="content_image",
                )
                processed_image = gr.Image(label="Processed Image", interactive=False)
            with gr.Row():
                with gr.Group():
                    do_remove_background = gr.Checkbox(label="Remove Background", value=True)
                    foreground_ratio = gr.Slider(
                        label="Foreground Ratio",
                        minimum=0.5,
                        maximum=1.0,
                        value=0.85,
                        step=0.05,
                    )
            with gr.Row():
                submit = gr.Button("Generate", elem_id="generate", variant="primary")
        with gr.Column():
            with gr.Tab("Model"):
                output_model = gr.Model3D(
                    label="Output Model",
                    interactive=False,
                )
    with gr.Row(variant="panel"):
        gr.Examples(
            examples=[os.path.join("TripoSR/examples", img_name) for img_name in sorted(os.listdir("TripoSR/examples"))],
            inputs=[input_image],
            outputs=[processed_image, output_model],
            label="Examples",
            examples_per_page=20,
        )
    submit.click(fn=check_input_image, inputs=[input_image]).success(
        fn=preprocess,
        inputs=[input_image, do_remove_background, foreground_ratio],
        outputs=[processed_image],
    ).success(
        fn=generate,
        inputs=[processed_image],
        outputs=[output_model],
    )

try:
    demo.launch(debug=True, height=680)
except Exception:
    demo.queue().launch(share=True, debug=True, height=680)
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/