"""
TRELLIS.2 Text-to-3D Generator
🎨 Comic Classic Theme

Gradio application that chains three stages:

1. Z-Image-Turbo turns a text prompt into an image.
2. The image is background-removed, cropped and alpha-premultiplied.
3. TRELLIS.2 turns the image into a textured mesh, which is exported as a
   GLB file and wrapped in a Rerun recording for in-browser viewing.
"""
import os
import shutil
import tempfile
import uuid
from datetime import datetime
from typing import Optional, Tuple

import torch
import numpy as np
from PIL import Image
import rerun as rr

# The blueprint API is optional; the viewer still works without it.
try:
    import rerun.blueprint as rrb
except ImportError:
    rrb = None

from gradio_rerun import Rerun
import gradio as gr
from gradio_client import Client, handle_file
import spaces
from diffusers import ZImagePipeline
from trellis2.pipelines import Trellis2ImageTo3DPipeline
import o_voxel

# NOTE(review): these variables are assigned AFTER torch / trellis2 / o_voxel
# have been imported. Any library that reads them at import time will not see
# them — confirm whether they need to move above the imports.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["ATTN_BACKEND"] = "flash_attn_3"
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json'
)
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'

MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')

# Maps the UI "Resolution" choice to the pipeline_type passed to TRELLIS.2.
_PIPELINE_TYPES = {"512": "512", "1024": "1024_cascade", "1536": "1536_cascade"}

_RERUN_APP_ID = "TRELLIS-3D-Viewer"

# Model loading is best-effort: on failure the pipeline is set to None and the
# corresponding handler reports a gr.Error instead of crashing the app.
print("Loading Z-Image-Turbo...")
try:
    z_pipe = ZImagePipeline.from_pretrained(
        "Tongyi-MAI/Z-Image-Turbo",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=False,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    z_pipe.to(device)
except Exception as e:
    print(f"Failed to load Z-Image-Turbo: {e}")
    z_pipe = None

print("Loading TRELLIS.2...")
try:
    trellis_pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
    # Background removal is delegated to the remote BRIA client below.
    trellis_pipeline.rembg_model = None
    trellis_pipeline.low_vram = False
    trellis_pipeline.cuda()
except Exception as e:
    print(f"Failed to load TRELLIS.2: {e}")
    trellis_pipeline = None

rmbg_client = Client("briaai/BRIA-RMBG-2.0")


def start_session(req: gr.Request):
    """Create the per-session scratch directory under TMP_DIR."""
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)


def end_session(req: gr.Request):
    """Delete the per-session scratch directory, if it exists."""
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    if os.path.exists(user_dir):
        shutil.rmtree(user_dir)


def remove_background(input: Image.Image) -> Image.Image:
    """Send *input* to the BRIA-RMBG-2.0 Space and return the resulting image.

    The image is converted to RGB and written to a temporary PNG, which is
    uploaded through ``rmbg_client``; the first file of the first result is
    opened and returned.
    """
    with tempfile.NamedTemporaryFile(suffix='.png') as f:
        input = input.convert('RGB')
        # Written by name while the handle is still open, so the file exists
        # for the duration of the remote call.
        input.save(f.name)
        output = rmbg_client.predict(handle_file(f.name), api_name="/image")[0][0]
        output = Image.open(output)
    return output


def preprocess_image(input: Optional[Image.Image]) -> Optional[Image.Image]:
    """Prepare an image for TRELLIS.2.

    Steps: downscale so the longest side is at most 1024 px, obtain an alpha
    mask (reuse the image's own alpha if it is not fully opaque, otherwise
    call ``remove_background``), crop to a square around the foreground and
    premultiply RGB by alpha.

    Args:
        input: Source image, or None.

    Returns:
        None if *input* is None; the uncropped RGBA image if no pixel's alpha
        exceeds 80%; otherwise an RGB image of the cropped, premultiplied
        foreground.
    """
    if input is None:
        return None

    has_alpha = False
    if input.mode == 'RGBA':
        alpha = np.array(input)[:, :, 3]
        if not np.all(alpha == 255):
            has_alpha = True

    max_size = max(input.size)
    scale = min(1, 1024 / max_size)
    if scale < 1:
        input = input.resize(
            (int(input.width * scale), int(input.height * scale)),
            Image.Resampling.LANCZOS,
        )

    output = input if has_alpha else remove_background(input)
    # The alpha channel is indexed below; normalise in case the background
    # removal service returns something other than RGBA.
    if output.mode != 'RGBA':
        output = output.convert('RGBA')

    output_np = np.array(output)
    alpha = output_np[:, :, 3]
    bbox = np.argwhere(alpha > 0.8 * 255)
    if bbox.size == 0:
        # Nothing is opaque enough to crop around.
        return output

    # argwhere yields (row, col); reorder to (left, top, right, bottom).
    bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
    center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
    size = int(max(bbox[2] - bbox[0], bbox[3] - bbox[1]))
    bbox = (
        center[0] - size // 2,
        center[1] - size // 2,
        center[0] + size // 2,
        center[1] + size // 2,
    )
    output = output.crop(bbox)

    # Premultiply RGB by alpha so the background becomes black.
    output = np.array(output).astype(np.float32) / 255
    output = output[:, :, :3] * output[:, :, 3:4]
    output = Image.fromarray((output * 255).astype(np.uint8))
    return output


def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return a random seed in [0, MAX_SEED) if requested, else *seed*."""
    return np.random.randint(0, MAX_SEED) if randomize_seed else seed


@spaces.GPU
def generate_txt2img(prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate a 1024x1024 image from *prompt* with Z-Image-Turbo.

    The generator is seeded with a fixed value (42), so the same prompt uses
    the same seed on every call.

    Raises:
        gr.Error: if the model failed to load, the prompt is empty, or the
            pipeline raises.
    """
    if z_pipe is None:
        raise gr.Error("Z-Image-Turbo model failed to load.")
    # `prompt` may be None as well as blank; both are rejected.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device).manual_seed(42)

    progress(0.1, desc="Generating Image...")
    try:
        result = z_pipe(
            prompt=prompt,
            negative_prompt=None,
            height=1024,
            width=1024,
            num_inference_steps=9,
            guidance_scale=0.0,
            generator=generator,
        )
        return result.images[0]
    except Exception as e:
        raise gr.Error(f"Generation failed: {str(e)}") from e


def _sampler_params(steps, guidance_strength, guidance_rescale, rescale_t) -> dict:
    """Build one sampler-parameter dict; `steps` is forced to an integer."""
    return {
        "steps": int(steps),
        "guidance_strength": guidance_strength,
        "guidance_rescale": guidance_rescale,
        "rescale_t": rescale_t,
    }


def _export_glb(mesh, grid_size, decimation_target, texture_size, remesh: bool):
    """Convert *mesh* to a GLB object via o_voxel, with or without remeshing."""
    return o_voxel.postprocess.to_glb(
        vertices=mesh.vertices,
        faces=mesh.faces,
        attr_volume=mesh.attrs,
        coords=mesh.coords,
        attr_layout=trellis_pipeline.pbr_attr_layout,
        grid_size=grid_size,
        aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        decimation_target=decimation_target,
        texture_size=texture_size,
        remesh=remesh,
        remesh_band=1,
        remesh_project=0,
        use_tqdm=True,
    )


def _new_recording(run_id: str):
    """Create a Rerun recording using whichever API this rerun build exposes."""
    if hasattr(rr, "new_recording"):
        return rr.new_recording(application_id=_RERUN_APP_ID, recording_id=run_id)
    if hasattr(rr, "RecordingStream"):
        return rr.RecordingStream(application_id=_RERUN_APP_ID, recording_id=run_id)
    # Last resort: log through the module itself.
    return rr


def _save_rerun_recording(glb_path: str, rrd_path: str) -> None:
    """Log the GLB at *glb_path* into a new recording and save it to *rrd_path*."""
    rec = _new_recording(str(uuid.uuid4()))
    rec.log("world", rr.Clear(recursive=True), static=True)
    rec.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, static=True)
    rec.log("world/model", rr.Asset3D(path=glb_path), static=True)

    if rrb is not None:
        # The blueprint only affects viewer layout; failing to send it must
        # not fail the generation, but the reason is logged.
        try:
            blueprint = rrb.Blueprint(
                rrb.Spatial3DView(origin="/world", name="3D View"),
                collapse_panels=True,
            )
            rec.send_blueprint(blueprint)
        except Exception as e:
            print(f"Skipping Rerun blueprint: {e}")

    rec.save(rrd_path)


@spaces.GPU(duration=120)
def generate_3d(
    image: Image.Image,
    seed: int,
    resolution: str,
    decimation_target: int,
    texture_size: int,
    ss_guidance_strength: float,
    ss_guidance_rescale: float,
    ss_sampling_steps: int,
    ss_rescale_t: float,
    shape_guidance: float,
    shape_rescale: float,
    shape_steps: int,
    shape_rescale_t: float,
    tex_guidance: float,
    tex_rescale: float,
    tex_steps: int,
    tex_rescale_t: float,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True)
) -> Tuple[str, str]:
    """Run TRELLIS.2 on *image* and export the result.

    Args:
        image: Preprocessed input image (the pipeline is called with
            ``preprocess_image=False``).
        seed: Sampling seed.
        resolution: One of "512", "1024", "1536"; selects the pipeline type.
        decimation_target: Target face count passed to the GLB export.
        texture_size: Texture resolution passed to the GLB export.
        ss_*, shape_*, tex_*: Sampler settings for the sparse-structure,
            shape and texture stages respectively.
        req: Gradio request; its session hash selects the output directory.
        progress: Gradio progress tracker.

    Returns:
        ``(rrd_path, glb_path)``: the Rerun recording and the GLB file, both
        written to the session directory.

    Raises:
        gr.Error: if no image is given, the model is not loaded, or any step
            of generation/export fails.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")
    if trellis_pipeline is None:
        raise gr.Error("TRELLIS model is not loaded.")

    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)

    progress(0.1, desc="Generating 3D...")
    try:
        outputs, latents = trellis_pipeline.run(
            image,
            # Slider values can arrive as floats; these must be integers.
            seed=int(seed),
            preprocess_image=False,
            sparse_structure_sampler_params=_sampler_params(
                ss_sampling_steps, ss_guidance_strength, ss_guidance_rescale, ss_rescale_t
            ),
            shape_slat_sampler_params=_sampler_params(
                shape_steps, shape_guidance, shape_rescale, shape_rescale_t
            ),
            tex_slat_sampler_params=_sampler_params(
                tex_steps, tex_guidance, tex_rescale, tex_rescale_t
            ),
            pipeline_type=_PIPELINE_TYPES[resolution],
            return_latent=True,
        )

        progress(0.7, desc="Processing Mesh...")
        mesh = outputs[0]
        mesh.simplify(1000000)

        progress(0.9, desc="Exporting GLB...")
        grid_size = latents[2]
        decimation_target = int(decimation_target)
        texture_size = int(texture_size)
        try:
            glb = _export_glb(mesh, grid_size, decimation_target, texture_size, remesh=True)
        except RuntimeError:
            # Retry once without remeshing when the remeshed export fails.
            glb = _export_glb(mesh, grid_size, decimation_target, texture_size, remesh=False)

        timestamp = datetime.now().strftime("%Y-%m-%dT%H%M%S")
        glb_path = os.path.join(user_dir, f'output_{timestamp}.glb')
        glb.export(glb_path, extension_webp=False)

        progress(0.95, desc="Creating Viewer...")
        rrd_path = os.path.join(user_dir, f'output_{timestamp}.rrd')
        _save_rerun_recording(glb_path, rrd_path)

        return rrd_path, glb_path
    except Exception as e:
        raise gr.Error(f"Generation failed: {str(e)}") from e
    finally:
        # Release cached GPU memory on both the success and the failure path.
        torch.cuda.empty_cache()


css = """
@import url('https://fonts.googleapis.com/css2?family=Bangers&family=Comic+Neue:wght@400;700&display=swap');

.gradio-container {
    background-color: #FEF9C3 !important;
    background-image: radial-gradient(#1F2937 1px, transparent 1px) !important;
    background-size: 20px 20px !important;
    min-height: 100vh !important;
    font-family: 'Comic Neue', cursive, sans-serif !important;
}

.huggingface-space-header, #space-header, .space-header, [class*="space-header"], .svelte-1ed2p3z, .space-header-badge, .header-badge, [data-testid="space-header"], .svelte-kqij2n, .svelte-1ax1toq, .embed-container > div:first-child {
    display: none !important;
    visibility: hidden !important;
    height: 0 !important;
    width: 0 !important;
    overflow: hidden !important;
    opacity: 0 !important;
    pointer-events: none !important;
}

footer, .footer, .gradio-container footer, .built-with, [class*="footer"], .gradio-footer, .main-footer, div[class*="footer"], .show-api, .built-with-gradio, a[href*="gradio.app"], a[href*="huggingface.co/spaces"] {
    display: none !important;
    visibility: hidden !important;
    height: 0 !important;
    padding: 0 !important;
    margin: 0 !important;
}

#col-container {
    max-width: 960px;
    margin: 0 auto;
}

.header-text h1 {
    font-family: 'Bangers', cursive !important;
    color: #1F2937 !important;
    font-size: 3.5rem !important;
    font-weight: 400 !important;
    text-align: center !important;
    margin-bottom: 0.5rem !important;
    text-shadow: 4px 4px 0px #FACC15, 6px 6px 0px #1F2937 !important;
    letter-spacing: 3px !important;
    -webkit-text-stroke: 2px #1F2937 !important;
}

.subtitle {
    text-align: center !important;
    font-family: 'Comic Neue', cursive !important;
    font-size: 1.2rem !important;
    color: #1F2937 !important;
    margin-bottom: 1.5rem !important;
    font-weight: 700 !important;
}

.gr-panel, .gr-box, .gr-form, .block, .gr-group {
    background: #FFFFFF !important;
    border: 3px solid #1F2937 !important;
    border-radius: 8px !important;
    box-shadow: 6px 6px 0px #1F2937 !important;
    transition: all 0.2s ease !important;
}

.gr-panel:hover, .block:hover {
    transform: translate(-2px, -2px) !important;
    box-shadow: 8px 8px 0px #1F2937 !important;
}

textarea, input[type="text"], input[type="number"] {
    background: #FFFFFF !important;
    border: 3px solid #1F2937 !important;
    border-radius: 8px !important;
    color: #1F2937 !important;
    font-family: 'Comic Neue', cursive !important;
    font-size: 1rem !important;
    font-weight: 700 !important;
    transition: all 0.2s ease !important;
}

textarea:focus, input[type="text"]:focus, input[type="number"]:focus {
    border-color: #3B82F6 !important;
    box-shadow: 4px 4px 0px #3B82F6 !important;
    outline: none !important;
}

.gr-button-primary, button.primary, .gr-button.primary {
    background: #3B82F6 !important;
    border: 3px solid #1F2937 !important;
    border-radius: 8px !important;
    color: #FFFFFF !important;
    font-family: 'Bangers', cursive !important;
    font-weight: 400 !important;
    font-size: 1.3rem !important;
    letter-spacing: 2px !important;
    padding: 14px 28px !important;
    box-shadow: 5px 5px 0px #1F2937 !important;
    transition: all 0.1s ease !important;
    text-shadow: 1px 1px 0px #1F2937 !important;
}

.gr-button-primary:hover, button.primary:hover, .gr-button.primary:hover {
    background: #2563EB !important;
    transform: translate(-2px, -2px) !important;
    box-shadow: 7px 7px 0px #1F2937 !important;
}

.gr-button-primary:active, button.primary:active, .gr-button.primary:active {
    transform: translate(3px, 3px) !important;
    box-shadow: 2px 2px 0px #1F2937 !important;
}

.gr-button-secondary, button.secondary {
    background: #EF4444 !important;
    border: 3px solid #1F2937 !important;
    border-radius: 8px !important;
    color: #FFFFFF !important;
    font-family: 'Bangers', cursive !important;
    font-weight: 400 !important;
    font-size: 1.1rem !important;
    letter-spacing: 1px !important;
    box-shadow: 4px 4px 0px #1F2937 !important;
    transition: all 0.1s ease !important;
    text-shadow: 1px 1px 0px #1F2937 !important;
}

.gr-button-secondary:hover, button.secondary:hover {
    background: #DC2626 !important;
    transform: translate(-2px, -2px) !important;
    box-shadow: 6px 6px 0px #1F2937 !important;
}

label, .gr-input-label, .gr-block-label {
    color: #1F2937 !important;
    font-family: 'Comic Neue', cursive !important;
    font-weight: 700 !important;
    font-size: 1rem !important;
}

.gr-file-upload {
    border: 3px dashed #1F2937 !important;
    border-radius: 8px !important;
    background: #FEF9C3 !important;
}

.gr-file-upload:hover {
    border-color: #3B82F6 !important;
    background: #EFF6FF !important;
}

::-webkit-scrollbar {
    width: 12px;
    height: 12px;
}

::-webkit-scrollbar-track {
    background: #FEF9C3;
    border: 2px solid #1F2937;
}

::-webkit-scrollbar-thumb {
    background: #3B82F6;
    border: 2px solid #1F2937;
    border-radius: 0px;
}

::-webkit-scrollbar-thumb:hover {
    background: #EF4444;
}

::selection {
    background: #FACC15;
    color: #1F2937;
}

a {
    color: #3B82F6 !important;
    text-decoration: none !important;
    font-weight: 700 !important;
}

a:hover {
    color: #EF4444 !important;
}

@media (max-width: 768px) {
    .header-text h1 {
        font-size: 2.2rem !important;
        text-shadow: 3px 3px 0px #FACC15, 4px 4px 0px #1F2937 !important;
    }
    .gr-button-primary, button.primary {
        padding: 12px 20px !important;
        font-size: 1.1rem !important;
    }
    .gr-panel, .block {
        box-shadow: 4px 4px 0px #1F2937 !important;
    }
}

@media (prefers-color-scheme: dark) {
    .gradio-container {
        background-color: #FEF9C3 !important;
    }
}
"""

EXAMPLES_IMAGE = [f"example-images/A ({i}).webp" for i in range(1, 72)]

EXAMPLES_TEXT = [
    "A Cat 3D model",
    "A realistic Cat 3D model",
    "A cartoon Cat 3D model",
    "A low poly Cat 3D",
    "A cyberpunk Cat 3D",
    "A robotic Cat 3D",
    "A Plane 3D model",
    "A fighter jet Plane 3D",
    "A vintage Plane 3D",
    "A Car 3D model",
    "A sports Car 3D",
    "A cyberpunk Car 3D",
    "A Shoe 3D model",
    "A sneaker Shoe 3D",
    "A boot Shoe 3D",
    "A Chair 3D model",
    "A Table 3D model",
    "A Robot 3D model",
    "A House 3D model",
    "A Spaceship 3D model",
    "A Motorcycle 3D model",
]

if __name__ == "__main__":
    os.makedirs(TMP_DIR, exist_ok=True)

    with gr.Blocks(title="TRELLIS.2 Text-to-3D", delete_cache=(300, 300)) as demo:
        gr.LoginButton(value="Option: HuggingFace 'Login' for extra GPU quota +", size="sm")

        # `css` is not passed to gr.Blocks, so it is injected as a <style>
        # element here; without this the stylesheet above is never applied.
        gr.HTML(f"<style>{css}</style>")
        # NOTE(review): only the text "HOME" is visible in this literal; any
        # surrounding markup (e.g. a link) is not recoverable — confirm.
        gr.HTML("\nHOME\n")

        gr.Markdown("# 🎮 TRELLIS.2 TEXT-TO-3D 🎮", elem_classes="header-text")
        # The stylesheet defines `.subtitle`; attach it the same way the
        # header attaches `.header-text`.
        gr.Markdown(
            "\n\n✨ Generate 3D models from text or images! 🚀\n\n",
            elem_classes="subtitle",
        )

        with gr.Row():
            with gr.Column(scale=1, min_width=360):
                with gr.Tabs():
                    with gr.Tab("📝 Text-to-3D"):
                        txt_prompt = gr.Textbox(label="💬 Prompt", placeholder="e.g. A Cat 3D model", lines=2)
                        btn_gen_img = gr.Button("1️⃣ Generate Image", variant="primary")
                    with gr.Tab("🖼️ Image-to-3D"):
                        gr.Markdown("Upload an image directly.")
                        image_prompt = gr.Image(label="📷 Input Image", format="png", image_mode="RGBA", type="pil", height=350)

                with gr.Accordion(label="⚙️ 3D Settings", open=False):
                    resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
                    seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                    randomize_seed = gr.Checkbox(label="🎲 Randomize Seed", value=True)
                    decimation_target = gr.Slider(50000, 500000, label="Target Faces", value=150000, step=10000)
                    texture_size = gr.Slider(512, 4096, label="Texture Size", value=1024, step=512)

                btn_gen_3d = gr.Button("2️⃣ Generate 3D", variant="primary")

                with gr.Accordion(label="🔧 Advanced Sampler", open=False):
                    # step=1 on the "Steps" sliders: a step count must be an
                    # integer, and the default slider step can be fractional.
                    gr.Markdown("**Stage 1: Sparse Structure**")
                    ss_guidance_strength = gr.Slider(1.0, 10.0, value=7.5, label="Guidance")
                    ss_guidance_rescale = gr.Slider(0.0, 1.0, value=0.7, label="Rescale")
                    ss_sampling_steps = gr.Slider(1, 50, value=12, step=1, label="Steps")
                    ss_rescale_t = gr.Slider(1.0, 6.0, value=5.0, label="Rescale T")

                    gr.Markdown("**Stage 2: Shape**")
                    shape_guidance = gr.Slider(1.0, 10.0, value=7.5, label="Guidance")
                    shape_rescale = gr.Slider(0.0, 1.0, value=0.5, label="Rescale")
                    shape_steps = gr.Slider(1, 50, value=12, step=1, label="Steps")
                    shape_rescale_t = gr.Slider(1.0, 6.0, value=3.0, label="Rescale T")

                    gr.Markdown("**Stage 3: Material**")
                    tex_guidance = gr.Slider(1.0, 10.0, value=1.0, label="Guidance")
                    tex_rescale = gr.Slider(0.0, 1.0, value=0.0, label="Rescale")
                    tex_steps = gr.Slider(1, 50, value=12, step=1, label="Steps")
                    tex_rescale_t = gr.Slider(1.0, 6.0, value=3.0, label="Rescale T")

            with gr.Column(scale=2):
                gr.Markdown("### 🎯 3D Output")
                rerun_output = Rerun(label="3D Viewer", height=600)
                download_btn = gr.DownloadButton(label="3️⃣ Download GLB", variant="primary")

        gr.Examples(examples=[[img] for img in EXAMPLES_IMAGE], inputs=[image_prompt], label="🖼️ Image Examples")
        gr.Examples(examples=[[txt] for txt in EXAMPLES_TEXT], inputs=[txt_prompt], label="📝 Text Examples")

        demo.load(start_session)
        demo.unload(end_session)

        # Text path: generate an image, then run the same preprocessing an
        # uploaded image gets.
        btn_gen_img.click(generate_txt2img, inputs=[txt_prompt], outputs=[image_prompt]).then(
            preprocess_image, inputs=[image_prompt], outputs=[image_prompt]
        )
        image_prompt.upload(preprocess_image, inputs=[image_prompt], outputs=[image_prompt])

        # Resolve the seed first so the value used is shown in the slider.
        btn_gen_3d.click(get_seed, inputs=[randomize_seed, seed], outputs=[seed]).then(
            generate_3d,
            inputs=[
                image_prompt, seed, resolution, decimation_target, texture_size,
                ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
                shape_guidance, shape_rescale, shape_steps, shape_rescale_t,
                tex_guidance, tex_rescale, tex_steps, tex_rescale_t,
            ],
            outputs=[rerun_output, download_btn],
        )

    demo.launch()