	Upload app.py
app.py CHANGED
@@ -5,11 +5,10 @@
 # LICENSE file in the root directory of this source tree.
 
 """
-MapAnything V2
--
--
--
-- DBSCAN clustering for cross-view object matching
 """
 
 import gc
@@ -18,8 +17,6 @@ import shutil
 import sys
 import time
 from datetime import datetime
-from pathlib import Path
-from collections import defaultdict
 
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 
@@ -28,10 +25,8 @@ import gradio as gr
 import numpy as np
 import spaces
 import torch
-import trimesh
 from PIL import Image
 from pillow_heif import register_heif_opener
-from sklearn.cluster import DBSCAN
 
 register_heif_opener()
 
@@ -65,10 +60,6 @@ def get_logo_base64():
         return None
 
 
-# ============================================================================
-# Configuration
-# ============================================================================
-
 # MapAnything Configuration
 high_level_config = {
     "path": "configs/train.yaml",
@@ -89,846 +80,13 @@ high_level_config = {
     "resolution": 518,
 }
 
-#
-# Method selection:
-# 1. "segformer" - SegFormer (lightest, ~14MB, fastest)
-# 2. "maskformer" - MaskFormer (medium, ~100MB, instance segmentation)
-# 3. "grounding_sam" - GroundingDINO + SAM (strongest, ~110MB, text prompts)
-
-SEGMENTATION_METHOD = "segformer"  # default to the lightest option
-
-# SegFormer Configuration (recommended - CPU friendly)
-SEGFORMER_MODEL_ID = "nvidia/segformer-b0-finetuned-ade-512-512"  # 14MB, 150 object classes
-
-# MaskFormer Configuration (alternative)
-MASKFORMER_MODEL_ID = "facebook/maskformer-swin-tiny-ade"  # 100MB, instance segmentation
-
-# GroundingDINO + SAM Configuration (original approach - requires text prompts)
-GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
-GROUNDING_DINO_BOX_THRESHOLD = 0.25
-GROUNDING_DINO_TEXT_THRESHOLD = 0.2
-SAM_MODEL_ID = "dhkim2810/MobileSAM"
-USE_MOBILE_SAM = True
-
-DEFAULT_TEXT_PROMPT = "chair . table . sofa . bed . desk . cabinet"
-
-# Common objects prompt for detection
-COMMON_OBJECTS_PROMPT = (
-    "person . face . hand . "
-    "chair . sofa . couch . bed . table . desk . cabinet . shelf . drawer . "
-    "door . window . wall . floor . ceiling . curtain . "
-    "tv . monitor . screen . computer . laptop . keyboard . mouse . "
-    "phone . tablet . remote . "
-    "lamp . light . chandelier . "
-    "book . magazine . paper . pen . pencil . "
-    "bottle . cup . glass . mug . plate . bowl . fork . knife . spoon . "
-    "vase . plant . flower . pot . "
-    "clock . picture . frame . mirror . "
-    "pillow . cushion . blanket . towel . "
-    "bag . backpack . suitcase . "
-    "box . basket . container . "
-    "shoe . hat . coat . "
-    "toy . ball . "
-    "car . bicycle . motorcycle . bus . truck . "
-    "tree . grass . sky . cloud . sun . "
-    "dog . cat . bird . "
-    "building . house . bridge . road . street . "
-    "sign . pole . bench"
-)
-
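Review note: GroundingDINO consumes its text prompt as a single string of lowercase, period-separated phrases, which is exactly the shape of the two constants above. A minimal sketch of building such a prompt from a label list (`build_prompt` is a hypothetical helper, not part of this commit):

```python
# Hypothetical helper: GroundingDINO-style prompts are lowercase phrases
# joined by " . ", like DEFAULT_TEXT_PROMPT and COMMON_OBJECTS_PROMPT above.
def build_prompt(labels):
    return " . ".join(label.strip().lower() for label in labels)

print(build_prompt(["Chair", "Table", "Sofa"]))  # "chair . table . sofa"
```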
-# DBSCAN clustering configuration (eps in meters)
-DBSCAN_EPS_CONFIG = {
-    'sofa': 1.5,
-    'bed': 1.5,
-    'couch': 1.5,
-    'desk': 0.8,
-    'table': 0.8,
-    'chair': 0.6,
-    'cabinet': 0.8,
-    'window': 0.5,
-    'door': 0.6,
-    'tv': 0.6,
-    'default': 1.0
-}
-
-DBSCAN_MIN_SAMPLES = 1
-
-# Quality control
-MIN_DETECTION_CONFIDENCE = 0.35
-MIN_MASK_AREA = 100
-
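Review note: for readers tracing what this deletion removes, a minimal runnable sketch of how `DBSCAN_EPS_CONFIG` and `DBSCAN_MIN_SAMPLES` were consumed downstream: look up the per-category eps with a `'default'` fallback, then cluster the 3D object centers. `cluster_centers` is a hypothetical name.

```python
import numpy as np
from sklearn.cluster import DBSCAN

DBSCAN_EPS_CONFIG = {'chair': 0.6, 'sofa': 1.5, 'default': 1.0}
DBSCAN_MIN_SAMPLES = 1

def cluster_centers(norm_label, centers):
    """Cluster Nx3 object centers (meters) for one normalized label."""
    eps = DBSCAN_EPS_CONFIG.get(norm_label, DBSCAN_EPS_CONFIG['default'])
    return DBSCAN(eps=eps, min_samples=DBSCAN_MIN_SAMPLES).fit_predict(centers)

# Two chairs 0.3 m apart merge; a third chair 1.7 m away stays separate.
labels = cluster_centers('chair', np.array([[0.0, 0, 0], [0.3, 0, 0], [2.0, 0, 0]]))
print(labels)  # [0 0 1]
```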
            -
            # Global model variables
         | 
| 161 | 
             
            model = None
         | 
| 162 | 
            -
            grounding_dino_model = None
         | 
| 163 | 
            -
            grounding_dino_processor = None
         | 
| 164 | 
            -
            sam_predictor = None
         | 
| 165 | 
            -
             | 
| 166 | 
            -
            # SegFormer 模型(轻量级语义分割)
         | 
| 167 | 
            -
            segformer_processor = None
         | 
| 168 | 
            -
            segformer_model = None
         | 
| 169 | 
            -
             | 
| 170 | 
            -
            # MaskFormer 模型(实例分割)
         | 
| 171 | 
            -
            maskformer_processor = None
         | 
| 172 | 
            -
            maskformer_model = None
         | 
| 173 | 
            -
             | 
| 174 | 
            -
             | 
| 175 | 
            -
            # ============================================================================
         | 
| 176 | 
            -
            # Model Loading Functions
         | 
| 177 | 
            -
            # ============================================================================
         | 
| 178 | 
            -
             | 
| 179 | 
            -
            def load_segformer_model(device="cpu"):
         | 
| 180 | 
            -
                """加载 SegFormer 模型(最轻量,CPU友好)"""
         | 
| 181 | 
            -
                global segformer_processor, segformer_model
         | 
| 182 | 
            -
                
         | 
| 183 | 
            -
                if segformer_model is not None:
         | 
| 184 | 
            -
                    print("✅ SegFormer already loaded")
         | 
| 185 | 
            -
                    return
         | 
| 186 | 
            -
                
         | 
| 187 | 
            -
                try:
         | 
| 188 | 
            -
                    from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
         | 
| 189 | 
            -
                    import os
         | 
| 190 | 
            -
                    
         | 
| 191 | 
            -
                    print(f"📥 Loading SegFormer from HuggingFace: {SEGFORMER_MODEL_ID}")
         | 
| 192 | 
            -
                    print(f"   💡 SegFormer-B0: ~14MB, 150类物体, CPU优化")
         | 
| 193 | 
            -
                    
         | 
| 194 | 
            -
                    cache_dir = os.getenv("HF_HOME", "./hf_cache")
         | 
| 195 | 
            -
                    
         | 
| 196 | 
            -
                    print(f"   正在下载 processor...")
         | 
| 197 | 
            -
                    segformer_processor = SegformerImageProcessor.from_pretrained(
         | 
| 198 | 
            -
                        SEGFORMER_MODEL_ID,
         | 
| 199 | 
            -
                        cache_dir=cache_dir
         | 
| 200 | 
            -
                    )
         | 
| 201 | 
            -
                    
         | 
| 202 | 
            -
                    print(f"   正在下载 model...")
         | 
| 203 | 
            -
                    segformer_model = SegformerForSemanticSegmentation.from_pretrained(
         | 
| 204 | 
            -
                        SEGFORMER_MODEL_ID,
         | 
| 205 | 
            -
                        cache_dir=cache_dir,
         | 
| 206 | 
            -
                        low_cpu_mem_usage=True
         | 
| 207 | 
            -
                    ).to(device).eval()
         | 
| 208 | 
            -
                    
         | 
| 209 | 
            -
                    print(f"✅ SegFormer loaded successfully on {device.upper()}")
         | 
| 210 | 
            -
                    print(f"   可识别类别: 人、家具、墙壁、地板等150类")
         | 
| 211 | 
            -
                    
         | 
| 212 | 
            -
                except Exception as e:
         | 
| 213 | 
            -
                    print(f"❌ SegFormer loading failed: {type(e).__name__}: {e}")
         | 
| 214 | 
            -
                    import traceback
         | 
| 215 | 
            -
                    traceback.print_exc()
         | 
| 216 | 
            -
             | 
| 217 | 
            -
             | 
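Review note: all three loaders in this section share the same lazy-singleton shape. A condensed sketch of the pattern (illustrative names, not from app.py): a module-level global, an early return when already populated, and a try/except that logs instead of raising so the app still boots.

```python
_model = None

def load_once(factory):
    global _model
    if _model is not None:          # already loaded: reuse the singleton
        return _model
    try:
        _model = factory()          # e.g. a from_pretrained(...) call
    except Exception as e:
        print(f"loading failed: {type(e).__name__}: {e}")
    return _model
```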
-def load_maskformer_model(device="cpu"):
-    """Load the MaskFormer model (instance segmentation)"""
-    global maskformer_processor, maskformer_model
-
-    if maskformer_model is not None:
-        print("✅ MaskFormer already loaded")
-        return
-
-    try:
-        from transformers import MaskFormerImageProcessor, MaskFormerForInstanceSegmentation
-        import os
-
-        print(f"📥 Loading MaskFormer from HuggingFace: {MASKFORMER_MODEL_ID}")
-        print(f"   💡 MaskFormer: ~100MB, instance segmentation")
-
-        cache_dir = os.getenv("HF_HOME", "./hf_cache")
-
-        print(f"   Downloading processor...")
-        maskformer_processor = MaskFormerImageProcessor.from_pretrained(
-            MASKFORMER_MODEL_ID,
-            cache_dir=cache_dir
-        )
-
-        print(f"   Downloading model...")
-        maskformer_model = MaskFormerForInstanceSegmentation.from_pretrained(
-            MASKFORMER_MODEL_ID,
-            cache_dir=cache_dir,
-            low_cpu_mem_usage=True
-        ).to(device).eval()
-
-        print(f"✅ MaskFormer loaded successfully on {device.upper()}")
-
-    except Exception as e:
-        print(f"❌ MaskFormer loading failed: {type(e).__name__}: {e}")
-        import traceback
-        traceback.print_exc()
-
-def load_grounding_dino_model(device="cpu"):
-    """Load GroundingDINO model from HuggingFace (CPU optimized)"""
-    global grounding_dino_model, grounding_dino_processor
-
-    if grounding_dino_model is not None:
-        print("✅ GroundingDINO already loaded")
-        return
-
-    try:
-        from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
-        import os
-
-        # Force CPU for segmentation (saves GPU resources)
-        seg_device = "cpu"
-        print(f"📥 Loading GroundingDINO from HuggingFace: {GROUNDING_DINO_MODEL_ID} (using {seg_device.upper()})")
-
-        # Set the cache directory (HuggingFace Spaces friendly)
-        cache_dir = os.getenv("HF_HOME", "./hf_cache")
-
-        # Load the model (with retry and detailed logging)
-        print(f"   Downloading processor...")
-        grounding_dino_processor = AutoProcessor.from_pretrained(
-            GROUNDING_DINO_MODEL_ID,
-            cache_dir=cache_dir,
-            trust_remote_code=True  # allow running remote code
-        )
-
-        print(f"   Downloading model...")
-        grounding_dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
-            GROUNDING_DINO_MODEL_ID,
-            cache_dir=cache_dir,
-            trust_remote_code=True,
-            low_cpu_mem_usage=True  # reduce CPU memory usage
-        ).to(seg_device).eval()
-
-        print(f"✅ GroundingDINO loaded successfully on {seg_device.upper()}")
-
-    except ImportError as e:
-        print(f"❌ ImportError: {e}")
-        print(f"💡 Check that requirements.txt includes the transformers library")
-        import traceback
-        traceback.print_exc()
-    except OSError as e:
-        print(f"❌ OSError (network/file problem): {e}")
-        print(f"💡 Possibly a network issue, or the model repository is unreachable")
-        print(f"💡 Things to try:")
-        print(f"   1. Check the HuggingFace Spaces network connection")
-        print(f"   2. Verify the model ID is correct: {GROUNDING_DINO_MODEL_ID}")
-        print(f"   3. Make sure there is enough disk space")
-        import traceback
-        traceback.print_exc()
-    except Exception as e:
-        print(f"❌ GroundingDINO loading failed: {type(e).__name__}: {e}")
-        import traceback
-        traceback.print_exc()
-
-
-def load_sam_model(device="cpu"):
-    """Load the MobileSAM model from HuggingFace (CPU optimized, 60x faster than SAM)"""
-    global sam_predictor
-
-    if sam_predictor is not None:
-        print("✅ SAM already loaded")
-        return
-
-    try:
-        from transformers import SamModel, SamProcessor
-        import os
-
-        # Force CPU for segmentation (MobileSAM is optimized for mobile devices/CPU)
-        seg_device = "cpu"
-        print(f"📥 Loading MobileSAM from HuggingFace: {SAM_MODEL_ID} (using {seg_device.upper()})")
-        print(f"   💡 MobileSAM is a lightweight variant: 60x faster than SAM-huge, only 10MB, well suited to CPU")
-
-        # Set the cache directory
-        cache_dir = os.getenv("HF_HOME", "./hf_cache")
-
-        print(f"   Downloading processor...")
-        sam_processor = SamProcessor.from_pretrained(
-            SAM_MODEL_ID,
-            cache_dir=cache_dir
-        )
-
-        print(f"   Downloading model...")
-        sam_model = SamModel.from_pretrained(
-            SAM_MODEL_ID,
-            cache_dir=cache_dir,
-            low_cpu_mem_usage=True
-        ).to(seg_device).eval()
-
-        # Wrap in a predictor-like interface
-        class SAMPredictor:
-            def __init__(self, model, processor, device):
-                self.model = model
-                self.processor = processor
-                self.device = device
-                self.image = None
-
-            def set_image(self, image):
-                """Set image for prediction"""
-                if image.dtype == np.uint8:
-                    self.image = Image.fromarray(image)
-                else:
-                    self.image = Image.fromarray((image * 255).astype(np.uint8))
-
-            def predict(self, box, multimask_output=False):
-                """Predict mask from box (CPU optimized)"""
-                inputs = self.processor(
-                    self.image,
-                    input_boxes=[[[box]]],
-                    return_tensors="pt"
-                )
-                # Make sure everything runs on the CPU
-                inputs = {k: v.to(self.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
-
-                with torch.no_grad():
-                    outputs = self.model(**inputs)
-
-                masks = self.processor.image_processor.post_process_masks(
-                    outputs.pred_masks.cpu(),
-                    inputs["original_sizes"].cpu() if "original_sizes" in inputs else outputs.pred_masks.new_tensor([[self.image.height, self.image.width]]),
-                    inputs["reshaped_input_sizes"].cpu() if "reshaped_input_sizes" in inputs else outputs.pred_masks.new_tensor([[self.image.height, self.image.width]])
-                )[0].squeeze().numpy()
-
-                if len(masks.shape) == 2:
-                    masks = masks[np.newaxis, ...]
-
-                return masks, None, None
-
-        sam_predictor = SAMPredictor(sam_model, sam_processor, seg_device)
-        print(f"✅ MobileSAM loaded successfully on {seg_device.upper()}")
-
-    except ImportError as e:
-        print(f"❌ ImportError: {e}")
-        print(f"💡 Check that requirements.txt includes the transformers library")
-        import traceback
-        traceback.print_exc()
-    except OSError as e:
-        print(f"❌ OSError (network/file problem): {e}")
-        print(f"💡 Possibly a network issue, or the model repository is unreachable")
-        print(f"💡 Things to try:")
-        print(f"   1. Check the HuggingFace Spaces network connection")
-        print(f"   2. Verify the model ID is correct: {SAM_MODEL_ID}")
-        print(f"   3. Make sure there is enough disk space")
-        import traceback
-        traceback.print_exc()
-
-
            # ============================================================================
         | 
| 408 | 
            -
            # Segmentation Functions
         | 
| 409 | 
            -
            # ============================================================================
         | 
| 410 | 
            -
             | 
| 411 | 
            -
            def generate_distinct_colors(n):
         | 
| 412 | 
            -
                """Generate N visually distinct colors (RGB, 0-255)"""
         | 
| 413 | 
            -
                import colorsys
         | 
| 414 | 
            -
                if n == 0:
         | 
| 415 | 
            -
                    return []
         | 
| 416 | 
            -
                
         | 
| 417 | 
            -
                colors = []
         | 
| 418 | 
            -
                for i in range(n):
         | 
| 419 | 
            -
                    hue = i / max(n, 1)
         | 
| 420 | 
            -
                    rgb = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
         | 
| 421 | 
            -
                    rgb_color = tuple(int(c * 255) for c in rgb)
         | 
| 422 | 
            -
                    colors.append(rgb_color)
         | 
| 423 | 
            -
                
         | 
| 424 | 
            -
                return colors
         | 
| 425 | 
            -
             | 
| 426 | 
            -
             | 
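Review note: `generate_distinct_colors` spaces hues evenly around the HSV wheel at fixed saturation and value; a one-liner reproducing its output for n=4:

```python
import colorsys

# Hues at 0, 90, 180, 270 degrees: red, green, cyan, and violet tones.
print([tuple(int(c * 255) for c in colorsys.hsv_to_rgb(i / 4, 0.9, 0.95))
       for i in range(4)])
```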
-# ============================================================================
-# SegFormer segmentation functions (simplified approach)
-# ============================================================================
-
-def run_segformer_segmentation(image_np, device="cpu"):
-    """Semantic segmentation with SegFormer (simplest option, CPU friendly)"""
-    if segformer_model is None or segformer_processor is None:
-        print("❌ SegFormer model not loaded")
-        return []
-
-    try:
-        import torch
-        from PIL import Image
-
-        # Prepare the image
-        if image_np.dtype != np.uint8:
-            image_np = (image_np * 255).astype(np.uint8)
-        image_pil = Image.fromarray(image_np)
-
-        # Inference
-        inputs = segformer_processor(images=image_pil, return_tensors="pt")
-        inputs = {k: v.to(device) for k, v in inputs.items()}
-
-        with torch.no_grad():
-            outputs = segformer_model(**inputs)
-
-        # Read out the segmentation result
-        logits = outputs.logits  # (1, num_classes, H, W)
-        predicted_segmentation = logits.argmax(dim=1).squeeze().cpu().numpy()
-
-        # Build instance masks (split contiguous regions of the same class)
-        from scipy import ndimage
-
-        # ADE20K common class mapping (partial)
-        ade20k_labels = {
-            5: "wall", 7: "floor", 11: "ceiling", 18: "window", 14: "door",
-            19: "table", 20: "chair", 22: "sofa", 23: "bed", 28: "cabinet",
-            34: "desk", 39: "lamp", 65: "television", 89: "shelf"
-        }
-
-        detections = []
-        masks = []
-
-        # Extract instances for each class
-        unique_labels = np.unique(predicted_segmentation)
-        for label_id in unique_labels:
-            if label_id == 0:  # skip background
-                continue
-
-            # Mask for this class
-            class_mask = (predicted_segmentation == label_id)
-
-            # Separate connected components (distinct instances)
-            labeled_mask, num_features = ndimage.label(class_mask)
-
-            for instance_id in range(1, num_features + 1):
-                instance_mask = (labeled_mask == instance_id)
-                mask_area = instance_mask.sum()
-
-                # Filter out small regions
-                if mask_area < MIN_MASK_AREA:
-                    continue
-
-                # Compute the bounding box
-                rows, cols = np.where(instance_mask)
-                if len(rows) == 0:
-                    continue
-
-                y_min, y_max = rows.min(), rows.max()
-                x_min, x_max = cols.min(), cols.max()
-                bbox = [x_min, y_min, x_max, y_max]
-
-                # Class name
-                label_name = ade20k_labels.get(int(label_id), f"object_{label_id}")
-
-                detections.append({
-                    'bbox': bbox,
-                    'label': label_name,
-                    'confidence': 0.9,  # SegFormer provides no confidence; use a fixed value
-                    'class_id': int(label_id)
-                })
-                masks.append(instance_mask)
-
-        return detections, masks
-
-    except Exception as e:
-        print(f"❌ SegFormer segmentation failed: {e}")
-        import traceback
-        traceback.print_exc()
-        return [], []
 
 
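Review note: a small runnable sketch of the `(detections, masks)` contract returned by `run_segformer_segmentation`, reduced here to per-label pixel counts; the data is dummy and illustrative only.

```python
import numpy as np
from collections import Counter

detections = [{'bbox': [0, 0, 4, 4], 'label': 'chair', 'confidence': 0.9, 'class_id': 20}]
masks = [np.ones((5, 5), dtype=bool)]  # one boolean mask per detection

areas = Counter()
for det, mask in zip(detections, masks):
    areas[det['label']] += int(mask.sum())
print(areas)  # Counter({'chair': 25})
```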
-
-
-
-        print("⚠️ GroundingDINO not loaded")
-        return []
-
-    try:
-        print(f"🔍 GroundingDINO detection (CPU): {text_prompt}")
-
-        # Convert to PIL Image
-        if image_np.dtype == np.uint8:
-            pil_image = Image.fromarray(image_np)
-        else:
-            pil_image = Image.fromarray((image_np * 255).astype(np.uint8))
-
-        # Preprocess - force CPU
-        seg_device = "cpu"
-        inputs = grounding_dino_processor(images=pil_image, text=text_prompt, return_tensors="pt")
-        inputs = {k: v.to(seg_device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
-
-        # Inference
-        with torch.no_grad():
-            outputs = grounding_dino_model(**inputs)
-
-        # Post-process
-        results = grounding_dino_processor.post_process_grounded_object_detection(
-            outputs,
-            inputs["input_ids"],
-            threshold=GROUNDING_DINO_BOX_THRESHOLD,
-            text_threshold=GROUNDING_DINO_TEXT_THRESHOLD,
-            target_sizes=[pil_image.size[::-1]]
-        )[0]
-
-        # Convert to unified format
-        detections = []
-        boxes = results["boxes"].cpu().numpy()
-        scores = results["scores"].cpu().numpy()
-        labels = results["labels"]
-
-        print(f"✅ Detected {len(boxes)} objects")
-
-        for box, score, label in zip(boxes, scores, labels):
-            detection = {
-                'bbox': box.tolist(),  # [x1, y1, x2, y2]
-                'label': label,
-                'confidence': float(score)
-            }
-            detections.append(detection)
-            print(f"   - {label}: {score:.2f}")
-
-        return detections
-
-    except Exception as e:
-        print(f"❌ GroundingDINO detection failed: {e}")
-        import traceback
-        traceback.print_exc()
-        return []
-
-
-def run_sam_refinement(image_np, boxes):
-    """Run SAM precise segmentation"""
-    if sam_predictor is None:
-        print("⚠️ SAM not loaded, using bbox as mask")
-        # Use bbox to create simple rectangular mask
-        masks = []
-        h, w = image_np.shape[:2]
-        for box in boxes:
-            x1, y1, x2, y2 = map(int, box)
-            mask = np.zeros((h, w), dtype=bool)
-            mask[y1:y2, x1:x2] = True
-            masks.append(mask)
-        return masks
-
-    try:
-        print(f"🎯 SAM precise segmentation for {len(boxes)} regions...")
-        sam_predictor.set_image(image_np)
-
-        masks = []
-        for box in boxes:
-            x1, y1, x2, y2 = map(int, box)
-            box_array = np.array([x1, y1, x2, y2])
-
-            mask_output, _, _ = sam_predictor.predict(
-                box=box_array,
-                multimask_output=False
-            )
-            masks.append(mask_output[0])
-
-        print(f"✅ SAM segmentation complete")
-        return masks
-
-    except Exception as e:
-        print(f"❌ SAM segmentation failed: {e}")
-        # Fallback to bbox masks
-        masks = []
-        h, w = image_np.shape[:2]
-        for box in boxes:
-            x1, y1, x2, y2 = map(int, box)
-            mask = np.zeros((h, w), dtype=bool)
-            mask[y1:y2, x1:x2] = True
-            masks.append(mask)
-        return masks
-
-
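Review note: a runnable sketch of the bbox-fallback path in `run_sam_refinement` above: when SAM is unavailable or fails, each box degrades to a filled rectangular mask.

```python
import numpy as np

h, w = 8, 8
x1, y1, x2, y2 = 2, 1, 5, 4          # box in pixel coordinates
mask = np.zeros((h, w), dtype=bool)
mask[y1:y2, x1:x2] = True            # filled rectangle stands in for the SAM mask
print(int(mask.sum()))               # 9 = (5 - 2) * (4 - 1)
```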
            -
            def normalize_label(label):
         | 
| 624 | 
            -
                """Normalize label to main category"""
         | 
| 625 | 
            -
                label = label.strip().lower()
         | 
| 626 | 
            -
                
         | 
| 627 | 
            -
                priority_labels = ['sofa', 'bed', 'table', 'desk', 'chair', 'cabinet', 'window', 'door']
         | 
| 628 | 
            -
                
         | 
| 629 | 
            -
                for priority in priority_labels:
         | 
| 630 | 
            -
                    if priority in label:
         | 
| 631 | 
            -
                        return priority
         | 
| 632 | 
            -
                
         | 
| 633 | 
            -
                first_word = label.split()[0] if label else label
         | 
| 634 | 
            -
                
         | 
| 635 | 
            -
                # Handle plural forms
         | 
| 636 | 
            -
                if first_word.endswith('s') and len(first_word) > 1:
         | 
| 637 | 
            -
                    singular = first_word[:-1]
         | 
| 638 | 
            -
                    if first_word.endswith('sses'):
         | 
| 639 | 
            -
                        singular = first_word[:-2]
         | 
| 640 | 
            -
                    elif first_word.endswith('ies'):
         | 
| 641 | 
            -
                        singular = first_word[:-3] + 'y'
         | 
| 642 | 
            -
                    elif first_word.endswith('ves'):
         | 
| 643 | 
            -
                        singular = first_word[:-3] + 'f'
         | 
| 644 | 
            -
                    return singular
         | 
| 645 | 
            -
                
         | 
| 646 | 
            -
                return first_word
         | 
| 647 | 
            -
             | 
| 648 | 
            -
             | 
| 649 | 
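Review note: the plural handling is the subtle part of `normalize_label`. A standalone runnable sketch of just those suffix rules (`singularize` is a hypothetical name; rule order matches the removed code):

```python
def singularize(word):
    # Special suffixes first, bare 's' last, as in the removed code.
    if word.endswith('sses'):
        return word[:-2]
    if word.endswith('ies'):
        return word[:-3] + 'y'
    if word.endswith('ves'):
        return word[:-3] + 'f'
    if word.endswith('s') and len(word) > 1:
        return word[:-1]
    return word

print(singularize('lamps'), singularize('shelves'), singularize('glasses'))
# lamp shelf glass
```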
            -
            def compute_object_3d_center(points, mask):
         | 
| 650 | 
            -
                """Compute 3D center of object"""
         | 
| 651 | 
            -
                masked_points = points[mask]
         | 
| 652 | 
            -
                if len(masked_points) == 0:
         | 
| 653 | 
            -
                    return None
         | 
| 654 | 
            -
                return np.median(masked_points, axis=0)
         | 
| 655 | 
            -
             | 
| 656 | 
            -
             | 
| 657 | 
            -
            def compute_adaptive_eps(centers, base_eps):
         | 
| 658 | 
            -
                """Adaptively compute eps value based on object distribution"""
         | 
| 659 | 
            -
                if len(centers) <= 1:
         | 
| 660 | 
            -
                    return base_eps
         | 
| 661 | 
            -
                
         | 
| 662 | 
            -
                from scipy.spatial.distance import pdist
         | 
| 663 | 
            -
                distances = pdist(centers)
         | 
| 664 | 
            -
                
         | 
| 665 | 
            -
                if len(distances) == 0:
         | 
| 666 | 
            -
                    return base_eps
         | 
| 667 | 
            -
                
         | 
| 668 | 
            -
                median_dist = np.median(distances)
         | 
| 669 | 
            -
                
         | 
| 670 | 
            -
                if median_dist > base_eps * 2:
         | 
| 671 | 
            -
                    adaptive_eps = min(median_dist * 0.6, base_eps * 2.5)
         | 
| 672 | 
            -
                elif median_dist > base_eps:
         | 
| 673 | 
            -
                    adaptive_eps = median_dist * 0.5
         | 
| 674 | 
            -
                else:
         | 
| 675 | 
            -
                    adaptive_eps = base_eps
         | 
| 676 | 
            -
                
         | 
| 677 | 
            -
                return adaptive_eps
         | 
| 678 | 
            -
             | 
| 679 | 
            -
             | 
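Review note: a worked example (assuming scipy) of the adaptive-eps rule in `compute_adaptive_eps` above: when detections are spread much wider than `base_eps`, eps grows with the median pairwise distance, capped at 2.5x the base value.

```python
import numpy as np
from scipy.spatial.distance import pdist

base_eps = 0.6                                   # chair default from DBSCAN_EPS_CONFIG
centers = np.array([[0.0, 0, 0], [1.5, 0, 0], [3.0, 0, 0]])
median_dist = float(np.median(pdist(centers)))   # 1.5 m > 2 * base_eps
eps = min(median_dist * 0.6, base_eps * 2.5)     # min(0.9, 1.5) = 0.9 m
print(median_dist, eps)                          # 1.5 0.9
```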
| 680 | 
            -
            def match_objects_across_views(all_view_detections):
         | 
| 681 | 
            -
                """Match objects across views using DBSCAN clustering"""
         | 
| 682 | 
            -
                print("\n🔗 Matching objects across views using DBSCAN clustering...")
         | 
| 683 | 
            -
                
         | 
| 684 | 
            -
                objects_by_label = defaultdict(list)
         | 
| 685 | 
            -
                
         | 
| 686 | 
            -
                for view_idx, detections in enumerate(all_view_detections):
         | 
| 687 | 
            -
                    for det_idx, det in enumerate(detections):
         | 
| 688 | 
            -
                        if det.get('center_3d') is None:
         | 
| 689 | 
            -
                            continue
         | 
| 690 | 
            -
                        
         | 
| 691 | 
            -
                        norm_label = normalize_label(det['label'])
         | 
| 692 | 
            -
                        objects_by_label[norm_label].append({
         | 
| 693 | 
            -
                            'view_idx': view_idx,
         | 
| 694 | 
            -
                            'det_idx': det_idx,
         | 
| 695 | 
            -
                            'label': det['label'],
         | 
| 696 | 
            -
                            'norm_label': norm_label,
         | 
| 697 | 
            -
                            'center_3d': det['center_3d'],
         | 
| 698 | 
            -
                            'confidence': det['confidence'],
         | 
| 699 | 
            -
                        })
         | 
| 700 | 
            -
                
         | 
| 701 | 
            -
                if len(objects_by_label) == 0:
         | 
| 702 | 
            -
                    return {}, []
         | 
| 703 | 
            -
                
         | 
| 704 | 
            -
                object_id_map = defaultdict(dict)
         | 
| 705 | 
            -
                unique_objects = []
         | 
| 706 | 
            -
                next_global_id = 0
         | 
| 707 | 
            -
                
         | 
| 708 | 
            -
                for norm_label, objects in objects_by_label.items():
         | 
| 709 | 
            -
                    print(f"\n   📦 Processing {norm_label}: {len(objects)} detections")
         | 
| 710 | 
            -
                    
         | 
| 711 | 
            -
                    if len(objects) == 1:
         | 
| 712 | 
            -
                        obj = objects[0]
         | 
| 713 | 
            -
                        unique_objects.append({
         | 
| 714 | 
            -
                            'global_id': next_global_id,
         | 
| 715 | 
            -
                            'label': obj['label'],
         | 
| 716 | 
            -
                            'views': [(obj['view_idx'], obj['det_idx'])],
         | 
| 717 | 
            -
                            'center_3d': obj['center_3d'],
         | 
| 718 | 
            -
                        })
         | 
| 719 | 
            -
                        object_id_map[obj['view_idx']][obj['det_idx']] = next_global_id
         | 
| 720 | 
            -
                        next_global_id += 1
         | 
| 721 | 
            -
                        print(f"      → 1 cluster (single detection)")
         | 
| 722 | 
            -
                        continue
         | 
| 723 | 
            -
                    
         | 
| 724 | 
            -
                    centers = np.array([obj['center_3d'] for obj in objects])
         | 
| 725 | 
            -
                    
         | 
| 726 | 
            -
                    base_eps = DBSCAN_EPS_CONFIG.get(norm_label, DBSCAN_EPS_CONFIG.get('default', 1.0))
         | 
| 727 | 
            -
                    eps = compute_adaptive_eps(centers, base_eps)
         | 
| 728 | 
            -
                    
-        clustering = DBSCAN(eps=eps, min_samples=DBSCAN_MIN_SAMPLES, metric='euclidean')
-        cluster_labels = clustering.fit_predict(centers)
-
-        n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
-        n_noise = list(cluster_labels).count(-1)
-
-        if eps != base_eps:
-            print(f"      → {n_clusters} clusters (base_eps={base_eps}m → adaptive_eps={eps:.2f}m)")
-        else:
-            print(f"      → {n_clusters} clusters (eps={eps}m)")
-        if n_noise > 0:
-            print(f"      ⚠️  {n_noise} noise points (isolated detections)")
-
-        for cluster_id in set(cluster_labels):
-            if cluster_id == -1:
-                for i, label in enumerate(cluster_labels):
-                    if label == -1:
-                        obj = objects[i]
-                        unique_objects.append({
-                            'global_id': next_global_id,
-                            'label': obj['label'],
-                            'views': [(obj['view_idx'], obj['det_idx'])],
-                            'center_3d': obj['center_3d'],
-                        })
-                        object_id_map[obj['view_idx']][obj['det_idx']] = next_global_id
-                        next_global_id += 1
-            else:
-                cluster_objects = [objects[i] for i, label in enumerate(cluster_labels) if label == cluster_id]
-
-                total_conf = sum(o['confidence'] for o in cluster_objects)
-                weighted_center = sum(o['center_3d'] * o['confidence'] for o in cluster_objects) / total_conf
-
-                unique_objects.append({
-                    'global_id': next_global_id,
-                    'label': cluster_objects[0]['label'],
-                    'views': [(o['view_idx'], o['det_idx']) for o in cluster_objects],
-                    'center_3d': weighted_center,
-                })
-
-                for obj in cluster_objects:
-                    object_id_map[obj['view_idx']][obj['det_idx']] = next_global_id
-
-                next_global_id += 1
-
-    print(f"\n   📊 Summary:")
-    print(f"      Total detections: {sum(len(objs) for objs in objects_by_label.values())}")
-    print(f"      Unique objects: {len(unique_objects)}")
-
-    return object_id_map, unique_objects
-
-
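The removed matcher above reduces to: cluster same-label 3D centers with DBSCAN, keep noise points (label -1) as singleton objects, and merge each cluster at the confidence-weighted mean center. A minimal, self-contained sketch of that idea; the toy detections and the eps value are illustrative, and min_samples=1 stands in for the app's separate noise-point handling:

```python
# Sketch of DBSCAN-based cross-view object matching (toy data).
import numpy as np
from sklearn.cluster import DBSCAN

# (center_3d, confidence) for detections sharing one label
detections = [
    (np.array([0.00, 0.0, 1.0]), 0.9),  # chair seen from view 1
    (np.array([0.05, 0.0, 1.0]), 0.7),  # same chair seen from view 2
    (np.array([3.00, 0.0, 1.0]), 0.8),  # a second chair
]
centers = np.stack([c for c, _ in detections])
confs = np.array([w for _, w in detections])

labels = DBSCAN(eps=0.5, min_samples=1, metric="euclidean").fit_predict(centers)

merged = []
for cid in set(labels):
    sel = labels == cid
    w = confs[sel]
    # confidence-weighted mean center, as in the removed code
    merged.append((centers[sel] * w[:, None]).sum(axis=0) / w.sum())

print(len(merged), "unique objects")  # -> 2
```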
-def create_multi_view_segmented_mesh(processed_data, all_view_detections, all_view_masks,
-                                     object_id_map, unique_objects, target_dir):
-    """Create multi-view fused segmented mesh"""
-    try:
-        print("\n🎨 Generating multi-view segmented mesh...")
-
-        unique_normalized_labels = sorted(set(normalize_label(obj['label']) for obj in unique_objects))
-        label_colors = {}
-        colors = generate_distinct_colors(len(unique_normalized_labels))
-
-        for i, norm_label in enumerate(unique_normalized_labels):
-            label_colors[norm_label] = colors[i]
-
-        for obj in unique_objects:
-            norm_label = normalize_label(obj['label'])
-            obj['color'] = label_colors[norm_label]
-            obj['normalized_label'] = norm_label
-
-        print(f"   Object category color mapping:")
-        for norm_label, color in sorted(label_colors.items()):
-            count = sum(1 for obj in unique_objects if normalize_label(obj['label']) == norm_label)
-            print(f"   {norm_label} × {count} → RGB{color}")
-
-        import utils3d
-
-        all_meshes = []
-
-        for view_idx in range(len(processed_data)):
-            view_data = processed_data[view_idx]
-            image = view_data["image"]
-            points3d = view_data["points3d"]
-            mask = view_data.get("mask")
-            normal = view_data.get("normal")
-
-            detections = all_view_detections[view_idx]
-            masks = all_view_masks[view_idx]
-
-            if len(detections) == 0:
-                continue
-
-            if image.dtype != np.uint8:
-                if image.max() <= 1.0:
-                    image = (image * 255).astype(np.uint8)
-                else:
-                    image = image.astype(np.uint8)
-
-            colored_image = image.copy()
-            confidence_map = np.zeros((image.shape[0], image.shape[1]), dtype=np.float32)
-
-            detections_info = []
-            filtered_count = 0
-            for det_idx, (det, seg_mask) in enumerate(zip(detections, masks)):
-                if det['confidence'] < MIN_DETECTION_CONFIDENCE:
-                    filtered_count += 1
-                    continue
-
-                mask_area = seg_mask.sum()
-                if mask_area < MIN_MASK_AREA:
-                    filtered_count += 1
-                    continue
-
-                global_id = object_id_map[view_idx].get(det_idx)
-                if global_id is None:
-                    continue
-
-                unique_obj = next((obj for obj in unique_objects if obj['global_id'] == global_id), None)
-                if unique_obj is None:
-                    continue
-
-                detections_info.append({
-                    'mask': seg_mask,
-                    'color': unique_obj['color'],
-                    'confidence': det['confidence'],
-                })
-
-            if filtered_count > 0:
-                print(f"   View {view_idx + 1}: filtered {filtered_count} low-quality detections")
-
-            detections_info.sort(key=lambda x: x['confidence'])
-
-            for info in detections_info:
-                seg_mask = info['mask']
-                color = info['color']
-                conf = info['confidence']
-
-                update_mask = seg_mask & (conf > confidence_map)
-                colored_image[update_mask] = color
-                confidence_map[update_mask] = conf
-
-            height, width = image.shape[:2]
-
-            if normal is None:
-                faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh(
-                    points3d,
-                    colored_image.astype(np.float32) / 255,
-                    utils3d.numpy.image_uv(width=width, height=height),
-                    mask=mask if mask is not None else np.ones((height, width), dtype=bool),
-                    tri=True
-                )
-                vertex_normals = None
-            else:
-                faces, vertices, vertex_colors, vertex_uvs, vertex_normals = utils3d.numpy.image_mesh(
-                    points3d,
-                    colored_image.astype(np.float32) / 255,
-                    utils3d.numpy.image_uv(width=width, height=height),
-                    normal,
-                    mask=mask if mask is not None else np.ones((height, width), dtype=bool),
-                    tri=True
-                )
-
-            vertices = vertices * np.array([1, -1, -1], dtype=np.float32)
-            if vertex_normals is not None:
-                vertex_normals = vertex_normals * np.array([1, -1, -1], dtype=np.float32)
-
-            view_mesh = trimesh.Trimesh(
-                vertices=vertices,
-                faces=faces,
-                vertex_normals=vertex_normals,
-                vertex_colors=(vertex_colors * 255).astype(np.uint8),
-                process=False
-            )
-
-            all_meshes.append(view_mesh)
-            print(f"   View {view_idx + 1}: {len(vertices):,} vertices, {len(faces):,} faces")
-
-        if len(all_meshes) == 0:
-            print("⚠️ No mesh generated")
-            return None
-
-        print("   Fusing all views...")
-        combined_mesh = trimesh.util.concatenate(all_meshes)
-
-        glb_path = os.path.join(target_dir, 'segmented_mesh.glb')
-        combined_mesh.export(glb_path)
-
-        print(f"✅ Multi-view segmented mesh saved: {glb_path}")
-        print(f"   Total: {len(combined_mesh.vertices):,} vertices, {len(combined_mesh.faces):,} faces")
-        print(f"   {len(unique_objects)} unique objects")
-
-        return glb_path
-
-    except Exception as e:
-        print(f"❌ Failed to generate multi-view mesh: {e}")
-        import traceback
-        traceback.print_exc()
-        return None
-
-
-# ============================================================================
-# Core Model Inference
-# ============================================================================
-
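The removed builder paints each view's frame with per-object colors (higher-confidence masks overwrite lower ones, which is why it sorts ascending by confidence), triangulates the painted frame with utils3d, and fuses all views with trimesh.util.concatenate. A toy sketch of the fusion and GLB export step; the two triangles stand in for real per-view image meshes:

```python
# Sketch: fuse per-view meshes into one GLB via trimesh.util.concatenate.
import numpy as np
import trimesh

tri_a = trimesh.Trimesh(
    vertices=np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=np.float32),
    faces=np.array([[0, 1, 2]]),
    vertex_colors=np.array([[255, 0, 0, 255]] * 3, dtype=np.uint8),
    process=False,  # keep vertices untouched, as the removed code did
)
tri_b = trimesh.Trimesh(
    vertices=np.array([[0, 0, 1], [1, 0, 1], [0, 1, 1]], dtype=np.float32),
    faces=np.array([[0, 1, 2]]),
    vertex_colors=np.array([[0, 255, 0, 255]] * 3, dtype=np.uint8),
    process=False,
)

combined = trimesh.util.concatenate([tri_a, tri_b])
combined.export("segmented_mesh.glb")  # face indices are re-offset automatically
print(combined.vertices.shape, combined.faces.shape)  # (6, 3) (2, 3)
```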
 @spaces.GPU(duration=120)
 def run_model(
     target_dir,
@@ -936,24 +94,24 @@ def run_model(
     mask_edges=True,
     filter_black_bg=False,
     filter_white_bg=False,
-    enable_segmentation=False,
-    text_prompt=DEFAULT_TEXT_PROMPT,
     progress=gr.Progress(),
 ):
     """
-    Run the MapAnything model …
     """
     global model
-    import torch
 
-    …
     print(f"Processing images from {target_dir}")
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
     device = torch.device(device)
 
-    # Initialize …
-    progress(0.05, desc="📥 …
     if model is None:
         model = initialize_mapanything_model(high_level_config, device)
     else:
@@ -961,46 +119,8 @@ def run_model(
 
     model.eval()
 
-    # Load …
-    …
-        progress(0.1, desc="🎯 Loading segmentation models (CPU)...")
-        print(f"\n{'='*70}")
-        print(f"🎯 Loading segmentation models... (method: {SEGMENTATION_METHOD})")
-        print(f"{'='*70}")
-
-        if SEGMENTATION_METHOD == "segformer":
-            # Option 1: SegFormer (lightest, ~14MB, fastest)
-            print("📌 Using: SegFormer (lightweight, no text prompt required)")
-            load_segformer_model("cpu")
-            if segformer_model is None:
-                print("❌ Failed to load the SegFormer model!")
-                raise RuntimeError("Failed to load the SegFormer model; check your network connection")
-
-        elif SEGMENTATION_METHOD == "maskformer":
-            # Option 2: MaskFormer (medium, ~100MB)
-            print("📌 Using: MaskFormer (instance segmentation)")
-            load_maskformer_model("cpu")
-            if maskformer_model is None:
-                print("❌ Failed to load the MaskFormer model!")
-                raise RuntimeError("Failed to load the MaskFormer model; check your network connection")
-
-        else:  # "grounding_sam"
-            # Option 3: GroundingDINO + SAM (strongest, ~110MB, requires text prompts)
-            print("📌 Using: GroundingDINO + SAM (text-prompt driven)")
-            load_grounding_dino_model("cpu")
-            load_sam_model("cpu")
-            if grounding_dino_model is None:
-                print("❌ Failed to load the GroundingDINO model!")
-                raise RuntimeError("Failed to load the GroundingDINO model; check your network connection")
-            if sam_predictor is None:
-                print("❌ Failed to load the SAM model!")
-                raise RuntimeError("Failed to load the SAM model; check your network connection")
-
-        print(f"✅ Segmentation models loaded successfully")
-        print(f"{'='*70}\n")
-
-    # Load images
-    progress(0.15, desc="📷 Loading images...")
     print("Loading images...")
     image_folder_path = os.path.join(target_dir, "images")
     views = load_images(image_folder_path)
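For reference, the SegFormer path this removed block loaded can be reproduced with the stock transformers API. A sketch on CPU, assuming a recent transformers release; the app's load_segformer_model and run_segformer_segmentation wrappers are not shown anywhere in this diff, and segment below is an illustrative helper name:

```python
# Sketch of the SegFormer path on CPU (checkpoint id from the config above).
import torch
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

processor = SegformerImageProcessor.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512"
)
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512"
).eval()

def segment(image_np):
    """Return an (H, W) map of ADE20K class ids for an RGB uint8 image."""
    inputs = processor(images=image_np, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits          # (1, 150, H/4, W/4)
    logits = torch.nn.functional.interpolate(
        logits, size=image_np.shape[:2], mode="bilinear", align_corners=False
    )
    return logits.argmax(dim=1)[0].numpy()
```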
@@ -1010,15 +130,22 @@ def run_model(
         raise ValueError("No images found. Check your upload.")
 
     # Run model inference
-    …
     print("Running inference...")
     outputs = model.infer(
         views, apply_mask=apply_mask, mask_edges=True, memory_efficient_inference=False
     )
 
-    # Convert predictions
-    progress(0.…
     predictions = {}
     extrinsic_list = []
     intrinsic_list = []
     world_points_list = []
@@ -1026,158 +153,81 @@ def run_model(
     images_list = []
     final_mask_list = []
 
-    …
-
-    …
-
 
         pts3d_computed, valid_mask = depthmap_to_world_frame(
             depthmap_torch, intrinsics_torch, camera_pose_torch
         )
 
         if "mask" in pred:
             mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
         else:
             mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)
 
         mask = mask & valid_mask.cpu().numpy()
         image = pred["img_no_norm"][0].cpu().numpy()
 
         extrinsic_list.append(camera_pose_torch.cpu().numpy())
         intrinsic_list.append(intrinsics_torch.cpu().numpy())
         world_points_list.append(pts3d_computed.cpu().numpy())
         depth_maps_list.append(depthmap_torch.cpu().numpy())
-        images_list.append(image)
-        final_mask_list.append(mask)
 
     predictions["extrinsic"] = np.stack(extrinsic_list, axis=0)
     predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
     predictions["world_points"] = np.stack(world_points_list, axis=0)
 
     depth_maps = np.stack(depth_maps_list, axis=0)
     if len(depth_maps.shape) == 3:
         depth_maps = depth_maps[..., np.newaxis]
     predictions["depth"] = depth_maps
 
     predictions["images"] = np.stack(images_list, axis=0)
     predictions["final_mask"] = np.stack(final_mask_list, axis=0)
 
-    # Process visualization …
-    progress(0.…
     processed_data = process_predictions_for_visualization(
         predictions, views, high_level_config, filter_black_bg, filter_white_bg
     )
 
-    # …
-    segmented_glb = None
-    if enable_segmentation:
-        progress(0.65, desc="🎯 Starting object segmentation...")
-        print(f"\n{'='*70}")
-        print(f"🎯 Starting object segmentation... (method: {SEGMENTATION_METHOD})")
-        print(f"📐 Minimum mask area: {MIN_MASK_AREA} px")
-        if SEGMENTATION_METHOD == "grounding_sam":
-            print(f"🔍 Detection prompt: {text_prompt[:100]}...")
-            print(f"📊 Confidence threshold: {GROUNDING_DINO_BOX_THRESHOLD}")
-        print(f"{'='*70}\n")
-
-        all_view_detections = []
-        all_view_masks = []
-
-        for view_idx, ref_image in enumerate(images_list):
-            progress(0.65 + (view_idx / len(images_list)) * 0.2,
-                    desc=f"🔍 Detecting objects in view {view_idx + 1}/{len(images_list)}...")
-            print(f"\n📸 Processing view {view_idx + 1}/{len(images_list)}...")
-
-            if ref_image.dtype != np.uint8:
-                ref_image_np = (ref_image * 255).astype(np.uint8)
-            else:
-                ref_image_np = ref_image
-
-            # Pick the processing pipeline for the chosen segmentation method
-            if SEGMENTATION_METHOD == "segformer":
-                # SegFormer: direct semantic segmentation, no text prompt needed
-                detections, masks = run_segformer_segmentation(ref_image_np, "cpu")
-                print(f"   ✓ Detected {len(detections)} objects")
-
-                if len(detections) > 0:
-                    for i, det in enumerate(detections):
-                        print(f"      Object {i+1}: {det['label']}")
-
-                    points3d = world_points_list[view_idx]
-                    for det_idx, (det, mask) in enumerate(zip(detections, masks)):
-                        center_3d = compute_object_3d_center(points3d, mask)
-                        det['center_3d'] = center_3d
-                        det['mask_2d'] = mask
-
-                    all_view_detections.append(detections)
-                    all_view_masks.append(masks)
-                else:
-                    all_view_detections.append([])
-                    all_view_masks.append([])
-
-            elif SEGMENTATION_METHOD == "grounding_sam":
-                # GroundingDINO + SAM: text-prompt driven
-                detections = run_grounding_dino_detection(ref_image_np, text_prompt, "cpu")
-                print(f"   ✓ Detected {len(detections)} objects")
-
-                if len(detections) > 0:
-                    for i, det in enumerate(detections):
-                        print(f"      Object {i+1}: {det['label']} (confidence: {det['confidence']:.2f})")
-                    boxes = [d['bbox'] for d in detections]
-                    masks = run_sam_refinement(ref_image_np, boxes)
-
-                    points3d = world_points_list[view_idx]
-                    for det_idx, (det, mask) in enumerate(zip(detections, masks)):
-                        center_3d = compute_object_3d_center(points3d, mask)
-                        det['center_3d'] = center_3d
-                        det['mask_2d'] = mask
-
-                    all_view_detections.append(detections)
-                    all_view_masks.append(masks)
-            else:
-                all_view_detections.append([])
-                all_view_masks.append([])
-
-        # Match objects across views
-        total_detections = sum(len(dets) for dets in all_view_detections)
-        print(f"\n📊 Total detections: {total_detections}")
-
-        if any(len(dets) > 0 for dets in all_view_detections):
-            progress(0.85, desc="🔗 Matching objects across views...")
-            object_id_map, unique_objects = match_objects_across_views(all_view_detections)
-
-            # Generate segmented mesh
-            progress(0.9, desc="🏗️ Generating segmented 3D model...")
-            segmented_glb = create_multi_view_segmented_mesh(
-                processed_data, all_view_detections, all_view_masks,
-                object_id_map, unique_objects, target_dir
-            )
-
-            if segmented_glb:
-                print(f"✅ Segmented 3D model generated: {segmented_glb}")
-            else:
-                print(f"⚠️ Failed to generate the segmented 3D model")
-        else:
-            print(f"\n{'='*70}")
-            print(f"⚠️ No objects detected, so no segmented model can be generated")
-            print(f"\n💡 Debugging tips:")
-            print(f"   1. Check that the detection prompt is accurate (current: {text_prompt[:50]}...)")
-            print(f"   2. Current confidence threshold: {GROUNDING_DINO_BOX_THRESHOLD}")
-            print(f"   3. Try a more generic prompt, e.g.: {COMMON_OBJECTS_PROMPT[:80]}...")
-            print(f"   4. Make sure the images contain clearly visible objects")
-            print(f"{'='*70}\n")
-
-    # Cleanup
     progress(0.95, desc="🧹 Cleaning up memory...")
     torch.cuda.empty_cache()
 
-    …
-    …
 
 
-# ============================================================================
-# Helper Functions (from app.py)
-# ============================================================================
 
 def update_view_selectors(processed_data):
     """Update view selector dropdowns based on available views"""
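The per-view loop above hinges on depthmap_to_world_frame, which lifts each predicted depth map into world coordinates using the predicted intrinsics and camera pose. That helper lives in the MapAnything utilities; the standard pinhole back-projection it is expected to perform looks like this (a sketch, not the library function):

```python
# Sketch of depth -> world-frame points under the pinhole model.
import torch

def backproject(depth, K, cam_to_world):
    """depth: (H, W); K: (3, 3); cam_to_world: (4, 4) -> (H, W, 3) world points."""
    H, W = depth.shape
    v, u = torch.meshgrid(
        torch.arange(H, dtype=depth.dtype),
        torch.arange(W, dtype=depth.dtype),
        indexing="ij",
    )
    pix = torch.stack([u, v, torch.ones_like(u)], dim=-1)   # (H, W, 3) homogeneous pixels
    rays = pix @ torch.linalg.inv(K).T                      # K^-1 [u, v, 1]^T per pixel
    pts_cam = rays * depth[..., None]                       # scale rays by depth
    R, t = cam_to_world[:3, :3], cam_to_world[:3, 3]
    return pts_cam @ R.T + t                                # rotate into the world frame
    # the real helper also returns a validity mask, e.g. depth > 0
```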
@@ -1188,9 +238,9 @@ def update_view_selectors(processed_data):
         choices = [f"View {i + 1}" for i in range(num_views)]
 
     return (
-        gr.Dropdown(choices=choices, value=choices[0]),
-        gr.Dropdown(choices=choices, value=choices[0]),
-        gr.Dropdown(choices=choices, value=choices[0]),
     )
 
 
@@ -1228,24 +278,33 @@ def update_measure_view(processed_data, view_index):
     """Update measure view for a specific view index with mask overlay"""
     view_data = get_view_data_by_index(processed_data, view_index)
     if view_data is None:
-        return None, []
 
     image = view_data["image"].copy()
 
     if image.dtype != np.uint8:
         if image.max() <= 1.0:
             image = (image * 255).astype(np.uint8)
         else:
             image = image.astype(np.uint8)
 
     if view_data["mask"] is not None:
         mask = view_data["mask"]
-        …
 
         if invalid_mask.any():
             overlay_color = np.array([255, 220, 220], dtype=np.uint8)
-            …
-            …
                 image[:, :, c] = np.where(
                     invalid_mask,
                     (1 - alpha) * image[:, :, c] + alpha * overlay_color[c],
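The kept overlay code above tints invalid pixels channel by channel with np.where; the same effect vectorizes to out = (1 - alpha) * img + alpha * color. A sketch, where the alpha value is illustrative because the app's actual value sits on a line this diff truncates:

```python
# Sketch: tint masked-out pixels with a translucent color overlay.
import numpy as np

def tint(image, invalid_mask, color=(255, 220, 220), alpha=0.4):
    """image: (H, W, 3) uint8; invalid_mask: (H, W) bool; alpha is illustrative."""
    out = image.astype(np.float32)
    overlay = np.array(color, dtype=np.float32)
    # broadcasting blends all three channels at once
    out[invalid_mask] = (1 - alpha) * out[invalid_mask] + alpha * overlay
    return out.astype(np.uint8)
```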
@@ -1260,6 +319,7 @@ def navigate_depth_view(processed_data, current_selector_value, direction):
     if processed_data is None or len(processed_data) == 0:
         return "View 1", None
 
     try:
         current_view = int(current_selector_value.split()[1]) - 1
     except:
@@ -1279,6 +339,7 @@ def navigate_normal_view(processed_data, current_selector_value, direction):
     if processed_data is None or len(processed_data) == 0:
         return "View 1", None
 
     try:
         current_view = int(current_selector_value.split()[1]) - 1
     except:
@@ -1298,6 +359,7 @@ def navigate_measure_view(processed_data, current_selector_value, direction):
     if processed_data is None or len(processed_data) == 0:
         return "View 1", None, []
 
     try:
         current_view = int(current_selector_value.split()[1]) - 1
     except:
@@ -1317,6 +379,7 @@ def populate_visualization_tabs(processed_data):
     if processed_data is None or len(processed_data) == 0:
         return None, None, None, []
 
     depth_vis = update_depth_view(processed_data, 0)
     normal_vis = update_normal_view(processed_data, 0)
     measure_img, _ = update_measure_view(processed_data, 0)
@@ -1324,6 +387,9 @@ def populate_visualization_tabs(processed_data):
     return depth_vis, normal_vis, measure_img, []
 
 
 def handle_uploads(unified_upload, s_time_interval=1.0):
     """
     Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
@@ -1333,10 +399,12 @@ def handle_uploads(unified_upload, s_time_interval=1.0):
     gc.collect()
     torch.cuda.empty_cache()
 
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
     target_dir = f"input_images_{timestamp}"
     target_dir_images = os.path.join(target_dir, "images")
 
     if os.path.exists(target_dir):
         shutil.rmtree(target_dir)
     os.makedirs(target_dir)
@@ -1344,6 +412,7 @@ def handle_uploads(unified_upload, s_time_interval=1.0):
 
     image_paths = []
 
     if unified_upload is not None:
         for file_data in unified_upload:
             if isinstance(file_data, dict) and "name" in file_data:
@@ -1353,13 +422,23 @@ def handle_uploads(unified_upload, s_time_interval=1.0):
 
             file_ext = os.path.splitext(file_path)[1].lower()
 
             video_extensions = [
-                ".mp4", …
             ]
             if file_ext in video_extensions:
                 vs = cv2.VideoCapture(file_path)
                 fps = vs.get(cv2.CAP_PROP_FPS)
-                frame_interval = int(fps * s_time_interval)
 
                 count = 0
                 video_frame_num = 0
@@ -1369,6 +448,7 @@ def handle_uploads(unified_upload, s_time_interval=1.0):
                     break
                 count += 1
                 if count % frame_interval == 0:
                     base_name = os.path.splitext(os.path.basename(file_path))[0]
                     image_path = os.path.join(
                         target_dir_images, f"{base_name}_{video_frame_num:06}.png"
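The video branch keeps one frame per int(fps * s_time_interval) reads; the removed frame_interval line above was replaced by something this diff does not capture. A standalone sketch of the sampling loop, adding a guard against an unreadable fps (which would otherwise make the modulo divide by zero):

```python
# Sketch: save one frame every interval_s seconds of video.
import cv2

def sample_frames(video_path, out_pattern="frame_{:06d}.png", interval_s=1.0):
    vs = cv2.VideoCapture(video_path)
    fps = vs.get(cv2.CAP_PROP_FPS) or 30.0   # fall back if fps is unreadable
    step = max(1, int(fps * interval_s))     # guard the sketch adds
    count = saved = 0
    while True:
        ok, frame = vs.read()
        if not ok:
            break
        count += 1
        if count % step == 0:
            cv2.imwrite(out_pattern.format(saved), frame)
            saved += 1
    vs.release()
    return saved
```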
@@ -1377,52 +457,82 @@ def handle_uploads(unified_upload, s_time_interval=1.0):
                     image_paths.append(image_path)
                     video_frame_num += 1
             vs.release()
-            print(…
 
         else:
             if file_ext in [".heic", ".heif"]:
                 try:
                     with Image.open(file_path) as img:
                         if img.mode not in ("RGB", "L"):
                             img = img.convert("RGB")
 
                         base_name = os.path.splitext(os.path.basename(file_path))[0]
-                        dst_path = os.path.join(…
 
                         img.save(dst_path, "JPEG", quality=95)
                         image_paths.append(dst_path)
-                        print(…
                 except Exception as e:
                     print(f"Error converting HEIC file {file_path}: {e}")
-                    …
                     shutil.copy(file_path, dst_path)
                     image_paths.append(dst_path)
             else:
-                …
                 shutil.copy(file_path, dst_path)
                 image_paths.append(dst_path)
 
     image_paths = sorted(image_paths)
 
     end_time = time.time()
-    print(…
     return target_dir, image_paths
 
 
 def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0):
-    """…
     if not input_video and not input_images:
-        return None, None, None, None, …
     target_dir, image_paths = handle_uploads(input_video, input_images, s_time_interval)
     return (
-        None,
         None,
         target_dir,
         image_paths,
-        "…
     )
 
 
 @spaces.GPU(duration=120)
 def gradio_demo(
     target_dir,
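The HEIC branch a little further up works only because pillow-heif registers an opener with PIL, which app.py does at import time. A self-contained sketch of that conversion; heic_to_jpeg is an illustrative name, not the app's helper:

```python
# Sketch: convert a HEIC/HEIF upload to JPEG via pillow-heif + PIL.
import os
from PIL import Image
from pillow_heif import register_heif_opener

register_heif_opener()  # teaches PIL to open .heic / .heif files

def heic_to_jpeg(src_path, dst_dir):
    base = os.path.splitext(os.path.basename(src_path))[0]
    dst_path = os.path.join(dst_dir, f"{base}.jpg")
    with Image.open(src_path) as img:
        if img.mode not in ("RGB", "L"):
            img = img.convert("RGB")      # JPEG cannot store alpha or HEIF modes
        img.save(dst_path, "JPEG", quality=95)
    return dst_path
```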
@@ -1432,19 +542,20 @@ def gradio_demo(
     filter_white_bg=False,
     apply_mask=True,
     show_mesh=True,
-    enable_segmentation=False,
-    text_prompt=DEFAULT_TEXT_PROMPT,
     progress=gr.Progress(),
 ):
-    """…
     if not os.path.isdir(target_dir) or target_dir == "None":
-        return None, …
 
     progress(0, desc="🔄 Preparing reconstruction...")
     start_time = time.time()
     gc.collect()
     torch.cuda.empty_cache()
 
     target_dir_images = os.path.join(target_dir, "images")
     all_files = (
         sorted(os.listdir(target_dir_images))
@@ -1454,94 +565,92 @@ def gradio_demo(
     all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
     frame_filter_choices = ["All"] + all_files
 
-    progress(0.05, desc="🚀 Running the MapAnything model...")
-    print(…
     with torch.no_grad():
-        predictions, processed_data…
-            target_dir, apply_mask, True, filter_black_bg, filter_white_bg,
-            enable_segmentation, text_prompt, progress
         )
 
-    # …
     progress(0.92, desc="💾 Saving predictions...")
     prediction_save_path = os.path.join(target_dir, "predictions.npz")
     np.savez(prediction_save_path, **predictions)
 
     if frame_filter is None:
         frame_filter = "All"
 
-    # …
-    progress(0.93, desc="🏗️ …
     glbfile = os.path.join(
         target_dir,
-        f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}.glb",
     )
 
-    # …
     glbscene = predictions_to_glb(
         predictions,
         filter_by_frames=frame_filter,
         show_cam=show_cam,
         mask_black_bg=filter_black_bg,
         mask_white_bg=filter_white_bg,
-        as_mesh=show_mesh,
     )
     glbscene.export(file_obj=glbfile)
 
-    # …
     progress(0.96, desc="🧹 Cleaning up memory...")
     del predictions
     gc.collect()
     torch.cuda.empty_cache()
 
     end_time = time.time()
-    …
-    …
 
-    # Populate visualization tabs
     progress(0.98, desc="🎨 Generating visualizations...")
     depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(
         processed_data
     )
 
-    # Update view selectors
     depth_selector, normal_selector, measure_selector = update_view_selectors(
         processed_data
     )
 
     progress(1.0, desc="✅ All done!")
-
-    # Append segmentation status to the log
-    if enable_segmentation:
-        if segmented_glb:
-            log_msg += f"\n🎨 Segmented model generated"
-        else:
-            log_msg += f"\n⚠️ No objects detected; no segmented model"
 
     return (
         glbfile,
-        segmented_glb,
         log_msg,
         gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
         processed_data,
         depth_vis,
         normal_vis,
         measure_img,
-        "",
        depth_selector,
         normal_selector,
         measure_selector,
     )
 
 
 def colorize_depth(depth_map, mask=None):
     """Convert depth map to colorized visualization with optional mask"""
     if depth_map is None:
         return None
 
     depth_normalized = depth_map.copy()
     valid_mask = depth_normalized > 0
 
     if mask is not None:
         valid_mask = valid_mask & mask
@@ -1552,12 +661,14 @@ def colorize_depth(depth_map, mask=None):
 
         depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / (p95 - p5)
 
     import matplotlib.pyplot as plt
 
     colormap = plt.cm.turbo_r
     colored = colormap(depth_normalized)
     colored = (colored[:, :, :3] * 255).astype(np.uint8)
 
     colored[~valid_mask] = [255, 255, 255]
 
     return colored
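colorize_depth normalizes valid depths to the 5th to 95th percentile band before applying matplotlib's reversed turbo colormap, so a handful of far outliers cannot wash out the visible range; invalid pixels are painted white. A compact sketch, where the clip and epsilon are guards the sketch adds on top of the kept code:

```python
# Sketch: robust percentile depth colorization with turbo_r.
import numpy as np
import matplotlib.pyplot as plt

def colorize(depth):
    valid = depth > 0
    out = depth.astype(np.float32).copy()
    if valid.any():
        p5, p95 = np.percentile(out[valid], [5, 95])
        out[valid] = np.clip((out[valid] - p5) / max(p95 - p5, 1e-6), 0, 1)
    rgb = (plt.cm.turbo_r(out)[:, :, :3] * 255).astype(np.uint8)
    rgb[~valid] = [255, 255, 255]            # white where depth is invalid
    return rgb
```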
@@ -1568,12 +679,15 @@ def colorize_normal(normal_map, mask=None):
     if normal_map is None:
         return None
 
     normal_vis = normal_map.copy()
 
     if mask is not None:
         invalid_mask = ~mask
-        normal_vis[invalid_mask] = [0, 0, 0]
 
     normal_vis = (normal_vis + 1.0) / 2.0
     normal_vis = (normal_vis * 255).astype(np.uint8)
 
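colorize_normal maps unit normals from [-1, 1] into displayable RGB via (n + 1) / 2, with invalid pixels blacked out. The whole transform in one helper, as a sketch:

```python
# Sketch: turn a unit-normal map into a displayable RGB image.
import numpy as np

def colorize_normals(normals, mask=None):
    vis = (normals + 1.0) / 2.0              # [-1, 1] -> [0, 1]
    vis = (vis * 255).astype(np.uint8)
    if mask is not None:
        vis[~mask] = 0                       # black out invalid pixels
    return vis
```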
@@ -1586,11 +700,15 @@ def process_predictions_for_visualization(
     """Extract depth, normal, and 3D points from predictions for visualization"""
     processed_data = {}
 
     for view_idx, view in enumerate(views):
         image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
 
         pred_pts3d = predictions["world_points"][view_idx]
 
         view_data = {
             "image": image[0],
             "points3d": pred_pts3d,
| @@ -1599,15 +717,22 @@ def process_predictions_for_visualization( | |
| 1599 | 
             
                        "mask": None,
         | 
| 1600 | 
             
                    }
         | 
| 1601 |  | 
|  | |
| 1602 | 
             
                    mask = predictions["final_mask"][view_idx].copy()
         | 
| 1603 |  | 
|  | |
| 1604 | 
             
                    if filter_black_bg:
         | 
|  | |
| 1605 | 
             
                        view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
         | 
|  | |
| 1606 | 
             
                        black_bg_mask = view_colors.sum(axis=2) >= 16
         | 
| 1607 | 
             
                        mask = mask & black_bg_mask
         | 
| 1608 |  | 
|  | |
| 1609 | 
             
                    if filter_white_bg:
         | 
|  | |
| 1610 | 
             
                        view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
         | 
|  | |
| 1611 | 
             
                        white_bg_mask = ~(
         | 
| 1612 | 
             
                            (view_colors[:, :, 0] > 240)
         | 
| 1613 | 
             
                            & (view_colors[:, :, 1] > 240)
         | 
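A minimal sketch of the two background filters above: near-black pixels (channel sum below 16) and near-white pixels (all channels above 240) are dropped from the mask. Synthetic data only:

# Sketch only; thresholds mirror the hunk above.
import numpy as np

img = (np.random.rand(32, 32, 3) * 255).astype(np.uint8)
black_bg_mask = img.sum(axis=2) >= 16  # keep pixels that are not near-black
white_bg_mask = ~((img[:, :, 0] > 240) & (img[:, :, 1] > 240) & (img[:, :, 2] > 240))
keep = black_bg_mask & white_bg_mask   # combined with the model's own mask upstream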
@@ -1631,6 +756,7 @@ def reset_measure(processed_data):
     if processed_data is None or len(processed_data) == 0:
         return None, [], ""

     first_view = list(processed_data.values())[0]
     return first_view["image"], [], ""
@@ -1640,18 +766,20 @@ def measure(
 ):
     """Handle measurement on images"""
     try:
-        print(f"

         if processed_data is None or len(processed_data) == 0:
-            return None, [], "

         try:
             current_view_index = int(current_view_selector.split()[1]) - 1
         except:
             current_view_index = 0

-        print(f"

         if current_view_index < 0 or current_view_index >= len(processed_data):
             current_view_index = 0

@@ -1659,46 +787,54 @@ def measure(
         current_view = processed_data[view_keys[current_view_index]]

         if current_view is None:
-            return None, [], "

         point2d = event.index[0], event.index[1]
-        print(f"

         if (
             current_view["mask"] is not None
             and 0 <= point2d[1] < current_view["mask"].shape[0]
             and 0 <= point2d[0] < current_view["mask"].shape[1]
         ):
             if not current_view["mask"][point2d[1], point2d[0]]:
-                print(f"
                 masked_image, _ = update_measure_view(
                     processed_data, current_view_index
                 )
                 return (
                     masked_image,
                     measure_points,
-                    '<span style="color: red; font-weight: bold;"
                 )

         measure_points.append(point2d)

         image, _ = update_measure_view(processed_data, current_view_index)
         if image is None:
-            return None, [], "

         image = image.copy()
         points3d = current_view["points3d"]

         try:
             if image.dtype != np.uint8:
                 if image.max() <= 1.0:
                     image = (image * 255).astype(np.uint8)
                 else:
                     image = image.astype(np.uint8)
         except Exception as e:
-            print(f"
-            return None, [], f"

         try:
             for p in measure_points:
                 if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
@@ -1706,8 +842,8 @@ def measure(
                         image, p, radius=5, color=(255, 0, 0), thickness=2
                     )
         except Exception as e:
-            print(f"
-            return None, [], f"

         depth_text = ""
         try:
@@ -1718,22 +854,24 @@ def measure(
                     and 0 <= p[0] < current_view["depth"].shape[1]
                 ):
                     d = current_view["depth"][p[1], p[0]]
-                    depth_text += f"- **P{i + 1}
                 else:
                     if (
                         points3d is not None
                         and 0 <= p[1] < points3d.shape[0]
                         and 0 <= p[0] < points3d.shape[1]
                     ):
                         z = points3d[p[1], p[0], 2]
-                        depth_text += f"- **P{i + 1} Z
         except Exception as e:
-            print(f"
-            depth_text = f"

         if len(measure_points) == 2:
             try:
                 point1, point2 = measure_points
                 if (
                     0 <= point1[0] < image.shape[1]
                     and 0 <= point1[1] < image.shape[0]
@@ -1744,7 +882,8 @@ def measure(
                     image, point1, point2, color=(255, 0, 0), thickness=2
                 )

-
                 if (
                     points3d is not None
                     and 0 <= point1[1] < points3d.shape[0]
@@ -1756,35 +895,39 @@ def measure(
                         p1_3d = points3d[point1[1], point1[0]]
                         p2_3d = points3d[point2[1], point2[0]]
                         distance = np.linalg.norm(p1_3d - p2_3d)
-                        distance_text = f"-
                     except Exception as e:
-                        print(f"
-                        distance_text = f"-

                 measure_points = []
                 text = depth_text + distance_text
-                print(f"
                 return [image, measure_points, text]
             except Exception as e:
-                print(f"
-                return None, [], f"
         else:
-            print(f"
             return [image, measure_points, depth_text]

     except Exception as e:
-        print(f"
-        return None, [], f"

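The core of the two-point measurement above is indexing a per-pixel (H, W, 3) world-point map at two clicked pixels and taking the Euclidean distance. A minimal sketch with synthetic points (note the row=y, col=x indexing, as in measure()):

# Sketch only; points3d is synthetic here.
import numpy as np

points3d = np.random.rand(480, 640, 3)   # per-pixel 3D points in the world frame
p1, p2 = (100, 200), (300, 400)          # (x, y) pixel coordinates
p1_3d = points3d[p1[1], p1[0]]           # index as [row, col] = [y, x]
p2_3d = points3d[p2[1], p2[0]]
distance = np.linalg.norm(p1_3d - p2_3d)  # in metres, if the points are metric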
             
            def clear_fields():
         | 
| 1781 | 
            -
                """ | 
| 1782 | 
            -
                 | 
|  | |
|  | |
| 1783 |  | 
| 1784 |  | 
| 1785 | 
             
            def update_log():
         | 
| 1786 | 
            -
                """ | 
| 1787 | 
            -
                 | 
|  | |
|  | |
| 1788 |  | 
| 1789 |  | 
 def update_visualization(
@@ -1796,16 +939,30 @@ def update_visualization(
     filter_white_bg=False,
     show_mesh=True,
 ):
-    """
     if is_example == "True":
-        return

     if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
-        return

     predictions_path = os.path.join(target_dir, "predictions.npz")
     if not os.path.exists(predictions_path):
-        return

     loaded = np.load(predictions_path, allow_pickle=True)
     predictions = {key: loaded[key] for key in loaded.keys()}
@@ -1815,17 +972,21 @@ def update_visualization(
         f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
     )

-
-
-
-
-
-
-
-
-

-    return

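The filename built above doubles as a cache key: every visualization toggle is folded into the GLB name, so each settings combination maps to its own cached file. A sketch of that scheme (glb_cache_name is a hypothetical helper, not part of the app):

# Sketch only; glb_cache_name is illustrative, not from the diff.
import os

def glb_cache_name(target_dir, frame_filter, show_cam, show_mesh, black, white):
    safe = frame_filter.replace(".", "_").replace(":", "").replace(" ", "_")
    return os.path.join(
        target_dir,
        f"glbscene_{safe}_cam{show_cam}_mesh{show_mesh}_black{black}_white{white}.glb",
    )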
 def update_all_views_on_filter_change(
@@ -1837,7 +998,11 @@ def update_all_views_on_filter_change(
     normal_view_selector,
     measure_view_selector,
 ):
-    """
     if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
         return processed_data, None, None, None, []

@@ -1846,16 +1011,20 @@ def update_all_views_on_filter_change(
         return processed_data, None, None, None, []

     try:
         loaded = np.load(predictions_path, allow_pickle=True)
         predictions = {key: loaded[key] for key in loaded.keys()}

         image_folder_path = os.path.join(target_dir, "images")
         views = load_images(image_folder_path)

         new_processed_data = process_predictions_for_visualization(
             predictions, views, high_level_config, filter_black_bg, filter_white_bg
         )

         try:
             depth_view_idx = (
                 int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0
@@ -1879,6 +1048,7 @@ def update_all_views_on_filter_change(
         except:
             measure_view_idx = 0

         depth_vis = update_depth_view(new_processed_data, depth_view_idx)
         normal_vis = update_normal_view(new_processed_data, normal_view_idx)
         measure_img, _ = update_measure_view(new_processed_data, measure_view_idx)
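The "View N" parsing above (int(selector.split()[1]) - 1 with a bare except) recurs throughout the app. A compact sketch of the same idea with the fallback made explicit (parse_view_index is a hypothetical helper):

# Sketch only; parse_view_index is illustrative, not from the diff.
def parse_view_index(selector: str, num_views: int) -> int:
    """Map 'View 3' -> 2; fall back to 0 on anything malformed or out of range."""
    try:
        idx = int(selector.split()[1]) - 1
    except (AttributeError, IndexError, ValueError):
        idx = 0
    return idx if 0 <= idx < num_views else 0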
@@ -1890,10 +1060,9 @@ def update_all_views_on_filter_change(
         return processed_data, None, None, None, []


-# ============================================================================
-# Example Scenes
-# ============================================================================
-
 def get_scene_info(examples_dir):
     """Get information about scenes in the examples directory"""
     import glob
@@ -1905,6 +1074,7 @@ def get_scene_info(examples_dir):
     for scene_folder in sorted(os.listdir(examples_dir)):
         scene_path = os.path.join(examples_dir, scene_folder)
         if os.path.isdir(scene_path):
             image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
             image_files = []
             for ext in image_extensions:
@@ -1912,6 +1082,7 @@ def get_scene_info(examples_dir):
                 image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))

             if image_files:
                 image_files = sorted(image_files)
                 first_image = image_files[0]
                 num_images = len(image_files)
@@ -1930,9 +1101,10 @@ def get_scene_info(examples_dir):


 def load_example_scene(scene_name, examples_dir="examples"):
-    """
     scenes = get_scene_info(examples_dir)

     selected_scene = None
     for scene in scenes:
         if scene["name"] == scene_name:
@@ -1940,26 +1112,28 @@ def load_example_scene(scene_name, examples_dir="examples"):
             break

     if selected_scene is None:
-        return None, None, None, "

     file_objects = []
     for image_path in selected_scene["image_files"]:
         file_objects.append(image_path)

     target_dir, image_paths = handle_uploads(file_objects, 1.0)

     return (
-        None,
-        target_dir,
-        image_paths,
-        f"
     )

-# ============================================================================
-# Gradio UI
-# ============================================================================
-
 theme = get_gradio_theme()

 # Custom CSS to prevent UI jitter
@@ -2022,45 +1196,44 @@ CUSTOM_CSS = GRADIO_CSS + """
 }
 """

-
-
-<script>
-// Add clipboard-paste support
-document.addEventListener('paste', function(e) {
-    const items = e.clipboardData.items;
-    for (let i = 0; i < items.length; i++) {
-        if (items[i].type.indexOf('image') !== -1) {
-            const blob = items[i].getAsFile();
-            const fileInput = document.querySelector('input[type="file"][multiple]');
-            if (fileInput) {
-                const dataTransfer = new DataTransfer();
-                dataTransfer.items.add(blob);
-                fileInput.files = dataTransfer.files;
-                fileInput.dispatchEvent(new Event('change', { bubbles: true }));
-                console.log('✅ Image pasted from clipboard');
-            }
-        }
-    }
-});
-
-// Log a hint that paste is available
-console.log('💡 Clipboard paste enabled: press Ctrl+V to paste a screenshot directly');
-</script>
-"""
-
-with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D Reconstruction & Object Segmentation") as demo:
     is_example = gr.Textbox(label="is_example", visible=False, value="None")
     processed_data_state = gr.State(value=None)
     measure_points_state = gr.State(value=[])

     # JavaScript that adds clipboard-paste support
     gr.HTML(PASTE_JS)
-
-    #
     gr.HTML("""
     <div style="text-align: center; margin: 20px 0;">
-        <h2 style="color: #1976D2; margin-bottom: 10px;">MapAnything
-        <p style="color: #666; font-size: 16px;"
     </div>
     """)

@@ -2133,23 +1306,6 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D Reconstruction & Object Segmentation") as demo:
                 clear_color=[0.0, 0.0, 0.0, 0.0]
             )

-            with gr.Tab("🎨 Segmented 3D"):
-                gr.Markdown(
-                    """
-                    💡 **Instructions**:
-                    1. Tick "Enable semantic segmentation (CPU)" under "⚙️ Advanced Options" below
-                    2. Click the "Start Reconstruction" button
-                    3. Once processing finishes, the segmentation result is shown here
-
-                    📌 If no segmentation result appears, check the console logs for the cause
-                    """,
-                    elem_classes=["info-box"]
-                )
-                segmented_output = gr.Model3D(
-                    height=450, zoom_speed=0.5, pan_speed=0.5,
-                    clear_color=[0.0, 0.0, 0.0, 0.0]
-                )
-
             with gr.Tab("📊 Depth Map"):
                 with gr.Row(elem_classes=["navigation-row"]):
                     prev_depth_btn = gr.Button("◀", size="sm", scale=1)
@@ -2200,8 +1356,8 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D Reconstruction & Object Segmentation") as demo:
             max_lines=1
         )

-    #
-    with gr.Accordion("⚙️ Advanced Options", open=
         with gr.Row(equal_height=False):
             with gr.Column(scale=1, min_width=300):
                 gr.Markdown("#### Visualization Parameters")
@@ -2218,32 +1374,13 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D Reconstruction & Object Segmentation") as demo:
             apply_mask_checkbox = gr.Checkbox(
                 label="Apply depth mask", value=True
             )
-
-            gr.Markdown("#### Segmentation Parameters")
-            gr.Markdown("💡 **Note**: segmentation runs on the CPU (lightweight MobileSAM model) and does not use GPU resources")
-            enable_segmentation = gr.Checkbox(
-                label="Enable semantic segmentation (CPU)", value=False
-            )
-
-            text_prompt = gr.Textbox(
-                value=DEFAULT_TEXT_PROMPT,
-                label="Objects to detect (separate with .)",
-                placeholder="e.g. chair . table . sofa",
-                lines=2,
-                max_lines=2
-            )
-
-            with gr.Row():
-                detect_all_btn = gr.Button("🔍 Detect all", size="sm")
-                restore_default_btn = gr.Button("↻ Default", size="sm")
-
-            gr.Markdown("📌 **Tip**: when enabled, a colored segmentation model is shown in the \"Segmented 3D\" tab")
-
     # Example scenes (collapsible)
     with gr.Accordion("🖼️ Example Scenes", open=False):
         scenes = get_scene_info("examples")
         if scenes:
-            for i in range(0, len(scenes), 4):
             with gr.Row(equal_height=True):
                 for j in range(4):
                     scene_idx = i + j
@@ -2251,10 +1388,10 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D Reconstruction & Object Segmentation") as demo:
                         scene = scenes[scene_idx]
                         with gr.Column(scale=1, min_width=150):
                             scene_img = gr.Image(
-                                value=scene["thumbnail"],
                                 height=150,
-                                interactive=False,
-                                show_label=False,
                                 sources=[],
                                 container=False
                             )
@@ -2266,22 +1403,14 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D Reconstruction & Object Segmentation") as demo:
                                 fn=lambda name=scene["name"]: load_example_scene(name),
                                 outputs=[
                                     reconstruction_output,
-                                    target_dir_output,
-
                                 )

     # === Event bindings ===

-    # Segmentation option buttons
-    detect_all_btn.click(
-        fn=lambda: COMMON_OBJECTS_PROMPT,
-        outputs=[text_prompt]
-    )
-    restore_default_btn.click(
-        fn=lambda: DEFAULT_TEXT_PROMPT,
-        outputs=[text_prompt]
-    )
-
    # Auto-update on file upload
     def update_gallery_on_unified_upload(files, interval):
         if not files:
@@ -2411,7 +1540,7 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D Reconstruction & Object Segmentation") as demo:
     # Reconstruction button
     submit_btn.click(
         fn=clear_fields,
-        outputs=[reconstruction_output
     ).then(
         fn=update_log,
         outputs=[log_output]
@@ -2420,11 +1549,10 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D Reconstruction & Object Segmentation") as demo:
         inputs=[
             target_dir_output, frame_filter, show_cam,
             filter_black_bg, filter_white_bg,
-            apply_mask_checkbox, show_mesh
-            enable_segmentation, text_prompt
         ],
         outputs=[
-            reconstruction_output,
             processed_data_state, depth_map, normal_map, measure_image,
             measure_text, depth_view_selector, normal_view_selector, measure_view_selector
         ]
@@ -2434,8 +1562,8 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D Reconstruction & Object Segmentation") as demo:
     )

     # Clear button
-    clear_btn.add([reconstruction_output,
-
     # Live updates for visualization parameters
     for component in [frame_filter, show_cam, show_mesh]:
         component.change(
@@ -2457,7 +1585,7 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D Reconstruction & Object Segmentation") as demo:
         ],
         outputs=[processed_data_state, depth_map, normal_map, measure_image, measure_points_state]
     )
-
     # Depth map navigation
     prev_depth_btn.click(
         fn=lambda pd, cs: navigate_depth_view(pd, cs, -1),
@@ -2514,17 +1642,4 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D Reconstruction & Object Segmentation") as demo:
         outputs=[measure_image, measure_points_state]
     )

-# Startup info
-print("\n" + "="*70)
-print("🚀 MapAnything V2 - 3D Reconstruction & Object Segmentation")
-print("="*70)
-print("📊 Core techniques: adaptive DBSCAN clustering + multi-view fusion")
-print(f"🔧 Quality control: confidence ≥ {MIN_DETECTION_CONFIDENCE} | area ≥ {MIN_MASK_AREA}px")
-print(f"🎯 Clustering radii: sofa {DBSCAN_EPS_CONFIG['sofa']}m | table {DBSCAN_EPS_CONFIG['table']}m | window {DBSCAN_EPS_CONFIG['window']}m | default {DBSCAN_EPS_CONFIG['default']}m")
-print("\n💡 Segmentation config (CPU-optimized):")
-print(f"   - Detection model: {GROUNDING_DINO_MODEL_ID} (CPU)")
-print(f"   - Segmentation model: {SAM_MODEL_ID} (MobileSAM, 10MB, CPU)")
-print(f"   - Device: CPU (no GPU usage; suited to split deployment)")
-print("="*70 + "\n")
-
 demo.queue(max_size=20).launch(show_error=True, share=True, ssr_mode=False)
 # LICENSE file in the root directory of this source tree.

 """
+MapAnything V2 - 3D Reconstruction System (Chinese-language UI edition)
+- Multi-view 3D reconstruction
+- Depth estimation and normal computation
+- Distance measurement
 """

 import gc
 import sys
 import time
 from datetime import datetime

 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

 import numpy as np
 import spaces
 import torch
 from PIL import Image
 from pillow_heif import register_heif_opener

 register_heif_opener()
         return None

 # MapAnything Configuration
 high_level_config = {
     "path": "configs/train.yaml",
     "resolution": 518,
 }

+# Initialize model - this will be done on GPU when needed
 model = None

+# -------------------------------------------------------------------------
+# 1) Core model inference
+# -------------------------------------------------------------------------
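run_model below relies on depthmap_to_world_frame to turn per-view depth into world-frame points. A sketch of what that call presumably does — an assumed pinhole back-projection, not the library's actual implementation:

# Sketch only; assumes a pinhole camera model. unproject_to_world is
# illustrative and is NOT the mapanything depthmap_to_world_frame source.
import torch

def unproject_to_world(depth, K, cam_pose):
    H, W = depth.shape
    v, u = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    x = (u - K[0, 2]) / K[0, 0] * depth
    y = (v - K[1, 2]) / K[1, 1] * depth
    pts_cam = torch.stack([x, y, depth], dim=-1)               # (H, W, 3), camera frame
    pts_world = pts_cam @ cam_pose[:3, :3].T + cam_pose[:3, 3]  # rotate + translate
    return pts_world, depth > 0                                 # points + validity mask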
 @spaces.GPU(duration=120)
 def run_model(
     target_dir,
     mask_edges=True,
     filter_black_bg=False,
     filter_white_bg=False,
     progress=gr.Progress(),
 ):
     """
+    Run the MapAnything model on images in the 'target_dir/images' folder and return predictions.
     """
     global model
+    import torch  # Ensure torch is available in function scope

+    start_time = time.time()
     print(f"Processing images from {target_dir}")

+    # Device check
+    progress(0, desc="🔧 Initializing device...")
     device = "cuda" if torch.cuda.is_available() else "cpu"
     device = torch.device(device)

+    # Initialize model if not already done
+    progress(0.05, desc="📥 Loading model... (~5s)")
     if model is None:
         model = initialize_mapanything_model(high_level_config, device)
     else:

     model.eval()

+    # Load images using MapAnything's load_images function
+    progress(0.15, desc="📷 Loading images... (~2s)")
     print("Loading images...")
     image_folder_path = os.path.join(target_dir, "images")
     views = load_images(image_folder_path)
         raise ValueError("No images found. Check your upload.")

     # Run model inference
+    num_images = len(views)
+    estimated_time = num_images * 3  # rough estimate of ~3 seconds per image
+    progress(0.2, desc=f"🚀 Running 3D reconstruction... ({num_images} images, est. {estimated_time}s)")
     print("Running inference...")
+
+    inference_start = time.time()
     outputs = model.infer(
         views, apply_mask=apply_mask, mask_edges=True, memory_efficient_inference=False
     )
+    inference_time = time.time() - inference_start

+    # Convert predictions to format expected by visualization
+    progress(0.6, desc=f"🔄 Processing predictions... (inference took {inference_time:.1f}s)")
     predictions = {}
+
+    # Initialize lists for the required keys
     extrinsic_list = []
     intrinsic_list = []
     world_points_list = []
     images_list = []
     final_mask_list = []

+    # Loop through the outputs
+    for i, pred in enumerate(outputs):
+        if i % max(1, len(outputs) // 5) == 0:
+            progress(0.6 + (i / len(outputs)) * 0.25, desc=f"🔄 Processing view {i+1}/{len(outputs)}...")
+        # Extract data from predictions
+        depthmap_torch = pred["depth_z"][0].squeeze(-1)  # (H, W)
+        intrinsics_torch = pred["intrinsics"][0]  # (3, 3)
+        camera_pose_torch = pred["camera_poses"][0]  # (4, 4)

+        # Compute new pts3d using depth, intrinsics, and camera pose
         pts3d_computed, valid_mask = depthmap_to_world_frame(
             depthmap_torch, intrinsics_torch, camera_pose_torch
         )

+        # Convert to numpy arrays for visualization
+        # Check if a mask key exists in pred; if not, fill with boolean trues in the size of depthmap_torch
         if "mask" in pred:
             mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
         else:
+            # Fill with boolean trues in the size of depthmap_torch
             mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)

+        # Combine with valid depth mask
         mask = mask & valid_mask.cpu().numpy()
+
         image = pred["img_no_norm"][0].cpu().numpy()

+        # Append to lists
         extrinsic_list.append(camera_pose_torch.cpu().numpy())
         intrinsic_list.append(intrinsics_torch.cpu().numpy())
         world_points_list.append(pts3d_computed.cpu().numpy())
         depth_maps_list.append(depthmap_torch.cpu().numpy())
+        images_list.append(image)  # Add image to list
+        final_mask_list.append(mask)  # Add final_mask to list

+    # Convert lists to numpy arrays with required shapes
+    # extrinsic: (S, 4, 4) - batch of stacked camera pose matrices
     predictions["extrinsic"] = np.stack(extrinsic_list, axis=0)

+    # intrinsic: (S, 3, 3) - batch of camera intrinsic matrices
     predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)

+    # world_points: (S, H, W, 3) - batch of 3D world points
     predictions["world_points"] = np.stack(world_points_list, axis=0)

+    # depth: (S, H, W, 1) or (S, H, W) - batch of depth maps
     depth_maps = np.stack(depth_maps_list, axis=0)
+    # Add channel dimension if needed to match (S, H, W, 1) format
     if len(depth_maps.shape) == 3:
         depth_maps = depth_maps[..., np.newaxis]
+
     predictions["depth"] = depth_maps

+    # images: (S, H, W, 3) - batch of input images
     predictions["images"] = np.stack(images_list, axis=0)
+
+    # final_mask: (S, H, W) - batch of final masks for filtering
     predictions["final_mask"] = np.stack(final_mask_list, axis=0)

+    # Process data for visualization tabs (depth, normal, measure)
+    progress(0.85, desc="🎨 Generating depth and normal maps...")
     processed_data = process_predictions_for_visualization(
         predictions, views, high_level_config, filter_black_bg, filter_white_bg
     )

+    # Clean up
     progress(0.95, desc="🧹 Cleaning up memory...")
     torch.cuda.empty_cache()

+    total_time = time.time() - start_time
+    progress(1.0, desc=f"✅ Done! Total time: {total_time:.1f}s")
+    print(f"Total processing time: {total_time:.2f} seconds")

+    return predictions, processed_data

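The stacking step above turns S per-view arrays into one batched array per key. A minimal sketch of the resulting shapes with synthetic data:

# Sketch only; shapes mirror the predictions dict built in run_model.
import numpy as np

S, H, W = 4, 8, 8
world_points = np.stack([np.zeros((H, W, 3)) for _ in range(S)], axis=0)  # (S, H, W, 3)
depth_maps = np.stack([np.zeros((H, W)) for _ in range(S)], axis=0)       # (S, H, W)
if depth_maps.ndim == 3:
    depth_maps = depth_maps[..., np.newaxis]                              # (S, H, W, 1)
final_mask = np.stack([np.ones((H, W), dtype=bool) for _ in range(S)], axis=0)  # (S, H, W)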
| 232 | 
             
            def update_view_selectors(processed_data):
         | 
| 233 | 
             
                """Update view selector dropdowns based on available views"""
         | 
|  | |
| 238 | 
             
                    choices = [f"View {i + 1}" for i in range(num_views)]
         | 
| 239 |  | 
| 240 | 
             
                return (
         | 
| 241 | 
            +
                    gr.Dropdown(choices=choices, value=choices[0]),  # depth_view_selector
         | 
| 242 | 
            +
                    gr.Dropdown(choices=choices, value=choices[0]),  # normal_view_selector
         | 
| 243 | 
            +
                    gr.Dropdown(choices=choices, value=choices[0]),  # measure_view_selector
         | 
| 244 | 
             
                )
         | 
| 245 |  | 
| 246 |  | 


# ...
def update_measure_view(processed_data, view_index):
    """Update measure view for a specific view index with mask overlay"""
    view_data = get_view_data_by_index(processed_data, view_index)
    if view_data is None:
        return None, []  # image, measure_points

    # Get the base image
    image = view_data["image"].copy()

    # Ensure image is in uint8 format
    if image.dtype != np.uint8:
        if image.max() <= 1.0:
            image = (image * 255).astype(np.uint8)
        else:
            image = image.astype(np.uint8)

    # Apply mask overlay if mask is available
    if view_data["mask"] is not None:
        mask = view_data["mask"]

        # Masked areas (False values) are overlaid with a light tint
        invalid_mask = ~mask  # Areas where mask is False

        if invalid_mask.any():
            # Light red tint (RGB: 255, 220, 220) for invalid areas
            overlay_color = np.array([255, 220, 220], dtype=np.uint8)

            # Apply overlay with some transparency
            alpha = 0.5  # Transparency level
            for c in range(3):  # RGB channels
                image[:, :, c] = np.where(
                    invalid_mask,
                    (1 - alpha) * image[:, :, c] + alpha * overlay_color[c],
                    # ...
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None

    # Parse current view number
    try:
        current_view = int(current_selector_value.split()[1]) - 1
    except Exception:
        # ...

# ...
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None

    # Parse current view number
    try:
        current_view = int(current_selector_value.split()[1]) - 1
    except Exception:
        # ...

# ...
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None, []

    # Parse current view number
    try:
        current_view = int(current_selector_value.split()[1]) - 1
    except Exception:
        # ...

# ...
    if processed_data is None or len(processed_data) == 0:
        return None, None, None, []

    # Use update functions to ensure confidence filtering is applied from the start
    depth_vis = update_depth_view(processed_data, 0)
    normal_vis = update_normal_view(processed_data, 0)
    measure_img, _ = update_measure_view(processed_data, 0)

    return depth_vis, normal_vis, measure_img, []


# -------------------------------------------------------------------------
# 2) Handle uploaded video/images --> produce target_dir + images
# -------------------------------------------------------------------------
def handle_uploads(unified_upload, s_time_interval=1.0):
    """
    Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
    images (or frames extracted from uploaded videos) inside it.
    Returns (target_dir, image_paths).
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Create a unique folder name
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")

    # Clean up if somehow that folder already exists
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []

    # --- Handle uploaded files (both images and videos) ---
    if unified_upload is not None:
        for file_data in unified_upload:
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = file_data

            file_ext = os.path.splitext(file_path)[1].lower()

            # Check if it's a video file
            video_extensions = [
                ".mp4",
                ".avi",
                ".mov",
                ".mkv",
                ".wmv",
                ".flv",
                ".webm",
                ".m4v",
                ".3gp",
            ]
            if file_ext in video_extensions:
                # Handle as video
                vs = cv2.VideoCapture(file_path)
                fps = vs.get(cv2.CAP_PROP_FPS)
                # Frames per sampling interval; max(1, ...) guards against
                # fps == 0 (unreadable metadata), which would crash the modulo below
                frame_interval = max(1, int(fps * s_time_interval))

                count = 0
                video_frame_num = 0
                while True:
                    ret, frame = vs.read()
                    if not ret:
                        break
                    count += 1
                    if count % frame_interval == 0:
                        # Use original filename as prefix for frames
                        base_name = os.path.splitext(os.path.basename(file_path))[0]
                        image_path = os.path.join(
                            target_dir_images, f"{base_name}_{video_frame_num:06}.png"
                        )
                        cv2.imwrite(image_path, frame)
                        image_paths.append(image_path)
                        video_frame_num += 1
                vs.release()
                print(
                    f"Extracted {video_frame_num} frames from video: {os.path.basename(file_path)}"
                )

            else:
                # Handle as image
                # Check if the file is a HEIC image
                if file_ext in [".heic", ".heif"]:
                    # Convert HEIC to JPEG for better gallery compatibility
                    try:
                        with Image.open(file_path) as img:
                            # Convert to RGB if necessary (HEIC can have different color modes)
                            if img.mode not in ("RGB", "L"):
                                img = img.convert("RGB")

                            # Create JPEG filename
                            base_name = os.path.splitext(os.path.basename(file_path))[0]
                            dst_path = os.path.join(
                                target_dir_images, f"{base_name}.jpg"
                            )

                            # Save as JPEG with high quality
                            img.save(dst_path, "JPEG", quality=95)
                            image_paths.append(dst_path)
                            print(
                                f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> {os.path.basename(dst_path)}"
                            )
                    except Exception as e:
                        print(f"Error converting HEIC file {file_path}: {e}")
                        # Fall back to copying as is
                        dst_path = os.path.join(
                            target_dir_images, os.path.basename(file_path)
                        )
                        shutil.copy(file_path, dst_path)
                        image_paths.append(dst_path)
                else:
                    # Regular image files - copy as is
                    dst_path = os.path.join(
                        target_dir_images, os.path.basename(file_path)
                    )
                    shutil.copy(file_path, dst_path)
                    image_paths.append(dst_path)

    # Sort final images for gallery
    image_paths = sorted(image_paths)

    end_time = time.time()
    print(
        f"Files processed to {target_dir_images}; took {end_time - start_time:.3f} seconds"
    )
    return target_dir, image_paths
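
# Illustrative check of the sampling rule above (not executed by the app):
#   fps, s_time_interval = 30.0, 0.5
#   frame_interval = max(1, int(fps * s_time_interval))          # -> 15
#   kept = [n for n in range(1, 61) if n % frame_interval == 0]  # [15, 30, 45, 60]
# i.e. roughly one frame is written out every s_time_interval seconds.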


# -------------------------------------------------------------------------
# 3) Update gallery on upload
# -------------------------------------------------------------------------
def update_gallery_on_upload(unified_upload, s_time_interval=1.0):
    """
    Whenever the user uploads or changes files, immediately handle them
    and show them in the gallery. Returns (reconstruction, target_dir,
    image_paths, log_message), all None if nothing is uploaded.
    """
    if not unified_upload:
        return None, None, None, None
    target_dir, image_paths = handle_uploads(unified_upload, s_time_interval)
    return (
        None,
        target_dir,
        image_paths,
        "Upload complete. Click 'Start Reconstruction' to run 3D processing.",
    )


# -------------------------------------------------------------------------
# 4) Reconstruction: uses the target_dir plus any viz parameters
# -------------------------------------------------------------------------
@spaces.GPU(duration=120)
def gradio_demo(
    target_dir,
    # ...
    filter_white_bg=False,
    apply_mask=True,
    show_mesh=True,
    progress=gr.Progress(),
):
    """
    Perform reconstruction using the already-created target_dir/images.
    """
    if not os.path.isdir(target_dir) or target_dir == "None":
        return (
            None,
            "❌ No valid target directory found. Please upload files first.",
            None, None, None, None, None, None, None, None, None,
        )

    progress(0, desc="🔄 Preparing reconstruction...")
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Prepare frame_filter dropdown
    target_dir_images = os.path.join(target_dir, "images")
    all_files = (
        sorted(os.listdir(target_dir_images))
        if os.path.isdir(target_dir_images)
        else []
    )
    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
    frame_filter_choices = ["All"] + all_files

    progress(0.05, desc=f"🚀 Running the MapAnything model... ({len(all_files)} images)")
    print("Running MapAnything model...")
    with torch.no_grad():
        predictions, processed_data = run_model(
            target_dir, apply_mask, True, filter_black_bg, filter_white_bg, progress
        )

    # Save predictions
    progress(0.92, desc="💾 Saving predictions...")
    prediction_save_path = os.path.join(target_dir, "predictions.npz")
    np.savez(prediction_save_path, **predictions)

    # Handle None frame_filter
    if frame_filter is None:
        frame_filter = "All"

    # Build a GLB file name
    progress(0.93, desc="🏗️ Generating the 3D model file...")
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
    )

    # Convert predictions to GLB
    glbscene = predictions_to_glb(
        predictions,
        filter_by_frames=frame_filter,
        show_cam=show_cam,
        mask_black_bg=filter_black_bg,
        mask_white_bg=filter_white_bg,
        as_mesh=show_mesh,  # Use the show_mesh parameter
    )
    glbscene.export(file_obj=glbfile)

    # Cleanup
    progress(0.96, desc="🧹 Cleaning up memory...")
    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    end_time = time.time()
    total_time = end_time - start_time
    print(f"Total time: {total_time:.2f}s")
    log_msg = f"✅ Reconstruction succeeded ({len(all_files)} frames in {total_time:.1f}s)"

    # Populate visualization tabs with processed data
    progress(0.98, desc="🎨 Generating visualizations...")
    depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(
        processed_data
    )

    # Update view selectors based on available views
    depth_selector, normal_selector, measure_selector = update_view_selectors(
        processed_data
    )

    progress(1.0, desc="✅ All done!")

    return (
        glbfile,
        log_msg,
        gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
        processed_data,
        depth_vis,
        normal_vis,
        measure_img,
        "",  # measure_text (empty initially)
        depth_selector,
        normal_selector,
        measure_selector,
    )
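
# Sketch (illustrative): the predictions saved above can be reloaded later
# without re-running the model, exactly as update_visualization() does below:
#   loaded = np.load(os.path.join(target_dir, "predictions.npz"), allow_pickle=True)
#   predictions = {key: loaded[key] for key in loaded.keys()}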


# -------------------------------------------------------------------------
# 5) Helper functions for UI resets + re-visualization
# -------------------------------------------------------------------------
def colorize_depth(depth_map, mask=None):
    """Convert depth map to colorized visualization with optional mask"""
    if depth_map is None:
        return None

    # Normalize depth to 0-1 range
    depth_normalized = depth_map.copy()
    valid_mask = depth_normalized > 0

    # Apply additional mask if provided (for background filtering)
    if mask is not None:
        valid_mask = valid_mask & mask

    # ...
        depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / (p95 - p5)

    # Apply colormap
    import matplotlib.pyplot as plt

    colormap = plt.cm.turbo_r
    colored = colormap(depth_normalized)
    colored = (colored[:, :, :3] * 255).astype(np.uint8)

    # Set invalid pixels to white
    colored[~valid_mask] = [255, 255, 255]

    return colored
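

# A standalone sketch of the percentile normalization used above, assuming an
# (H, W) depth array: rescaling by the 5th/95th percentiles of the valid pixels
# keeps a few extreme depths from flattening the rest of the color range. The
# helper name is illustrative and is not referenced anywhere else in this app.
def _example_percentile_normalize(depth, lo=5, hi=95):
    valid = depth > 0
    out = np.zeros_like(depth, dtype=np.float64)
    if not valid.any():
        return out
    p_lo, p_hi = np.percentile(depth[valid], [lo, hi])
    out[valid] = np.clip((depth[valid] - p_lo) / max(p_hi - p_lo, 1e-8), 0.0, 1.0)
    return out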


def colorize_normal(normal_map, mask=None):
    """Convert normal map to colorized visualization with optional mask"""
    if normal_map is None:
        return None

    # Create a copy for modification
    normal_vis = normal_map.copy()

    # Apply mask if provided (masked areas are set to [0, 0, 0], which becomes grey after normalization)
    if mask is not None:
        invalid_mask = ~mask
        normal_vis[invalid_mask] = [0, 0, 0]  # Set invalid areas to zero

    # Normalize normals to [0, 1] range for visualization
    normal_vis = (normal_vis + 1.0) / 2.0
    normal_vis = (normal_vis * 255).astype(np.uint8)
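
    # The mapping above in numbers: each component in [-1, 1] is shifted to
    # [0, 1] and scaled to [0, 255], so a masked zero vector lands at mid-grey
    # (127, 127, 127) and an up-facing normal (0, 0, 1) maps to (127, 127, 255).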
    # ...


def process_predictions_for_visualization(
    predictions, views, high_level_config, filter_black_bg, filter_white_bg
):
    """Extract depth, normal, and 3D points from predictions for visualization"""
    processed_data = {}

    # Process each view
    for view_idx, view in enumerate(views):
        # Get image
        image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])

        # Get predicted points
        pred_pts3d = predictions["world_points"][view_idx]

        # Initialize data for this view
        view_data = {
            "image": image[0],
            "points3d": pred_pts3d,
            # ...
            "mask": None,
        }

        # Start with the final mask from predictions
        mask = predictions["final_mask"][view_idx].copy()

        # Apply black background filtering if enabled
        if filter_black_bg:
            # Get the image colors (ensure they're in 0-255 range)
            view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
            # Filter out black background pixels (sum of RGB < 16)
            black_bg_mask = view_colors.sum(axis=2) >= 16
            mask = mask & black_bg_mask

        # Apply white background filtering if enabled
        if filter_white_bg:
            # Get the image colors (ensure they're in 0-255 range)
            view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
            # Filter out white background pixels (all RGB > 240)
            white_bg_mask = ~(
                (view_colors[:, :, 0] > 240)
                & (view_colors[:, :, 1] > 240)
                & (view_colors[:, :, 2] > 240)
            )
            mask = mask & white_bg_mask
            # ...


# ...
    if processed_data is None or len(processed_data) == 0:
        return None, [], ""

    # Return the first view image
    first_view = list(processed_data.values())[0]
    return first_view["image"], [], ""


# ...
):
    """Handle measurement on images"""
    try:
        print(f"Measure function called with selector: {current_view_selector}")

        if processed_data is None or len(processed_data) == 0:
            return None, [], "No data available"

        # Use the currently selected view instead of always using the first view
        try:
            current_view_index = int(current_view_selector.split()[1]) - 1
        except Exception:
            current_view_index = 0

        print(f"Using view index: {current_view_index}")

        # Get view data safely
        if current_view_index < 0 or current_view_index >= len(processed_data):
            current_view_index = 0

        view_keys = list(processed_data.keys())
        current_view = processed_data[view_keys[current_view_index]]

        if current_view is None:
            return None, [], "No view data available"

        point2d = event.index[0], event.index[1]
        print(f"Clicked point: {point2d}")

        # Check if the clicked point is in a masked area (prevent interaction)
        if (
            current_view["mask"] is not None
            and 0 <= point2d[1] < current_view["mask"].shape[0]
            and 0 <= point2d[0] < current_view["mask"].shape[1]
        ):
            # Check if the point is in a masked (invalid) area
            if not current_view["mask"][point2d[1], point2d[0]]:
                print(f"Clicked point {point2d} is in masked area, ignoring click")
                # Always return image with mask overlay
                masked_image, _ = update_measure_view(
                    processed_data, current_view_index
                )
                return (
                    masked_image,
                    measure_points,
                    '<span style="color: red; font-weight: bold;">Cannot measure on masked areas (shown in grey)</span>',
                )

        measure_points.append(point2d)

        # Get image with mask overlay and ensure it's valid
        image, _ = update_measure_view(processed_data, current_view_index)
        if image is None:
            return None, [], "No image available"

        image = image.copy()
        points3d = current_view["points3d"]

        # Ensure image is in uint8 format for proper cv2 operations
        try:
            if image.dtype != np.uint8:
                if image.max() <= 1.0:
                    # Image is in [0, 1] range, convert to [0, 255]
                    image = (image * 255).astype(np.uint8)
                else:
                    # Image is already in [0, 255] range
                    image = image.astype(np.uint8)
        except Exception as e:
            print(f"Image conversion error: {e}")
            return None, [], f"Image conversion error: {e}"

        # Draw circles for points
        try:
            for p in measure_points:
                if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
                    cv2.circle(
                        image, p, radius=5, color=(255, 0, 0), thickness=2
                    )
        except Exception as e:
            print(f"Drawing error: {e}")
            return None, [], f"Drawing error: {e}"

        depth_text = ""
        try:
            for i, p in enumerate(measure_points):
                if (
                    current_view["depth"] is not None
                    and 0 <= p[1] < current_view["depth"].shape[0]
                    and 0 <= p[0] < current_view["depth"].shape[1]
                ):
                    d = current_view["depth"][p[1], p[0]]
                    depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
                else:
                    # Use Z coordinate of 3D points if depth not available
                    if (
                        points3d is not None
                        and 0 <= p[1] < points3d.shape[0]
                        and 0 <= p[0] < points3d.shape[1]
                    ):
                        z = points3d[p[1], p[0], 2]
                        depth_text += f"- **P{i + 1} Z-coord: {z:.2f}m.**\n"
        except Exception as e:
            print(f"Depth text error: {e}")
            depth_text = f"Error computing depth: {e}\n"

        if len(measure_points) == 2:
            try:
                point1, point2 = measure_points
                # Draw line
                if (
                    0 <= point1[0] < image.shape[1]
                    and 0 <= point1[1] < image.shape[0]
                    and 0 <= point2[0] < image.shape[1]
                    and 0 <= point2[1] < image.shape[0]
                ):
                    cv2.line(
                        image, point1, point2, color=(255, 0, 0), thickness=2
                    )

                # Compute 3D distance
                distance_text = "- **Distance: Unable to compute**"
                if (
                    points3d is not None
                    and 0 <= point1[1] < points3d.shape[0]
                    and 0 <= point1[0] < points3d.shape[1]
                    and 0 <= point2[1] < points3d.shape[0]
                    and 0 <= point2[0] < points3d.shape[1]
                ):
                    try:
                        p1_3d = points3d[point1[1], point1[0]]
                        p2_3d = points3d[point2[1], point2[0]]
                        distance = np.linalg.norm(p1_3d - p2_3d)
                        distance_text = f"- **Distance: {distance:.2f}m**"
                    except Exception as e:
                        print(f"Distance computation error: {e}")
                        distance_text = f"- **Distance computation error: {e}**"

                measure_points = []
                text = depth_text + distance_text
                print(f"Measurement complete: {text}")
                return [image, measure_points, text]
            except Exception as e:
                print(f"Final measurement error: {e}")
                return None, [], f"Measurement error: {e}"
        else:
            print(f"Single point measurement: {depth_text}")
            return [image, measure_points, depth_text]

    except Exception as e:
        print(f"Overall measure function error: {e}")
        return None, [], f"Measure function error: {e}"
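
# Minimal sketch of the measurement math above: two clicked pixels (x, y) are
# looked up in the view's (H, W, 3) world-point map and the Euclidean distance
# between the two 3D points is reported. The helper name is illustrative only.
def _example_pixel_distance(points3d, p1, p2):
    a = points3d[p1[1], p1[0]]  # (x, y) click -> [row, col] array indexing
    b = points3d[p2[1], p2[0]]
    return float(np.linalg.norm(a - b))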


def clear_fields():
    """
    Clears the 3D viewer, the stored target_dir, and empties the gallery.
    """
    return None


def update_log():
    """
    Display a quick log message while waiting.
    """
    return "Loading and reconstructing..."


def update_visualization(
    # ...
    filter_white_bg=False,
    show_mesh=True,
):
    """
    Reload saved predictions from npz, create (or reuse) the GLB for the new
    parameters, and return it for the 3D viewer. If is_example == "True", skip.
    """

    # If it's an example click, skip as requested
    if is_example == "True":
        return (
            gr.update(),
            "No reconstruction available. Please click the Reconstruct button first.",
        )

    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return (
            gr.update(),
            "No reconstruction available. Please click the Reconstruct button first.",
        )

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return (
            gr.update(),
            f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
        )

    loaded = np.load(predictions_path, allow_pickle=True)
    predictions = {key: loaded[key] for key in loaded.keys()}

    glbfile = os.path.join(
        target_dir,
        f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
    )

    if not os.path.exists(glbfile):
        glbscene = predictions_to_glb(
            predictions,
            filter_by_frames=frame_filter,
            show_cam=show_cam,
            mask_black_bg=filter_black_bg,
            mask_white_bg=filter_white_bg,
            as_mesh=show_mesh,
        )
        glbscene.export(file_obj=glbfile)

    return (
        glbfile,
        "Visualization updated",
    )
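
# Caching note (illustrative): because the GLB filename encodes every parameter
# that changes the exported geometry, switching back to a previously rendered
# parameter combination reuses the file on disk instead of re-exporting, e.g.
#   glbscene_All_camTrue_meshTrue_blackFalse_whiteFalse.glb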


def update_all_views_on_filter_change(
    # ...
    normal_view_selector,
    measure_view_selector,
):
    """
    Update all individual view tabs when background filtering checkboxes change.
    This regenerates the processed data with new filtering and updates all views.
    """
    # Check if we have a valid target directory and predictions
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return processed_data, None, None, None, []

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return processed_data, None, None, None, []

    try:
        # Load the original predictions and views
        loaded = np.load(predictions_path, allow_pickle=True)
        predictions = {key: loaded[key] for key in loaded.keys()}

        # Load images using MapAnything's load_images function
        image_folder_path = os.path.join(target_dir, "images")
        views = load_images(image_folder_path)

        # Regenerate processed data with new filtering settings
        new_processed_data = process_predictions_for_visualization(
            predictions, views, high_level_config, filter_black_bg, filter_white_bg
        )

        # Get current view indices
        try:
            depth_view_idx = (
                int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0
            )
            # ...
        except Exception:
            measure_view_idx = 0

        # Update all views with new filtered data
        depth_vis = update_depth_view(new_processed_data, depth_view_idx)
        normal_vis = update_normal_view(new_processed_data, normal_view_idx)
        measure_img, _ = update_measure_view(new_processed_data, measure_view_idx)
        # ...
    except Exception:
        return processed_data, None, None, None, []


# -------------------------------------------------------------------------
# Example scene functions
# -------------------------------------------------------------------------
def get_scene_info(examples_dir):
    """Get information about scenes in the examples directory"""
    import glob
    # ...
    for scene_folder in sorted(os.listdir(examples_dir)):
        scene_path = os.path.join(examples_dir, scene_folder)
        if os.path.isdir(scene_path):
            # Find all image files in the scene folder
            image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
            image_files = []
            for ext in image_extensions:
                image_files.extend(glob.glob(os.path.join(scene_path, ext)))
                image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))

            if image_files:
                # Sort images and use the first one as the thumbnail
                image_files = sorted(image_files)
                first_image = image_files[0]
                num_images = len(image_files)
                # ...


def load_example_scene(scene_name, examples_dir="examples"):
    """Load a scene from the examples directory"""
    scenes = get_scene_info(examples_dir)

    # Find the selected scene
    selected_scene = None
    for scene in scenes:
        if scene["name"] == scene_name:
            selected_scene = scene
            break

    if selected_scene is None:
        return None, None, None, "Scene not found"

    # Convert the scene's image file paths to the format expected by the
    # unified upload system
    file_objects = []
    for image_path in selected_scene["image_files"]:
        file_objects.append(image_path)

    # Create target directory and copy images using the unified upload system
    target_dir, image_paths = handle_uploads(file_objects, 1.0)

    return (
        None,  # Clear reconstruction output
        target_dir,  # Set target directory
        image_paths,  # Set gallery
        f"Loaded scene '{scene_name}' ({selected_scene['num_images']} images). Click 'Start Reconstruction' to run 3D processing.",
    )
+# -------------------------------------------------------------------------
+# 6) Build Gradio UI
+# -------------------------------------------------------------------------
 theme = get_gradio_theme()
 
 # Custom CSS to prevent UI jitter
@@ ... @@
 }
 """
 
+with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything - 3D Reconstruction System") as demo:
+    # State variables for the tabbed interface
     is_example = gr.Textbox(label="is_example", visible=False, value="None")
+    num_images = gr.Textbox(label="num_images", visible=False, value="None")
     processed_data_state = gr.State(value=None)
     measure_points_state = gr.State(value=[])
+    current_view_index = gr.State(value=0)  # Track current view index for navigation
 
     # JavaScript for clipboard paste support
+    PASTE_JS = """
+    <script>
+    // Add clipboard paste support
+    document.addEventListener('paste', function(e) {
+        const items = e.clipboardData.items;
+        for (let i = 0; i < items.length; i++) {
+            if (items[i].type.indexOf('image') !== -1) {
+                const blob = items[i].getAsFile();
+                const fileInput = document.querySelector('input[type="file"][multiple]');
+                if (fileInput) {
+                    const dataTransfer = new DataTransfer();
+                    dataTransfer.items.add(blob);
+                    fileInput.files = dataTransfer.files;
+                    fileInput.dispatchEvent(new Event('change', { bubbles: true }));
+                    console.log('✅ Image pasted from clipboard');
+                }
+            }
+        }
+    });
+    console.log('💡 Clipboard paste enabled: press Ctrl+V to paste a screenshot');
+    </script>
+    """
     gr.HTML(PASTE_JS)
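# Note: browsers do not execute <script> tags inserted via innerHTML, which is
# how gr.HTML injects markup, so the paste handler above may never run in some
# Gradio versions. A hedged alternative, assuming a Gradio release that supports
# the `js` load hook on Blocks, is to pass the handler as a JS function string:
#
#     demo = gr.Blocks(theme=theme, css=CUSTOM_CSS, js="() => { /* paste handler from PASTE_JS */ }")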
+
+    # Styled page header
     gr.HTML("""
     <div style="text-align: center; margin: 20px 0;">
+        <h2 style="color: #1976D2; margin-bottom: 10px;">MapAnything - 3D Reconstruction System</h2>
+        <p style="color: #666; font-size: 16px;">Multi-view 3D reconstruction | Depth estimation | Normal computation | Distance measurement</p>
     </div>
     """)
 
@@ ... @@
                     clear_color=[0.0, 0.0, 0.0, 0.0]
                 )
 
             with gr.Tab("📊 Depth Map"):
                 with gr.Row(elem_classes=["navigation-row"]):
                     prev_depth_btn = gr.Button("◀", size="sm", scale=1)
@@ ... @@
         max_lines=1
     )
 
+    # Advanced options (collapsed by default)
+    with gr.Accordion("⚙️ Advanced Options", open=False):
         with gr.Row(equal_height=False):
             with gr.Column(scale=1, min_width=300):
                 gr.Markdown("#### Visualization Parameters")
@@ ... @@
                 apply_mask_checkbox = gr.Checkbox(
                     label="Apply depth mask", value=True
                 )
     # Example scenes (collapsible)
     with gr.Accordion("🖼️ Example Scenes", open=False):
+        gr.Markdown("Click a thumbnail to load a scene for reconstruction")
         scenes = get_scene_info("examples")
+
         if scenes:
+            for i in range(0, len(scenes), 4):  # Process 4 scenes per row
                 with gr.Row(equal_height=True):
                     for j in range(4):
                         scene_idx = i + j
@@ ... @@
                             scene = scenes[scene_idx]
                             with gr.Column(scale=1, min_width=150):
                                 scene_img = gr.Image(
+                                    value=scene["thumbnail"],
                                     height=150,
+                                    interactive=False,
+                                    show_label=False,
                                     sources=[],
                                     container=False
                                 )
@@ ... @@
                                     fn=lambda name=scene["name"]: load_example_scene(name),
                                     outputs=[
                                         reconstruction_output,
+                                        target_dir_output,
+                                        image_gallery,
+                                        log_output,
+                                    ],
                                 )
 
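# For reference: a minimal sketch of the shape `get_scene_info("examples")`
# would need to return for the grid above to work — one dict per scene folder
# with "name", "thumbnail", and "num_images" keys. The scan logic below is an
# assumption; the real helper is defined earlier in app.py.
import glob
import os

def get_scene_info_sketch(examples_dir="examples"):
    scenes = []
    for scene_dir in sorted(glob.glob(os.path.join(examples_dir, "*"))):
        if not os.path.isdir(scene_dir):
            continue
        images = sorted(
            glob.glob(os.path.join(scene_dir, "*.jpg"))
            + glob.glob(os.path.join(scene_dir, "*.png"))
        )
        if images:
            scenes.append({
                "name": os.path.basename(scene_dir),
                "thumbnail": images[0],  # first image doubles as the thumbnail
                "num_images": len(images),
            })
    return scenes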
     # === Event bindings ===
 
     # Auto-update the gallery when files are uploaded
     def update_gallery_on_unified_upload(files, interval):
         if not files:
@@ ... @@
     # Reconstruction button
     submit_btn.click(
         fn=clear_fields,
+        outputs=[reconstruction_output]
     ).then(
         fn=update_log,
         outputs=[log_output]
@@ ... @@
         inputs=[
             target_dir_output, frame_filter, show_cam,
             filter_black_bg, filter_white_bg,
+            apply_mask_checkbox, show_mesh
         ],
         outputs=[
+            reconstruction_output, log_output, frame_filter,
             processed_data_state, depth_map, normal_map, measure_image,
             measure_text, depth_view_selector, normal_view_selector, measure_view_selector
         ]
@@ ... @@
     )
 
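# The .click(...).then(...) chain above runs its steps sequentially: clear the
# stale viewer, update the log, then launch the heavy reconstruction with the
# full input/output wiring. A minimal standalone example of the same Gradio
# pattern (component names here are illustrative, not from app.py):
import gradio as gr

with gr.Blocks() as chain_demo:
    status = gr.Textbox(label="status")
    run_btn = gr.Button("Run")
    run_btn.click(fn=lambda: "working...", outputs=[status]).then(
        fn=lambda: "done", outputs=[status]
    )
# chain_demo.launch()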
     # Clear button
+    clear_btn.add([reconstruction_output, log_output])
+
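# gr.ClearButton.add(...) registers extra components to be reset when the
# button is pressed, in addition to any passed at construction time — no
# explicit .click handler is needed for the clearing itself.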
     # Live-update the visualization when parameters change
     for component in [frame_filter, show_cam, show_mesh]:
         component.change(
@@ ... @@
             ],
             outputs=[processed_data_state, depth_map, normal_map, measure_image, measure_points_state]
         )
+
     # Depth map navigation
     prev_depth_btn.click(
         fn=lambda pd, cs: navigate_depth_view(pd, cs, -1),
@@ ... @@
         outputs=[measure_image, measure_points_state]
     )
 
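# For reference: a plausible shape for the `navigate_depth_view(processed_data,
# current_index, step)` function bound above — cycle through per-view depth
# images with wraparound. The real function and its data keys are defined
# earlier in app.py; the "depth_maps" key below is an assumption.
def navigate_depth_view_sketch(processed_data, current_index, step):
    if not processed_data:
        return None, current_index
    depth_maps = processed_data["depth_maps"]
    new_index = (current_index + step) % len(depth_maps)
    return depth_maps[new_index], new_index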
 demo.queue(max_size=20).launch(show_error=True, share=True, ssr_mode=False)
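# queue(max_size=20) caps the number of pending requests before new ones are
# rejected; share=True asks Gradio for a temporary public link when run
# locally (Spaces already serve the app publicly); ssr_mode=False opts out of
# server-side rendering, which can be flaky in hosted environments.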
