try:
    import spaces
    GPU = spaces.GPU
    print("spaces GPU is available")
except ImportError:
    def GPU(duration=15):
        def decorator(func):
            return func
        return decorator
    print("spaces GPU is NOT available, using fallback decorator")
import os
import torch
import numpy as np
import imageio
import json
import time
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import einops
import torch.nn as nn
import torch.nn.functional as F

from models import *
from utils import *

from transformers import T5TokenizerFast, UMT5EncoderModel
from diffusers import FlowMatchEulerDiscreteScheduler
class MyFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        return torch.argmin(
            (timestep - schedule_timesteps.to(timestep.device)).abs(), dim=0).item()
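# GenerationSystem bundles the full pipeline: the Wan2.2-TI2V-5B transformer (conditioned on
# raymaps and optional image latents), its VAE, the UMT5 text encoder, and a pixel-aligned
# 3D Gaussian Splatting reconstruction decoder. Only the four timesteps in `denoising_steps`
# are used at inference time.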
class GenerationSystem(nn.Module):
    def __init__(self, ckpt_path=None, device="cuda:0", offload_t5=False, offload_vae=False):
        super().__init__()
        self.device = device
        self.offload_t5 = offload_t5
        self.offload_vae = offload_vae

        self.latent_dim = 48
        self.temporal_downsample_factor = 4
        self.spatial_downsample_factor = 16
        self.feat_dim = 1024
        self.latent_patch_size = 2
        self.denoising_steps = [0, 250, 500, 750]

        model_id = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"

        self.vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float).eval()

        from models.autoencoder_kl_wan import WanCausalConv3d
        with torch.no_grad():
            for name, module in self.vae.named_modules():
                if isinstance(module, WanCausalConv3d):
                    time_pad = module._padding[4]
                    module.padding = (0, module._padding[2], module._padding[0])
                    module._padding = (0, 0, 0, 0, 0, 0)
                    module.weight = torch.nn.Parameter(module.weight[:, :, time_pad:].clone())

        self.vae.requires_grad_(False)

        self.register_buffer('latents_mean', torch.tensor(self.vae.config.latents_mean).float().view(1, self.vae.config.z_dim, 1, 1, 1).to(self.device))
        self.register_buffer('latents_std', torch.tensor(self.vae.config.latents_std).float().view(1, self.vae.config.z_dim, 1, 1, 1).to(self.device))

        self.latent_scale_fn = lambda x: (x - self.latents_mean) / self.latents_std
        self.latent_unscale_fn = lambda x: x * self.latents_std + self.latents_mean

        self.tokenizer = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer")
        self.text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.float32).eval().requires_grad_(False).to(self.device if not self.offload_t5 else "cpu")

        self.transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float32).train().requires_grad_(False)
        self.transformer.patch_embedding.weight = nn.Parameter(F.pad(self.transformer.patch_embedding.weight, (0, 0, 0, 0, 0, 0, 0, 6 + self.latent_dim)))

        weight = self.transformer.proj_out.weight.reshape(self.latent_patch_size ** 2, self.latent_dim, self.transformer.proj_out.weight.shape[1])
        bias = self.transformer.proj_out.bias.reshape(self.latent_patch_size ** 2, self.latent_dim)
        extra_weight = torch.randn(self.latent_patch_size ** 2, self.feat_dim, self.transformer.proj_out.weight.shape[1]) * 0.02
        extra_bias = torch.zeros(self.latent_patch_size ** 2, self.feat_dim)
        self.transformer.proj_out.weight = nn.Parameter(torch.cat([weight, extra_weight], dim=1).flatten(0, 1).detach().clone())
        self.transformer.proj_out.bias = nn.Parameter(torch.cat([bias, extra_bias], dim=1).flatten(0, 1).detach().clone())

        self.recon_decoder = WANDecoderPixelAligned3DGSReconstructionModel(self.vae, self.feat_dim, use_render_checkpointing=True, use_network_checkpointing=False).train().requires_grad_(False).to(self.device)

        self.scheduler = MyFlowMatchEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler", shift=3)
        self.register_buffer('timesteps', self.scheduler.timesteps.clone().to(self.device))

        self.transformer.disable_gradient_checkpointing()
        self.transformer.gradient_checkpointing = False

        self.add_feedback_for_transformer()

        if ckpt_path is not None:
            state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
            self.transformer.load_state_dict(state_dict["transformer"])
            self.recon_decoder.load_state_dict(state_dict["recon_decoder"])
            print(f"Loaded {ckpt_path}.")

        from quant import FluxFp8GeMMProcessor
        FluxFp8GeMMProcessor(self.transformer)

        del self.vae.post_quant_conv, self.vae.decoder

        self.vae.to(self.device if not self.offload_vae else "cpu")
        self.transformer.to(self.device)
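    # Extra input channels on the patch embedding let the transformer take the previous step's
    # decoder features and latent prediction as feedback conditioning (see `generate`).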
    def add_feedback_for_transformer(self):
        self.use_feedback = True
        self.transformer.patch_embedding.weight = nn.Parameter(F.pad(self.transformer.patch_embedding.weight, (0, 0, 0, 0, 0, 0, 0, self.feat_dim + self.latent_dim)))
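    # UMT5 text encoding: embeddings are trimmed to each prompt's true token length and then
    # zero-padded back to max_sequence_length.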
    def encode_text(self, texts):
        max_sequence_length = 512

        text_inputs = self.tokenizer(
            texts,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        if getattr(self, "offload_t5", False):
            text_input_ids = text_inputs.input_ids.to("cpu")
            mask = text_inputs.attention_mask.to("cpu")
        else:
            text_input_ids = text_inputs.input_ids.to(self.device)
            mask = text_inputs.attention_mask.to(self.device)
        seq_lens = mask.gt(0).sum(dim=1).long()

        if getattr(self, "offload_t5", False):
            with torch.no_grad():
                text_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state.to(self.device)
        else:
            text_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
        text_embeds = [u[:v] for u, v in zip(text_embeds, seq_lens)]
        text_embeds = torch.stack(
            [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in text_embeds], dim=0
        )

        return text_embeds.float()
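    # Single generator step: one transformer pass predicts the flow-matching velocity
    # (latent_dim channels) and dense features (feat_dim channels). In 3D mode the features are
    # decoded into Gaussian scene parameters, rendered from `render_cameras`, and re-encoded by
    # the VAE so the step's output latent is 3D-consistent.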
    def forward_generator(self, noisy_latents, raymaps, condition_latents, t, text_embeds, cameras, render_cameras, image_height, image_width, need_3d_mode=True):
        out = self.transformer(
            hidden_states=torch.cat([noisy_latents, raymaps, condition_latents], dim=1),
            timestep=t,
            encoder_hidden_states=text_embeds,
            return_dict=False,
        )[0]
        v_pred, feats = out.split([self.latent_dim, self.feat_dim], dim=1)

        sigma = torch.stack([self.scheduler.sigmas[self.scheduler.index_for_timestep(_t)] for _t in t.unbind(0)], dim=0).to(self.device)
        latents_pred_2d = noisy_latents - sigma * v_pred

        if need_3d_mode:
            scene_params = self.recon_decoder(
                einops.rearrange(feats, 'B C T H W -> (B T) C H W').unsqueeze(2),
                einops.rearrange(self.latent_unscale_fn(latents_pred_2d.detach()), 'B C T H W -> (B T) C H W').unsqueeze(2),
                cameras
            ).flatten(1, -2)

            images_pred, _ = self.recon_decoder.render(scene_params.unbind(0), render_cameras, image_height, image_width, bg_mode="white")

            latents_pred_3d = einops.rearrange(self.latent_scale_fn(self.vae.encode(
                einops.rearrange(images_pred, 'B T C H W -> (B T) C H W', T=images_pred.shape[1]).unsqueeze(2).to(self.device if not self.offload_vae else "cpu").float()
            ).latent_dist.sample().to(self.device)).squeeze(2), '(B T) C H W -> B C T H W', T=images_pred.shape[1]).to(noisy_latents.dtype)

        return {
            '2d': latents_pred_2d,
            '3d': latents_pred_3d if need_3d_mode else None,
            'rgb_3d': images_pred if need_3d_mode else None,
            'scene': scene_params if need_3d_mode else None,
            'feat': feats
        }
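    # Sampling loop. `cameras` holds one row of 11 values per frame:
    # [qw, qx, qy, qz, x, y, z, fx/W, fy/H, cx/W, cy/H] (see generate_scene below).
    # Intermediate steps feed the 3D-mode prediction back as conditioning; the final step keeps
    # the decoded Gaussian scene parameters and optionally renders a preview video.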
    def generate(self, cameras, n_frame, image=None, text="", image_index=0, image_height=480, image_width=704, video_output_path=None):
        with torch.no_grad():
            batch_size = 1

            cameras = cameras.to(self.device).unsqueeze(0)
            if cameras.shape[1] != n_frame:
                render_cameras = cameras.clone()
                cameras = sample_from_dense_cameras(cameras.squeeze(0), torch.linspace(0, 1, n_frame, device=self.device)).unsqueeze(0)
            else:
                render_cameras = cameras

            cameras, ref_w2c, T_norm = normalize_cameras(cameras, return_meta=True, n_frame=None)
            render_cameras = normalize_cameras(render_cameras, ref_w2c=ref_w2c, T_norm=T_norm, n_frame=None)

            text = "[Static] " + text
            text_embeds = self.encode_text([text])

            masks = torch.zeros(batch_size, n_frame, device=self.device)
            condition_latents = torch.zeros(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)

            if image is not None:
                image = image.to(self.device)
                latent = self.latent_scale_fn(self.vae.encode(
                    image.unsqueeze(0).unsqueeze(2).to(self.device if not self.offload_vae else "cpu").float()
                ).latent_dist.sample().to(self.device)).squeeze(2)
                masks[:, image_index] = 1
                condition_latents[:, :, image_index] = latent

            raymaps = create_raymaps(cameras, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor)
            raymaps = einops.rearrange(raymaps, 'B T H W C -> B C T H W', T=n_frame)

            noise = torch.randn(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
            noisy_latents = noise

            torch.cuda.empty_cache()

            if self.use_feedback:
                prev_latents_pred = torch.zeros(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
                prev_feats = torch.zeros(batch_size, self.feat_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)

            for i in range(len(self.denoising_steps)):
                t_ids = torch.full((noisy_latents.shape[0],), self.denoising_steps[i], device=self.device)
                t = self.timesteps[t_ids]

                if self.use_feedback:
                    _condition_latents = torch.cat([condition_latents, prev_feats, prev_latents_pred], dim=1)
                else:
                    _condition_latents = condition_latents

                if i < len(self.denoising_steps) - 1:
                    out = self.forward_generator(noisy_latents, raymaps, _condition_latents, t, text_embeds, cameras, cameras, image_height, image_width, need_3d_mode=True)
                    latents_pred = out["3d"]

                    if self.use_feedback:
                        prev_latents_pred = latents_pred
                        prev_feats = out['feat']

                    noisy_latents = self.scheduler.scale_noise(latents_pred, self.timesteps[torch.full((noisy_latents.shape[0],), self.denoising_steps[i + 1], device=self.device)], torch.randn_like(noise))
                else:
                    out = self.transformer(
                        hidden_states=torch.cat([noisy_latents, raymaps, _condition_latents], dim=1),
                        timestep=t,
                        encoder_hidden_states=text_embeds,
                        return_dict=False,
                    )[0]
                    v_pred, feats = out.split([self.latent_dim, self.feat_dim], dim=1)

                    sigma = torch.stack([self.scheduler.sigmas[self.scheduler.index_for_timestep(_t)] for _t in t.unbind(0)], dim=0).to(self.device)
                    latents_pred = noisy_latents - sigma * v_pred

                    scene_params = self.recon_decoder(
                        einops.rearrange(feats, 'B C T H W -> (B T) C H W').unsqueeze(2),
                        einops.rearrange(self.latent_unscale_fn(latents_pred.detach()), 'B C T H W -> (B T) C H W').unsqueeze(2),
                        cameras
                    ).flatten(1, -2)

            if video_output_path is not None:
                interpolated_images_pred, _ = self.recon_decoder.render(scene_params.unbind(0), render_cameras, image_height, image_width, bg_mode="white")
                interpolated_images_pred = einops.rearrange(interpolated_images_pred[0].clamp(-1, 1).add(1).div(2), 'T C H W -> T H W C')
                interpolated_images_pred = [torch.cat([img], dim=1).detach().cpu().mul(255).numpy().astype(np.uint8) for i, img in enumerate(interpolated_images_pred.unbind(0))]
                imageio.mimwrite(video_output_path, interpolated_images_pred, fps=15, quality=8, macro_block_size=1)

            scene_params = scene_params[0]
            scene_params = scene_params.detach().cpu()

            return scene_params, ref_w2c, T_norm
# Initialize the model globally (outside the GPU decorator)
print("Initializing model...")

import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", default=None)
parser.add_argument("--gpu", type=int, default=0)
parser.add_argument("--offload_t5", action="store_true", help="Offload T5 encoder to CPU to save GPU memory")
args, _ = parser.parse_known_args()

# Ensure model.ckpt exists; download it if not present
if args.ckpt is None:
    from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
    ckpt_path = os.path.join(HUGGINGFACE_HUB_CACHE, "models--imlixinyang--FlashWorld", "snapshots", "6a8e88c6f88678ac098e4c82675f0aee555d6e5d", "model.ckpt")
    if not os.path.exists(ckpt_path):
        print("Downloading model checkpoint...")
        # hf_hub_download returns the cached file path; use it so loading still works if the
        # snapshot hash above goes stale.
        ckpt_path = hf_hub_download(repo_id="imlixinyang/FlashWorld", filename="model.ckpt")
else:
    ckpt_path = args.ckpt

device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
print(f"Loading model on device: {device}")
generation_system = GenerationSystem(ckpt_path=ckpt_path, device=device, offload_t5=args.offload_t5)
print("Model loaded successfully!")
# GPU-decorated generation function with a 15-second budget
@GPU(duration=15)
def generate_scene(
    image_prompt,
    text_prompt,
    camera_json,
    resolution,
    progress=gr.Progress()
):
| """ | |
| Generate 3D scene from image/text prompts and camera trajectory. | |
| Args: | |
| image_prompt: PIL Image or None | |
| text_prompt: str | |
| camera_json: JSON string with camera trajectory | |
| resolution: str in format "NxHxW" | |
| """ | |
    try:
        progress(0, desc="Parsing inputs...")

        # Parse resolution
        n_frame, image_height, image_width = [int(x) for x in resolution.split('x')]

        # Parse camera JSON
        try:
            camera_data = json.loads(camera_json)
            if "cameras" not in camera_data or len(camera_data["cameras"]) == 0:
                return None, "Error: No cameras found in JSON"
        except json.JSONDecodeError as e:
            return None, f"Error: Invalid JSON format: {str(e)}"

        progress(0.1, desc="Processing camera trajectory...")

        # Convert cameras to tensor
        cameras = []
        for cam in camera_data["cameras"]:
            quat = cam["quaternion"]  # [w, x, y, z]
            pos = cam["position"]     # [x, y, z]
            fx = cam.get("fx", 0.5 / np.tan(0.5 * 60 * np.pi / 180) * image_height)
            fy = cam.get("fy", 0.5 / np.tan(0.5 * 60 * np.pi / 180) * image_height)
            cx = cam.get("cx", 0.5 * image_width)
            cy = cam.get("cy", 0.5 * image_height)

            camera_tensor = np.array([
                quat[0], quat[1], quat[2], quat[3],   # quaternion
                pos[0], pos[1], pos[2],               # position
                fx / image_width, fy / image_height,  # normalized focal lengths
                cx / image_width, cy / image_height   # normalized principal point
            ], dtype=np.float32)
            cameras.append(camera_tensor)

        cameras = torch.from_numpy(np.stack(cameras, axis=0))

        # Process image prompt
        image = None
        if image_prompt is not None:
            progress(0.2, desc="Processing image prompt...")

            # Convert PIL to tensor and resize
            img = image_prompt.convert('RGB')
            w, h = img.size

            # Center crop to the target aspect ratio, then resize
            if image_height / h > image_width / w:
                scale = image_height / h
            else:
                scale = image_width / w
            new_h = int(image_height / scale)
            new_w = int(image_width / scale)
            img = img.crop((
                (w - new_w) // 2, (h - new_h) // 2,
                new_w + (w - new_w) // 2, new_h + (h - new_h) // 2
            )).resize((image_width, image_height))

            image = torch.from_numpy(np.array(img)).float().permute(2, 0, 1) / 255.0 * 2 - 1

        progress(0.3, desc="Generating 3D scene (this takes ~7 seconds)...")

        # Generate scene
        output_path = f"/tmp/flashworld_output_{int(time.time())}.mp4"
        scene_params, ref_w2c, T_norm = generation_system.generate(
            cameras=cameras,
            n_frame=n_frame,
            image=image,
            text=text_prompt,
            image_index=0,
            image_height=image_height,
            image_width=image_width,
            video_output_path=output_path
        )

        progress(0.9, desc="Exporting result...")

        # Export to PLY
        ply_path = f"/tmp/flashworld_output_{int(time.time())}.ply"
        export_ply_for_gaussians(ply_path, scene_params, opacity_threshold=0.001, T_norm=T_norm)

        progress(1.0, desc="Done!")

        return ply_path, f"Generation successful! Scene contains {scene_params.shape[0]} Gaussians."

    except Exception as e:
        import traceback
        error_msg = f"Error during generation: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, error_msg
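# Gradio UI: prompt inputs, a JSON editor for the camera trajectory, a resolution dropdown, and
# a downloadable PLY file as output.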
# Create Gradio interface
def create_demo():
    with gr.Blocks(title="FlashWorld: Fast 3D Scene Generation") as demo:
        gr.Markdown("""
        # FlashWorld: High-quality 3D Scene Generation within Seconds

        Generate 3D scenes in ~7 seconds from text or image prompts with a camera trajectory!

        **Note:** This demo uses ZeroGPU with a 15-second budget. Please keep your camera trajectory reasonable.
        """)

        with gr.Row():
            with gr.Column():
                # Input controls
                gr.Markdown("### 1. Prompts")
                image_input = gr.Image(label="Image Prompt (Optional)", type="pil")
                text_input = gr.Textbox(
                    label="Text Prompt",
                    placeholder="A beautiful mountain landscape with trees...",
                    value=""
                )

                gr.Markdown("### 2. Camera Trajectory")
                camera_json_input = gr.Code(
                    label="Camera JSON",
                    language="json",
                    value="""{
  "cameras": [
    {
      "quaternion": [1, 0, 0, 0],
      "position": [0, 0, 0],
      "fx": 352.0,
      "fy": 352.0,
      "cx": 352.0,
      "cy": 240.0
    },
    {
      "quaternion": [1, 0, 0, 0],
      "position": [0, 0, -0.5],
      "fx": 352.0,
      "fy": 352.0,
      "cx": 352.0,
      "cy": 240.0
    }
  ]
}""",
                    lines=15
                )

                gr.Markdown("### 3. Resolution")
                resolution_input = gr.Dropdown(
                    label="Resolution (NxHxW)",
                    choices=["24x480x704", "24x704x480"],
                    value="24x480x704"
                )

                generate_btn = gr.Button("Generate 3D Scene", variant="primary", size="lg")

            with gr.Column():
                # Output
                gr.Markdown("### Output")
                output_file = gr.File(label="Download PLY File")
                output_message = gr.Textbox(label="Status", lines=3)

                gr.Markdown("""
                ### Instructions:
                1. **Optional:** Upload an image prompt
                2. **Optional:** Enter a text description
                3. **Required:** Provide the camera trajectory as JSON
                4. Select a resolution (24 frames recommended)
                5. Click "Generate 3D Scene"

                The camera JSON should contain an array of cameras with:
                - `quaternion`: [w, x, y, z] rotation
                - `position`: [x, y, z] translation
                - `fx`, `fy`: focal lengths (pixels)
                - `cx`, `cy`: principal point (pixels)

                **Tips:**
                - Generation takes ~7 seconds on GPU
                - Download the PLY file to view it in a 3D viewer
                - Use reasonable camera trajectories (not too many frames)
                """)

        # Connect the button
        generate_btn.click(
            fn=generate_scene,
            inputs=[image_input, text_input, camera_json_input, resolution_input],
            outputs=[output_file, output_message]
        )

    return demo
if __name__ == "__main__":
    demo = create_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)