import gc import os import shutil import time from datetime import datetime import io import sys os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import cv2 import gradio as gr import numpy as np import spaces import torch from PIL import Image from pillow_heif import register_heif_opener register_heif_opener() from src.utils.inference_utils import load_and_preprocess_images from src.utils.geometry import ( depth_edge, normals_edge ) from src.utils.visual_util import ( convert_predictions_to_glb_scene, segment_sky, download_file_from_url ) from src.utils.save_utils import save_camera_params, save_gs_ply, process_ply_to_splat, convert_gs_to_ply from src.utils.render_utils import render_interpolated_video import onnxruntime # Initialize model - this will be done on GPU when needed model = None # Global variable to store current terminal output current_terminal_output = "" # Helper class to capture terminal output class TeeOutput: """Capture output while still printing to console""" def __init__(self, max_chars=10000): self.terminal = sys.stdout self.log = io.StringIO() self.max_chars = max_chars # 限制最大字符数 def write(self, message): global current_terminal_output self.terminal.write(message) self.log.write(message) # 获取当前内容并限制长度 content = self.log.getvalue() if len(content) > self.max_chars: # 只保留最后 max_chars 个字符 content = "...(earlier output truncated)...\n" + content[-self.max_chars:] self.log = io.StringIO() self.log.write(content) current_terminal_output = self.log.getvalue() def flush(self): self.terminal.flush() def getvalue(self): return self.log.getvalue() def clear(self): global current_terminal_output self.log = io.StringIO() current_terminal_output = "" # ------------------------------------------------------------------------- # Model inference # ------------------------------------------------------------------------- @spaces.GPU(duration=120) def run_model( target_dir, confidence_percentile: float = 10, edge_normal_threshold: float = 5.0, edge_depth_threshold: float = 0.03, apply_confidence_mask: bool = True, apply_edge_mask: bool = True, ): """ Run the WorldMirror model on images in the 'target_dir/images' folder and return predictions. """ global model import torch # Ensure torch is available in function scope from src.models.models.worldmirror import WorldMirror from src.models.utils.geometry import depth_to_world_coords_points print(f"Processing images from {target_dir}") # Device check device = "cuda" if torch.cuda.is_available() else "cpu" device = torch.device(device) # Initialize model if not already done if model is None: model = WorldMirror.from_pretrained("tencent/HunyuanWorld-Mirror").to(device) else: model.to(device) model.eval() # Load images using WorldMirror's load_images function print("Loading images...") image_folder_path = os.path.join(target_dir, "images") image_file_paths = [os.path.join(image_folder_path, path) for path in os.listdir(image_folder_path)] img = load_and_preprocess_images(image_file_paths).to(device) print(f"Loaded {img.shape[1]} images") if img.shape[1] == 0: raise ValueError("No images found. Check your upload.") # Run model inference print("Running inference...") inputs = {} inputs['img'] = img use_amp = torch.cuda.is_available() and torch.cuda.is_bf16_supported() if use_amp: amp_dtype = torch.bfloat16 else: amp_dtype = torch.float32 with torch.amp.autocast('cuda', enabled=bool(use_amp), dtype=amp_dtype): predictions = model(inputs) # img imgs = inputs["img"].permute(0, 1, 3, 4, 2) imgs = imgs[0].detach().cpu().numpy() # S H W 3 # depth output depth_preds = predictions["depth"] depth_conf = predictions["depth_conf"] depth_preds = depth_preds[0].detach().cpu().numpy() # S H W 1 depth_conf = depth_conf[0].detach().cpu().numpy() # S H W # normal output normal_preds = predictions["normals"] # S H W 3 normal_preds = normal_preds[0].detach().cpu().numpy() # S H W 3 # camera parameters camera_poses = predictions["camera_poses"][0].detach().cpu().numpy() # [S,4,4] camera_intrs = predictions["camera_intrs"][0].detach().cpu().numpy() # [S,3,3] # points output pts3d_preds = depth_to_world_coords_points(predictions["depth"][0, ..., 0], predictions["camera_poses"][0], predictions["camera_intrs"][0])[0] pts3d_preds = pts3d_preds.detach().cpu().numpy() # S H W 3 pts3d_conf = depth_conf # S H W # sky mask segmentation if not os.path.exists("skyseg.onnx"): print("Downloading skyseg.onnx...") download_file_from_url( "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", "skyseg.onnx" ) skyseg_session = onnxruntime.InferenceSession("skyseg.onnx") sky_mask_list = [] for i, img_path in enumerate([os.path.join(image_folder_path, path) for path in os.listdir(image_folder_path)]): sky_mask = segment_sky(img_path, skyseg_session) # Resize mask to match H×W if needed if sky_mask.shape[0] != imgs.shape[1] or sky_mask.shape[1] != imgs.shape[2]: sky_mask = cv2.resize(sky_mask, (imgs.shape[2], imgs.shape[1])) sky_mask_list.append(sky_mask) sky_mask = np.stack(sky_mask_list, axis=0) # [S, H, W] sky_mask = sky_mask>0 # mask computation final_mask_list = [] for i in range(inputs["img"].shape[1]): final_mask = None if apply_confidence_mask: # compute confidence mask based on the pointmap confidence confidences = pts3d_conf[i, :, :] # [H, W] percentile_threshold = np.quantile(confidences, confidence_percentile / 100.0) conf_mask = confidences >= percentile_threshold if final_mask is None: final_mask = conf_mask else: final_mask = final_mask & conf_mask if apply_edge_mask: # compute edge mask based on the normalmap normal_pred = normal_preds[i] # [H, W, 3] normal_edges = normals_edge( normal_pred, tol=edge_normal_threshold, mask=final_mask ) # compute depth mask based on the depthmap depth_pred = depth_preds[i, :, :, 0] # [H, W] depth_edges = depth_edge( depth_pred, rtol=edge_depth_threshold, mask=final_mask ) edge_mask = ~(depth_edges & normal_edges) if final_mask is None: final_mask = edge_mask else: final_mask = final_mask & edge_mask final_mask_list.append(final_mask) if final_mask_list[0] is not None: final_mask = np.stack(final_mask_list, axis=0) # [S, H, W] else: final_mask = np.ones(pts3d_conf.shape[:3], dtype=bool) # [S, H, W] # gaussian splatting output if "splats" in predictions: splats_dict = {} splats_dict['means'] = predictions["splats"]["means"] splats_dict['scales'] = predictions["splats"]["scales"] splats_dict['quats'] = predictions["splats"]["quats"] splats_dict['opacities'] = predictions["splats"]["opacities"] if "sh" in predictions["splats"]: splats_dict['sh'] = predictions["splats"]["sh"] if "colors" in predictions["splats"]: splats_dict['colors'] = predictions["splats"]["colors"] # output lists outputs = {} outputs['images'] = imgs outputs['world_points'] = pts3d_preds outputs['depth'] = depth_preds outputs['normal'] = normal_preds outputs['final_mask'] = final_mask outputs['sky_mask'] = sky_mask outputs['camera_poses'] = camera_poses outputs['camera_intrs'] = camera_intrs if "splats" in predictions: outputs['splats'] = splats_dict # Process data for visualization tabs (depth, normal) processed_data = prepare_visualization_data( outputs, inputs ) # Clean up torch.cuda.empty_cache() return outputs, processed_data # ------------------------------------------------------------------------- # Update and navigation function # ------------------------------------------------------------------------- def update_view_info(current_view, total_views, view_type="Depth"): """Update view information display""" return f"""
{view_type} View Navigation | Current: View {current_view} / {total_views} views
""" def update_view_selectors(processed_data): """Update view selector sliders and info displays based on available views""" if processed_data is None or len(processed_data) == 0: num_views = 1 else: num_views = len(processed_data) # 确保 num_views 至少为 1 num_views = max(1, num_views) # 更新滑块的最大值和视图信息,使用 gr.update() 而不是创建新组件 depth_slider_update = gr.update(minimum=1, maximum=num_views, value=1, step=1) normal_slider_update = gr.update(minimum=1, maximum=num_views, value=1, step=1) # 更新视图信息显示 depth_info_update = update_view_info(1, num_views, "Depth") normal_info_update = update_view_info(1, num_views, "Normal") return ( depth_slider_update, # depth_view_slider normal_slider_update, # normal_view_slider depth_info_update, # depth_view_info normal_info_update, # normal_view_info ) def get_view_data_by_index(processed_data, view_index): """Get view data by index, handling bounds""" if processed_data is None or len(processed_data) == 0: return None view_keys = list(processed_data.keys()) if view_index < 0 or view_index >= len(view_keys): view_index = 0 return processed_data[view_keys[view_index]] def update_depth_view(processed_data, view_index): """Update depth view for a specific view index""" view_data = get_view_data_by_index(processed_data, view_index) if view_data is None or view_data["depth"] is None: return None return render_depth_visualization(view_data["depth"], mask=view_data.get("mask")) def update_normal_view(processed_data, view_index): """Update normal view for a specific view index""" view_data = get_view_data_by_index(processed_data, view_index) if view_data is None or view_data["normal"] is None: return None return render_normal_visualization(view_data["normal"], mask=view_data.get("mask")) def initialize_depth_normal_views(processed_data): """Initialize the depth and normal view displays with the first view data""" if processed_data is None or len(processed_data) == 0: return None, None # Use update functions to ensure confidence filtering is applied from the start depth_vis = update_depth_view(processed_data, 0) normal_vis = update_normal_view(processed_data, 0) return depth_vis, normal_vis # ------------------------------------------------------------------------- # File upload and update preview gallery # ------------------------------------------------------------------------- def process_uploaded_files(files, time_interval=1.0): """ Process uploaded files by extracting video frames or copying images. Args: files: List of uploaded file objects (videos or images) time_interval: Interval in seconds for video frame extraction Returns: tuple: (target_dir, image_paths) where target_dir is the output directory and image_paths is a list of processed image file paths """ gc.collect() torch.cuda.empty_cache() # Create unique output directory timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") target_dir = f"input_images_{timestamp}" images_dir = os.path.join(target_dir, "images") if os.path.exists(target_dir): shutil.rmtree(target_dir) os.makedirs(images_dir) image_paths = [] if files is None: return target_dir, image_paths video_exts = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"] for file_data in files: # Get file path if isinstance(file_data, dict) and "name" in file_data: src_path = file_data["name"] else: src_path = str(file_data) ext = os.path.splitext(src_path)[1].lower() base_name = os.path.splitext(os.path.basename(src_path))[0] # Process video: extract frames if ext in video_exts: cap = cv2.VideoCapture(src_path) fps = cap.get(cv2.CAP_PROP_FPS) interval = int(fps * time_interval) frame_count = 0 saved_count = 0 while True: ret, frame = cap.read() if not ret: break frame_count += 1 if frame_count % interval == 0: dst_path = os.path.join(images_dir, f"{base_name}_{saved_count:06}.png") cv2.imwrite(dst_path, frame) image_paths.append(dst_path) saved_count += 1 cap.release() print(f"Extracted {saved_count} frames from: {os.path.basename(src_path)}") # Process HEIC/HEIF: convert to JPEG elif ext in [".heic", ".heif"]: try: with Image.open(src_path) as img: if img.mode not in ("RGB", "L"): img = img.convert("RGB") dst_path = os.path.join(images_dir, f"{base_name}.jpg") img.save(dst_path, "JPEG", quality=95) image_paths.append(dst_path) print(f"Converted HEIC: {os.path.basename(src_path)} -> {os.path.basename(dst_path)}") except Exception as e: print(f"HEIC conversion failed for {src_path}: {e}") dst_path = os.path.join(images_dir, os.path.basename(src_path)) shutil.copy(src_path, dst_path) image_paths.append(dst_path) # Process regular images: copy directly else: dst_path = os.path.join(images_dir, os.path.basename(src_path)) shutil.copy(src_path, dst_path) image_paths.append(dst_path) image_paths = sorted(image_paths) print(f"Processed files to {images_dir}") return target_dir, image_paths # Handle file upload and update preview gallery def update_gallery_on_upload(input_video, input_images, time_interval=1.0): """ Process uploaded files immediately when user uploads or changes files, and display them in the gallery. Returns (target_dir, image_paths). If nothing is uploaded, returns None and empty list. """ if not input_video and not input_images: return None, None, None, None target_dir, image_paths = process_uploaded_files(input_video, input_images, time_interval) return ( None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing.", ) # ------------------------------------------------------------------------- # Init function # ------------------------------------------------------------------------- def prepare_visualization_data( model_outputs, input_views ): """Transform model predictions into structured format for display components""" visualization_dict = {} # Iterate through each input view nviews = input_views["img"].shape[1] for idx in range(nviews): # Extract RGB image data rgb_image = input_views["img"][0, idx].detach().cpu().numpy() # Retrieve 3D coordinate predictions world_coordinates = model_outputs["world_points"][idx] # Build view-specific data structure current_view_info = { "image": rgb_image, "points3d": world_coordinates, "depth": None, "normal": None, "mask": None, } # Apply final segmentation mask from model segmentation_mask = model_outputs["final_mask"][idx].copy() current_view_info["mask"] = segmentation_mask current_view_info["depth"] = model_outputs["depth"][idx].squeeze() surface_normals = model_outputs["normal"][idx] current_view_info["normal"] = surface_normals visualization_dict[idx] = current_view_info return visualization_dict @spaces.GPU(duration=120) def gradio_demo( target_dir, frame_selector="All", show_camera=False, filter_sky_bg=False, show_mesh=False, filter_ambiguous=False, ): """ Perform reconstruction using the already-created target_dir/images. """ # Capture terminal output tee = TeeOutput() old_stdout = sys.stdout sys.stdout = tee try: if not os.path.isdir(target_dir) or target_dir == "None": terminal_log = tee.getvalue() sys.stdout = old_stdout return None, "No valid target directory found. Please upload first.", None, None, None, None, None, None, None, None, None, None, None, None, terminal_log start_time = time.time() gc.collect() torch.cuda.empty_cache() # Prepare frame_selector dropdown target_dir_images = os.path.join(target_dir, "images") all_files = ( sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else [] ) all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)] frame_selector_choices = ["All"] + all_files print("Running WorldMirror model...") with torch.no_grad(): predictions, processed_data = run_model(target_dir) # Save predictions prediction_save_path = os.path.join(target_dir, "predictions.npz") np.savez(prediction_save_path, **predictions) # Save camera parameters as JSON camera_params_file = save_camera_params( predictions['camera_poses'], predictions['camera_intrs'], target_dir ) # Handle None frame_selector if frame_selector is None: frame_selector = "All" # Build a GLB file name glbfile = os.path.join( target_dir, f"glbscene_{frame_selector.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_camera}_mesh{show_mesh}.glb", ) # Convert predictions to GLB glbscene = convert_predictions_to_glb_scene( predictions, filter_by_frames=frame_selector, show_camera=show_camera, mask_sky_bg=filter_sky_bg, as_mesh=show_mesh, # Use the show_mesh parameter mask_ambiguous=filter_ambiguous ) glbscene.export(file_obj=glbfile) end_time = time.time() print(f"Total time: {end_time - start_time:.2f} seconds") log_msg = ( f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization." ) # Convert predictions to 3dgs ply gs_file = None splat_mode = 'ply' if "splats" in predictions: # Get Gaussian parameters (already filtered by GaussianSplatRenderer) means = predictions["splats"]["means"][0].reshape(-1, 3) scales = predictions["splats"]["scales"][0].reshape(-1, 3) quats = predictions["splats"]["quats"][0].reshape(-1, 4) colors = (predictions["splats"]["sh"][0] if "sh" in predictions["splats"] else predictions["splats"]["colors"][0]).reshape(-1, 3) opacities = predictions["splats"]["opacities"][0].reshape(-1) # Convert to torch tensors if needed if not isinstance(means, torch.Tensor): means = torch.from_numpy(means) if not isinstance(scales, torch.Tensor): scales = torch.from_numpy(scales) if not isinstance(quats, torch.Tensor): quats = torch.from_numpy(quats) if not isinstance(colors, torch.Tensor): colors = torch.from_numpy(colors) if not isinstance(opacities, torch.Tensor): opacities = torch.from_numpy(opacities) if splat_mode == 'ply': gs_file = os.path.join(target_dir, "gaussians.ply") save_gs_ply( gs_file, means, scales, quats, colors, opacities ) print(f"Saved Gaussian Splatting PLY to: {gs_file}") print(f"File exists: {os.path.exists(gs_file)}") if os.path.exists(gs_file): print(f"File size: {os.path.getsize(gs_file)} bytes") elif splat_mode == 'splat': # Save Gaussian splat plydata = convert_gs_to_ply( means, scales, quats, colors, opacities ) gs_file = os.path.join(target_dir, "gaussians.splat") gs_file = process_ply_to_splat(plydata, gs_file) # Initialize depth and normal view displays with processed data depth_vis, normal_vis = initialize_depth_normal_views( processed_data ) # Update view selectors and info displays based on available views depth_slider, normal_slider, depth_info, normal_info = update_view_selectors( processed_data ) # Automatically generate render video # Generate render video if possible rgb_video_path = None depth_video_path = None if "splats" in predictions: # try: from pathlib import Path device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Get camera parameters and image dimensions camera_poses = torch.tensor(predictions['camera_poses']).unsqueeze(0).to(device) camera_intrs = torch.tensor(predictions['camera_intrs']).unsqueeze(0).to(device) H, W = predictions['images'].shape[1], predictions['images'].shape[2] # Render video out_path = Path(target_dir) / "rendered_video" render_interpolated_video( model.gs_renderer, predictions["splats"], camera_poses, camera_intrs, (H, W), out_path, interp_per_pair=15, loop_reverse=True, save_mode="split" ) # Check output files rgb_video_path = str(out_path) + "_rgb.mp4" depth_video_path = str(out_path) + "_depth.mp4" if not os.path.exists(rgb_video_path) and not os.path.exists(depth_video_path): rgb_video_path = None depth_video_path = None # Cleanup del predictions gc.collect() torch.cuda.empty_cache() # Get terminal output and restore stdout terminal_log = tee.getvalue() sys.stdout = old_stdout return ( glbfile, log_msg, gr.Dropdown(choices=frame_selector_choices, value=frame_selector, interactive=True), processed_data, depth_vis, normal_vis, depth_slider, normal_slider, depth_info, normal_info, camera_params_file, gs_file, rgb_video_path, depth_video_path, terminal_log, ) except Exception as e: # In case of error, still restore stdout terminal_log = tee.getvalue() sys.stdout = old_stdout print(f"Error occurred: {e}") raise # ------------------------------------------------------------------------- # Helper functions for visualization # ------------------------------------------------------------------------- def render_depth_visualization(depth_map, mask=None): """Generate a color-coded depth visualization image with masking capabilities""" if depth_map is None: return None # Create working copy and identify positive depth values depth_copy = depth_map.copy() positive_depth_mask = depth_copy > 0 # Combine with user-provided mask for filtering if mask is not None: positive_depth_mask = positive_depth_mask & mask # Perform percentile-based normalization on valid regions if positive_depth_mask.sum() > 0: valid_depth_values = depth_copy[positive_depth_mask] lower_bound = np.percentile(valid_depth_values, 5) upper_bound = np.percentile(valid_depth_values, 95) depth_copy[positive_depth_mask] = (depth_copy[positive_depth_mask] - lower_bound) / (upper_bound - lower_bound) # Convert to RGB using matplotlib colormap import matplotlib.pyplot as plt color_mapper = plt.cm.turbo_r rgb_result = color_mapper(depth_copy) rgb_result = (rgb_result[:, :, :3] * 255).astype(np.uint8) # Mark invalid regions with white color rgb_result[~positive_depth_mask] = [255, 255, 255] return rgb_result def render_normal_visualization(normal_map, mask=None): """Convert surface normal vectors to RGB color representation for display""" if normal_map is None: return None # Make a working copy to avoid modifying original data normal_display = normal_map.copy() # Handle masking by zeroing out invalid regions if mask is not None: masked_regions = ~mask normal_display[masked_regions] = [0, 0, 0] # Zero out masked pixels # Transform from [-1, 1] to [0, 1] range for RGB display normal_display = (normal_display + 1.0) / 2.0 normal_display = (normal_display * 255).astype(np.uint8) return normal_display def clear_fields(): """ Clears the 3D viewer, the stored target_dir, and empties the gallery. """ return None def update_log(): """ Display a quick log message while waiting. """ return "Loading and Reconstructing..." def get_terminal_output(): """ Get current terminal output for real-time display """ global current_terminal_output return current_terminal_output # ------------------------------------------------------------------------- # FunctionExample scene metadata extraction # ------------------------------------------------------------------------- def extract_example_scenes_metadata(base_directory): """ Extract comprehensive metadata for all scene directories containing valid images. Args: base_directory: Root path where example scene directories are located Returns: Collection of dictionaries with scene details (title, location, preview, etc.) """ from glob import glob # Return empty list if base directory is missing if not os.path.exists(base_directory): return [] # Define supported image format extensions VALID_IMAGE_FORMATS = ['jpg', 'jpeg', 'png', 'bmp', 'tiff', 'tif'] scenes_data = [] # Process each subdirectory in the base directory for directory_name in sorted(os.listdir(base_directory)): current_directory = os.path.join(base_directory, directory_name) # Filter out non-directory items if not os.path.isdir(current_directory): continue # Gather all valid image files within the current directory discovered_images = [] for file_format in VALID_IMAGE_FORMATS: # Include both lowercase and uppercase format variations discovered_images.extend(glob(os.path.join(current_directory, f'*.{file_format}'))) discovered_images.extend(glob(os.path.join(current_directory, f'*.{file_format.upper()}'))) # Skip directories without any valid images if not discovered_images: continue # Ensure consistent image ordering discovered_images.sort() # Construct scene metadata record scene_record = { 'name': directory_name, 'path': current_directory, 'thumbnail': discovered_images[0], 'num_images': len(discovered_images), 'image_files': discovered_images, } scenes_data.append(scene_record) return scenes_data def load_example_scenes(scene_name, scenes): """ Initialize and prepare an example scene for 3D reconstruction processing. Args: scene_name: Identifier of the target scene to load scenes: List containing all available scene configurations Returns: Tuple containing processed scene data and status information """ # Locate the target scene configuration by matching names target_scene_config = None for scene_config in scenes: if scene_config["name"] == scene_name: target_scene_config = scene_config break # Handle case where requested scene doesn't exist if target_scene_config is None: return None, None, None, "Scene not found" # Prepare image file paths for processing pipeline # Extract all image file paths from the selected scene image_file_paths = [] for img_file_path in target_scene_config["image_files"]: image_file_paths.append(img_file_path) # Process the scene images through the standard upload pipeline processed_target_dir, processed_image_list = process_uploaded_files(image_file_paths, 1.0) # Return structured response with scene data and user feedback status_message = f"Successfully loaded scene '{scene_name}' containing {target_scene_config['num_images']} images. Click 'Reconstruct' to begin 3D processing." return ( None, # Reset reconstruction visualization None, # Reset gaussian splatting output processed_target_dir, # Provide working directory path processed_image_list, # Update image gallery display status_message, ) # ------------------------------------------------------------------------- # UI and event handling # ------------------------------------------------------------------------- theme = gr.themes.Base() with gr.Blocks( theme=theme, css=""" .custom-log * { font-style: italic; font-size: 22px !important; background-image: linear-gradient(120deg, #a9b8f8 0%, #7081e8 60%, #4254c5 100%); -webkit-background-clip: text; background-clip: text; font-weight: bold !important; color: transparent !important; text-align: center !important; } .normal-weight-btn button, .normal-weight-btn button span, .normal-weight-btn button *, .normal-weight-btn * { font-weight: 400 !important; } .terminal-output { max-height: 400px !important; overflow-y: auto !important; } .terminal-output textarea { font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace !important; font-size: 13px !important; line-height: 1.5 !important; color: #333 !important; background-color: #f8f9fa !important; max-height: 400px !important; } .example-gallery { width: 100% !important; } .example-gallery img { width: 100% !important; height: 280px !important; object-fit: contain !important; aspect-ratio: 16 / 9 !important; } .example-gallery .grid-wrap { width: 100% !important; } /* 滑块导航样式 */ .depth-tab-improved .gradio-slider input[type="range"] { height: 8px !important; border-radius: 4px !important; background: linear-gradient(90deg, #667eea 0%, #764ba2 100%) !important; } .depth-tab-improved .gradio-slider input[type="range"]::-webkit-slider-thumb { height: 20px !important; width: 20px !important; border-radius: 50% !important; background: #fff !important; box-shadow: 0 2px 6px rgba(0,0,0,0.3) !important; } .depth-tab-improved button { transition: all 0.3s ease !important; border-radius: 6px !important; font-weight: 500 !important; } .depth-tab-improved button:hover { transform: translateY(-1px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; } .normal-tab-improved .gradio-slider input[type="range"] { height: 8px !important; border-radius: 4px !important; background: linear-gradient(90deg, #667eea 0%, #764ba2 100%) !important; } .normal-tab-improved .gradio-slider input[type="range"]::-webkit-slider-thumb { height: 20px !important; width: 20px !important; border-radius: 50% !important; background: #fff !important; box-shadow: 0 2px 6px rgba(0,0,0,0.3) !important; } .normal-tab-improved button { transition: all 0.3s ease !important; border-radius: 6px !important; font-weight: 500 !important; } .normal-tab-improved button:hover { transform: translateY(-1px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; } #depth-view-info, #normal-view-info { animation: fadeIn 0.5s ease-in-out; } @keyframes fadeIn { from { opacity: 0; transform: translateY(-10px); } to { opacity: 1; transform: translateY(0); } } """ ) as demo: # State variables for the tabbed interface is_example = gr.Textbox(label="is_example", visible=False, value="None") num_images = gr.Textbox(label="num_images", visible=False, value="None") processed_data_state = gr.State(value=None) current_view_index = gr.State(value=0) # Track current view index for navigation # Header and description gr.HTML( """

WorldMirror: Universal 3D World Reconstruction with Any Prior Prompting

📄 ArXiv Paper | 🌐 Project Page | 💻 GitHub Repository | 🤗 Hugging Face Model

WorldMirror supports any combination of inputs (images, intrinsics, poses, and depth) and multiple outputs including point clouds, camera parameters, depth maps, normal maps, and 3D Gaussian Splatting (3DGS).

How to Use:

  1. Upload Your Data: Click the "Upload Video or Images" button to add your files. Videos are automatically extracted into frames at one-second intervals.
  2. Reconstruct: Click the "Reconstruct" button to start the 3D reconstruction.
  3. Visualize: Explore multiple reconstruction results across different tabs:
    • 3D View: Interactive point cloud/mesh visualization with camera poses (downloadable as GLB)
    • 3D Gaussian Splatting: Interactive 3D Gaussian Splatting visualization with RGB and depth videos (downloadable as PLY)
    • Depth Maps: Per-view depth estimation results (downloadable as PNG)
    • Normal Maps: Per-view surface orientation visualization (downloadable as PNG)
    • Camera Parameters: Estimated camera poses and intrinsics (downloadable as JSON)

Please note: Loading data and displaying 3D effects may take a moment. For faster performance, we recommend downloading the code from our GitHub and running it locally.

""") output_path_state = gr.Textbox(label="Output Path", visible=False, value="None") # Main UI components with gr.Row(equal_height=False): with gr.Column(scale=1): file_upload = gr.File( file_count="multiple", label="Upload Video or Images", interactive=True, file_types=["image", "video"], height="200px", ) time_interval = gr.Slider( minimum=0.1, maximum=10.0, value=1.0, step=0.1, label="Video Sample interval", interactive=True, visible=True, scale=4, ) resample_btn = gr.Button( "Resample", visible=True, scale=1, elem_classes=["normal-weight-btn"], ) image_gallery = gr.Gallery( label="Image Preview", columns=4, height="200px", show_download_button=True, object_fit="contain", preview=True ) terminal_output = gr.Textbox( label="Terminal Output", lines=6, max_lines=6, interactive=False, show_copy_button=True, container=True, elem_classes=["terminal-output"], autoscroll=True ) with gr.Column(scale=3): log_output = gr.Markdown( "Upload video or images first, then click Reconstruct to start processing", elem_classes=["custom-log"], ) with gr.Tabs() as tabs: with gr.Tab("3D Gaussian Splatting", id=1) as gs_tab: with gr.Row(): with gr.Column(scale=3): gs_output = gr.Model3D( label="Gaussian Splatting", height=500, ) with gr.Column(scale=1): gs_rgb_video = gr.Video( label="Rendered RGB Video", height=250, autoplay=False, loop=False, interactive=False, ) gs_depth_video = gr.Video( label="Rendered Depth Video", height=250, autoplay=False, loop=False, interactive=False, ) with gr.Tab("Point Cloud/Mesh", id=0): reconstruction_output = gr.Model3D( label="3D Pointmap/Mesh", height=500, zoom_speed=0.4, pan_speed=0.4, ) with gr.Tab("Depth", elem_classes=["depth-tab-improved"]): depth_view_info = gr.HTML( value="
" "Depth View Navigation | Current: View 1 / 1 views
", elem_id="depth-view-info" ) depth_view_slider = gr.Slider( minimum=1, maximum=1, step=1, value=1, label="View Selection Slider", interactive=True, elem_id="depth-view-slider" ) depth_map = gr.Image( type="numpy", label="Depth Map", format="png", interactive=False, height=340 ) with gr.Tab("Normal", elem_classes=["normal-tab-improved"]): normal_view_info = gr.HTML( value="
" "Normal View Navigation | Current: View 1 / 1 views
", elem_id="normal-view-info" ) normal_view_slider = gr.Slider( minimum=1, maximum=1, step=1, value=1, label="View Selection Slider", interactive=True, elem_id="normal-view-slider" ) normal_map = gr.Image( type="numpy", label="Normal Map", format="png", interactive=False, height=340 ) with gr.Tab("Camera Parameters", elem_classes=["camera-tab"]): with gr.Row(): gr.HTML("") camera_params = gr.DownloadButton( label="Download Camera Parameters", scale=1, variant="primary", ) gr.HTML("") with gr.Row(): reconstruct_btn = gr.Button( "Reconstruct", scale=1, variant="primary" ) clear_btn = gr.ClearButton( [ file_upload, reconstruction_output, log_output, output_path_state, image_gallery, depth_map, normal_map, depth_view_slider, normal_view_slider, depth_view_info, normal_view_info, camera_params, gs_output, gs_rgb_video, gs_depth_video, ], scale=1, ) with gr.Row(): frame_selector = gr.Dropdown( choices=["All"], value="All", label="Show Points of a Specific Frame" ) gr.Markdown("### Reconstruction Options: (not applied to 3DGS)") with gr.Row(): show_camera = gr.Checkbox(label="Show Camera", value=True) show_mesh = gr.Checkbox(label="Show Mesh", value=True) filter_ambiguous = gr.Checkbox(label="Filter low confidence & depth/normal edges", value=True) filter_sky_bg = gr.Checkbox(label="Filter Sky Background", value=False) with gr.Column(scale=1): gr.Markdown("### Click to load example scenes") realworld_scenes = extract_example_scenes_metadata("examples/realistic") if os.path.exists("examples/realistic") else extract_example_scenes_metadata("examples") generated_scenes = extract_example_scenes_metadata("examples/stylistic") if os.path.exists("examples/stylistic") else [] # If no subdirectories exist, fall back to single gallery if not os.path.exists("examples/realistic") and not os.path.exists("examples/stylistic"): # Fallback: use all scenes from examples directory all_scenes = extract_example_scenes_metadata("examples") if all_scenes: gallery_items = [ (scene["thumbnail"], f"{scene['name']}\n📷 {scene['num_images']} images") for scene in all_scenes ] example_gallery = gr.Gallery( value=gallery_items, label="Example Scenes", columns=1, rows=None, height=800, object_fit="contain", show_label=False, interactive=True, preview=False, allow_preview=False, elem_classes=["example-gallery"] ) def handle_example_selection(evt: gr.SelectData): if evt: result = load_example_scenes(all_scenes[evt.index]["name"], all_scenes) return result return (None, None, None, None, "No scene selected") example_gallery.select( fn=handle_example_selection, outputs=[ reconstruction_output, gs_output, output_path_state, image_gallery, log_output, ], ) else: # Tabbed interface for categorized examples with gr.Tabs(): with gr.Tab("🌍 Realistic Cases"): if realworld_scenes: realworld_items = [ (scene["thumbnail"], f"{scene['name']}\n📷 {scene['num_images']} images") for scene in realworld_scenes ] realworld_gallery = gr.Gallery( value=realworld_items, label="Real-world Examples", columns=1, rows=None, height=750, object_fit="contain", show_label=False, interactive=True, preview=False, allow_preview=False, elem_classes=["example-gallery"] ) def handle_realworld_selection(evt: gr.SelectData): if evt: result = load_example_scenes(realworld_scenes[evt.index]["name"], realworld_scenes) return result return (None, None, None, None, "No scene selected") realworld_gallery.select( fn=handle_realworld_selection, outputs=[ reconstruction_output, gs_output, output_path_state, image_gallery, log_output, ], ) else: gr.Markdown("No real-world examples available") with gr.Tab("🎨 Stylistic Cases"): if generated_scenes: generated_items = [ (scene["thumbnail"], f"{scene['name']}\n📷 {scene['num_images']} images") for scene in generated_scenes ] generated_gallery = gr.Gallery( value=generated_items, label="Generated Examples", columns=1, rows=None, height=750, object_fit="contain", show_label=False, interactive=True, preview=False, allow_preview=False, elem_classes=["example-gallery"] ) def handle_generated_selection(evt: gr.SelectData): if evt: result = load_example_scenes(generated_scenes[evt.index]["name"], generated_scenes) return result return (None, None, None, None, "No scene selected") generated_gallery.select( fn=handle_generated_selection, outputs=[ reconstruction_output, gs_output, output_path_state, image_gallery, log_output, ], ) else: gr.Markdown("No generated examples available") # ------------------------------------------------------------------------- # Click logic # ------------------------------------------------------------------------- reconstruct_btn.click(fn=clear_fields, inputs=[], outputs=[]).then( fn=update_log, inputs=[], outputs=[log_output] ).then( fn=gradio_demo, inputs=[ output_path_state, frame_selector, show_camera, filter_sky_bg, show_mesh, filter_ambiguous ], outputs=[ reconstruction_output, log_output, frame_selector, processed_data_state, depth_map, normal_map, depth_view_slider, normal_view_slider, depth_view_info, normal_view_info, camera_params, gs_output, gs_rgb_video, gs_depth_video, terminal_output, ], ).then( fn=lambda: "False", inputs=[], outputs=[is_example], # set is_example to "False" ) # ------------------------------------------------------------------------- # Live update logic # ------------------------------------------------------------------------- def refresh_3d_scene( workspace_path, frame_selector, show_camera, is_example, filter_sky_bg=False, show_mesh=False, filter_ambiguous=False ): """ Refresh 3D scene visualization Load prediction data from workspace, generate or reuse GLB scene files based on current parameters, and return file paths needed for the 3D viewer. Args: workspace_path: Workspace directory path for reconstruction results frame_selector: Frame selector value for filtering points from specific frames show_camera: Whether to display camera positions is_example: Whether this is an example scene filter_sky_bg: Whether to filter sky background show_mesh: Whether to display as mesh mode filter_ambiguous: Whether to filter low-confidence ambiguous areas Returns: tuple: (GLB scene file path, Gaussian point cloud file path, status message) """ # If example scene is clicked, skip processing directly if is_example == "True": return ( gr.update(), gr.update(), "No reconstruction results available. Please click the Reconstruct button first.", ) # Validate workspace directory path if not workspace_path or workspace_path == "None" or not os.path.isdir(workspace_path): return ( gr.update(), gr.update(), "No reconstruction results available. Please click the Reconstruct button first.", ) # Check if prediction data file exists prediction_file_path = os.path.join(workspace_path, "predictions.npz") if not os.path.exists(prediction_file_path): return ( gr.update(), gr.update(), f"Prediction file does not exist: {prediction_file_path}. Please run reconstruction first.", ) # Load prediction data prediction_data = np.load(prediction_file_path, allow_pickle=True) predictions = {key: prediction_data[key] for key in prediction_data.keys() if key != 'splats'} # Generate GLB scene file path (named based on parameter combination) safe_frame_name = frame_selector.replace('.', '_').replace(':', '').replace(' ', '_') scene_filename = f"scene_{safe_frame_name}_cam{show_camera}_mesh{show_mesh}_edges{filter_ambiguous}_sky{filter_sky_bg}.glb" scene_glb_path = os.path.join(workspace_path, scene_filename) # If GLB file doesn't exist, generate new scene file if not os.path.exists(scene_glb_path): scene_model = convert_predictions_to_glb_scene( predictions, filter_by_frames=frame_selector, show_camera=show_camera, mask_sky_bg=filter_sky_bg, as_mesh=show_mesh, mask_ambiguous=filter_ambiguous ) scene_model.export(file_obj=scene_glb_path) # Find Gaussian point cloud file gaussian_file_path = os.path.join(workspace_path, "gaussians.ply") if not os.path.exists(gaussian_file_path): gaussian_file_path = None return ( scene_glb_path, gaussian_file_path, "3D scene updated.", ) def refresh_view_displays_on_filter_update( workspace_dir, sky_background_filter, current_processed_data, depth_slider_position, normal_slider_position, ): """ Refresh depth and normal view displays when filter settings change When the background filter checkbox state changes, regenerate processed data and update all view displays. This ensures that filter effects are reflected in real-time in the depth map and normal map visualizations. Args: workspace_dir: Workspace directory path containing prediction data and images sky_background_filter: Sky background filter enable status current_processed_data: Currently processed visualization data depth_slider_position: Current position of the depth view slider normal_slider_position: Current position of the normal view slider Returns: tuple: (updated processed data, depth visualization result, normal visualization result) """ # Validate workspace directory validity if not workspace_dir or workspace_dir == "None" or not os.path.isdir(workspace_dir): return current_processed_data, None, None # Build and check prediction data file path prediction_data_path = os.path.join(workspace_dir, "predictions.npz") if not os.path.exists(prediction_data_path): return current_processed_data, None, None try: # Load raw prediction data raw_prediction_data = np.load(prediction_data_path, allow_pickle=True) predictions_dict = {key: raw_prediction_data[key] for key in raw_prediction_data.keys()} # Load image data using WorldMirror's load_images function images_directory = os.path.join(workspace_dir, "images") image_file_paths = [os.path.join(images_directory, path) for path in os.listdir(images_directory)] img = load_and_preprocess_images(image_file_paths) img = img.detach().cpu().numpy() # Regenerate processed data with new filter settings refreshed_data = {} for view_idx in range(img.shape[1]): view_data = { "image": img[0, view_idx], "points3d": predictions_dict["world_points"][view_idx], "depth": None, "normal": None, "mask": None, } mask = predictions_dict["final_mask"][view_idx].copy() if sky_background_filter: sky_mask = predictions_dict["sky_mask"][view_idx] mask = mask & sky_mask view_data["mask"] = mask view_data["depth"] = predictions_dict["depth"][view_idx].squeeze() view_data["normal"] = predictions_dict["normal"][view_idx] refreshed_data[view_idx] = view_data # Get current view indices from slider positions (convert to 0-based indices) current_depth_index = int(depth_slider_position) - 1 if depth_slider_position else 0 current_normal_index = int(normal_slider_position) - 1 if normal_slider_position else 0 # Update depth and normal views with new filter data updated_depth_visualization = update_depth_view(refreshed_data, current_depth_index) updated_normal_visualization = update_normal_view(refreshed_data, current_normal_index) return refreshed_data, updated_depth_visualization, updated_normal_visualization except Exception as error: print(f"Error occurred while refreshing view displays: {error}") return current_processed_data, None, None frame_selector.change( refresh_3d_scene, [ output_path_state, frame_selector, show_camera, is_example, filter_sky_bg, show_mesh, filter_ambiguous ], [reconstruction_output, gs_output, log_output], ) show_camera.change( refresh_3d_scene, [ output_path_state, frame_selector, show_camera, is_example, filter_sky_bg, show_mesh, filter_ambiguous ], [reconstruction_output, gs_output, log_output], ) show_mesh.change( refresh_3d_scene, [ output_path_state, frame_selector, show_camera, is_example, filter_sky_bg, show_mesh, filter_ambiguous ], [reconstruction_output, gs_output, log_output], ) filter_sky_bg.change( refresh_3d_scene, [ output_path_state, frame_selector, show_camera, is_example, filter_sky_bg, show_mesh, filter_ambiguous ], [reconstruction_output, gs_output, log_output], ).then( fn=refresh_view_displays_on_filter_update, inputs=[ output_path_state, filter_sky_bg, processed_data_state, depth_view_slider, normal_view_slider, ], outputs=[ processed_data_state, depth_map, normal_map, ], ) filter_ambiguous.change( refresh_3d_scene, [ output_path_state, frame_selector, show_camera, is_example, filter_sky_bg, show_mesh, filter_ambiguous ], [reconstruction_output, gs_output, log_output], ).then( fn=refresh_view_displays_on_filter_update, inputs=[ output_path_state, filter_sky_bg, processed_data_state, depth_view_slider, normal_view_slider, ], outputs=[ processed_data_state, depth_map, normal_map, ], ) # ------------------------------------------------------------------------- # Auto update gallery when user uploads or changes files # ------------------------------------------------------------------------- def update_gallery_on_file_upload(files, interval): if not files: return None, None, None, "" # Capture terminal output tee = TeeOutput() old_stdout = sys.stdout sys.stdout = tee try: target_dir, image_paths = process_uploaded_files(files, interval) terminal_log = tee.getvalue() sys.stdout = old_stdout return ( target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing.", terminal_log, ) except Exception as e: terminal_log = tee.getvalue() sys.stdout = old_stdout print(f"Error occurred: {e}") raise def resample_video_with_new_interval(files, new_interval, current_target_dir): """Resample video with new slider value""" if not files: return ( current_target_dir, None, "No files to resample.", "", ) # Check if we have videos to resample video_extensions = [ ".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp", ] has_video = any( os.path.splitext( str(file_data["name"] if isinstance(file_data, dict) else file_data) )[1].lower() in video_extensions for file_data in files ) if not has_video: return ( current_target_dir, None, "No videos found to resample.", "", ) # Capture terminal output tee = TeeOutput() old_stdout = sys.stdout sys.stdout = tee try: # Clean up old target directory if it exists if ( current_target_dir and current_target_dir != "None" and os.path.exists(current_target_dir) ): shutil.rmtree(current_target_dir) # Process files with new interval target_dir, image_paths = process_uploaded_files(files, new_interval) terminal_log = tee.getvalue() sys.stdout = old_stdout return ( target_dir, image_paths, f"Video resampled with {new_interval}s interval. Click 'Reconstruct' to begin 3D processing.", terminal_log, ) except Exception as e: terminal_log = tee.getvalue() sys.stdout = old_stdout print(f"Error occurred: {e}") raise file_upload.change( fn=update_gallery_on_file_upload, inputs=[file_upload, time_interval], outputs=[output_path_state, image_gallery, log_output, terminal_output], ) resample_btn.click( fn=resample_video_with_new_interval, inputs=[file_upload, time_interval, output_path_state], outputs=[output_path_state, image_gallery, log_output, terminal_output], ) # ------------------------------------------------------------------------- # Navigation for Depth, Normal tabs # ------------------------------------------------------------------------- def navigate_with_slider(processed_data, target_view): """Navigate to specified view using slider""" if processed_data is None or len(processed_data) == 0: return None, update_view_info(1, 1) # Check if target_view is None or invalid value, and safely convert to int try: if target_view is None: target_view = 1 else: target_view = int(float(target_view)) # Convert to float first then int, handle decimal input except (ValueError, TypeError): target_view = 1 total_views = len(processed_data) # Ensure view index is within valid range view_index = max(1, min(target_view, total_views)) - 1 # Update depth map depth_vis = update_depth_view(processed_data, view_index) # Update view information info_html = update_view_info(view_index + 1, total_views) return depth_vis, info_html def navigate_with_slider_normal(processed_data, target_view): """Navigate to specified normal view using slider""" if processed_data is None or len(processed_data) == 0: return None, update_view_info(1, 1, "Normal") # Check if target_view is None or invalid value, and safely convert to int try: if target_view is None: target_view = 1 else: target_view = int(float(target_view)) # Convert to float first then int, handle decimal input except (ValueError, TypeError): target_view = 1 total_views = len(processed_data) # Ensure view index is within valid range view_index = max(1, min(target_view, total_views)) - 1 # Update normal map normal_vis = update_normal_view(processed_data, view_index) # Update view information info_html = update_view_info(view_index + 1, total_views, "Normal") return normal_vis, info_html def handle_depth_slider_change(processed_data, target_view): return navigate_with_slider(processed_data, target_view) def handle_normal_slider_change(processed_data, target_view): return navigate_with_slider_normal(processed_data, target_view) depth_view_slider.change( fn=handle_depth_slider_change, inputs=[processed_data_state, depth_view_slider], outputs=[depth_map, depth_view_info] ) normal_view_slider.change( fn=handle_normal_slider_change, inputs=[processed_data_state, normal_view_slider], outputs=[normal_map, normal_view_info] ) # ------------------------------------------------------------------------- # Real-time terminal output update # ------------------------------------------------------------------------- # Use a timer to periodically update terminal output timer = gr.Timer(value=0.5) # Update every 0.5 seconds timer.tick( fn=get_terminal_output, inputs=[], outputs=[terminal_output] ) gr.HTML("""

Acknowledgements

🔗 MoGe2 on HuggingFace | 🔗 VGGT on HuggingFace

""") demo.queue().launch( show_error=True, share=True, ssr_mode=False, )