# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # conda activate hf3.10 import base64 import gc import os import shutil import sys import time from datetime import datetime import cv2 import gradio as gr import numpy as np import spaces import torch from huggingface_hub import hf_hub_download sys.path.append("mapanything/") from hf_utils.css_and_html import ( get_acknowledgements_html, get_description_html, get_gradio_theme, get_header_html, GRADIO_CSS, MEASURE_INSTRUCTIONS_HTML, ) from hf_utils.vgg_geometry import unproject_depth_map_to_point_map from hf_utils.visual_util import predictions_to_glb from mapanything.models import init_model from mapanything.utils.geometry import depth_edge, normals_edge, points_to_normals from mapanything.utils.image import load_images, rgb from mapanything.utils.inference import loss_of_one_batch_multi_view def get_logo_base64(): """Convert WAI logo to base64 for embedding in HTML""" logo_path = "examples/wai_logo/wai_logo.png" try: with open(logo_path, "rb") as img_file: img_data = img_file.read() base64_str = base64.b64encode(img_data).decode() return f"data:image/png;base64,{base64_str}" except FileNotFoundError: return None print("Initializing and loading MapAnything model...") def load_hf_token(): """Load HuggingFace access token from local file""" token_file_paths = [ "~/hf_token.txt", ] for token_path in token_file_paths: if os.path.exists(token_path): try: with open(token_path, "r") as f: token = f.read().strip() print(f"Loaded HuggingFace token from: {token_path}") return token except Exception as e: print(f"Error reading token from {token_path}: {e}") continue # Also try environment variable # see https://huggingface.co/docs/hub/spaces-overview#managing-secrets on options token = ( os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN") or os.getenv("HUGGING_FACE_MODEL_TOKEN") ) if token: print("Loaded HuggingFace token from environment variable") return token print( "Warning: No HuggingFace token found. Model loading may fail for private repositories." ) return None def init_hydra_config(config_path, overrides=None): "Initialize Hydra config" import hydra config_dir = os.path.dirname(config_path) config_name = os.path.basename(config_path).split(".")[0] relative_path = os.path.relpath(config_dir, os.path.dirname(__file__)) hydra.core.global_hydra.GlobalHydra.instance().clear() hydra.initialize(version_base=None, config_path=relative_path) if overrides is not None: cfg = hydra.compose(config_name=config_name, overrides=overrides) else: cfg = hydra.compose(config_name=config_name) return cfg def init_inference_model(config, ckpt_path, device): "Initialize the model for inference" if isinstance(config, dict): config_path = config["path"] overrrides = config["config_overrides"] model_args = init_hydra_config(config_path, overrides=overrrides) model = init_model(model_args.model.model_str, model_args.model.model_config) else: config_path = config model_args = init_hydra_config(config_path) model = init_model(model_args.model_str, model_args.model_config) model.to(device) if ckpt_path is not None: print("Loading model from: ", ckpt_path) # Load HuggingFace token for private repositories hf_token = load_hf_token() # Try to download from HuggingFace Hub first if it's a HF URL if "huggingface.co" in ckpt_path: try: # Extract repo_id and filename from URL # URL format: https://huggingface.co/facebook/MapAnything/resolve/main/mapa_curri_24v_13d_48ipg_64g.pth parts = ckpt_path.replace("https://huggingface.co/", "").split("/") repo_id = f"{parts[0]}/{parts[1]}" # e.g., "facebook/MapAnything" filename = "/".join( parts[4:] ) # e.g., "mapa_curri_24v_13d_48ipg_64g.pth" print(f"Downloading from HuggingFace Hub: {repo_id}/{filename}") local_file = hf_hub_download( repo_id=repo_id, filename=filename, token=hf_token, cache_dir=None, # Use default cache ) ckpt = torch.load(local_file, map_location=device, weights_only=False) except Exception as e: print(f"HuggingFace Hub download failed: {e}") print("Falling back to torch.hub.load_state_dict_from_url...") # Fallback to original method ckpt = torch.hub.load_state_dict_from_url( ckpt_path, map_location=device ) else: # Use original method for non-HF URLs ckpt = torch.hub.load_state_dict_from_url(ckpt_path, map_location=device) print(model.load_state_dict(ckpt["model"], strict=False)) model.eval() return model # MapAnything Configuration high_level_config = { "path": "configs/train.yaml", "config_overrides": [ "machine=aws", "model=mapanything", "model/task=images_only", "model.encoder.uses_torch_hub=false", ], "checkpoint_path": "https://huggingface.co/facebook/MapAnything/resolve/main/mapa_curri_24v_13d_48ipg_64g.pth", "trained_with_amp": True, "trained_with_amp_dtype": "fp16", "data_norm_type": "dinov2", "patch_size": 14, "resolution": 518, } # Initialize model - this will be done on GPU when needed model = None # ------------------------------------------------------------------------- # 1) Core model inference # ------------------------------------------------------------------------- @spaces.GPU(duration=120) def run_model(target_dir, model_placeholder): """ Run the MapAnything model on images in the 'target_dir/images' folder and return predictions. """ global model print(f"Processing images from {target_dir}") # Device check device = "cuda" if torch.cuda.is_available() else "cpu" device = torch.device(device) # if not torch.cuda.is_available(): # raise ValueError("CUDA is not available. Check your environment.") # Initialize model if not already done if model is None: print("Initializing MapAnything model...") model = init_inference_model( high_level_config, high_level_config["checkpoint_path"], device ) else: model = model.to(device) model.eval() # Load images using MapAnything's load_images function print("Loading images...") image_folder_path = os.path.join(target_dir, "images") views = load_images( image_folder_path, resolution_set=high_level_config["resolution"], verbose=False, norm_type=high_level_config["data_norm_type"], patch_size=high_level_config["patch_size"], stride=1, ) print(f"Loaded {len(views)} images") if len(views) == 0: raise ValueError("No images found. Check your upload.") # Run inference using MapAnything's inference function print("Running MapAnything inference...") with torch.no_grad(): pred_result = loss_of_one_batch_multi_view( views, model, None, device, use_amp=high_level_config["trained_with_amp"], amp_dtype=high_level_config["trained_with_amp_dtype"], ) # Convert predictions to format expected by visualization predictions = {} # Initialize lists for the required keys extrinsic_list = [] intrinsic_list = [] world_points_list = [] depth_maps_list = [] images_list = [] confidence_list = [] final_mask_list = [] # Check if confidence data is available has_confidence = False for view_idx, view in enumerate(views): view_key = f"pred{view_idx + 1}" if view_key in pred_result and "conf" in pred_result[view_key]: has_confidence = True break # Extract predictions for each view for view_idx, view in enumerate(views): # Get image for colors image = rgb(view["img"], norm_type=high_level_config["data_norm_type"]) view_key = f"pred{view_idx + 1}" if view_key in pred_result: pred_pts3d = pred_result[view_key]["pts3d"][0].cpu().numpy() # Get confidence data if available confidence_map = None if "conf" in pred_result[view_key]: confidence_map = pred_result[view_key]["conf"][0].cpu().numpy() # Compute final_mask just like in visualize_raw_inference_output function # Create the prediction mask based on parameters pred_mask = None use_gt_mask_on_pred = False # Set based on your requirements use_pred_mask = True # Set based on your requirements use_non_ambi_mask = True # Set based on your requirements use_conf_mask = False # Set based on your requirements conf_percentile = 10 # Set based on your requirements use_edge_mask = True # Set based on your requirements pts_edge_tol = 5 # Set based on your requirements depth_edge_rtol = 0.03 # Set based on your requirements if use_pred_mask: # Get non ambiguous mask if available and requested has_non_ambiguous_mask = ( "non_ambiguous_mask" in pred_result[view_key] and use_non_ambi_mask ) if has_non_ambiguous_mask: non_ambiguous_mask = ( pred_result[view_key]["non_ambiguous_mask"][0].cpu().numpy() ) pred_mask = non_ambiguous_mask # Get confidence mask if available and requested has_conf = "conf" in pred_result[view_key] and use_conf_mask if has_conf: confidences = pred_result[view_key]["conf"][0].cpu() percentile_threshold = torch.quantile( confidences, conf_percentile / 100.0 ) conf_mask = confidences > percentile_threshold conf_mask = conf_mask.numpy() if pred_mask is not None: pred_mask = pred_mask & conf_mask else: pred_mask = conf_mask # Apply edge mask if requested if use_edge_mask and pred_mask is not None: if "cam_quats" not in pred_result[view_key]: # For direct point prediction # Compute normals and edge mask normals, normals_mask = points_to_normals( pred_pts3d, mask=pred_mask ) edge_mask = ~( normals_edge(normals, tol=pts_edge_tol, mask=normals_mask) ) else: # For ray-based prediction ray_depth = pred_result[view_key]["depth_along_ray"][0].cpu() local_pts3d = ( pred_result[view_key]["ray_directions"][0].cpu() * ray_depth ) depth_z = local_pts3d[..., 2].numpy() # Compute normals and edge mask normals, normals_mask = points_to_normals( pred_pts3d, mask=pred_mask ) edge_mask = ~( depth_edge(depth_z, rtol=depth_edge_rtol, mask=pred_mask) & normals_edge(normals, tol=pts_edge_tol, mask=normals_mask) ) if pred_mask is not None: pred_mask = pred_mask & edge_mask # Determine final mask to use (like in visualize_raw_inference_output) final_mask = None valid_mask = np.ones_like( pred_pts3d[..., 0], dtype=bool ) # Create dummy valid_mask for app.py context if use_gt_mask_on_pred: final_mask = valid_mask if use_pred_mask and pred_mask is not None: final_mask = final_mask & pred_mask elif use_pred_mask and pred_mask is not None: final_mask = pred_mask else: final_mask = np.ones_like(valid_mask, dtype=bool) # Check if we have camera pose and intrinsics data if "cam_quats" in pred_result[view_key]: # Get decoupled quantities (like in visualize_raw_custom_data_inference_output) cam_quats = pred_result[view_key]["cam_quats"][0].cpu() cam_trans = pred_result[view_key]["cam_trans"][0].cpu() ray_directions = pred_result[view_key]["ray_directions"][0].cpu() ray_depth = pred_result[view_key]["depth_along_ray"][0].cpu() # Convert the quantities from mapanything.utils.geometry import ( quaternion_to_rotation_matrix, recover_pinhole_intrinsics_from_ray_directions, ) cam_rot = quaternion_to_rotation_matrix(cam_quats) cam_pose = torch.eye(4) cam_pose[:3, :3] = cam_rot cam_pose[:3, 3] = cam_trans cam_pose = np.linalg.inv(cam_pose) cam_intrinsics = recover_pinhole_intrinsics_from_ray_directions( ray_directions, use_geometric_calculation=True ) # Compute depth as in app_map.py local_pts3d = ray_directions * ray_depth depth_z = local_pts3d[..., 2] # Convert to numpy and extract 3x4 extrinsic (remove bottom row) extrinsic = cam_pose[:3, :4].numpy() # Shape: (3, 4) intrinsic = cam_intrinsics.numpy() # Shape: (3, 3) depth_z = depth_z.numpy() # Shape: (H, W) else: # Use dummy values if camera info not available # extrinsic: (3, 4) - [R|t] matrix extrinsic = np.eye(3, 4) # Identity rotation, zero translation # intrinsic: (3, 3) - camera intrinsic matrix intrinsic = np.eye(3) # depth_z: (H, W) - dummy depth values depth_z = np.zeros_like(pred_pts3d[..., 0]) # Append to lists extrinsic_list.append(extrinsic) intrinsic_list.append(intrinsic) world_points_list.append(pred_pts3d) depth_maps_list.append(depth_z) images_list.append(image[0]) # Add image to list final_mask_list.append(final_mask) # Add final_mask to list # Add confidence data (or None if not available) if confidence_map is not None: confidence_list.append(confidence_map) elif has_confidence: # If some views have confidence but this one doesn't, add dummy confidence confidence_list.append(np.ones_like(depth_z)) # Convert lists to numpy arrays with required shapes # extrinsic: (S, 3, 4) - batch of camera extrinsic matrices predictions["extrinsic"] = np.stack(extrinsic_list, axis=0) # intrinsic: (S, 3, 3) - batch of camera intrinsic matrices predictions["intrinsic"] = np.stack(intrinsic_list, axis=0) # world_points: (S, H, W, 3) - batch of 3D world points predictions["world_points"] = np.stack(world_points_list, axis=0) # depth: (S, H, W, 1) or (S, H, W) - batch of depth maps depth_maps = np.stack(depth_maps_list, axis=0) # Add channel dimension if needed to match (S, H, W, 1) format if len(depth_maps.shape) == 3: depth_maps = depth_maps[..., np.newaxis] predictions["depth"] = depth_maps # images: (S, H, W, 3) - batch of input images predictions["images"] = np.stack(images_list, axis=0) # confidence: (S, H, W) - batch of confidence maps (only if available) if confidence_list: predictions["confidence"] = np.stack(confidence_list, axis=0) # final_mask: (S, H, W) - batch of final masks for filtering predictions["final_mask"] = np.stack(final_mask_list, axis=0) world_points = unproject_depth_map_to_point_map( depth_maps, predictions["extrinsic"], predictions["intrinsic"] ) predictions["world_points_from_depth"] = world_points # Process data for visualization tabs (depth, normal, measure) processed_data = process_predictions_for_visualization( pred_result, views, high_level_config ) # Clean up torch.cuda.empty_cache() return predictions, processed_data def update_view_selectors(processed_data): """Update view selector dropdowns based on available views""" if processed_data is None or len(processed_data) == 0: choices = ["View 1"] else: num_views = len(processed_data) choices = [f"View {i + 1}" for i in range(num_views)] return ( gr.Dropdown(choices=choices, value=choices[0]), # depth_view_selector gr.Dropdown(choices=choices, value=choices[0]), # normal_view_selector gr.Dropdown(choices=choices, value=choices[0]), # measure_view_selector ) def get_view_data_by_index(processed_data, view_index): """Get view data by index, handling bounds""" if processed_data is None or len(processed_data) == 0: return None view_keys = list(processed_data.keys()) if view_index < 0 or view_index >= len(view_keys): view_index = 0 return processed_data[view_keys[view_index]] def update_depth_view(processed_data, view_index, conf_thres=None): """Update depth view for a specific view index with optional confidence filtering""" view_data = get_view_data_by_index(processed_data, view_index) if view_data is None or view_data["depth"] is None: return None # Use confidence filtering if available confidence = view_data.get("confidence") return colorize_depth( view_data["depth"], confidence=confidence, conf_thres=conf_thres ) def update_normal_view(processed_data, view_index, conf_thres=None): """Update normal view for a specific view index with optional confidence filtering""" view_data = get_view_data_by_index(processed_data, view_index) if view_data is None or view_data["normal"] is None: return None # Use confidence filtering if available confidence = view_data.get("confidence") return colorize_normal( view_data["normal"], confidence=confidence, conf_thres=conf_thres ) def update_measure_view(processed_data, view_index): """Update measure view for a specific view index""" view_data = get_view_data_by_index(processed_data, view_index) if view_data is None: return None, [] # image, measure_points return view_data["image"], [] def navigate_depth_view( processed_data, current_selector_value, direction, conf_thres=None ): """Navigate depth view (direction: -1 for previous, +1 for next)""" if processed_data is None or len(processed_data) == 0: return "View 1", None # Parse current view number try: current_view = int(current_selector_value.split()[1]) - 1 except: current_view = 0 num_views = len(processed_data) new_view = (current_view + direction) % num_views new_selector_value = f"View {new_view + 1}" depth_vis = update_depth_view(processed_data, new_view, conf_thres=conf_thres) return new_selector_value, depth_vis def navigate_normal_view( processed_data, current_selector_value, direction, conf_thres=None ): """Navigate normal view (direction: -1 for previous, +1 for next)""" if processed_data is None or len(processed_data) == 0: return "View 1", None # Parse current view number try: current_view = int(current_selector_value.split()[1]) - 1 except: current_view = 0 num_views = len(processed_data) new_view = (current_view + direction) % num_views new_selector_value = f"View {new_view + 1}" normal_vis = update_normal_view(processed_data, new_view, conf_thres=conf_thres) return new_selector_value, normal_vis def navigate_measure_view(processed_data, current_selector_value, direction): """Navigate measure view (direction: -1 for previous, +1 for next)""" if processed_data is None or len(processed_data) == 0: return "View 1", None, [] # Parse current view number try: current_view = int(current_selector_value.split()[1]) - 1 except: current_view = 0 num_views = len(processed_data) new_view = (current_view + direction) % num_views new_selector_value = f"View {new_view + 1}" measure_image, measure_points = update_measure_view(processed_data, new_view) return new_selector_value, measure_image, measure_points def populate_visualization_tabs(processed_data, conf_thres=None): """Populate the depth, normal, and measure tabs with processed data""" if processed_data is None or len(processed_data) == 0: return None, None, None, [] # Use update functions to ensure confidence filtering is applied from the start depth_vis = update_depth_view(processed_data, 0, conf_thres=conf_thres) normal_vis = update_normal_view(processed_data, 0, conf_thres=conf_thres) measure_img, _ = update_measure_view(processed_data, 0) return depth_vis, normal_vis, measure_img, [] # ------------------------------------------------------------------------- # 2) Handle uploaded video/images --> produce target_dir + images # ------------------------------------------------------------------------- def handle_uploads(input_video, input_images, s_time_interval=1.0): """ Create a new 'target_dir' + 'images' subfolder, and place user-uploaded images or extracted frames from video into it. Return (target_dir, image_paths). """ start_time = time.time() gc.collect() torch.cuda.empty_cache() # Create a unique folder name timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") target_dir = f"input_images_{timestamp}" target_dir_images = os.path.join(target_dir, "images") # Clean up if somehow that folder already exists if os.path.exists(target_dir): shutil.rmtree(target_dir) os.makedirs(target_dir) os.makedirs(target_dir_images) image_paths = [] # --- Handle images --- if input_images is not None: for file_data in input_images: if isinstance(file_data, dict) and "name" in file_data: file_path = file_data["name"] else: file_path = file_data dst_path = os.path.join(target_dir_images, os.path.basename(file_path)) shutil.copy(file_path, dst_path) image_paths.append(dst_path) # --- Handle video --- if input_video is not None: if isinstance(input_video, dict) and "name" in input_video: video_path = input_video["name"] else: video_path = input_video vs = cv2.VideoCapture(video_path) fps = vs.get(cv2.CAP_PROP_FPS) frame_interval = int(fps * s_time_interval) # 1 frame/sec count = 0 video_frame_num = 0 while True: gotit, frame = vs.read() if not gotit: break count += 1 if count % frame_interval == 0: image_path = os.path.join( target_dir_images, f"{video_frame_num:06}.png" ) cv2.imwrite(image_path, frame) image_paths.append(image_path) video_frame_num += 1 # Sort final images for gallery image_paths = sorted(image_paths) end_time = time.time() print( f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds" ) return target_dir, image_paths # ------------------------------------------------------------------------- # 3) Update gallery on upload # ------------------------------------------------------------------------- def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0): """ Whenever user uploads or changes files, immediately handle them and show in the gallery. Return (target_dir, image_paths). If nothing is uploaded, returns "None" and empty list. """ if not input_video and not input_images: return None, None, None, None target_dir, image_paths = handle_uploads(input_video, input_images, s_time_interval) return ( None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing.", ) # ------------------------------------------------------------------------- # 4) Reconstruction: uses the target_dir plus any viz parameters # ------------------------------------------------------------------------- @spaces.GPU(duration=120) def gradio_demo( target_dir, conf_thres=3.0, frame_filter="All", show_cam=True, filter_sky=False, filter_black_bg=False, filter_white_bg=False, mask_ambiguous=False, ): """ Perform reconstruction using the already-created target_dir/images. """ if not os.path.isdir(target_dir) or target_dir == "None": return None, "No valid target directory found. Please upload first.", None, None start_time = time.time() gc.collect() torch.cuda.empty_cache() # Always use Pointmap Branch for MapAnything prediction_mode = "Pointmap Branch" # Prepare frame_filter dropdown target_dir_images = os.path.join(target_dir, "images") all_files = ( sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else [] ) all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)] frame_filter_choices = ["All"] + all_files print("Running MapAnything model...") with torch.no_grad(): predictions, processed_data = run_model(target_dir, None) # Save predictions prediction_save_path = os.path.join(target_dir, "predictions.npz") np.savez(prediction_save_path, **predictions) # Handle None frame_filter if frame_filter is None: frame_filter = "All" # Build a GLB file name glbfile = os.path.join( target_dir, f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_sky{filter_sky}_black{filter_black_bg}_white{filter_white_bg}_mask{mask_ambiguous}_pred{prediction_mode.replace(' ', '_')}.glb", ) # Convert predictions to GLB glbscene = predictions_to_glb( predictions, conf_thres=conf_thres, filter_by_frames=frame_filter, show_cam=show_cam, target_dir=target_dir, prediction_mode=prediction_mode, mask_sky=filter_sky, mask_black_bg=filter_black_bg, mask_white_bg=filter_white_bg, mask_ambiguous=mask_ambiguous, ) glbscene.export(file_obj=glbfile) # Cleanup del predictions gc.collect() torch.cuda.empty_cache() end_time = time.time() print(f"Total time: {end_time - start_time:.2f} seconds") log_msg = ( f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization." ) # Populate visualization tabs with processed data depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs( processed_data, conf_thres=conf_thres ) # Update view selectors based on available views depth_selector, normal_selector, measure_selector = update_view_selectors( processed_data ) return ( glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True), processed_data, depth_vis, normal_vis, measure_img, "", # measure_text (empty initially) depth_selector, normal_selector, measure_selector, ) # ------------------------------------------------------------------------- # 5) Helper functions for UI resets + re-visualization # ------------------------------------------------------------------------- def apply_confidence_filtering(data, confidence, conf_thres): """Apply confidence filtering to data arrays""" if confidence is None or data is None: return data # Convert confidence threshold from percentage to confidence value conf_threshold = np.percentile(confidence, conf_thres) conf_mask = (confidence >= conf_threshold) & (confidence > 1e-5) # conf_mask = confidence >= (conf_thres) # Apply mask to data if len(data.shape) == 3: # 3D data (H, W, C) filtered_data = data.copy() for c in range(data.shape[2]): filtered_data[:, :, c] = np.where(conf_mask, data[:, :, c], 0) elif len(data.shape) == 2: # 2D data (H, W) filtered_data = np.where(conf_mask, data, 0) else: filtered_data = data return filtered_data def colorize_depth(depth_map, confidence=None, conf_thres=None): """Convert depth map to colorized visualization with optional confidence filtering""" if depth_map is None: return None # Apply confidence filtering if available if confidence is not None and conf_thres is not None: depth_map = apply_confidence_filtering(depth_map, confidence, conf_thres) # Normalize depth to 0-1 range depth_normalized = depth_map.copy() valid_mask = depth_normalized > 0 if valid_mask.sum() > 0: valid_depths = depth_normalized[valid_mask] p5 = np.percentile(valid_depths, 5) p95 = np.percentile(valid_depths, 95) depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / (p95 - p5) # Apply colormap import matplotlib.pyplot as plt colormap = plt.cm.turbo_r # colormap = plt.cm.plasma # colormap = plt.cm.viridis colored = colormap(depth_normalized) colored = (colored[:, :, :3] * 255).astype(np.uint8) # Set invalid pixels to white colored[~valid_mask] = [255, 255, 255] return colored def colorize_normal(normal_map, confidence=None, conf_thres=None): """Convert normal map to colorized visualization with optional confidence filtering""" if normal_map is None: return None # Apply confidence filtering if available if confidence is not None and conf_thres is not None: normal_map = apply_confidence_filtering(normal_map, confidence, conf_thres) # Normalize normals to [0, 1] range for visualization normal_vis = (normal_map + 1.0) / 2.0 normal_vis = (normal_vis * 255).astype(np.uint8) return normal_vis def process_predictions_for_visualization(pred_result, views, high_level_config): """Extract depth, normal, and 3D points from predictions for visualization""" processed_data = {} # Check if confidence data is available in any view has_confidence_data = False for view_idx, view in enumerate(views): view_key = f"pred{view_idx + 1}" if view_key in pred_result and "conf" in pred_result[view_key]: has_confidence_data = True break # Process each view for view_idx, view in enumerate(views): view_key = f"pred{view_idx + 1}" if view_key not in pred_result: continue # Get image image = rgb(view["img"], norm_type=high_level_config["data_norm_type"]) # Get predicted points pred_pts3d = pred_result[view_key]["pts3d"][0].cpu().numpy() # Initialize data for this view view_data = { "image": image[0], "points3d": pred_pts3d, "depth": None, "normal": None, "mask": None, "confidence": None, "has_confidence": has_confidence_data, } # Get confidence data if available if "conf" in pred_result[view_key]: confidence = pred_result[view_key]["conf"][0].cpu().numpy() view_data["confidence"] = confidence # Get masks if available has_non_ambiguous_mask = "non_ambiguous_mask" in pred_result[view_key] if has_non_ambiguous_mask: view_data["mask"] = ( pred_result[view_key]["non_ambiguous_mask"][0].cpu().numpy() ) # Extract depth and camera info if available if "cam_quats" in pred_result[view_key]: ray_directions = pred_result[view_key]["ray_directions"][0].cpu() ray_depth = pred_result[view_key]["depth_along_ray"][0].cpu() # Compute depth local_pts3d = ray_directions * ray_depth depth_z = local_pts3d[..., 2].numpy() view_data["depth"] = depth_z # Compute normals if we have valid points if has_non_ambiguous_mask: try: normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"]) view_data["normal"] = normals except: # If normal computation fails, skip it pass processed_data[view_idx] = view_data return processed_data def reset_measure(processed_data): """Reset measure points""" if processed_data is None or len(processed_data) == 0: return None, [], "" # Return the first view image first_view = list(processed_data.values())[0] return first_view["image"], [], "" def measure( processed_data, measure_points, current_view_selector, event: gr.SelectData ): """Handle measurement on images""" try: print(f"Measure function called with selector: {current_view_selector}") if processed_data is None or len(processed_data) == 0: return None, [], "No data available" # Use the currently selected view instead of always using the first view try: current_view_index = int(current_view_selector.split()[1]) - 1 except: current_view_index = 0 print(f"Using view index: {current_view_index}") # Get view data safely if current_view_index < 0 or current_view_index >= len(processed_data): current_view_index = 0 view_keys = list(processed_data.keys()) current_view = processed_data[view_keys[current_view_index]] if current_view is None: return None, [], "No view data available" point2d = event.index[0], event.index[1] print(f"Clicked point: {point2d}") measure_points.append(point2d) # Get image and ensure it's valid image = current_view["image"] if image is None: return None, [], "No image available" image = image.copy() points3d = current_view["points3d"] # Ensure image is in uint8 format for proper cv2 operations try: if image.dtype != np.uint8: if image.max() <= 1.0: # Image is in [0, 1] range, convert to [0, 255] image = (image * 255).astype(np.uint8) else: # Image is already in [0, 255] range image = image.astype(np.uint8) except Exception as e: print(f"Image conversion error: {e}") return None, [], f"Image conversion error: {e}" # Draw circles for points try: for p in measure_points: if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]: image = cv2.circle( image, p, radius=5, color=(255, 0, 0), thickness=2 ) except Exception as e: print(f"Drawing error: {e}") return None, [], f"Drawing error: {e}" depth_text = "" try: for i, p in enumerate(measure_points): if ( current_view["depth"] is not None and 0 <= p[1] < current_view["depth"].shape[0] and 0 <= p[0] < current_view["depth"].shape[1] ): d = current_view["depth"][p[1], p[0]] depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n" else: # Use Z coordinate of 3D points if depth not available if ( points3d is not None and 0 <= p[1] < points3d.shape[0] and 0 <= p[0] < points3d.shape[1] ): z = points3d[p[1], p[0], 2] depth_text += f"- **P{i + 1} Z-coord: {z:.2f}m.**\n" except Exception as e: print(f"Depth text error: {e}") depth_text = f"Error computing depth: {e}\n" if len(measure_points) == 2: try: point1, point2 = measure_points # Draw line if ( 0 <= point1[0] < image.shape[1] and 0 <= point1[1] < image.shape[0] and 0 <= point2[0] < image.shape[1] and 0 <= point2[1] < image.shape[0] ): image = cv2.line( image, point1, point2, color=(255, 0, 0), thickness=2 ) # Compute 3D distance distance_text = "- **Distance: Unable to compute**" if ( points3d is not None and 0 <= point1[1] < points3d.shape[0] and 0 <= point1[0] < points3d.shape[1] and 0 <= point2[1] < points3d.shape[0] and 0 <= point2[0] < points3d.shape[1] ): try: p1_3d = points3d[point1[1], point1[0]] p2_3d = points3d[point2[1], point2[0]] distance = np.linalg.norm(p1_3d - p2_3d) distance_text = f"- **Distance: {distance:.2f}m**" except Exception as e: print(f"Distance computation error: {e}") distance_text = f"- **Distance computation error: {e}**" measure_points = [] text = depth_text + distance_text print(f"Measurement complete: {text}") return [image, measure_points, text] except Exception as e: print(f"Final measurement error: {e}") return None, [], f"Measurement error: {e}" else: print(f"Single point measurement: {depth_text}") return [image, measure_points, depth_text] except Exception as e: print(f"Overall measure function error: {e}") return None, [], f"Measure function error: {e}" def clear_fields(): """ Clears the 3D viewer, the stored target_dir, and empties the gallery. """ return None def update_log(): """ Display a quick log message while waiting. """ return "Loading and Reconstructing..." def update_visualization( target_dir, conf_thres, frame_filter, show_cam, is_example, filter_sky=False, filter_black_bg=False, filter_white_bg=False, mask_ambiguous=False, ): """ Reload saved predictions from npz, create (or reuse) the GLB for new parameters, and return it for the 3D viewer. If is_example == "True", skip. """ # If it's an example click, skip as requested if is_example == "True": return ( gr.update(), "No reconstruction available. Please click the Reconstruct button first.", ) if not target_dir or target_dir == "None" or not os.path.isdir(target_dir): return ( gr.update(), "No reconstruction available. Please click the Reconstruct button first.", ) predictions_path = os.path.join(target_dir, "predictions.npz") if not os.path.exists(predictions_path): return ( gr.update(), f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.", ) loaded = np.load(predictions_path, allow_pickle=True) predictions = {key: loaded[key] for key in loaded.keys()} # Always use Pointmap Branch for MapAnything prediction_mode = "Pointmap Branch" glbfile = os.path.join( target_dir, f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_sky{filter_sky}_black{filter_black_bg}_white{filter_white_bg}_pred{prediction_mode.replace(' ', '_')}.glb", ) if not os.path.exists(glbfile): glbscene = predictions_to_glb( predictions, conf_thres=conf_thres, filter_by_frames=frame_filter, show_cam=show_cam, target_dir=target_dir, prediction_mode=prediction_mode, mask_sky=filter_sky, mask_black_bg=filter_black_bg, mask_white_bg=filter_white_bg, mask_ambiguous=mask_ambiguous, ) glbscene.export(file_obj=glbfile) return ( glbfile, "Visualization updated.", ) # ------------------------------------------------------------------------- # Example scene functions # ------------------------------------------------------------------------- def get_scene_info(examples_dir): """Get information about scenes in the examples directory""" import glob scenes = [] if not os.path.exists(examples_dir): return scenes for scene_folder in sorted(os.listdir(examples_dir)): scene_path = os.path.join(examples_dir, scene_folder) if os.path.isdir(scene_path): # Find all image files in the scene folder image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"] image_files = [] for ext in image_extensions: image_files.extend(glob.glob(os.path.join(scene_path, ext))) image_files.extend(glob.glob(os.path.join(scene_path, ext.upper()))) if image_files: # Sort images and get the first one for thumbnail image_files = sorted(image_files) first_image = image_files[0] num_images = len(image_files) scenes.append( { "name": scene_folder, "path": scene_path, "thumbnail": first_image, "num_images": num_images, "image_files": image_files, } ) return scenes def load_example_scene(scene_name, examples_dir="examples"): """Load a scene from examples directory""" scenes = get_scene_info(examples_dir) # Find the selected scene selected_scene = None for scene in scenes: if scene["name"] == scene_name: selected_scene = scene break if selected_scene is None: return None, None, None, "Scene not found" # Create target directory and copy images target_dir, image_paths = handle_uploads(None, selected_scene["image_files"]) return ( None, # Clear reconstruction output target_dir, # Set target directory image_paths, # Set gallery f"Loaded scene '{scene_name}' with {selected_scene['num_images']} images. Click 'Reconstruct' to begin 3D processing.", ) # ------------------------------------------------------------------------- # 6) Build Gradio UI # ------------------------------------------------------------------------- theme = get_gradio_theme() with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo: # State variables for the tabbed interface is_example = gr.Textbox(label="is_example", visible=False, value="None") num_images = gr.Textbox(label="num_images", visible=False, value="None") processed_data_state = gr.State(value=None) measure_points_state = gr.State(value=[]) current_view_index = gr.State(value=0) # Track current view index for navigation gr.HTML(get_header_html(get_logo_base64())) gr.HTML(get_description_html()) target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None") with gr.Row(): with gr.Column(scale=2): input_video = gr.Video(label="Upload Video", interactive=True) s_time_interval = gr.Slider( minimum=0.1, maximum=5.0, value=1.0, step=0.1, label="Sample time interval (take a sample every x sec.)", interactive=True, visible=True, ) input_images = gr.File( file_count="multiple", label="Upload Images", interactive=True ) image_gallery = gr.Gallery( label="Preview", columns=4, height="300px", show_download_button=True, object_fit="contain", preview=True, ) with gr.Column(scale=4): with gr.Column(): gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**") log_output = gr.Markdown( "Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"], ) # Add tabbed interface similar to MoGe with gr.Tabs(): with gr.Tab("3D View"): reconstruction_output = gr.Model3D( height=520, zoom_speed=0.5, pan_speed=0.5, clear_color=[0.0, 0.0, 0.0, 0.0], key="persistent_3d_viewer", elem_id="reconstruction_3d_viewer", ) with gr.Tab("Depth"): with gr.Row(elem_classes=["navigation-row"]): prev_depth_btn = gr.Button("◀ Previous", size="sm", scale=1) depth_view_selector = gr.Dropdown( choices=["View 1"], value="View 1", label="Select View", scale=2, interactive=True, allow_custom_value=True, ) next_depth_btn = gr.Button("Next ▶", size="sm", scale=1) depth_map = gr.Image( type="numpy", label="Colorized Depth Map", format="png", interactive=False, ) with gr.Tab("Normal"): with gr.Row(elem_classes=["navigation-row"]): prev_normal_btn = gr.Button( "◀ Previous", size="sm", scale=1 ) normal_view_selector = gr.Dropdown( choices=["View 1"], value="View 1", label="Select View", scale=2, interactive=True, allow_custom_value=True, ) next_normal_btn = gr.Button("Next ▶", size="sm", scale=1) normal_map = gr.Image( type="numpy", label="Normal Map", format="png", interactive=False, ) with gr.Tab("Measure"): gr.Markdown(MEASURE_INSTRUCTIONS_HTML) with gr.Row(elem_classes=["navigation-row"]): prev_measure_btn = gr.Button( "◀ Previous", size="sm", scale=1 ) measure_view_selector = gr.Dropdown( choices=["View 1"], value="View 1", label="Select View", scale=2, interactive=True, allow_custom_value=True, ) next_measure_btn = gr.Button("Next ▶", size="sm", scale=1) measure_image = gr.Image( type="numpy", show_label=False, format="webp", interactive=False, sources=[], ) measure_text = gr.Markdown("") with gr.Row(): submit_btn = gr.Button("Reconstruct", scale=1, variant="primary") clear_btn = gr.ClearButton( [ input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery, ], scale=1, ) with gr.Row(): conf_thres = gr.Slider( minimum=0, maximum=100, value=0, step=0.1, label="Confidence Threshold (%), only shown in depth and normals", ) frame_filter = gr.Dropdown( choices=["All"], value="All", label="Show Points from Frame" ) with gr.Column(): show_cam = gr.Checkbox(label="Show Camera", value=True) filter_sky = gr.Checkbox( label="Filter Sky (using skyseg.onnx)", value=False ) filter_black_bg = gr.Checkbox( label="Filter Black Background", value=False ) filter_white_bg = gr.Checkbox( label="Filter White Background", value=False ) mask_ambiguous = gr.Checkbox(label="Mask Ambiguous", value=True) # ---------------------- Example Scenes Section ---------------------- gr.Markdown("## Example Scenes") gr.Markdown("Click any thumbnail to load the scene for reconstruction.") # Get scene information scenes = get_scene_info("examples") # Create thumbnail grid (4 columns, N rows) if scenes: for i in range(0, len(scenes), 4): # Process 4 scenes per row with gr.Row(): for j in range(4): scene_idx = i + j if scene_idx < len(scenes): scene = scenes[scene_idx] with gr.Column(scale=1, elem_classes=["clickable-thumbnail"]): # Clickable thumbnail scene_img = gr.Image( value=scene["thumbnail"], height=150, interactive=False, show_label=False, elem_id=f"scene_thumb_{scene['name']}", sources=[], ) # Scene name and image count as text below thumbnail gr.Markdown( f"**{scene['name']}** \n {scene['num_images']} images", elem_classes=["scene-info"], ) # Connect thumbnail click to load scene scene_img.select( fn=lambda name=scene["name"]: load_example_scene(name), outputs=[ reconstruction_output, target_dir_output, image_gallery, log_output, ], ) else: # Empty column to maintain grid structure with gr.Column(scale=1): pass # ------------------------------------------------------------------------- # "Reconstruct" button logic: # - Clear fields # - Update log # - gradio_demo(...) with the existing target_dir # - Then set is_example = "False" # ------------------------------------------------------------------------- submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then( fn=update_log, inputs=[], outputs=[log_output] ).then( fn=gradio_demo, inputs=[ target_dir_output, conf_thres, frame_filter, show_cam, filter_sky, filter_black_bg, filter_white_bg, mask_ambiguous, ], outputs=[ reconstruction_output, log_output, frame_filter, processed_data_state, depth_map, normal_map, measure_image, measure_text, depth_view_selector, normal_view_selector, measure_view_selector, ], ).then( fn=lambda: "False", inputs=[], outputs=[is_example], # set is_example to "False" ) # ------------------------------------------------------------------------- # Real-time Visualization Updates # ------------------------------------------------------------------------- def update_all_visualizations_on_conf_change( processed_data, depth_selector, normal_selector, conf_thres_val, target_dir, frame_filter, show_cam, is_example, ): """Update 3D view and all tabs when confidence threshold changes""" # Update 3D pointcloud visualization glb_file, log_msg = update_visualization( target_dir, conf_thres_val, frame_filter, show_cam, is_example, ) # Update depth and normal tabs with new confidence threshold depth_vis = None normal_vis = None if processed_data is not None: # Get current view indices from selectors try: depth_view_idx = ( int(depth_selector.split()[1]) - 1 if depth_selector else 0 ) except: depth_view_idx = 0 try: normal_view_idx = ( int(normal_selector.split()[1]) - 1 if normal_selector else 0 ) except: normal_view_idx = 0 # Update visualizations with new confidence threshold depth_vis = update_depth_view( processed_data, depth_view_idx, conf_thres=conf_thres_val ) normal_vis = update_normal_view( processed_data, normal_view_idx, conf_thres=conf_thres_val ) return glb_file, log_msg, depth_vis, normal_vis conf_thres.change( fn=update_all_visualizations_on_conf_change, inputs=[ processed_data_state, depth_view_selector, normal_view_selector, conf_thres, target_dir_output, frame_filter, show_cam, is_example, ], outputs=[reconstruction_output, log_output, depth_map, normal_map], ) frame_filter.change( update_visualization, [ target_dir_output, conf_thres, frame_filter, show_cam, is_example, ], [reconstruction_output, log_output], ) show_cam.change( update_visualization, [ target_dir_output, conf_thres, frame_filter, show_cam, is_example, ], [reconstruction_output, log_output], ) filter_sky.change( update_visualization, [ target_dir_output, conf_thres, frame_filter, show_cam, is_example, filter_sky, filter_black_bg, filter_white_bg, mask_ambiguous, ], [reconstruction_output, log_output], ) filter_black_bg.change( update_visualization, [ target_dir_output, conf_thres, frame_filter, show_cam, is_example, filter_sky, filter_black_bg, filter_white_bg, mask_ambiguous, ], [reconstruction_output, log_output], ) filter_white_bg.change( update_visualization, [ target_dir_output, conf_thres, frame_filter, show_cam, is_example, filter_sky, filter_black_bg, filter_white_bg, mask_ambiguous, ], [reconstruction_output, log_output], ) mask_ambiguous.change( update_visualization, [ target_dir_output, conf_thres, frame_filter, show_cam, is_example, filter_sky, filter_black_bg, filter_white_bg, mask_ambiguous, ], [reconstruction_output, log_output], ) # ------------------------------------------------------------------------- # Auto-update gallery whenever user uploads or changes their files # ------------------------------------------------------------------------- input_video.change( fn=update_gallery_on_upload, inputs=[input_video, input_images, s_time_interval], outputs=[reconstruction_output, target_dir_output, image_gallery, log_output], ) input_images.change( fn=update_gallery_on_upload, inputs=[input_video, input_images, s_time_interval], outputs=[reconstruction_output, target_dir_output, image_gallery, log_output], ) # ------------------------------------------------------------------------- # Measure tab functionality # ------------------------------------------------------------------------- measure_image.select( fn=measure, inputs=[processed_data_state, measure_points_state, measure_view_selector], outputs=[measure_image, measure_points_state, measure_text], ) # ------------------------------------------------------------------------- # Navigation functionality for Depth, Normal, and Measure tabs # ------------------------------------------------------------------------- # Depth tab navigation prev_depth_btn.click( fn=lambda processed_data, current_selector, conf_thres_val: navigate_depth_view( processed_data, current_selector, -1, conf_thres=conf_thres_val ), inputs=[processed_data_state, depth_view_selector, conf_thres], outputs=[depth_view_selector, depth_map], ) next_depth_btn.click( fn=lambda processed_data, current_selector, conf_thres_val: navigate_depth_view( processed_data, current_selector, 1, conf_thres=conf_thres_val ), inputs=[processed_data_state, depth_view_selector, conf_thres], outputs=[depth_view_selector, depth_map], ) depth_view_selector.change( fn=lambda processed_data, selector_value, conf_thres_val: ( update_depth_view( processed_data, int(selector_value.split()[1]) - 1, conf_thres=conf_thres_val, ) if selector_value else None ), inputs=[processed_data_state, depth_view_selector, conf_thres], outputs=[depth_map], ) # Normal tab navigation prev_normal_btn.click( fn=lambda processed_data, current_selector, conf_thres_val: navigate_normal_view( processed_data, current_selector, -1, conf_thres=conf_thres_val ), inputs=[processed_data_state, normal_view_selector, conf_thres], outputs=[normal_view_selector, normal_map], ) next_normal_btn.click( fn=lambda processed_data, current_selector, conf_thres_val: navigate_normal_view( processed_data, current_selector, 1, conf_thres=conf_thres_val ), inputs=[processed_data_state, normal_view_selector, conf_thres], outputs=[normal_view_selector, normal_map], ) normal_view_selector.change( fn=lambda processed_data, selector_value, conf_thres_val: ( update_normal_view( processed_data, int(selector_value.split()[1]) - 1, conf_thres=conf_thres_val, ) if selector_value else None ), inputs=[processed_data_state, normal_view_selector, conf_thres], outputs=[normal_map], ) # Measure tab navigation prev_measure_btn.click( fn=lambda processed_data, current_selector: navigate_measure_view( processed_data, current_selector, -1 ), inputs=[processed_data_state, measure_view_selector], outputs=[measure_view_selector, measure_image, measure_points_state], ) next_measure_btn.click( fn=lambda processed_data, current_selector: navigate_measure_view( processed_data, current_selector, 1 ), inputs=[processed_data_state, measure_view_selector], outputs=[measure_view_selector, measure_image, measure_points_state], ) measure_view_selector.change( fn=lambda processed_data, selector_value: ( update_measure_view(processed_data, int(selector_value.split()[1]) - 1) if selector_value else (None, []) ), inputs=[processed_data_state, measure_view_selector], outputs=[measure_image, measure_points_state], ) # ------------------------------------------------------------------------- # Acknowledgement section # ------------------------------------------------------------------------- gr.HTML(get_acknowledgements_html()) demo.queue(max_size=20).launch(show_error=True, share=True, ssr_mode=False)