ColamanAI committed
Commit d7a13ed · verified · 1 Parent(s): 3e7f3f4

Upload app.py

Files changed (1)
  1. app.py +343 -1228
app.py CHANGED
@@ -5,11 +5,10 @@
5
  # LICENSE file in the root directory of this source tree.
6
 
7
  """
8
- MapAnything V2: 3D Reconstruction with Object Segmentation
9
- - Multi-view 3D reconstruction
10
- - GroundingDINO object detection
11
- - SAM precise segmentation
12
- - DBSCAN clustering for cross-view object matching
13
  """
14
 
15
  import gc
@@ -18,8 +17,6 @@ import shutil
18
  import sys
19
  import time
20
  from datetime import datetime
21
- from pathlib import Path
22
- from collections import defaultdict
23
 
24
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
25
 
@@ -28,10 +25,8 @@ import gradio as gr
28
  import numpy as np
29
  import spaces
30
  import torch
31
- import trimesh
32
  from PIL import Image
33
  from pillow_heif import register_heif_opener
34
- from sklearn.cluster import DBSCAN
35
 
36
  register_heif_opener()
37
 
@@ -65,10 +60,6 @@ def get_logo_base64():
65
  return None
66
 
67
 
68
- # ============================================================================
69
- # Configuration
70
- # ============================================================================
71
-
72
  # MapAnything Configuration
73
  high_level_config = {
74
  "path": "configs/train.yaml",
@@ -89,846 +80,13 @@ high_level_config = {
89
  "resolution": 518,
90
  }
91
 
92
- # ============ Segmentation model configuration ============
93
- # Method selection:
94
- # 1. "segformer" - SegFormer (lightest, ~14MB, fastest)
95
- # 2. "maskformer" - MaskFormer (medium, ~100MB, instance segmentation)
96
- # 3. "grounding_sam" - GroundingDINO + SAM (strongest, ~110MB, text-prompt driven)
97
-
98
- SEGMENTATION_METHOD = "segformer" # default to the lightest method
99
-
100
- # SegFormer Configuration (recommended - CPU friendly)
101
- SEGFORMER_MODEL_ID = "nvidia/segformer-b0-finetuned-ade-512-512" # 14MB, 150 object classes
102
-
103
- # MaskFormer Configuration (alternative)
104
- MASKFORMER_MODEL_ID = "facebook/maskformer-swin-tiny-ade" # 100MB, instance segmentation
105
-
106
- # GroundingDINO + SAM Configuration (原方案 - 需要文本提示)
107
- GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
108
- GROUNDING_DINO_BOX_THRESHOLD = 0.25
109
- GROUNDING_DINO_TEXT_THRESHOLD = 0.2
110
- SAM_MODEL_ID = "dhkim2810/MobileSAM"
111
- USE_MOBILE_SAM = True
112
-
113
- DEFAULT_TEXT_PROMPT = "chair . table . sofa . bed . desk . cabinet"
114
-
115
- # Common objects prompt for detection
116
- COMMON_OBJECTS_PROMPT = (
117
- "person . face . hand . "
118
- "chair . sofa . couch . bed . table . desk . cabinet . shelf . drawer . "
119
- "door . window . wall . floor . ceiling . curtain . "
120
- "tv . monitor . screen . computer . laptop . keyboard . mouse . "
121
- "phone . tablet . remote . "
122
- "lamp . light . chandelier . "
123
- "book . magazine . paper . pen . pencil . "
124
- "bottle . cup . glass . mug . plate . bowl . fork . knife . spoon . "
125
- "vase . plant . flower . pot . "
126
- "clock . picture . frame . mirror . "
127
- "pillow . cushion . blanket . towel . "
128
- "bag . backpack . suitcase . "
129
- "box . basket . container . "
130
- "shoe . hat . coat . "
131
- "toy . ball . "
132
- "car . bicycle . motorcycle . bus . truck . "
133
- "tree . grass . sky . cloud . sun . "
134
- "dog . cat . bird . "
135
- "building . house . bridge . road . street . "
136
- "sign . pole . bench"
137
- )
138
-
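Note on the prompt format: GroundingDINO expects lower-case category phrases separated by " . ", which is why the prompts above are written that way. A minimal sketch of feeding such a prompt to the zero-shot detector configured above (the image path is an illustrative placeholder):

```python
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
detector = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny")

image = Image.open("room.jpg")          # illustrative path
prompt = "chair . table . sofa . bed"   # phrases separated by " . "
inputs = processor(images=image, text=prompt, return_tensors="pt")
```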
139
- # DBSCAN clustering configuration (eps in meters)
140
- DBSCAN_EPS_CONFIG = {
141
- 'sofa': 1.5,
142
- 'bed': 1.5,
143
- 'couch': 1.5,
144
- 'desk': 0.8,
145
- 'table': 0.8,
146
- 'chair': 0.6,
147
- 'cabinet': 0.8,
148
- 'window': 0.5,
149
- 'door': 0.6,
150
- 'tv': 0.6,
151
- 'default': 1.0
152
- }
153
-
154
- DBSCAN_MIN_SAMPLES = 1
155
-
156
- # Quality control
157
- MIN_DETECTION_CONFIDENCE = 0.35
158
- MIN_MASK_AREA = 100
159
-
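For reference, a minimal sketch of how per-label eps values like these drive DBSCAN when merging detections of the same object seen from several views (the 3D centers below are made-up illustrative values):

```python
import numpy as np
from sklearn.cluster import DBSCAN

# Illustrative 3D centers (meters) of "chair" detections from different views
centers = np.array([[0.10, 0.00, 2.0], [0.15, 0.05, 2.1], [3.0, 0.0, 2.0]])

eps = DBSCAN_EPS_CONFIG.get("chair", DBSCAN_EPS_CONFIG["default"])  # 0.6 m
labels = DBSCAN(eps=eps, min_samples=DBSCAN_MIN_SAMPLES).fit_predict(centers)
# labels -> [0, 0, 1]: the first two detections are the same chair, the third is another
```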
160
- # Global model variables
161
  model = None
162
- grounding_dino_model = None
163
- grounding_dino_processor = None
164
- sam_predictor = None
165
-
166
- # SegFormer model (lightweight semantic segmentation)
167
- segformer_processor = None
168
- segformer_model = None
169
-
170
- # MaskFormer model (instance segmentation)
171
- maskformer_processor = None
172
- maskformer_model = None
173
-
174
-
175
- # ============================================================================
176
- # Model Loading Functions
177
- # ============================================================================
178
-
179
- def load_segformer_model(device="cpu"):
180
- """加载 SegFormer 模型(最轻量,CPU友好)"""
181
- global segformer_processor, segformer_model
182
-
183
- if segformer_model is not None:
184
- print("✅ SegFormer already loaded")
185
- return
186
-
187
- try:
188
- from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
189
- import os
190
-
191
- print(f"📥 Loading SegFormer from HuggingFace: {SEGFORMER_MODEL_ID}")
192
- print(f" 💡 SegFormer-B0: ~14MB, 150类物体, CPU优化")
193
-
194
- cache_dir = os.getenv("HF_HOME", "./hf_cache")
195
-
196
- print(f" 正在下载 processor...")
197
- segformer_processor = SegformerImageProcessor.from_pretrained(
198
- SEGFORMER_MODEL_ID,
199
- cache_dir=cache_dir
200
- )
201
-
202
- print(f" 正在下载 model...")
203
- segformer_model = SegformerForSemanticSegmentation.from_pretrained(
204
- SEGFORMER_MODEL_ID,
205
- cache_dir=cache_dir,
206
- low_cpu_mem_usage=True
207
- ).to(device).eval()
208
-
209
- print(f"✅ SegFormer loaded successfully on {device.upper()}")
210
- print(f" 可识别类别: 人、家具、墙壁、地板等150类")
211
-
212
- except Exception as e:
213
- print(f"❌ SegFormer loading failed: {type(e).__name__}: {e}")
214
- import traceback
215
- traceback.print_exc()
216
-
217
-
218
- def load_maskformer_model(device="cpu"):
219
- """加载 MaskFormer 模型(实例分割)"""
220
- global maskformer_processor, maskformer_model
221
-
222
- if maskformer_model is not None:
223
- print("✅ MaskFormer already loaded")
224
- return
225
-
226
- try:
227
- from transformers import MaskFormerImageProcessor, MaskFormerForInstanceSegmentation
228
- import os
229
-
230
- print(f"📥 Loading MaskFormer from HuggingFace: {MASKFORMER_MODEL_ID}")
231
- print(f" 💡 MaskFormer: ~100MB, 实例分割")
232
-
233
- cache_dir = os.getenv("HF_HOME", "./hf_cache")
234
-
235
- print(f" 正在下载 processor...")
236
- maskformer_processor = MaskFormerImageProcessor.from_pretrained(
237
- MASKFORMER_MODEL_ID,
238
- cache_dir=cache_dir
239
- )
240
-
241
- print(f" 正在下载 model...")
242
- maskformer_model = MaskFormerForInstanceSegmentation.from_pretrained(
243
- MASKFORMER_MODEL_ID,
244
- cache_dir=cache_dir,
245
- low_cpu_mem_usage=True
246
- ).to(device).eval()
247
-
248
- print(f"✅ MaskFormer loaded successfully on {device.upper()}")
249
-
250
- except Exception as e:
251
- print(f"❌ MaskFormer loading failed: {type(e).__name__}: {e}")
252
- import traceback
253
- traceback.print_exc()
254
-
255
- def load_grounding_dino_model(device="cpu"):
256
- """Load GroundingDINO model from HuggingFace (CPU优化)"""
257
- global grounding_dino_model, grounding_dino_processor
258
-
259
- if grounding_dino_model is not None:
260
- print("✅ GroundingDINO already loaded")
261
- return
262
-
263
- try:
264
- from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
265
- import os
266
-
267
- # Force CPU for segmentation (saves GPU resources)
268
- seg_device = "cpu"
269
- print(f"📥 Loading GroundingDINO from HuggingFace: {GROUNDING_DINO_MODEL_ID} (using {seg_device.upper()})")
270
-
271
- # Set the cache directory (HuggingFace Spaces friendly)
272
- cache_dir = os.getenv("HF_HOME", "./hf_cache")
273
-
274
- # Load the model (with retries and verbose logging)
275
- print(f" Downloading processor...")
276
- grounding_dino_processor = AutoProcessor.from_pretrained(
277
- GROUNDING_DINO_MODEL_ID,
278
- cache_dir=cache_dir,
279
- trust_remote_code=True # allow running remote code
280
- )
281
-
282
- print(f" 正在下载 model...")
283
- grounding_dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
284
- GROUNDING_DINO_MODEL_ID,
285
- cache_dir=cache_dir,
286
- trust_remote_code=True,
287
- low_cpu_mem_usage=True # reduce CPU memory usage
288
- ).to(seg_device).eval()
289
-
290
- print(f"✅ GroundingDINO loaded successfully on {seg_device.upper()}")
291
-
292
- except ImportError as e:
293
- print(f"❌ ImportError: {e}")
294
- print(f"💡 请检查 requirements.txt 是否包含 transformers 库")
295
- import traceback
296
- traceback.print_exc()
297
- except OSError as e:
298
- print(f"❌ OSError (网络/文件问题): {e}")
299
- print(f"💡 可能是网络连接问题或模型仓库不可访问")
300
- print(f"💡 尝试解决方案:")
301
- print(f" 1. 检查 HuggingFace Spaces 的网络连接")
302
- print(f" 2. 检查模型ID是否正确: {GROUNDING_DINO_MODEL_ID}")
303
- print(f" 3. 确保有足够的磁盘空间")
304
- import traceback
305
- traceback.print_exc()
306
- except Exception as e:
307
- print(f"❌ GroundingDINO loading failed: {type(e).__name__}: {e}")
308
- import traceback
309
- traceback.print_exc()
310
-
311
-
312
- def load_sam_model(device="cpu"):
313
- """Load MobileSAM model from HuggingFace (CPU优化,比SAM快60倍)"""
314
- global sam_predictor
315
-
316
- if sam_predictor is not None:
317
- print("✅ SAM already loaded")
318
- return
319
-
320
- try:
321
- from transformers import SamModel, SamProcessor
322
- import os
323
-
324
- # Force CPU for segmentation (MobileSAM is optimized for mobile devices / CPU)
325
- seg_device = "cpu"
326
- print(f"📥 Loading MobileSAM from HuggingFace: {SAM_MODEL_ID} (using {seg_device.upper()})")
327
- print(f" 💡 MobileSAM is a lightweight variant: ~60x faster than SAM-huge, only 10MB, well suited to CPU")
328
-
329
- # Set the cache directory
330
- cache_dir = os.getenv("HF_HOME", "./hf_cache")
331
-
332
- print(f" 正在下载 processor...")
333
- sam_processor = SamProcessor.from_pretrained(
334
- SAM_MODEL_ID,
335
- cache_dir=cache_dir
336
- )
337
-
338
- print(f" 正在下载 model...")
339
- sam_model = SamModel.from_pretrained(
340
- SAM_MODEL_ID,
341
- cache_dir=cache_dir,
342
- low_cpu_mem_usage=True
343
- ).to(seg_device).eval()
344
-
345
- # Wrap in a predictor-like interface
346
- class SAMPredictor:
347
- def __init__(self, model, processor, device):
348
- self.model = model
349
- self.processor = processor
350
- self.device = device
351
- self.image = None
352
-
353
- def set_image(self, image):
354
- """Set image for prediction"""
355
- if image.dtype == np.uint8:
356
- self.image = Image.fromarray(image)
357
- else:
358
- self.image = Image.fromarray((image * 255).astype(np.uint8))
359
-
360
- def predict(self, box, multimask_output=False):
361
- """Predict mask from box (CPU优化)"""
362
- inputs = self.processor(
363
- self.image,
364
- input_boxes=[[[box]]],
365
- return_tensors="pt"
366
- )
367
- # Make sure everything runs on CPU
368
- inputs = {k: v.to(self.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
369
-
370
- with torch.no_grad():
371
- outputs = self.model(**inputs)
372
-
373
- masks = self.processor.image_processor.post_process_masks(
374
- outputs.pred_masks.cpu(),
375
- inputs["original_sizes"].cpu() if "original_sizes" in inputs else outputs.pred_masks.new_tensor([[self.image.height, self.image.width]]),
376
- inputs["reshaped_input_sizes"].cpu() if "reshaped_input_sizes" in inputs else outputs.pred_masks.new_tensor([[self.image.height, self.image.width]])
377
- )[0].squeeze().numpy()
378
-
379
- if len(masks.shape) == 2:
380
- masks = masks[np.newaxis, ...]
381
-
382
- return masks, None, None
383
-
384
- sam_predictor = SAMPredictor(sam_model, sam_processor, seg_device)
385
- print(f"✅ MobileSAM loaded successfully on {seg_device.upper()}")
386
-
387
- except ImportError as e:
388
- print(f"❌ ImportError: {e}")
389
- print(f"💡 请检查 requirements.txt 是否包含 transformers 库")
390
- import traceback
391
- traceback.print_exc()
392
- except OSError as e:
393
- print(f"❌ OSError (网络/文件问题): {e}")
394
- print(f"💡 可能是网络连接问题或模型仓库不可访问")
395
- print(f"💡 尝试解决方案:")
396
- print(f" 1. 检查 HuggingFace Spaces 的网络连接")
397
- print(f" 2. 检查模型ID是否正确: {SAM_MODEL_ID}")
398
- print(f" 3. 确保有足够的磁盘空间")
399
- import traceback
400
- traceback.print_exc()
401
- except Exception as e:
402
- print(f"❌ SAM loading failed: {type(e).__name__}: {e}")
403
- import traceback
404
- traceback.print_exc()
405
-
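For reference, a minimal sketch of how the SAMPredictor wrapper above is meant to be used (the image array and box are illustrative values, and load_sam_model must have succeeded first):

```python
import numpy as np

load_sam_model("cpu")
image_np = np.zeros((480, 640, 3), dtype=np.uint8)            # illustrative RGB image
sam_predictor.set_image(image_np)
masks, _, _ = sam_predictor.predict(box=[50, 60, 200, 220])   # [x1, y1, x2, y2] in pixels
mask = masks[0]                                               # boolean mask for the box prompt
```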
406
-
407
- # ============================================================================
408
- # Segmentation Functions
409
- # ============================================================================
410
-
411
- def generate_distinct_colors(n):
412
- """Generate N visually distinct colors (RGB, 0-255)"""
413
- import colorsys
414
- if n == 0:
415
- return []
416
-
417
- colors = []
418
- for i in range(n):
419
- hue = i / max(n, 1)
420
- rgb = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
421
- rgb_color = tuple(int(c * 255) for c in rgb)
422
- colors.append(rgb_color)
423
-
424
- return colors
425
-
426
-
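Illustrative output of the helper above: colors are spread evenly around the hue wheel, e.g.

```python
generate_distinct_colors(3)
# -> [(242, 24, 24), (24, 242, 24), (24, 24, 242)]
```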
427
- # ============================================================================
428
- # SegFormer segmentation functions (simplified method)
429
- # ============================================================================
430
-
431
- def run_segformer_segmentation(image_np, device="cpu"):
432
- """Semantic segmentation with SegFormer (simplest, CPU friendly)"""
433
- if segformer_model is None or segformer_processor is None:
434
- print("❌ SegFormer model not loaded")
435
- return []
436
-
437
- try:
438
- import torch
439
- from PIL import Image
440
-
441
- # Prepare the image
442
- if image_np.dtype != np.uint8:
443
- image_np = (image_np * 255).astype(np.uint8)
444
- image_pil = Image.fromarray(image_np)
445
-
446
- # Inference
447
- inputs = segformer_processor(images=image_pil, return_tensors="pt")
448
- inputs = {k: v.to(device) for k, v in inputs.items()}
449
-
450
- with torch.no_grad():
451
- outputs = segformer_model(**inputs)
452
-
453
- # Get the segmentation result
454
- logits = outputs.logits # (1, num_classes, H, W)
455
- predicted_segmentation = logits.argmax(dim=1).squeeze().cpu().numpy()
456
-
457
- # Build instance masks (split connected regions of the same class)
458
- from scipy import ndimage
459
-
460
- # Partial mapping of common ADE20K classes
461
- ade20k_labels = {
462
- 5: "wall", 7: "floor", 11: "ceiling", 18: "window", 14: "door",
463
- 19: "table", 20: "chair", 22: "sofa", 23: "bed", 28: "cabinet",
464
- 34: "desk", 39: "lamp", 65: "television", 89: "shelf"
465
- }
466
-
467
- detections = []
468
- masks = []
469
-
470
- # Extract instances for each class
471
- unique_labels = np.unique(predicted_segmentation)
472
- for label_id in unique_labels:
473
- if label_id == 0: # skip background
474
- continue
475
-
476
- # Mask for this class
477
- class_mask = (predicted_segmentation == label_id)
478
-
479
- # Separate connected components (different instances)
480
- labeled_mask, num_features = ndimage.label(class_mask)
481
-
482
- for instance_id in range(1, num_features + 1):
483
- instance_mask = (labeled_mask == instance_id)
484
- mask_area = instance_mask.sum()
485
-
486
- # Filter out small regions
487
- if mask_area < MIN_MASK_AREA:
488
- continue
489
-
490
- # Compute the bounding box
491
- rows, cols = np.where(instance_mask)
492
- if len(rows) == 0:
493
- continue
494
-
495
- y_min, y_max = rows.min(), rows.max()
496
- x_min, x_max = cols.min(), cols.max()
497
- bbox = [x_min, y_min, x_max, y_max]
498
-
499
- # Get the class name
500
- label_name = ade20k_labels.get(int(label_id), f"object_{label_id}")
501
-
502
- detections.append({
503
- 'bbox': bbox,
504
- 'label': label_name,
505
- 'confidence': 0.9, # SegFormer provides no confidence score; use a fixed value
506
- 'class_id': int(label_id)
507
- })
508
- masks.append(instance_mask)
509
-
510
- return detections, masks
511
-
512
- except Exception as e:
513
- print(f"❌ SegFormer segmentation failed: {e}")
514
- import traceback
515
- traceback.print_exc()
516
- return [], []
517
 
518
 
519
- def run_grounding_dino_detection(image_np, text_prompt, device="cpu"):
520
- """Run GroundingDINO detection (CPU优化)"""
521
- if grounding_dino_model is None or grounding_dino_processor is None:
522
- print("⚠️ GroundingDINO not loaded")
523
- return []
524
-
525
- try:
526
- print(f"🔍 GroundingDINO detection (CPU): {text_prompt}")
527
-
528
- # Convert to PIL Image
529
- if image_np.dtype == np.uint8:
530
- pil_image = Image.fromarray(image_np)
531
- else:
532
- pil_image = Image.fromarray((image_np * 255).astype(np.uint8))
533
-
534
- # Preprocess - force CPU
535
- seg_device = "cpu"
536
- inputs = grounding_dino_processor(images=pil_image, text=text_prompt, return_tensors="pt")
537
- inputs = {k: v.to(seg_device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
538
-
539
- # Inference
540
- with torch.no_grad():
541
- outputs = grounding_dino_model(**inputs)
542
-
543
- # Post-process
544
- results = grounding_dino_processor.post_process_grounded_object_detection(
545
- outputs,
546
- inputs["input_ids"],
547
- threshold=GROUNDING_DINO_BOX_THRESHOLD,
548
- text_threshold=GROUNDING_DINO_TEXT_THRESHOLD,
549
- target_sizes=[pil_image.size[::-1]]
550
- )[0]
551
-
552
- # Convert to unified format
553
- detections = []
554
- boxes = results["boxes"].cpu().numpy()
555
- scores = results["scores"].cpu().numpy()
556
- labels = results["labels"]
557
-
558
- print(f"✅ Detected {len(boxes)} objects")
559
-
560
- for box, score, label in zip(boxes, scores, labels):
561
- detection = {
562
- 'bbox': box.tolist(), # [x1, y1, x2, y2]
563
- 'label': label,
564
- 'confidence': float(score)
565
- }
566
- detections.append(detection)
567
- print(f" - {label}: {score:.2f}")
568
-
569
- return detections
570
-
571
- except Exception as e:
572
- print(f"❌ GroundingDINO detection failed: {e}")
573
- import traceback
574
- traceback.print_exc()
575
- return []
576
-
577
-
578
- def run_sam_refinement(image_np, boxes):
579
- """Run SAM precise segmentation"""
580
- if sam_predictor is None:
581
- print("⚠️ SAM not loaded, using bbox as mask")
582
- # Use bbox to create simple rectangular mask
583
- masks = []
584
- h, w = image_np.shape[:2]
585
- for box in boxes:
586
- x1, y1, x2, y2 = map(int, box)
587
- mask = np.zeros((h, w), dtype=bool)
588
- mask[y1:y2, x1:x2] = True
589
- masks.append(mask)
590
- return masks
591
-
592
- try:
593
- print(f"🎯 SAM precise segmentation for {len(boxes)} regions...")
594
- sam_predictor.set_image(image_np)
595
-
596
- masks = []
597
- for box in boxes:
598
- x1, y1, x2, y2 = map(int, box)
599
- box_array = np.array([x1, y1, x2, y2])
600
-
601
- mask_output, _, _ = sam_predictor.predict(
602
- box=box_array,
603
- multimask_output=False
604
- )
605
- masks.append(mask_output[0])
606
-
607
- print(f"✅ SAM segmentation complete")
608
- return masks
609
-
610
- except Exception as e:
611
- print(f"❌ SAM segmentation failed: {e}")
612
- # Fallback to bbox masks
613
- masks = []
614
- h, w = image_np.shape[:2]
615
- for box in boxes:
616
- x1, y1, x2, y2 = map(int, box)
617
- mask = np.zeros((h, w), dtype=bool)
618
- mask[y1:y2, x1:x2] = True
619
- masks.append(mask)
620
- return masks
621
-
622
-
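A minimal sketch of chaining the two helpers above on a single view, as run_model does later in this file (the image array is an illustrative placeholder):

```python
import numpy as np

image_np = np.zeros((480, 640, 3), dtype=np.uint8)  # illustrative image

detections = run_grounding_dino_detection(image_np, "chair . table", device="cpu")
boxes = [d["bbox"] for d in detections]
masks = run_sam_refinement(image_np, boxes)          # one boolean mask per detected box
```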
623
- def normalize_label(label):
624
- """Normalize label to main category"""
625
- label = label.strip().lower()
626
-
627
- priority_labels = ['sofa', 'bed', 'table', 'desk', 'chair', 'cabinet', 'window', 'door']
628
-
629
- for priority in priority_labels:
630
- if priority in label:
631
- return priority
632
-
633
- first_word = label.split()[0] if label else label
634
-
635
- # Handle plural forms
636
- if first_word.endswith('s') and len(first_word) > 1:
637
- singular = first_word[:-1]
638
- if first_word.endswith('sses'):
639
- singular = first_word[:-2]
640
- elif first_word.endswith('ies'):
641
- singular = first_word[:-3] + 'y'
642
- elif first_word.endswith('ves'):
643
- singular = first_word[:-3] + 'f'
644
- return singular
645
-
646
- return first_word
647
-
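Illustrative behaviour of normalize_label as defined above:

```python
normalize_label("Office Chair")  # -> "chair"   (priority keyword match)
normalize_label("leather sofa")  # -> "sofa"
normalize_label("lamps")         # -> "lamp"    (simple plural stripping of the first word)
normalize_label("monitor")       # -> "monitor" (unchanged when no rule applies)
```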
648
-
649
- def compute_object_3d_center(points, mask):
650
- """Compute 3D center of object"""
651
- masked_points = points[mask]
652
- if len(masked_points) == 0:
653
- return None
654
- return np.median(masked_points, axis=0)
655
-
656
-
657
- def compute_adaptive_eps(centers, base_eps):
658
- """Adaptively compute eps value based on object distribution"""
659
- if len(centers) <= 1:
660
- return base_eps
661
-
662
- from scipy.spatial.distance import pdist
663
- distances = pdist(centers)
664
-
665
- if len(distances) == 0:
666
- return base_eps
667
-
668
- median_dist = np.median(distances)
669
-
670
- if median_dist > base_eps * 2:
671
- adaptive_eps = min(median_dist * 0.6, base_eps * 2.5)
672
- elif median_dist > base_eps:
673
- adaptive_eps = median_dist * 0.5
674
- else:
675
- adaptive_eps = base_eps
676
-
677
- return adaptive_eps
678
-
679
-
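A worked example of the rule above (illustrative numbers): with base_eps = 0.6 m and detections whose median pairwise distance is 2.0 m (more than 2 × base_eps), the eps actually used becomes min(2.0 × 0.6, 0.6 × 2.5) = 1.2 m, so detections of the same object that the views place slightly apart can still be merged into one cluster.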
680
- def match_objects_across_views(all_view_detections):
681
- """Match objects across views using DBSCAN clustering"""
682
- print("\n🔗 Matching objects across views using DBSCAN clustering...")
683
-
684
- objects_by_label = defaultdict(list)
685
-
686
- for view_idx, detections in enumerate(all_view_detections):
687
- for det_idx, det in enumerate(detections):
688
- if det.get('center_3d') is None:
689
- continue
690
-
691
- norm_label = normalize_label(det['label'])
692
- objects_by_label[norm_label].append({
693
- 'view_idx': view_idx,
694
- 'det_idx': det_idx,
695
- 'label': det['label'],
696
- 'norm_label': norm_label,
697
- 'center_3d': det['center_3d'],
698
- 'confidence': det['confidence'],
699
- })
700
-
701
- if len(objects_by_label) == 0:
702
- return {}, []
703
-
704
- object_id_map = defaultdict(dict)
705
- unique_objects = []
706
- next_global_id = 0
707
-
708
- for norm_label, objects in objects_by_label.items():
709
- print(f"\n 📦 Processing {norm_label}: {len(objects)} detections")
710
-
711
- if len(objects) == 1:
712
- obj = objects[0]
713
- unique_objects.append({
714
- 'global_id': next_global_id,
715
- 'label': obj['label'],
716
- 'views': [(obj['view_idx'], obj['det_idx'])],
717
- 'center_3d': obj['center_3d'],
718
- })
719
- object_id_map[obj['view_idx']][obj['det_idx']] = next_global_id
720
- next_global_id += 1
721
- print(f" → 1 cluster (single detection)")
722
- continue
723
-
724
- centers = np.array([obj['center_3d'] for obj in objects])
725
-
726
- base_eps = DBSCAN_EPS_CONFIG.get(norm_label, DBSCAN_EPS_CONFIG.get('default', 1.0))
727
- eps = compute_adaptive_eps(centers, base_eps)
728
-
729
- clustering = DBSCAN(eps=eps, min_samples=DBSCAN_MIN_SAMPLES, metric='euclidean')
730
- cluster_labels = clustering.fit_predict(centers)
731
-
732
- n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
733
- n_noise = list(cluster_labels).count(-1)
734
-
735
- if eps != base_eps:
736
- print(f" → {n_clusters} clusters (base_eps={base_eps}m → adaptive_eps={eps:.2f}m)")
737
- else:
738
- print(f" → {n_clusters} clusters (eps={eps}m)")
739
- if n_noise > 0:
740
- print(f" ⚠️ {n_noise} noise points (isolated detections)")
741
-
742
- for cluster_id in set(cluster_labels):
743
- if cluster_id == -1:
744
- for i, label in enumerate(cluster_labels):
745
- if label == -1:
746
- obj = objects[i]
747
- unique_objects.append({
748
- 'global_id': next_global_id,
749
- 'label': obj['label'],
750
- 'views': [(obj['view_idx'], obj['det_idx'])],
751
- 'center_3d': obj['center_3d'],
752
- })
753
- object_id_map[obj['view_idx']][obj['det_idx']] = next_global_id
754
- next_global_id += 1
755
- else:
756
- cluster_objects = [objects[i] for i, label in enumerate(cluster_labels) if label == cluster_id]
757
-
758
- total_conf = sum(o['confidence'] for o in cluster_objects)
759
- weighted_center = sum(o['center_3d'] * o['confidence'] for o in cluster_objects) / total_conf
760
-
761
- unique_objects.append({
762
- 'global_id': next_global_id,
763
- 'label': cluster_objects[0]['label'],
764
- 'views': [(o['view_idx'], o['det_idx']) for o in cluster_objects],
765
- 'center_3d': weighted_center,
766
- })
767
-
768
- for obj in cluster_objects:
769
- object_id_map[obj['view_idx']][obj['det_idx']] = next_global_id
770
-
771
- next_global_id += 1
772
-
773
- print(f"\n 📊 Summary:")
774
- print(f" Total detections: {sum(len(objs) for objs in objects_by_label.values())}")
775
- print(f" Unique objects: {len(unique_objects)}")
776
-
777
- return object_id_map, unique_objects
778
-
779
-
780
- def create_multi_view_segmented_mesh(processed_data, all_view_detections, all_view_masks,
781
- object_id_map, unique_objects, target_dir):
782
- """Create multi-view fused segmented mesh"""
783
- try:
784
- print("\n🎨 Generating multi-view segmented mesh...")
785
-
786
- unique_normalized_labels = sorted(set(normalize_label(obj['label']) for obj in unique_objects))
787
- label_colors = {}
788
- colors = generate_distinct_colors(len(unique_normalized_labels))
789
-
790
- for i, norm_label in enumerate(unique_normalized_labels):
791
- label_colors[norm_label] = colors[i]
792
-
793
- for obj in unique_objects:
794
- norm_label = normalize_label(obj['label'])
795
- obj['color'] = label_colors[norm_label]
796
- obj['normalized_label'] = norm_label
797
-
798
- print(f" Object category color mapping:")
799
- for norm_label, color in sorted(label_colors.items()):
800
- count = sum(1 for obj in unique_objects if normalize_label(obj['label']) == norm_label)
801
- print(f" {norm_label} × {count} → RGB{color}")
802
-
803
- import utils3d
804
-
805
- all_meshes = []
806
-
807
- for view_idx in range(len(processed_data)):
808
- view_data = processed_data[view_idx]
809
- image = view_data["image"]
810
- points3d = view_data["points3d"]
811
- mask = view_data.get("mask")
812
- normal = view_data.get("normal")
813
-
814
- detections = all_view_detections[view_idx]
815
- masks = all_view_masks[view_idx]
816
-
817
- if len(detections) == 0:
818
- continue
819
-
820
- if image.dtype != np.uint8:
821
- if image.max() <= 1.0:
822
- image = (image * 255).astype(np.uint8)
823
- else:
824
- image = image.astype(np.uint8)
825
-
826
- colored_image = image.copy()
827
- confidence_map = np.zeros((image.shape[0], image.shape[1]), dtype=np.float32)
828
-
829
- detections_info = []
830
- filtered_count = 0
831
- for det_idx, (det, seg_mask) in enumerate(zip(detections, masks)):
832
- if det['confidence'] < MIN_DETECTION_CONFIDENCE:
833
- filtered_count += 1
834
- continue
835
-
836
- mask_area = seg_mask.sum()
837
- if mask_area < MIN_MASK_AREA:
838
- filtered_count += 1
839
- continue
840
-
841
- global_id = object_id_map[view_idx].get(det_idx)
842
- if global_id is None:
843
- continue
844
-
845
- unique_obj = next((obj for obj in unique_objects if obj['global_id'] == global_id), None)
846
- if unique_obj is None:
847
- continue
848
-
849
- detections_info.append({
850
- 'mask': seg_mask,
851
- 'color': unique_obj['color'],
852
- 'confidence': det['confidence'],
853
- })
854
-
855
- if filtered_count > 0:
856
- print(f" View {view_idx + 1}: filtered {filtered_count} low-quality detections")
857
-
858
- detections_info.sort(key=lambda x: x['confidence'])
859
-
860
- for info in detections_info:
861
- seg_mask = info['mask']
862
- color = info['color']
863
- conf = info['confidence']
864
-
865
- update_mask = seg_mask & (conf > confidence_map)
866
- colored_image[update_mask] = color
867
- confidence_map[update_mask] = conf
868
-
869
- height, width = image.shape[:2]
870
-
871
- if normal is None:
872
- faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh(
873
- points3d,
874
- colored_image.astype(np.float32) / 255,
875
- utils3d.numpy.image_uv(width=width, height=height),
876
- mask=mask if mask is not None else np.ones((height, width), dtype=bool),
877
- tri=True
878
- )
879
- vertex_normals = None
880
- else:
881
- faces, vertices, vertex_colors, vertex_uvs, vertex_normals = utils3d.numpy.image_mesh(
882
- points3d,
883
- colored_image.astype(np.float32) / 255,
884
- utils3d.numpy.image_uv(width=width, height=height),
885
- normal,
886
- mask=mask if mask is not None else np.ones((height, width), dtype=bool),
887
- tri=True
888
- )
889
-
890
- vertices = vertices * np.array([1, -1, -1], dtype=np.float32)
891
- if vertex_normals is not None:
892
- vertex_normals = vertex_normals * np.array([1, -1, -1], dtype=np.float32)
893
-
894
- view_mesh = trimesh.Trimesh(
895
- vertices=vertices,
896
- faces=faces,
897
- vertex_normals=vertex_normals,
898
- vertex_colors=(vertex_colors * 255).astype(np.uint8),
899
- process=False
900
- )
901
-
902
- all_meshes.append(view_mesh)
903
- print(f" View {view_idx + 1}: {len(vertices):,} vertices, {len(faces):,} faces")
904
-
905
- if len(all_meshes) == 0:
906
- print("⚠️ No mesh generated")
907
- return None
908
-
909
- print(" Fusing all views...")
910
- combined_mesh = trimesh.util.concatenate(all_meshes)
911
-
912
- glb_path = os.path.join(target_dir, 'segmented_mesh.glb')
913
- combined_mesh.export(glb_path)
914
-
915
- print(f"✅ Multi-view segmented mesh saved: {glb_path}")
916
- print(f" Total: {len(combined_mesh.vertices):,} vertices, {len(combined_mesh.faces):,} faces")
917
- print(f" {len(unique_objects)} unique objects")
918
-
919
- return glb_path
920
-
921
- except Exception as e:
922
- print(f"❌ Failed to generate multi-view mesh: {e}")
923
- import traceback
924
- traceback.print_exc()
925
- return None
926
-
927
-
928
- # ============================================================================
929
- # Core Model Inference
930
- # ============================================================================
931
-
932
  @spaces.GPU(duration=120)
933
  def run_model(
934
  target_dir,
@@ -936,24 +94,24 @@ def run_model(
936
  mask_edges=True,
937
  filter_black_bg=False,
938
  filter_white_bg=False,
939
- enable_segmentation=False,
940
- text_prompt=DEFAULT_TEXT_PROMPT,
941
  progress=gr.Progress(),
942
  ):
943
  """
944
- Run the MapAnything model + optional segmentation
945
  """
946
  global model
947
- import torch
948
 
949
- progress(0, desc="🔧 初始化设备...")
950
  print(f"Processing images from {target_dir}")
951
 
 
 
952
  device = "cuda" if torch.cuda.is_available() else "cpu"
953
  device = torch.device(device)
954
 
955
- # Initialize MapAnything model
956
- progress(0.05, desc="📥 加载 MapAnything 模型...")
957
  if model is None:
958
  model = initialize_mapanything_model(high_level_config, device)
959
  else:
@@ -961,46 +119,8 @@ def run_model(
961
 
962
  model.eval()
963
 
964
- # Load segmentation models if enabled (use CPU to save GPU resources)
965
- if enable_segmentation:
966
- progress(0.1, desc="🎯 加载分割模型 (CPU)...")
967
- print(f"\n{'='*70}")
968
- print(f"🎯 分割模型加载开始... (方案: {SEGMENTATION_METHOD})")
969
- print(f"{'='*70}")
970
-
971
- if SEGMENTATION_METHOD == "segformer":
972
- # Option 1: SegFormer (lightest, ~14MB, fastest)
973
- print("📌 Using method: SegFormer (lightweight, no text prompt needed)")
974
- load_segformer_model("cpu")
975
- if segformer_model is None:
976
- print("❌ SegFormer 模型加载失败!")
977
- raise RuntimeError("SegFormer 模型加载失败,请检查网络连接")
978
-
979
- elif SEGMENTATION_METHOD == "maskformer":
980
- # Option 2: MaskFormer (medium, ~100MB)
981
- print("📌 Using method: MaskFormer (instance segmentation)")
982
- load_maskformer_model("cpu")
983
- if maskformer_model is None:
984
- print("❌ MaskFormer 模型加载失败!")
985
- raise RuntimeError("MaskFormer 模型加载失败,请检查网络连接")
986
-
987
- else: # "grounding_sam"
988
- # Option 3: GroundingDINO + SAM (strongest, ~110MB, needs a text prompt)
989
- print("📌 Using method: GroundingDINO + SAM (text-prompt driven)")
990
- load_grounding_dino_model("cpu")
991
- load_sam_model("cpu")
992
- if grounding_dino_model is None:
993
- print("❌ GroundingDINO 模型加载失败!")
994
- raise RuntimeError("GroundingDINO 模型加载失败,请检查网络连接")
995
- if sam_predictor is None:
996
- print("❌ SAM 模型加载失败!")
997
- raise RuntimeError("SAM 模型加载失败,请检查网络连接")
998
-
999
- print(f"✅ 分割模型加载成功")
1000
- print(f"{'='*70}\n")
1001
-
1002
- # Load images
1003
- progress(0.15, desc="📷 加载图片...")
1004
  print("Loading images...")
1005
  image_folder_path = os.path.join(target_dir, "images")
1006
  views = load_images(image_folder_path)
@@ -1010,15 +130,22 @@ def run_model(
1010
  raise ValueError("No images found. Check your upload.")
1011
 
1012
  # Run model inference
1013
- progress(0.2, desc=f"🚀 运行 3D 重建 ({len(views)} 张图片)...")
 
 
1014
  print("Running inference...")
 
 
1015
  outputs = model.infer(
1016
  views, apply_mask=apply_mask, mask_edges=True, memory_efficient_inference=False
1017
  )
 
1018
 
1019
- # Convert predictions
1020
- progress(0.5, desc="🔄 处理预测结果...")
1021
  predictions = {}
 
 
1022
  extrinsic_list = []
1023
  intrinsic_list = []
1024
  world_points_list = []
@@ -1026,158 +153,81 @@ def run_model(
1026
  images_list = []
1027
  final_mask_list = []
1028
 
1029
- for pred in outputs:
1030
- depthmap_torch = pred["depth_z"][0].squeeze(-1)
1031
- intrinsics_torch = pred["intrinsics"][0]
1032
- camera_pose_torch = pred["camera_poses"][0]
 
 
 
 
1033
 
 
1034
  pts3d_computed, valid_mask = depthmap_to_world_frame(
1035
  depthmap_torch, intrinsics_torch, camera_pose_torch
1036
  )
1037
 
 
 
1038
  if "mask" in pred:
1039
  mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
1040
  else:
 
1041
  mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)
1042
 
 
1043
  mask = mask & valid_mask.cpu().numpy()
 
1044
  image = pred["img_no_norm"][0].cpu().numpy()
1045
 
 
1046
  extrinsic_list.append(camera_pose_torch.cpu().numpy())
1047
  intrinsic_list.append(intrinsics_torch.cpu().numpy())
1048
  world_points_list.append(pts3d_computed.cpu().numpy())
1049
  depth_maps_list.append(depthmap_torch.cpu().numpy())
1050
- images_list.append(image)
1051
- final_mask_list.append(mask)
1052
 
 
 
1053
  predictions["extrinsic"] = np.stack(extrinsic_list, axis=0)
 
 
1054
  predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
 
 
1055
  predictions["world_points"] = np.stack(world_points_list, axis=0)
1056
 
 
1057
  depth_maps = np.stack(depth_maps_list, axis=0)
 
1058
  if len(depth_maps.shape) == 3:
1059
  depth_maps = depth_maps[..., np.newaxis]
 
1060
  predictions["depth"] = depth_maps
1061
 
 
1062
  predictions["images"] = np.stack(images_list, axis=0)
 
 
1063
  predictions["final_mask"] = np.stack(final_mask_list, axis=0)
1064
 
1065
- # Process visualization data
1066
- progress(0.6, desc="🎨 准备可视化数据...")
1067
  processed_data = process_predictions_for_visualization(
1068
  predictions, views, high_level_config, filter_black_bg, filter_white_bg
1069
  )
1070
 
1071
- # Segmentation processing
1072
- segmented_glb = None
1073
- if enable_segmentation:
1074
- progress(0.65, desc="🎯 开始物体分割...")
1075
- print(f"\n{'='*70}")
1076
- print(f"🎯 开始物体分割... (方案: {SEGMENTATION_METHOD})")
1077
- print(f"📐 最小掩码面积: {MIN_MASK_AREA} px")
1078
- if SEGMENTATION_METHOD == "grounding_sam":
1079
- print(f"🔍 检测提示词: {text_prompt[:100]}...")
1080
- print(f"📊 置信度阈值: {GROUNDING_DINO_BOX_THRESHOLD}")
1081
- print(f"{'='*70}\n")
1082
-
1083
- all_view_detections = []
1084
- all_view_masks = []
1085
-
1086
- for view_idx, ref_image in enumerate(images_list):
1087
- progress(0.65 + (view_idx / len(images_list)) * 0.2,
1088
- desc=f"🔍 检测视图 {view_idx + 1}/{len(images_list)}...")
1089
- print(f"\n📸 Processing view {view_idx + 1}/{len(images_list)}...")
1090
-
1091
- if ref_image.dtype != np.uint8:
1092
- ref_image_np = (ref_image * 255).astype(np.uint8)
1093
- else:
1094
- ref_image_np = ref_image
1095
-
1096
- # Choose the processing pipeline based on the segmentation method
1097
- if SEGMENTATION_METHOD == "segformer":
1098
- # SegFormer: direct semantic segmentation, no text prompt needed
1099
- detections, masks = run_segformer_segmentation(ref_image_np, "cpu")
1100
- print(f" ✓ 检测到 {len(detections)} 个物体")
1101
-
1102
- if len(detections) > 0:
1103
- for i, det in enumerate(detections):
1104
- print(f" 物体 {i+1}: {det['label']}")
1105
-
1106
- points3d = world_points_list[view_idx]
1107
- for det_idx, (det, mask) in enumerate(zip(detections, masks)):
1108
- center_3d = compute_object_3d_center(points3d, mask)
1109
- det['center_3d'] = center_3d
1110
- det['mask_2d'] = mask
1111
-
1112
- all_view_detections.append(detections)
1113
- all_view_masks.append(masks)
1114
- else:
1115
- all_view_detections.append([])
1116
- all_view_masks.append([])
1117
-
1118
- elif SEGMENTATION_METHOD == "grounding_sam":
1119
- # GroundingDINO + SAM: text-prompt driven
1120
- detections = run_grounding_dino_detection(ref_image_np, text_prompt, "cpu")
1121
- print(f" ✓ 检测到 {len(detections)} 个物体")
1122
-
1123
- if len(detections) > 0:
1124
- for i, det in enumerate(detections):
1125
- print(f" 物体 {i+1}: {det['label']} (置信度: {det['confidence']:.2f})")
1126
- boxes = [d['bbox'] for d in detections]
1127
- masks = run_sam_refinement(ref_image_np, boxes)
1128
-
1129
- points3d = world_points_list[view_idx]
1130
- for det_idx, (det, mask) in enumerate(zip(detections, masks)):
1131
- center_3d = compute_object_3d_center(points3d, mask)
1132
- det['center_3d'] = center_3d
1133
- det['mask_2d'] = mask
1134
-
1135
- all_view_detections.append(detections)
1136
- all_view_masks.append(masks)
1137
- else:
1138
- all_view_detections.append([])
1139
- all_view_masks.append([])
1140
-
1141
- # Match objects across views
1142
- total_detections = sum(len(dets) for dets in all_view_detections)
1143
- print(f"\n📊 总检测数: {total_detections}")
1144
-
1145
- if any(len(dets) > 0 for dets in all_view_detections):
1146
- progress(0.85, desc="🔗 匹配跨视图物体...")
1147
- object_id_map, unique_objects = match_objects_across_views(all_view_detections)
1148
-
1149
- # Generate segmented mesh
1150
- progress(0.9, desc="🏗️ 生成分割3D模型...")
1151
- segmented_glb = create_multi_view_segmented_mesh(
1152
- processed_data, all_view_detections, all_view_masks,
1153
- object_id_map, unique_objects, target_dir
1154
- )
1155
-
1156
- if segmented_glb:
1157
- print(f"✅ 分割3D模型已生成: {segmented_glb}")
1158
- else:
1159
- print(f"⚠️ 分割3D模型生成失败")
1160
- else:
1161
- print(f"\n{'='*70}")
1162
- print(f"⚠️ 未检测到任何物体,无法生成分割模型")
1163
- print(f"\n💡 调试提示:")
1164
- print(f" 1. 检查检测提示词是否准确(当前: {text_prompt[:50]}...)")
1165
- print(f" 2. 当前置信度阈值: {GROUNDING_DINO_BOX_THRESHOLD}")
1166
- print(f" 3. 尝试更通用的提示词,如: {COMMON_OBJECTS_PROMPT[:80]}...")
1167
- print(f" 4. 确保图片中有清晰可见的物体")
1168
- print(f"{'='*70}\n")
1169
-
1170
- # Cleanup
1171
  progress(0.95, desc="🧹 清理内存...")
1172
  torch.cuda.empty_cache()
1173
 
1174
- progress(1.0, desc="✅ 完成!")
1175
- return predictions, processed_data, segmented_glb
 
1176
 
 
1177
 
1178
- # ============================================================================
1179
- # Helper Functions (from app.py)
1180
- # ============================================================================
1181
 
1182
  def update_view_selectors(processed_data):
1183
  """Update view selector dropdowns based on available views"""
@@ -1188,9 +238,9 @@ def update_view_selectors(processed_data):
1188
  choices = [f"View {i + 1}" for i in range(num_views)]
1189
 
1190
  return (
1191
- gr.Dropdown(choices=choices, value=choices[0]),
1192
- gr.Dropdown(choices=choices, value=choices[0]),
1193
- gr.Dropdown(choices=choices, value=choices[0]),
1194
  )
1195
 
1196
 
@@ -1228,24 +278,33 @@ def update_measure_view(processed_data, view_index):
1228
  """Update measure view for a specific view index with mask overlay"""
1229
  view_data = get_view_data_by_index(processed_data, view_index)
1230
  if view_data is None:
1231
- return None, []
1232
 
 
1233
  image = view_data["image"].copy()
1234
 
 
1235
  if image.dtype != np.uint8:
1236
  if image.max() <= 1.0:
1237
  image = (image * 255).astype(np.uint8)
1238
  else:
1239
  image = image.astype(np.uint8)
1240
 
 
1241
  if view_data["mask"] is not None:
1242
  mask = view_data["mask"]
1243
- invalid_mask = ~mask
 
 
 
1244
 
1245
  if invalid_mask.any():
 
1246
  overlay_color = np.array([255, 220, 220], dtype=np.uint8)
1247
- alpha = 0.5
1248
- for c in range(3):
 
 
1249
  image[:, :, c] = np.where(
1250
  invalid_mask,
1251
  (1 - alpha) * image[:, :, c] + alpha * overlay_color[c],
@@ -1260,6 +319,7 @@ def navigate_depth_view(processed_data, current_selector_value, direction):
1260
  if processed_data is None or len(processed_data) == 0:
1261
  return "View 1", None
1262
 
 
1263
  try:
1264
  current_view = int(current_selector_value.split()[1]) - 1
1265
  except:
@@ -1279,6 +339,7 @@ def navigate_normal_view(processed_data, current_selector_value, direction):
1279
  if processed_data is None or len(processed_data) == 0:
1280
  return "View 1", None
1281
 
 
1282
  try:
1283
  current_view = int(current_selector_value.split()[1]) - 1
1284
  except:
@@ -1298,6 +359,7 @@ def navigate_measure_view(processed_data, current_selector_value, direction):
1298
  if processed_data is None or len(processed_data) == 0:
1299
  return "View 1", None, []
1300
 
 
1301
  try:
1302
  current_view = int(current_selector_value.split()[1]) - 1
1303
  except:
@@ -1317,6 +379,7 @@ def populate_visualization_tabs(processed_data):
1317
  if processed_data is None or len(processed_data) == 0:
1318
  return None, None, None, []
1319
 
 
1320
  depth_vis = update_depth_view(processed_data, 0)
1321
  normal_vis = update_normal_view(processed_data, 0)
1322
  measure_img, _ = update_measure_view(processed_data, 0)
@@ -1324,6 +387,9 @@ def populate_visualization_tabs(processed_data):
1324
  return depth_vis, normal_vis, measure_img, []
1325
 
1326
 
 
 
 
1327
  def handle_uploads(unified_upload, s_time_interval=1.0):
1328
  """
1329
  Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
@@ -1333,10 +399,12 @@ def handle_uploads(unified_upload, s_time_interval=1.0):
1333
  gc.collect()
1334
  torch.cuda.empty_cache()
1335
 
 
1336
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
1337
  target_dir = f"input_images_{timestamp}"
1338
  target_dir_images = os.path.join(target_dir, "images")
1339
 
 
1340
  if os.path.exists(target_dir):
1341
  shutil.rmtree(target_dir)
1342
  os.makedirs(target_dir)
@@ -1344,6 +412,7 @@ def handle_uploads(unified_upload, s_time_interval=1.0):
1344
 
1345
  image_paths = []
1346
 
 
1347
  if unified_upload is not None:
1348
  for file_data in unified_upload:
1349
  if isinstance(file_data, dict) and "name" in file_data:
@@ -1353,13 +422,23 @@ def handle_uploads(unified_upload, s_time_interval=1.0):
1353
 
1354
  file_ext = os.path.splitext(file_path)[1].lower()
1355
 
 
1356
  video_extensions = [
1357
- ".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp",
 
 
 
 
 
 
 
 
1358
  ]
1359
  if file_ext in video_extensions:
 
1360
  vs = cv2.VideoCapture(file_path)
1361
  fps = vs.get(cv2.CAP_PROP_FPS)
1362
- frame_interval = int(fps * s_time_interval)
1363
 
1364
  count = 0
1365
  video_frame_num = 0
@@ -1369,6 +448,7 @@ def handle_uploads(unified_upload, s_time_interval=1.0):
1369
  break
1370
  count += 1
1371
  if count % frame_interval == 0:
 
1372
  base_name = os.path.splitext(os.path.basename(file_path))[0]
1373
  image_path = os.path.join(
1374
  target_dir_images, f"{base_name}_{video_frame_num:06}.png"
@@ -1377,52 +457,82 @@ def handle_uploads(unified_upload, s_time_interval=1.0):
1377
  image_paths.append(image_path)
1378
  video_frame_num += 1
1379
  vs.release()
1380
- print(f"Extracted {video_frame_num} frames from video: {os.path.basename(file_path)}")
 
 
1381
 
1382
  else:
 
 
1383
  if file_ext in [".heic", ".heif"]:
 
1384
  try:
1385
  with Image.open(file_path) as img:
 
1386
  if img.mode not in ("RGB", "L"):
1387
  img = img.convert("RGB")
1388
 
 
1389
  base_name = os.path.splitext(os.path.basename(file_path))[0]
1390
- dst_path = os.path.join(target_dir_images, f"{base_name}.jpg")
 
 
1391
 
 
1392
  img.save(dst_path, "JPEG", quality=95)
1393
  image_paths.append(dst_path)
1394
- print(f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> {os.path.basename(dst_path)}")
 
 
1395
  except Exception as e:
1396
  print(f"Error converting HEIC file {file_path}: {e}")
1397
- dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
 
 
 
1398
  shutil.copy(file_path, dst_path)
1399
  image_paths.append(dst_path)
1400
  else:
1401
- dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
 
 
 
1402
  shutil.copy(file_path, dst_path)
1403
  image_paths.append(dst_path)
1404
 
 
1405
  image_paths = sorted(image_paths)
1406
 
1407
  end_time = time.time()
1408
- print(f"Files processed to {target_dir_images}; took {end_time - start_time:.3f} seconds")
 
 
1409
  return target_dir, image_paths
1410
 
1411
 
 
 
 
1412
  def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0):
1413
- """Update gallery on upload"""
 
 
 
 
1414
  if not input_video and not input_images:
1415
- return None, None, None, None, None
1416
  target_dir, image_paths = handle_uploads(input_video, input_images, s_time_interval)
1417
  return (
1418
- None,
1419
  None,
1420
  target_dir,
1421
  image_paths,
1422
- "Upload complete. Click 'Reconstruct' to begin 3D processing.",
1423
  )
1424
 
1425
 
 
 
 
1426
  @spaces.GPU(duration=120)
1427
  def gradio_demo(
1428
  target_dir,
@@ -1432,19 +542,20 @@ def gradio_demo(
1432
  filter_white_bg=False,
1433
  apply_mask=True,
1434
  show_mesh=True,
1435
- enable_segmentation=False,
1436
- text_prompt=DEFAULT_TEXT_PROMPT,
1437
  progress=gr.Progress(),
1438
  ):
1439
- """执行重建"""
 
 
1440
  if not os.path.isdir(target_dir) or target_dir == "None":
1441
- return None, None, "❌ 未找到有效的目标目录,请先上传文件", None, None, None, None, None, None, None, None, None
1442
 
1443
  progress(0, desc="🔄 准备重建...")
1444
  start_time = time.time()
1445
  gc.collect()
1446
  torch.cuda.empty_cache()
1447
 
 
1448
  target_dir_images = os.path.join(target_dir, "images")
1449
  all_files = (
1450
  sorted(os.listdir(target_dir_images))
@@ -1454,94 +565,92 @@ def gradio_demo(
1454
  all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
1455
  frame_filter_choices = ["All"] + all_files
1456
 
1457
- progress(0.05, desc="🚀 运行 MapAnything 模型...")
1458
- print("运行 MapAnything 模型...")
1459
  with torch.no_grad():
1460
- predictions, processed_data, segmented_glb = run_model(
1461
- target_dir, apply_mask, True, filter_black_bg, filter_white_bg,
1462
- enable_segmentation, text_prompt, progress
1463
  )
1464
 
1465
- # Save predictions
1466
  progress(0.92, desc="💾 Saving predictions...")
1467
  prediction_save_path = os.path.join(target_dir, "predictions.npz")
1468
  np.savez(prediction_save_path, **predictions)
1469
 
 
1470
  if frame_filter is None:
1471
  frame_filter = "All"
1472
 
1473
- # Build the GLB filename
1474
- progress(0.93, desc="🏗️ Generating original 3D model...")
1475
  glbfile = os.path.join(
1476
  target_dir,
1477
- f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}.glb",
1478
  )
1479
 
1480
- # Convert predictions to GLB
1481
  glbscene = predictions_to_glb(
1482
  predictions,
1483
  filter_by_frames=frame_filter,
1484
  show_cam=show_cam,
1485
  mask_black_bg=filter_black_bg,
1486
  mask_white_bg=filter_white_bg,
1487
- as_mesh=show_mesh,
1488
  )
1489
  glbscene.export(file_obj=glbfile)
1490
 
1491
- # Clean up memory
1492
  progress(0.96, desc="🧹 Cleaning up memory...")
1493
  del predictions
1494
  gc.collect()
1495
  torch.cuda.empty_cache()
1496
 
1497
  end_time = time.time()
1498
- print(f"总耗时: {end_time - start_time:.2f}秒")
1499
- log_msg = f" 重建成功 ({len(all_files)} 帧,耗时 {end_time - start_time:.1f}秒)"
 
1500
 
1501
- # Populate visualization tabs
1502
  progress(0.98, desc="🎨 生成可视化...")
1503
  depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(
1504
  processed_data
1505
  )
1506
 
1507
- # Update view selectors
1508
  depth_selector, normal_selector, measure_selector = update_view_selectors(
1509
  processed_data
1510
  )
1511
 
1512
  progress(1.0, desc="✅ 全部完成!")
1513
-
1514
- # Append segmentation status info
1515
- if enable_segmentation:
1516
- if segmented_glb:
1517
- log_msg += f"\n🎨 分割模型已生成"
1518
- else:
1519
- log_msg += f"\n⚠️ 未检测到物体,无分割模型"
1520
 
1521
  return (
1522
  glbfile,
1523
- segmented_glb,
1524
  log_msg,
1525
  gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
1526
  processed_data,
1527
  depth_vis,
1528
  normal_vis,
1529
  measure_img,
1530
- "",
1531
  depth_selector,
1532
  normal_selector,
1533
  measure_selector,
1534
  )
1535
 
1536
 
 
 
 
1537
  def colorize_depth(depth_map, mask=None):
1538
  """Convert depth map to colorized visualization with optional mask"""
1539
  if depth_map is None:
1540
  return None
1541
 
 
1542
  depth_normalized = depth_map.copy()
1543
  valid_mask = depth_normalized > 0
1544
 
 
1545
  if mask is not None:
1546
  valid_mask = valid_mask & mask
1547
 
@@ -1552,12 +661,14 @@ def colorize_depth(depth_map, mask=None):
1552
 
1553
  depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / (p95 - p5)
1554
 
 
1555
  import matplotlib.pyplot as plt
1556
 
1557
  colormap = plt.cm.turbo_r
1558
  colored = colormap(depth_normalized)
1559
  colored = (colored[:, :, :3] * 255).astype(np.uint8)
1560
 
 
1561
  colored[~valid_mask] = [255, 255, 255]
1562
 
1563
  return colored
@@ -1568,12 +679,15 @@ def colorize_normal(normal_map, mask=None):
1568
  if normal_map is None:
1569
  return None
1570
 
 
1571
  normal_vis = normal_map.copy()
1572
 
 
1573
  if mask is not None:
1574
  invalid_mask = ~mask
1575
- normal_vis[invalid_mask] = [0, 0, 0]
1576
 
 
1577
  normal_vis = (normal_vis + 1.0) / 2.0
1578
  normal_vis = (normal_vis * 255).astype(np.uint8)
1579
 
@@ -1586,11 +700,15 @@ def process_predictions_for_visualization(
1586
  """Extract depth, normal, and 3D points from predictions for visualization"""
1587
  processed_data = {}
1588
 
 
1589
  for view_idx, view in enumerate(views):
 
1590
  image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
1591
 
 
1592
  pred_pts3d = predictions["world_points"][view_idx]
1593
 
 
1594
  view_data = {
1595
  "image": image[0],
1596
  "points3d": pred_pts3d,
@@ -1599,15 +717,22 @@ def process_predictions_for_visualization(
1599
  "mask": None,
1600
  }
1601
 
 
1602
  mask = predictions["final_mask"][view_idx].copy()
1603
 
 
1604
  if filter_black_bg:
 
1605
  view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
 
1606
  black_bg_mask = view_colors.sum(axis=2) >= 16
1607
  mask = mask & black_bg_mask
1608
 
 
1609
  if filter_white_bg:
 
1610
  view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
 
1611
  white_bg_mask = ~(
1612
  (view_colors[:, :, 0] > 240)
1613
  & (view_colors[:, :, 1] > 240)
@@ -1631,6 +756,7 @@ def reset_measure(processed_data):
1631
  if processed_data is None or len(processed_data) == 0:
1632
  return None, [], ""
1633
 
 
1634
  first_view = list(processed_data.values())[0]
1635
  return first_view["image"], [], ""
1636
 
@@ -1640,18 +766,20 @@ def measure(
1640
  ):
1641
  """Handle measurement on images"""
1642
  try:
1643
- print(f"测量功能调用,选择器: {current_view_selector}")
1644
 
1645
  if processed_data is None or len(processed_data) == 0:
1646
- return None, [], " 没有可用数据"
1647
 
 
1648
  try:
1649
  current_view_index = int(current_view_selector.split()[1]) - 1
1650
  except:
1651
  current_view_index = 0
1652
 
1653
- print(f"使用视图索引: {current_view_index}")
1654
 
 
1655
  if current_view_index < 0 or current_view_index >= len(processed_data):
1656
  current_view_index = 0
1657
 
@@ -1659,46 +787,54 @@ def measure(
1659
  current_view = processed_data[view_keys[current_view_index]]
1660
 
1661
  if current_view is None:
1662
- return None, [], " 没有视图数据"
1663
 
1664
  point2d = event.index[0], event.index[1]
1665
- print(f"点击点: {point2d}")
1666
 
 
1667
  if (
1668
  current_view["mask"] is not None
1669
  and 0 <= point2d[1] < current_view["mask"].shape[0]
1670
  and 0 <= point2d[0] < current_view["mask"].shape[1]
1671
  ):
 
1672
  if not current_view["mask"][point2d[1], point2d[0]]:
1673
- print(f"点击点 {point2d} 在遮罩区域,忽略点击")
 
1674
  masked_image, _ = update_measure_view(
1675
  processed_data, current_view_index
1676
  )
1677
  return (
1678
  masked_image,
1679
  measure_points,
1680
- '<span style="color: red; font-weight: bold;">⚠️ 无法在遮罩区域测量(显示为灰色)</span>',
1681
  )
1682
 
1683
  measure_points.append(point2d)
1684
 
 
1685
  image, _ = update_measure_view(processed_data, current_view_index)
1686
  if image is None:
1687
- return None, [], " 没有可用图像"
1688
 
1689
  image = image.copy()
1690
  points3d = current_view["points3d"]
1691
 
 
1692
  try:
1693
  if image.dtype != np.uint8:
1694
  if image.max() <= 1.0:
 
1695
  image = (image * 255).astype(np.uint8)
1696
  else:
 
1697
  image = image.astype(np.uint8)
1698
  except Exception as e:
1699
- print(f"图像转换错误: {e}")
1700
- return None, [], f" 图像转换错误: {e}"
1701
 
 
1702
  try:
1703
  for p in measure_points:
1704
  if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
@@ -1706,8 +842,8 @@ def measure(
1706
  image, p, radius=5, color=(255, 0, 0), thickness=2
1707
  )
1708
  except Exception as e:
1709
- print(f"绘制错误: {e}")
1710
- return None, [], f" 绘制错误: {e}"
1711
 
1712
  depth_text = ""
1713
  try:
@@ -1718,22 +854,24 @@ def measure(
1718
  and 0 <= p[0] < current_view["depth"].shape[1]
1719
  ):
1720
  d = current_view["depth"][p[1], p[0]]
1721
- depth_text += f"- **P{i + 1} 深度: {d:.2f}m**\n"
1722
  else:
 
1723
  if (
1724
  points3d is not None
1725
  and 0 <= p[1] < points3d.shape[0]
1726
  and 0 <= p[0] < points3d.shape[1]
1727
  ):
1728
  z = points3d[p[1], p[0], 2]
1729
- depth_text += f"- **P{i + 1} Z坐标: {z:.2f}m**\n"
1730
  except Exception as e:
1731
- print(f"深度文本错误: {e}")
1732
- depth_text = f" 深度计算错误: {e}\n"
1733
 
1734
  if len(measure_points) == 2:
1735
  try:
1736
  point1, point2 = measure_points
 
1737
  if (
1738
  0 <= point1[0] < image.shape[1]
1739
  and 0 <= point1[1] < image.shape[0]
@@ -1744,7 +882,8 @@ def measure(
1744
  image, point1, point2, color=(255, 0, 0), thickness=2
1745
  )
1746
 
1747
- distance_text = "- **距离: 无法计算**"
 
1748
  if (
1749
  points3d is not None
1750
  and 0 <= point1[1] < points3d.shape[0]
@@ -1756,35 +895,39 @@ def measure(
1756
  p1_3d = points3d[point1[1], point1[0]]
1757
  p2_3d = points3d[point2[1], point2[0]]
1758
  distance = np.linalg.norm(p1_3d - p2_3d)
1759
- distance_text = f"- **距离: {distance:.2f}m**"
1760
  except Exception as e:
1761
- print(f"距离计算错误: {e}")
1762
- distance_text = f"- **距离计算错误: {e}**"
1763
 
1764
  measure_points = []
1765
  text = depth_text + distance_text
1766
- print(f"测量完成: {text}")
1767
  return [image, measure_points, text]
1768
  except Exception as e:
1769
- print(f"最终测量错误: {e}")
1770
- return None, [], f" 测量错误: {e}"
1771
  else:
1772
- print(f"单点测量: {depth_text}")
1773
  return [image, measure_points, depth_text]
1774
 
1775
  except Exception as e:
1776
- print(f"整体测量功能错误: {e}")
1777
- return None, [], f" 测量功能错误: {e}"
1778
 
1779
 
1780
  def clear_fields():
1781
- """清空 3D 查看器"""
1782
- return None, None
 
 
1783
 
1784
 
1785
  def update_log():
1786
- """显示日志消息"""
1787
- return "🔄 加载和重建中..."
 
 
1788
 
1789
 
1790
  def update_visualization(
@@ -1796,16 +939,30 @@ def update_visualization(
1796
  filter_white_bg=False,
1797
  show_mesh=True,
1798
  ):
1799
- """更新可视化"""
 
 
 
 
 
1800
  if is_example == "True":
1801
- return gr.update(), "❌ 没有可用的重建。请先点击重建按钮。"
 
 
 
1802
 
1803
  if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
1804
- return gr.update(), "❌ 没有可用的重建。请先点击重建按钮。"
 
 
 
1805
 
1806
  predictions_path = os.path.join(target_dir, "predictions.npz")
1807
  if not os.path.exists(predictions_path):
1808
- return gr.update(), f"❌ 没有可用的重建。请先运行「重建」。"
 
 
 
1809
 
1810
  loaded = np.load(predictions_path, allow_pickle=True)
1811
  predictions = {key: loaded[key] for key in loaded.keys()}
@@ -1815,17 +972,21 @@ def update_visualization(
1815
  f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
1816
  )
1817
 
1818
- glbscene = predictions_to_glb(
1819
- predictions,
1820
- filter_by_frames=frame_filter,
1821
- show_cam=show_cam,
1822
- mask_black_bg=filter_black_bg,
1823
- mask_white_bg=filter_white_bg,
1824
- as_mesh=show_mesh,
1825
- )
1826
- glbscene.export(file_obj=glbfile)
 
1827
 
1828
- return glbfile, "✅ 可视化已更新。"
 
 
 
1829
 
1830
 
1831
  def update_all_views_on_filter_change(
@@ -1837,7 +998,11 @@ def update_all_views_on_filter_change(
1837
  normal_view_selector,
1838
  measure_view_selector,
1839
  ):
1840
- """Update all individual view tabs when background filtering checkboxes change"""
 
 
 
 
1841
  if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
1842
  return processed_data, None, None, None, []
1843
 
@@ -1846,16 +1011,20 @@ def update_all_views_on_filter_change(
1846
  return processed_data, None, None, None, []
1847
 
1848
  try:
 
1849
  loaded = np.load(predictions_path, allow_pickle=True)
1850
  predictions = {key: loaded[key] for key in loaded.keys()}
1851
 
 
1852
  image_folder_path = os.path.join(target_dir, "images")
1853
  views = load_images(image_folder_path)
1854
 
 
1855
  new_processed_data = process_predictions_for_visualization(
1856
  predictions, views, high_level_config, filter_black_bg, filter_white_bg
1857
  )
1858
 
 
1859
  try:
1860
  depth_view_idx = (
1861
  int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0
@@ -1879,6 +1048,7 @@ def update_all_views_on_filter_change(
1879
  except:
1880
  measure_view_idx = 0
1881
 
 
1882
  depth_vis = update_depth_view(new_processed_data, depth_view_idx)
1883
  normal_vis = update_normal_view(new_processed_data, normal_view_idx)
1884
  measure_img, _ = update_measure_view(new_processed_data, measure_view_idx)
@@ -1890,10 +1060,9 @@ def update_all_views_on_filter_change(
1890
  return processed_data, None, None, None, []
1891
 
1892
 
1893
- # ============================================================================
1894
- # Example Scene Functions
1895
- # ============================================================================
1896
-
1897
  def get_scene_info(examples_dir):
1898
  """Get information about scenes in the examples directory"""
1899
  import glob
@@ -1905,6 +1074,7 @@ def get_scene_info(examples_dir):
1905
  for scene_folder in sorted(os.listdir(examples_dir)):
1906
  scene_path = os.path.join(examples_dir, scene_folder)
1907
  if os.path.isdir(scene_path):
 
1908
  image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
1909
  image_files = []
1910
  for ext in image_extensions:
@@ -1912,6 +1082,7 @@ def get_scene_info(examples_dir):
1912
  image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))
1913
 
1914
  if image_files:
 
1915
  image_files = sorted(image_files)
1916
  first_image = image_files[0]
1917
  num_images = len(image_files)
@@ -1930,9 +1101,10 @@ def get_scene_info(examples_dir):
1930
 
1931
 
1932
  def load_example_scene(scene_name, examples_dir="examples"):
1933
- """从示例目录加载场景"""
1934
  scenes = get_scene_info(examples_dir)
1935
 
 
1936
  selected_scene = None
1937
  for scene in scenes:
1938
  if scene["name"] == scene_name:
@@ -1940,26 +1112,28 @@ def load_example_scene(scene_name, examples_dir="examples"):
1940
  break
1941
 
1942
  if selected_scene is None:
1943
- return None, None, None, " 场景未找到"
1944
 
 
 
1945
  file_objects = []
1946
  for image_path in selected_scene["image_files"]:
1947
  file_objects.append(image_path)
1948
 
 
1949
  target_dir, image_paths = handle_uploads(file_objects, 1.0)
1950
 
1951
  return (
1952
- None,
1953
- target_dir,
1954
- image_paths,
1955
- f"已加载场景 '{scene_name}' ({selected_scene['num_images']} 张图像)。点击「开始重建」进行 3D 处理。",
1956
  )
1957
 
1958
 
1959
- # ============================================================================
1960
- # Gradio UI
1961
- # ============================================================================
1962
-
1963
  theme = get_gradio_theme()
1964
 
1965
  # 自定义CSS防止UI抖动
@@ -2022,45 +1196,44 @@ CUSTOM_CSS = GRADIO_CSS + """
2022
  }
2023
  """
2024
 
2025
- # JavaScript for paste support
2026
- PASTE_JS = """
2027
- <script>
2028
- // 添加粘贴板支持
2029
- document.addEventListener('paste', function(e) {
2030
- const items = e.clipboardData.items;
2031
- for (let i = 0; i < items.length; i++) {
2032
- if (items[i].type.indexOf('image') !== -1) {
2033
- const blob = items[i].getAsFile();
2034
- const fileInput = document.querySelector('input[type="file"][multiple]');
2035
- if (fileInput) {
2036
- const dataTransfer = new DataTransfer();
2037
- dataTransfer.items.add(blob);
2038
- fileInput.files = dataTransfer.files;
2039
- fileInput.dispatchEvent(new Event('change', { bubbles: true }));
2040
- console.log('✅ 图片已从剪贴板粘贴');
2041
- }
2042
- }
2043
- }
2044
- });
2045
-
2046
- // 添加提示信息
2047
- console.log('💡 粘贴板功能已启用:使用 Ctrl+V 可直接粘贴截图');
2048
- </script>
2049
- """
2050
-
2051
- with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D重建与物体分割") as demo:
2052
  is_example = gr.Textbox(label="is_example", visible=False, value="None")
 
2053
  processed_data_state = gr.State(value=None)
2054
  measure_points_state = gr.State(value=[])
 
2055
 
2056
  # 添加粘贴板支持的 JavaScript
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2057
  gr.HTML(PASTE_JS)
2058
-
2059
- # 顶部标题
2060
  gr.HTML("""
2061
  <div style="text-align: center; margin: 20px 0;">
2062
- <h2 style="color: #1976D2; margin-bottom: 10px;">MapAnything V2 - 3D重建与物体分割</h2>
2063
- <p style="color: #666; font-size: 16px;">基于DBSCAN聚类的智能物体识别 | 多视图融合 | 自适应参数调整</p>
2064
  </div>
2065
  """)
2066
 
@@ -2133,23 +1306,6 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D重建与
2133
  clear_color=[0.0, 0.0, 0.0, 0.0]
2134
  )
2135
 
2136
- with gr.Tab("🎨 分割3D"):
2137
- gr.Markdown(
2138
- """
2139
- 💡 **使用说明**:
2140
- 1. 在下方「⚙️ 高级选项」中勾选「启用语义分割 (CPU)」
2141
- 2. 点击「开始重建」按钮
2142
- 3. 等待处理完成后,分割结果将显示在此处
2143
-
2144
- 📌 如果没有显示分割结果,请查看控制台日志查找原因
2145
- """,
2146
- elem_classes=["info-box"]
2147
- )
2148
- segmented_output = gr.Model3D(
2149
- height=450, zoom_speed=0.5, pan_speed=0.5,
2150
- clear_color=[0.0, 0.0, 0.0, 0.0]
2151
- )
2152
-
2153
  with gr.Tab("📊 深度图"):
2154
  with gr.Row(elem_classes=["navigation-row"]):
2155
  prev_depth_btn = gr.Button("◀", size="sm", scale=1)
@@ -2200,8 +1356,8 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D重建与
2200
  max_lines=1
2201
  )
2202
 
2203
- # 高级选项(默认打开)
2204
- with gr.Accordion("⚙️ 高级选项", open=True):
2205
  with gr.Row(equal_height=False):
2206
  with gr.Column(scale=1, min_width=300):
2207
  gr.Markdown("#### 可视化参数")
@@ -2218,32 +1374,13 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D重建与
2218
  apply_mask_checkbox = gr.Checkbox(
2219
  label="应用深度掩码", value=True
2220
  )
2221
-
2222
- gr.Markdown("#### 分割参数")
2223
- gr.Markdown("💡 **说明**: 分割使用 CPU 运行(MobileSAM轻量级模型),不占用GPU资源")
2224
- enable_segmentation = gr.Checkbox(
2225
- label="启用语义分割 (CPU)", value=False
2226
- )
2227
-
2228
- text_prompt = gr.Textbox(
2229
- value=DEFAULT_TEXT_PROMPT,
2230
- label="检测物体(用 . 分隔)",
2231
- placeholder="例如: chair . table . sofa",
2232
- lines=2,
2233
- max_lines=2
2234
- )
2235
-
2236
- with gr.Row():
2237
- detect_all_btn = gr.Button("🔍 检测所有", size="sm")
2238
- restore_default_btn = gr.Button("↻ 默认", size="sm")
2239
-
2240
- gr.Markdown("📌 **提示**: 启用后会在「分割3D」标签页显示彩色分割模型")
2241
-
2242
  # 示例场景(可折叠)
2243
  with gr.Accordion("🖼️ 示例场景", open=False):
 
2244
  scenes = get_scene_info("examples")
 
2245
  if scenes:
2246
- for i in range(0, len(scenes), 4):
2247
  with gr.Row(equal_height=True):
2248
  for j in range(4):
2249
  scene_idx = i + j
@@ -2251,10 +1388,10 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D重建与
2251
  scene = scenes[scene_idx]
2252
  with gr.Column(scale=1, min_width=150):
2253
  scene_img = gr.Image(
2254
- value=scene["thumbnail"],
2255
  height=150,
2256
- interactive=False,
2257
- show_label=False,
2258
  sources=[],
2259
  container=False
2260
  )
@@ -2266,22 +1403,14 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D重建与
2266
  fn=lambda name=scene["name"]: load_example_scene(name),
2267
  outputs=[
2268
  reconstruction_output,
2269
- target_dir_output, image_gallery, log_output
2270
- ]
 
 
2271
  )
2272
 
2273
  # === 事件绑定 ===
2274
 
2275
- # 分割选项按钮
2276
- detect_all_btn.click(
2277
- fn=lambda: COMMON_OBJECTS_PROMPT,
2278
- outputs=[text_prompt]
2279
- )
2280
- restore_default_btn.click(
2281
- fn=lambda: DEFAULT_TEXT_PROMPT,
2282
- outputs=[text_prompt]
2283
- )
2284
-
2285
  # 上传文件自动更新
2286
  def update_gallery_on_unified_upload(files, interval):
2287
  if not files:
@@ -2411,7 +1540,7 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D重建与
2411
  # 重建按钮
2412
  submit_btn.click(
2413
  fn=clear_fields,
2414
- outputs=[reconstruction_output, segmented_output]
2415
  ).then(
2416
  fn=update_log,
2417
  outputs=[log_output]
@@ -2420,11 +1549,10 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D重建与
2420
  inputs=[
2421
  target_dir_output, frame_filter, show_cam,
2422
  filter_black_bg, filter_white_bg,
2423
- apply_mask_checkbox, show_mesh,
2424
- enable_segmentation, text_prompt
2425
  ],
2426
  outputs=[
2427
- reconstruction_output, segmented_output, log_output, frame_filter,
2428
  processed_data_state, depth_map, normal_map, measure_image,
2429
  measure_text, depth_view_selector, normal_view_selector, measure_view_selector
2430
  ]
@@ -2434,8 +1562,8 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D重建与
2434
  )
2435
 
2436
  # 清空按钮
2437
- clear_btn.add([reconstruction_output, segmented_output, log_output])
2438
-
2439
  # 可视化参数实时更新
2440
  for component in [frame_filter, show_cam, show_mesh]:
2441
  component.change(
@@ -2457,7 +1585,7 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D重建与
2457
  ],
2458
  outputs=[processed_data_state, depth_map, normal_map, measure_image, measure_points_state]
2459
  )
2460
-
2461
  # 深度图导航
2462
  prev_depth_btn.click(
2463
  fn=lambda pd, cs: navigate_depth_view(pd, cs, -1),
@@ -2514,17 +1642,4 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D重建与
2514
  outputs=[measure_image, measure_points_state]
2515
  )
2516
 
2517
- # 启动信息
2518
- print("\n" + "="*70)
2519
- print("🚀 MapAnything V2 - 3D重建与物体分割")
2520
- print("="*70)
2521
- print("📊 核心技术: 自适应DBSCAN聚类 + 多视图融合")
2522
- print(f"🔧 质量控制: 置信度≥{MIN_DETECTION_CONFIDENCE} | 面积≥{MIN_MASK_AREA}px")
2523
- print(f"🎯 聚类半径: 沙发{DBSCAN_EPS_CONFIG['sofa']}m | 桌子{DBSCAN_EPS_CONFIG['table']}m | 窗户{DBSCAN_EPS_CONFIG['window']}m | 默认{DBSCAN_EPS_CONFIG['default']}m")
2524
- print("\n💡 分割配置 (CPU优化):")
2525
- print(f" - 检测模型: {GROUNDING_DINO_MODEL_ID} (CPU)")
2526
- print(f" - 分割模型: {SAM_MODEL_ID} (MobileSAM, 10MB, CPU)")
2527
- print(f" - 运行设备: CPU (不占用GPU资源,适合分离部署)")
2528
- print("="*70 + "\n")
2529
-
2530
  demo.queue(max_size=20).launch(show_error=True, share=True, ssr_mode=False)
 
5
  # LICENSE file in the root directory of this source tree.
6
 
7
  """
8
+ MapAnything V2 - 3D Reconstruction System (Chinese-language UI)
9
+ - Multi-view 3D reconstruction
10
+ - Depth estimation and surface normal computation
11
+ - Distance measurement
 
12
  """
13
 
14
  import gc
 
17
  import sys
18
  import time
19
  from datetime import datetime
 
 
20
 
21
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
22
 
 
25
  import numpy as np
26
  import spaces
27
  import torch
 
28
  from PIL import Image
29
  from pillow_heif import register_heif_opener
 
30
 
31
  register_heif_opener()
32
 
 
60
  return None
61
 
62
 
 
 
 
 
63
  # MapAnything Configuration
64
  high_level_config = {
65
  "path": "configs/train.yaml",
 
80
  "resolution": 518,
81
  }
82
 
83
+ # Initialize model - this will be done on GPU when needed
 
 
 
 
 
 
84
  model = None
 
 
 
 
 
 
85
 
86
 
87
+ # -------------------------------------------------------------------------
88
+ # 1) Core model inference
89
+ # -------------------------------------------------------------------------
 
 
 
 
 
 
90
  @spaces.GPU(duration=120)
91
  def run_model(
92
  target_dir,
 
94
  mask_edges=True,
95
  filter_black_bg=False,
96
  filter_white_bg=False,
 
 
97
  progress=gr.Progress(),
98
  ):
99
  """
100
+ Run the MapAnything model on images in the 'target_dir/images' folder and return predictions.
101
  """
102
  global model
103
+ import torch # Ensure torch is available in function scope
104
 
105
+ start_time = time.time()
106
  print(f"Processing images from {target_dir}")
107
 
108
+ # Device check
109
+ progress(0, desc="🔧 初始化设备...")
110
  device = "cuda" if torch.cuda.is_available() else "cpu"
111
  device = torch.device(device)
112
 
113
+ # Initialize model if not already done
114
+ progress(0.05, desc="📥 加载模型... (~5秒)")
115
  if model is None:
116
  model = initialize_mapanything_model(high_level_config, device)
117
  else:
 
119
 
120
  model.eval()
121
 
122
+ # Load images using MapAnything's load_images function
123
+ progress(0.15, desc="📷 加载图片... (~2秒)")
 
 
 
 
 
 
124
  print("Loading images...")
125
  image_folder_path = os.path.join(target_dir, "images")
126
  views = load_images(image_folder_path)
 
130
  raise ValueError("No images found. Check your upload.")
131
 
132
  # Run model inference
133
+ num_images = len(views)
134
+ estimated_time = num_images * 3  # rough estimate of ~3 seconds per image
135
+ progress(0.2, desc=f"🚀 运行3D重建... ({num_images}张图片,预计{estimated_time}秒)")
136
  print("Running inference...")
137
+
138
+ inference_start = time.time()
139
  outputs = model.infer(
140
  views, apply_mask=apply_mask, mask_edges=True, memory_efficient_inference=False
141
  )
142
+ inference_time = time.time() - inference_start
143
 
144
+ # Convert predictions to format expected by visualization
145
+ progress(0.6, desc=f"🔄 处理预测结果... (推理耗时: {inference_time:.1f}秒)")
146
  predictions = {}
147
+
148
+ # Initialize lists for the required keys
149
  extrinsic_list = []
150
  intrinsic_list = []
151
  world_points_list = []
 
153
  images_list = []
154
  final_mask_list = []
155
 
156
+ # Loop through the outputs
157
+ for i, pred in enumerate(outputs):
158
+ if i % max(1, len(outputs) // 5) == 0:
159
+ progress(0.6 + (i / len(outputs)) * 0.25, desc=f"🔄 处理视图 {i+1}/{len(outputs)}...")
160
+ # Extract data from predictions
161
+ depthmap_torch = pred["depth_z"][0].squeeze(-1) # (H, W)
162
+ intrinsics_torch = pred["intrinsics"][0] # (3, 3)
163
+ camera_pose_torch = pred["camera_poses"][0] # (4, 4)
164
 
165
+ # Compute new pts3d using depth, intrinsics, and camera pose
166
  pts3d_computed, valid_mask = depthmap_to_world_frame(
167
  depthmap_torch, intrinsics_torch, camera_pose_torch
168
  )
169
 
170
+ # Convert to numpy arrays for visualization
171
+ # Check if mask key exists in pred, if not, fill with boolean trues in the size of depthmap_torch
172
  if "mask" in pred:
173
  mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
174
  else:
175
+ # Fill with boolean trues in the size of depthmap_torch
176
  mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)
177
 
178
+ # Combine with valid depth mask
179
  mask = mask & valid_mask.cpu().numpy()
180
+
181
  image = pred["img_no_norm"][0].cpu().numpy()
182
 
183
+ # Append to lists
184
  extrinsic_list.append(camera_pose_torch.cpu().numpy())
185
  intrinsic_list.append(intrinsics_torch.cpu().numpy())
186
  world_points_list.append(pts3d_computed.cpu().numpy())
187
  depth_maps_list.append(depthmap_torch.cpu().numpy())
188
+ images_list.append(image) # Add image to list
189
+ final_mask_list.append(mask) # Add final_mask to list
190
 
191
+ # Convert lists to numpy arrays with required shapes
192
+ # extrinsic: (S, 3, 4) - batch of camera extrinsic matrices
193
  predictions["extrinsic"] = np.stack(extrinsic_list, axis=0)
194
+
195
+ # intrinsic: (S, 3, 3) - batch of camera intrinsic matrices
196
  predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
197
+
198
+ # world_points: (S, H, W, 3) - batch of 3D world points
199
  predictions["world_points"] = np.stack(world_points_list, axis=0)
200
 
201
+ # depth: (S, H, W, 1) or (S, H, W) - batch of depth maps
202
  depth_maps = np.stack(depth_maps_list, axis=0)
203
+ # Add channel dimension if needed to match (S, H, W, 1) format
204
  if len(depth_maps.shape) == 3:
205
  depth_maps = depth_maps[..., np.newaxis]
206
+
207
  predictions["depth"] = depth_maps
208
 
209
+ # images: (S, H, W, 3) - batch of input images
210
  predictions["images"] = np.stack(images_list, axis=0)
211
+
212
+ # final_mask: (S, H, W) - batch of final masks for filtering
213
  predictions["final_mask"] = np.stack(final_mask_list, axis=0)
214
 
215
+ # Process data for visualization tabs (depth, normal, measure)
216
+ progress(0.85, desc="🎨 生成深度图与法线图...")
217
  processed_data = process_predictions_for_visualization(
218
  predictions, views, high_level_config, filter_black_bg, filter_white_bg
219
  )
220
 
221
+ # Clean up
 
 
 
 
 
 
222
  progress(0.95, desc="🧹 清理内存...")
223
  torch.cuda.empty_cache()
224
 
225
+ total_time = time.time() - start_time
226
+ progress(1.0, desc=f"✅ 完成!总耗时: {total_time:.1f}秒")
227
+ print(f"Total processing time: {total_time:.2f} seconds")
228
 
229
+ return predictions, processed_data
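The loop above lifts each predicted depth map into world coordinates with `depthmap_to_world_frame` before stacking the per-view arrays. A minimal sketch of that unprojection under a standard pinhole model (the helper name and the NumPy formulation are illustrative; the actual MapAnything utility operates on torch tensors and may differ in detail):

```python
import numpy as np

def unproject_depth_to_world(depth, K, cam_to_world):
    """Lift an (H, W) depth map to (H, W, 3) world-frame points. Illustrative sketch."""
    H, W = depth.shape
    u, v = np.meshgrid(np.arange(W), np.arange(H))      # pixel coordinates
    fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
    # Back-project into the camera frame
    x = (u - cx) / fx * depth
    y = (v - cy) / fy * depth
    pts_cam = np.stack([x, y, depth], axis=-1)           # (H, W, 3)
    # Rigid transform into the world frame
    R, t = cam_to_world[:3, :3], cam_to_world[:3, 3]
    pts_world = pts_cam @ R.T + t
    valid = depth > 0                                     # invalid depths are masked out
    return pts_world, valid
```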
230
 
 
 
 
231
 
232
  def update_view_selectors(processed_data):
233
  """Update view selector dropdowns based on available views"""
 
238
  choices = [f"View {i + 1}" for i in range(num_views)]
239
 
240
  return (
241
+ gr.Dropdown(choices=choices, value=choices[0]), # depth_view_selector
242
+ gr.Dropdown(choices=choices, value=choices[0]), # normal_view_selector
243
+ gr.Dropdown(choices=choices, value=choices[0]), # measure_view_selector
244
  )
245
 
246
 
 
278
  """Update measure view for a specific view index with mask overlay"""
279
  view_data = get_view_data_by_index(processed_data, view_index)
280
  if view_data is None:
281
+ return None, [] # image, measure_points
282
 
283
+ # Get the base image
284
  image = view_data["image"].copy()
285
 
286
+ # Ensure image is in uint8 format
287
  if image.dtype != np.uint8:
288
  if image.max() <= 1.0:
289
  image = (image * 255).astype(np.uint8)
290
  else:
291
  image = image.astype(np.uint8)
292
 
293
+ # Apply mask overlay if mask is available
294
  if view_data["mask"] is not None:
295
  mask = view_data["mask"]
296
+
297
+ # Create a light overlay for masked areas
298
+ # Masked areas (False values) will be tinted with a light red overlay
299
+ invalid_mask = ~mask # Areas where mask is False
300
 
301
  if invalid_mask.any():
302
+ # Light red overlay color (RGB: 255, 220, 220) marks invalid pixels
303
  overlay_color = np.array([255, 220, 220], dtype=np.uint8)
304
+
305
+ # Apply overlay with some transparency
306
+ alpha = 0.5 # Transparency level
307
+ for c in range(3): # RGB channels
308
  image[:, :, c] = np.where(
309
  invalid_mask,
310
  (1 - alpha) * image[:, :, c] + alpha * overlay_color[c],
 
319
  if processed_data is None or len(processed_data) == 0:
320
  return "View 1", None
321
 
322
+ # Parse current view number
323
  try:
324
  current_view = int(current_selector_value.split()[1]) - 1
325
  except:
 
339
  if processed_data is None or len(processed_data) == 0:
340
  return "View 1", None
341
 
342
+ # Parse current view number
343
  try:
344
  current_view = int(current_selector_value.split()[1]) - 1
345
  except:
 
359
  if processed_data is None or len(processed_data) == 0:
360
  return "View 1", None, []
361
 
362
+ # Parse current view number
363
  try:
364
  current_view = int(current_selector_value.split()[1]) - 1
365
  except:
 
379
  if processed_data is None or len(processed_data) == 0:
380
  return None, None, None, []
381
 
382
+ # Use update functions to ensure confidence filtering is applied from the start
383
  depth_vis = update_depth_view(processed_data, 0)
384
  normal_vis = update_normal_view(processed_data, 0)
385
  measure_img, _ = update_measure_view(processed_data, 0)
 
387
  return depth_vis, normal_vis, measure_img, []
388
 
389
 
390
+ # -------------------------------------------------------------------------
391
+ # 2) Handle uploaded video/images --> produce target_dir + images
392
+ # -------------------------------------------------------------------------
393
  def handle_uploads(unified_upload, s_time_interval=1.0):
394
  """
395
  Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
 
399
  gc.collect()
400
  torch.cuda.empty_cache()
401
 
402
+ # Create a unique folder name
403
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
404
  target_dir = f"input_images_{timestamp}"
405
  target_dir_images = os.path.join(target_dir, "images")
406
 
407
+ # Clean up if somehow that folder already exists
408
  if os.path.exists(target_dir):
409
  shutil.rmtree(target_dir)
410
  os.makedirs(target_dir)
 
412
 
413
  image_paths = []
414
 
415
+ # --- Handle uploaded files (both images and videos) ---
416
  if unified_upload is not None:
417
  for file_data in unified_upload:
418
  if isinstance(file_data, dict) and "name" in file_data:
 
422
 
423
  file_ext = os.path.splitext(file_path)[1].lower()
424
 
425
+ # Check if it's a video file
426
  video_extensions = [
427
+ ".mp4",
428
+ ".avi",
429
+ ".mov",
430
+ ".mkv",
431
+ ".wmv",
432
+ ".flv",
433
+ ".webm",
434
+ ".m4v",
435
+ ".3gp",
436
  ]
437
  if file_ext in video_extensions:
438
+ # Handle as video
439
  vs = cv2.VideoCapture(file_path)
440
  fps = vs.get(cv2.CAP_PROP_FPS)
441
+ frame_interval = max(1, int(fps * s_time_interval))  # frames between samples; at least 1 so the modulo below never divides by zero
442
 
443
  count = 0
444
  video_frame_num = 0
 
448
  break
449
  count += 1
450
  if count % frame_interval == 0:
451
+ # Use original filename as prefix for frames
452
  base_name = os.path.splitext(os.path.basename(file_path))[0]
453
  image_path = os.path.join(
454
  target_dir_images, f"{base_name}_{video_frame_num:06}.png"
 
457
  image_paths.append(image_path)
458
  video_frame_num += 1
459
  vs.release()
460
+ print(
461
+ f"Extracted {video_frame_num} frames from video: {os.path.basename(file_path)}"
462
+ )
463
 
464
  else:
465
+ # Handle as image
466
+ # Check if the file is a HEIC image
467
  if file_ext in [".heic", ".heif"]:
468
+ # Convert HEIC to JPEG for better gallery compatibility
469
  try:
470
  with Image.open(file_path) as img:
471
+ # Convert to RGB if necessary (HEIC can have different color modes)
472
  if img.mode not in ("RGB", "L"):
473
  img = img.convert("RGB")
474
 
475
+ # Create JPEG filename
476
  base_name = os.path.splitext(os.path.basename(file_path))[0]
477
+ dst_path = os.path.join(
478
+ target_dir_images, f"{base_name}.jpg"
479
+ )
480
 
481
+ # Save as JPEG with high quality
482
  img.save(dst_path, "JPEG", quality=95)
483
  image_paths.append(dst_path)
484
+ print(
485
+ f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> {os.path.basename(dst_path)}"
486
+ )
487
  except Exception as e:
488
  print(f"Error converting HEIC file {file_path}: {e}")
489
+ # Fall back to copying as is
490
+ dst_path = os.path.join(
491
+ target_dir_images, os.path.basename(file_path)
492
+ )
493
  shutil.copy(file_path, dst_path)
494
  image_paths.append(dst_path)
495
  else:
496
+ # Regular image files - copy as is
497
+ dst_path = os.path.join(
498
+ target_dir_images, os.path.basename(file_path)
499
+ )
500
  shutil.copy(file_path, dst_path)
501
  image_paths.append(dst_path)
502
 
503
+ # Sort final images for gallery
504
  image_paths = sorted(image_paths)
505
 
506
  end_time = time.time()
507
+ print(
508
+ f"Files processed to {target_dir_images}; took {end_time - start_time:.3f} seconds"
509
+ )
510
  return target_dir, image_paths
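Uploaded videos are sampled rather than decoded in full: one frame is kept for every `s_time_interval` seconds, based on the stream's FPS. A self-contained sketch of the same idea (file naming is illustrative; the `max(1, ...)` guard keeps low-FPS clips from producing a zero interval):

```python
import os
import cv2

def extract_frames(video_path, out_dir, interval_s=1.0):
    """Save roughly one frame per `interval_s` seconds of video; return the written paths."""
    os.makedirs(out_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0        # fall back if FPS is unreported
    step = max(1, int(fps * interval_s))           # frames between kept samples
    saved, count = [], 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if count % step == 0:
            path = os.path.join(out_dir, f"frame_{len(saved):06}.png")
            cv2.imwrite(path, frame)
            saved.append(path)
        count += 1
    cap.release()
    return saved
```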
511
 
512
 
513
+ # -------------------------------------------------------------------------
514
+ # 3) Update gallery on upload
515
+ # -------------------------------------------------------------------------
516
  def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0):
517
+ """
518
+ Whenever user uploads or changes files, immediately handle them
519
+ and show in the gallery. Return (target_dir, image_paths).
520
+ If nothing is uploaded, returns "None" and empty list.
521
+ """
522
  if not input_video and not input_images:
523
+ return None, None, None, None
524
  target_dir, image_paths = handle_uploads(input_video, input_images, s_time_interval)
525
  return (
 
526
  None,
527
  target_dir,
528
  image_paths,
529
+ "上传完成。点击「开始重建」进行3D处理",
530
  )
531
 
532
 
533
+ # -------------------------------------------------------------------------
534
+ # 4) Reconstruction: uses the target_dir plus any viz parameters
535
+ # -------------------------------------------------------------------------
536
  @spaces.GPU(duration=120)
537
  def gradio_demo(
538
  target_dir,
 
542
  filter_white_bg=False,
543
  apply_mask=True,
544
  show_mesh=True,
 
 
545
  progress=gr.Progress(),
546
  ):
547
+ """
548
+ Perform reconstruction using the already-created target_dir/images.
549
+ """
550
  if not os.path.isdir(target_dir) or target_dir == "None":
551
+ return None, "❌ 未找到有效的目标目录,请先上传文件", None, None, None, None, None, None, None, None, None
552
 
553
  progress(0, desc="🔄 准备重建...")
554
  start_time = time.time()
555
  gc.collect()
556
  torch.cuda.empty_cache()
557
 
558
+ # Prepare frame_filter dropdown
559
  target_dir_images = os.path.join(target_dir, "images")
560
  all_files = (
561
  sorted(os.listdir(target_dir_images))
 
565
  all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
566
  frame_filter_choices = ["All"] + all_files
567
 
568
+ progress(0.05, desc=f"🚀 运行 MapAnything 模型... ({len(all_files)}张图片)")
569
+ print("Running MapAnything model...")
570
  with torch.no_grad():
571
+ predictions, processed_data = run_model(
572
+ target_dir, apply_mask, True, filter_black_bg, filter_white_bg, progress
 
573
  )
574
 
575
+ # Save predictions
576
  progress(0.92, desc="💾 保存预测结果...")
577
  prediction_save_path = os.path.join(target_dir, "predictions.npz")
578
  np.savez(prediction_save_path, **predictions)
579
 
580
+ # Handle None frame_filter
581
  if frame_filter is None:
582
  frame_filter = "All"
583
 
584
+ # Build a GLB file name
585
+ progress(0.93, desc="🏗️ 生成3D模型文件...")
586
  glbfile = os.path.join(
587
  target_dir,
588
+ f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
589
  )
590
 
591
+ # Convert predictions to GLB
592
  glbscene = predictions_to_glb(
593
  predictions,
594
  filter_by_frames=frame_filter,
595
  show_cam=show_cam,
596
  mask_black_bg=filter_black_bg,
597
  mask_white_bg=filter_white_bg,
598
+ as_mesh=show_mesh, # Use the show_mesh parameter
599
  )
600
  glbscene.export(file_obj=glbfile)
601
 
602
+ # Cleanup
603
  progress(0.96, desc="🧹 清理内存...")
604
  del predictions
605
  gc.collect()
606
  torch.cuda.empty_cache()
607
 
608
  end_time = time.time()
609
+ total_time = end_time - start_time
610
+ print(f"总耗时: {total_time:.2f}秒")
611
+ log_msg = f"✅ 重建成功 ({len(all_files)} 帧,耗时 {total_time:.1f}秒)"
612
 
613
+ # Populate visualization tabs with processed data
614
  progress(0.98, desc="🎨 生成可视化...")
615
  depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(
616
  processed_data
617
  )
618
 
619
+ # Update view selectors based on available views
620
  depth_selector, normal_selector, measure_selector = update_view_selectors(
621
  processed_data
622
  )
623
 
624
  progress(1.0, desc="✅ 全部完成!")
 
 
 
 
 
 
 
625
 
626
  return (
627
  glbfile,
 
628
  log_msg,
629
  gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
630
  processed_data,
631
  depth_vis,
632
  normal_vis,
633
  measure_img,
634
+ "", # measure_text (empty initially)
635
  depth_selector,
636
  normal_selector,
637
  measure_selector,
638
  )
639
 
640
 
641
+ # -------------------------------------------------------------------------
642
+ # 5) Helper functions for UI resets + re-visualization
643
+ # -------------------------------------------------------------------------
644
  def colorize_depth(depth_map, mask=None):
645
  """Convert depth map to colorized visualization with optional mask"""
646
  if depth_map is None:
647
  return None
648
 
649
+ # Normalize depth to 0-1 range
650
  depth_normalized = depth_map.copy()
651
  valid_mask = depth_normalized > 0
652
 
653
+ # Apply additional mask if provided (for background filtering)
654
  if mask is not None:
655
  valid_mask = valid_mask & mask
656
 
 
661
 
662
  depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / (p95 - p5)
663
 
664
+ # Apply colormap
665
  import matplotlib.pyplot as plt
666
 
667
  colormap = plt.cm.turbo_r
668
  colored = colormap(depth_normalized)
669
  colored = (colored[:, :, :3] * 255).astype(np.uint8)
670
 
671
+ # Set invalid pixels to white
672
  colored[~valid_mask] = [255, 255, 255]
673
 
674
  return colored
 
679
  if normal_map is None:
680
  return None
681
 
682
+ # Create a copy for modification
683
  normal_vis = normal_map.copy()
684
 
685
+ # Apply mask if provided (set masked areas to [0, 0, 0] which becomes grey after normalization)
686
  if mask is not None:
687
  invalid_mask = ~mask
688
+ normal_vis[invalid_mask] = [0, 0, 0] # Set invalid areas to zero
689
 
690
+ # Normalize normals to [0, 1] range for visualization
691
  normal_vis = (normal_vis + 1.0) / 2.0
692
  normal_vis = (normal_vis * 255).astype(np.uint8)
693
 
 
700
  """Extract depth, normal, and 3D points from predictions for visualization"""
701
  processed_data = {}
702
 
703
+ # Process each view
704
  for view_idx, view in enumerate(views):
705
+ # Get image
706
  image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
707
 
708
+ # Get predicted points
709
  pred_pts3d = predictions["world_points"][view_idx]
710
 
711
+ # Initialize data for this view
712
  view_data = {
713
  "image": image[0],
714
  "points3d": pred_pts3d,
 
717
  "mask": None,
718
  }
719
 
720
+ # Start with the final mask from predictions
721
  mask = predictions["final_mask"][view_idx].copy()
722
 
723
+ # Apply black background filtering if enabled
724
  if filter_black_bg:
725
+ # Get the image colors (ensure they're in 0-255 range)
726
  view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
727
+ # Filter out black background pixels (sum of RGB < 16)
728
  black_bg_mask = view_colors.sum(axis=2) >= 16
729
  mask = mask & black_bg_mask
730
 
731
+ # Apply white background filtering if enabled
732
  if filter_white_bg:
733
+ # Get the image colors (ensure they're in 0-255 range)
734
  view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
735
+ # Filter out white background pixels (all RGB > 240)
736
  white_bg_mask = ~(
737
  (view_colors[:, :, 0] > 240)
738
  & (view_colors[:, :, 1] > 240)
 
756
  if processed_data is None or len(processed_data) == 0:
757
  return None, [], ""
758
 
759
+ # Return the first view image
760
  first_view = list(processed_data.values())[0]
761
  return first_view["image"], [], ""
762
 
 
766
  ):
767
  """Handle measurement on images"""
768
  try:
769
+ print(f"Measure function called with selector: {current_view_selector}")
770
 
771
  if processed_data is None or len(processed_data) == 0:
772
+ return None, [], "No data available"
773
 
774
+ # Use the currently selected view instead of always using the first view
775
  try:
776
  current_view_index = int(current_view_selector.split()[1]) - 1
777
  except:
778
  current_view_index = 0
779
 
780
+ print(f"Using view index: {current_view_index}")
781
 
782
+ # Get view data safely
783
  if current_view_index < 0 or current_view_index >= len(processed_data):
784
  current_view_index = 0
785
 
 
787
  current_view = processed_data[view_keys[current_view_index]]
788
 
789
  if current_view is None:
790
+ return None, [], "No view data available"
791
 
792
  point2d = event.index[0], event.index[1]
793
+ print(f"Clicked point: {point2d}")
794
 
795
+ # Check if the clicked point is in a masked area (prevent interaction)
796
  if (
797
  current_view["mask"] is not None
798
  and 0 <= point2d[1] < current_view["mask"].shape[0]
799
  and 0 <= point2d[0] < current_view["mask"].shape[1]
800
  ):
801
+ # Check if the point is in a masked (invalid) area
802
  if not current_view["mask"][point2d[1], point2d[0]]:
803
+ print(f"Clicked point {point2d} is in masked area, ignoring click")
804
+ # Always return image with mask overlay
805
  masked_image, _ = update_measure_view(
806
  processed_data, current_view_index
807
  )
808
  return (
809
  masked_image,
810
  measure_points,
811
+ '<span style="color: red; font-weight: bold;">Cannot measure on masked areas (shown with a light red tint)</span>',
812
  )
813
 
814
  measure_points.append(point2d)
815
 
816
+ # Get image with mask overlay and ensure it's valid
817
  image, _ = update_measure_view(processed_data, current_view_index)
818
  if image is None:
819
+ return None, [], "No image available"
820
 
821
  image = image.copy()
822
  points3d = current_view["points3d"]
823
 
824
+ # Ensure image is in uint8 format for proper cv2 operations
825
  try:
826
  if image.dtype != np.uint8:
827
  if image.max() <= 1.0:
828
+ # Image is in [0, 1] range, convert to [0, 255]
829
  image = (image * 255).astype(np.uint8)
830
  else:
831
+ # Image is already in [0, 255] range
832
  image = image.astype(np.uint8)
833
  except Exception as e:
834
+ print(f"Image conversion error: {e}")
835
+ return None, [], f"Image conversion error: {e}"
836
 
837
+ # Draw circles for points
838
  try:
839
  for p in measure_points:
840
  if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
 
842
  image, p, radius=5, color=(255, 0, 0), thickness=2
843
  )
844
  except Exception as e:
845
+ print(f"Drawing error: {e}")
846
+ return None, [], f"Drawing error: {e}"
847
 
848
  depth_text = ""
849
  try:
 
854
  and 0 <= p[0] < current_view["depth"].shape[1]
855
  ):
856
  d = current_view["depth"][p[1], p[0]]
857
+ depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
858
  else:
859
+ # Use Z coordinate of 3D points if depth not available
860
  if (
861
  points3d is not None
862
  and 0 <= p[1] < points3d.shape[0]
863
  and 0 <= p[0] < points3d.shape[1]
864
  ):
865
  z = points3d[p[1], p[0], 2]
866
+ depth_text += f"- **P{i + 1} Z-coord: {z:.2f}m.**\n"
867
  except Exception as e:
868
+ print(f"Depth text error: {e}")
869
+ depth_text = f"Error computing depth: {e}\n"
870
 
871
  if len(measure_points) == 2:
872
  try:
873
  point1, point2 = measure_points
874
+ # Draw line
875
  if (
876
  0 <= point1[0] < image.shape[1]
877
  and 0 <= point1[1] < image.shape[0]
 
882
  image, point1, point2, color=(255, 0, 0), thickness=2
883
  )
884
 
885
+ # Compute 3D distance
886
+ distance_text = "- **Distance: Unable to compute**"
887
  if (
888
  points3d is not None
889
  and 0 <= point1[1] < points3d.shape[0]
 
895
  p1_3d = points3d[point1[1], point1[0]]
896
  p2_3d = points3d[point2[1], point2[0]]
897
  distance = np.linalg.norm(p1_3d - p2_3d)
898
+ distance_text = f"- **Distance: {distance:.2f}m**"
899
  except Exception as e:
900
+ print(f"Distance computation error: {e}")
901
+ distance_text = f"- **Distance computation error: {e}**"
902
 
903
  measure_points = []
904
  text = depth_text + distance_text
905
+ print(f"Measurement complete: {text}")
906
  return [image, measure_points, text]
907
  except Exception as e:
908
+ print(f"Final measurement error: {e}")
909
+ return None, [], f"Measurement error: {e}"
910
  else:
911
+ print(f"Single point measurement: {depth_text}")
912
  return [image, measure_points, depth_text]
913
 
914
  except Exception as e:
915
+ print(f"Overall measure function error: {e}")
916
+ return None, [], f"Measure function error: {e}"
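Stripped of the drawing and error handling, the measurement itself is a lookup into the per-view world-point map followed by a Euclidean norm. A compact sketch (clicked points are (x, y) pixel coordinates; `points3d` has shape (H, W, 3) in metres):

```python
import numpy as np

def measure_distance(points3d, p1, p2):
    """Distance in metres between the 3D points under two clicked pixels."""
    a = points3d[p1[1], p1[0]]   # index as [row, col] = [y, x]
    b = points3d[p2[1], p2[0]]
    return float(np.linalg.norm(a - b))
```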
917
 
918
 
919
  def clear_fields():
920
+ """
921
+ Clears the 3D viewer, the stored target_dir, and empties the gallery.
922
+ """
923
+ return None
924
 
925
 
926
  def update_log():
927
+ """
928
+ Display a quick log message while waiting.
929
+ """
930
+ return "加载和重建中..."
931
 
932
 
933
  def update_visualization(
 
939
  filter_white_bg=False,
940
  show_mesh=True,
941
  ):
942
+ """
943
+ Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
944
+ and return it for the 3D viewer. If is_example == "True", skip.
945
+ """
946
+
947
+ # If it's an example click, skip as requested
948
  if is_example == "True":
949
+ return (
950
+ gr.update(),
951
+ "没有可用的重建。请先点击重建按钮。",
952
+ )
953
 
954
  if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
955
+ return (
956
+ gr.update(),
957
+ "没有可用的重建。请先点击重建按钮。",
958
+ )
959
 
960
  predictions_path = os.path.join(target_dir, "predictions.npz")
961
  if not os.path.exists(predictions_path):
962
+ return (
963
+ gr.update(),
964
+ f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
965
+ )
966
 
967
  loaded = np.load(predictions_path, allow_pickle=True)
968
  predictions = {key: loaded[key] for key in loaded.keys()}
 
972
  f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
973
  )
974
 
975
+ if not os.path.exists(glbfile):
976
+ glbscene = predictions_to_glb(
977
+ predictions,
978
+ filter_by_frames=frame_filter,
979
+ show_cam=show_cam,
980
+ mask_black_bg=filter_black_bg,
981
+ mask_white_bg=filter_white_bg,
982
+ as_mesh=show_mesh,
983
+ )
984
+ glbscene.export(file_obj=glbfile)
985
 
986
+ return (
987
+ glbfile,
988
+ "可视化已更新",
989
+ )
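Unlike the version it replaces, this `update_visualization` only exports a new GLB when no file exists for the current parameter combination; because every toggle is baked into the file name, switching options back and forth reuses earlier exports. A sketch of that cache-by-filename pattern (the helper name is illustrative):

```python
import os

def cached_glb_path(target_dir, frame_filter, show_cam, show_mesh, black_bg, white_bg):
    """Build a file name that uniquely encodes the visualization parameters."""
    safe_filter = frame_filter.replace(".", "_").replace(":", "").replace(" ", "_")
    name = (f"glbscene_{safe_filter}_cam{show_cam}_mesh{show_mesh}"
            f"_black{black_bg}_white{white_bg}.glb")
    return os.path.join(target_dir, name)

# Usage sketch: export only when this combination has not been rendered before.
# glbfile = cached_glb_path(target_dir, frame_filter, show_cam, show_mesh, black, white)
# if not os.path.exists(glbfile):
#     predictions_to_glb(predictions, ...).export(file_obj=glbfile)
```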
990
 
991
 
992
  def update_all_views_on_filter_change(
 
998
  normal_view_selector,
999
  measure_view_selector,
1000
  ):
1001
+ """
1002
+ Update all individual view tabs when background filtering checkboxes change.
1003
+ This regenerates the processed data with new filtering and updates all views.
1004
+ """
1005
+ # Check if we have a valid target directory and predictions
1006
  if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
1007
  return processed_data, None, None, None, []
1008
 
 
1011
  return processed_data, None, None, None, []
1012
 
1013
  try:
1014
+ # Load the original predictions and views
1015
  loaded = np.load(predictions_path, allow_pickle=True)
1016
  predictions = {key: loaded[key] for key in loaded.keys()}
1017
 
1018
+ # Load images using MapAnything's load_images function
1019
  image_folder_path = os.path.join(target_dir, "images")
1020
  views = load_images(image_folder_path)
1021
 
1022
+ # Regenerate processed data with new filtering settings
1023
  new_processed_data = process_predictions_for_visualization(
1024
  predictions, views, high_level_config, filter_black_bg, filter_white_bg
1025
  )
1026
 
1027
+ # Get current view indices
1028
  try:
1029
  depth_view_idx = (
1030
  int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0
 
1048
  except:
1049
  measure_view_idx = 0
1050
 
1051
+ # Update all views with new filtered data
1052
  depth_vis = update_depth_view(new_processed_data, depth_view_idx)
1053
  normal_vis = update_normal_view(new_processed_data, normal_view_idx)
1054
  measure_img, _ = update_measure_view(new_processed_data, measure_view_idx)
 
1060
  return processed_data, None, None, None, []
1061
 
1062
 
1063
+ # -------------------------------------------------------------------------
1064
+ # Example scene functions
1065
+ # -------------------------------------------------------------------------
 
1066
  def get_scene_info(examples_dir):
1067
  """Get information about scenes in the examples directory"""
1068
  import glob
 
1074
  for scene_folder in sorted(os.listdir(examples_dir)):
1075
  scene_path = os.path.join(examples_dir, scene_folder)
1076
  if os.path.isdir(scene_path):
1077
+ # Find all image files in the scene folder
1078
  image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
1079
  image_files = []
1080
  for ext in image_extensions:
 
1082
  image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))
1083
 
1084
  if image_files:
1085
+ # Sort images and get the first one for thumbnail
1086
  image_files = sorted(image_files)
1087
  first_image = image_files[0]
1088
  num_images = len(image_files)
 
1101
 
1102
 
1103
  def load_example_scene(scene_name, examples_dir="examples"):
1104
+ """Load a scene from examples directory"""
1105
  scenes = get_scene_info(examples_dir)
1106
 
1107
+ # Find the selected scene
1108
  selected_scene = None
1109
  for scene in scenes:
1110
  if scene["name"] == scene_name:
 
1112
  break
1113
 
1114
  if selected_scene is None:
1115
+ return None, None, None, "Scene not found"
1116
 
1117
+ # Create file-like objects for the unified upload system
1118
+ # Convert image file paths to the format expected by unified_upload
1119
  file_objects = []
1120
  for image_path in selected_scene["image_files"]:
1121
  file_objects.append(image_path)
1122
 
1123
+ # Create target directory and copy images using the unified upload system
1124
  target_dir, image_paths = handle_uploads(file_objects, 1.0)
1125
 
1126
  return (
1127
+ None, # Clear reconstruction output
1128
+ target_dir, # Set target directory
1129
+ image_paths, # Set gallery
1130
+ f"已加载场景 '{scene_name}'({selected_scene['num_images']} 张图片)。点击「开始重建」进行3D处理。",
1131
  )
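Example scenes are funnelled through the same `handle_uploads` path as user files, so the reconstruction button behaves identically afterwards. A condensed sketch of that flow (it reuses `get_scene_info` and `handle_uploads` defined above):

```python
def prepare_example(scene_name, examples_dir="examples"):
    """Stage an example scene exactly like a user upload (illustrative condensation)."""
    scenes = {s["name"]: s for s in get_scene_info(examples_dir)}
    scene = scenes.get(scene_name)
    if scene is None:
        return None, None, None, "Scene not found"
    target_dir, image_paths = handle_uploads(list(scene["image_files"]), 1.0)
    status = f"Loaded '{scene_name}' ({scene['num_images']} images); ready to reconstruct."
    return None, target_dir, image_paths, status
```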
1132
 
1133
 
1134
+ # -------------------------------------------------------------------------
1135
+ # 6) Build Gradio UI
1136
+ # -------------------------------------------------------------------------
 
1137
  theme = get_gradio_theme()
1138
 
1139
  # 自定义CSS防止UI抖动
 
1196
  }
1197
  """
1198
 
1199
+ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything - 3D重建系统") as demo:
1200
+ # State variables for the tabbed interface
 
 
 
 
 
 
1201
  is_example = gr.Textbox(label="is_example", visible=False, value="None")
1202
+ num_images = gr.Textbox(label="num_images", visible=False, value="None")
1203
  processed_data_state = gr.State(value=None)
1204
  measure_points_state = gr.State(value=[])
1205
+ current_view_index = gr.State(value=0) # Track current view index for navigation
1206
 
1207
  # 添加粘贴板支持的 JavaScript
1208
+ PASTE_JS = """
1209
+ <script>
1210
+ // 添加粘贴板支持
1211
+ document.addEventListener('paste', function(e) {
1212
+ const items = e.clipboardData.items;
1213
+ for (let i = 0; i < items.length; i++) {
1214
+ if (items[i].type.indexOf('image') !== -1) {
1215
+ const blob = items[i].getAsFile();
1216
+ const fileInput = document.querySelector('input[type="file"][multiple]');
1217
+ if (fileInput) {
1218
+ const dataTransfer = new DataTransfer();
1219
+ dataTransfer.items.add(blob);
1220
+ fileInput.files = dataTransfer.files;
1221
+ fileInput.dispatchEvent(new Event('change', { bubbles: true }));
1222
+ console.log('✅ 图片已从剪贴板粘贴');
1223
+ }
1224
+ }
1225
+ }
1226
+ });
1227
+ console.log('💡 粘贴板功能已启用:使用 Ctrl+V 可直接粘贴截图');
1228
+ </script>
1229
+ """
1230
  gr.HTML(PASTE_JS)
1231
+
1232
+ # Top header banner
1233
  gr.HTML("""
1234
  <div style="text-align: center; margin: 20px 0;">
1235
+ <h2 style="color: #1976D2; margin-bottom: 10px;">MapAnything - 3D重建系统</h2>
1236
+ <p style="color: #666; font-size: 16px;">多视图3D重建 | 深度估计 | 法线计算 | 距离测量</p>
1237
  </div>
1238
  """)
1239
 
 
1306
  clear_color=[0.0, 0.0, 0.0, 0.0]
1307
  )
1308
 
 
 
 
 
 
1309
  with gr.Tab("📊 深度图"):
1310
  with gr.Row(elem_classes=["navigation-row"]):
1311
  prev_depth_btn = gr.Button("◀", size="sm", scale=1)
 
1356
  max_lines=1
1357
  )
1358
 
1359
+ # Advanced options (collapsed by default)
1360
+ with gr.Accordion("⚙️ 高级选项", open=False):
1361
  with gr.Row(equal_height=False):
1362
  with gr.Column(scale=1, min_width=300):
1363
  gr.Markdown("#### 可视化参数")
 
1374
  apply_mask_checkbox = gr.Checkbox(
1375
  label="应用深度掩码", value=True
1376
  )
 
 
 
 
 
 
1377
  # 示例场景(可折叠)
1378
  with gr.Accordion("🖼️ 示例场景", open=False):
1379
+ gr.Markdown("点击缩略图加载场景进行重建")
1380
  scenes = get_scene_info("examples")
1381
+
1382
  if scenes:
1383
+ for i in range(0, len(scenes), 4): # Process 4 scenes per row
1384
  with gr.Row(equal_height=True):
1385
  for j in range(4):
1386
  scene_idx = i + j
 
1388
  scene = scenes[scene_idx]
1389
  with gr.Column(scale=1, min_width=150):
1390
  scene_img = gr.Image(
1391
+ value=scene["thumbnail"],
1392
  height=150,
1393
+ interactive=False,
1394
+ show_label=False,
1395
  sources=[],
1396
  container=False
1397
  )
 
1403
  fn=lambda name=scene["name"]: load_example_scene(name),
1404
  outputs=[
1405
  reconstruction_output,
1406
+ target_dir_output,
1407
+ image_gallery,
1408
+ log_output,
1409
+ ],
1410
  )
1411
 
1412
  # === 事件绑定 ===
1413
 
 
 
 
 
1414
  # 上传文件自动更新
1415
  def update_gallery_on_unified_upload(files, interval):
1416
  if not files:
 
1540
  # 重建按钮
1541
  submit_btn.click(
1542
  fn=clear_fields,
1543
+ outputs=[reconstruction_output]
1544
  ).then(
1545
  fn=update_log,
1546
  outputs=[log_output]
 
1549
  inputs=[
1550
  target_dir_output, frame_filter, show_cam,
1551
  filter_black_bg, filter_white_bg,
1552
+ apply_mask_checkbox, show_mesh
 
1553
  ],
1554
  outputs=[
1555
+ reconstruction_output, log_output, frame_filter,
1556
  processed_data_state, depth_map, normal_map, measure_image,
1557
  measure_text, depth_view_selector, normal_view_selector, measure_view_selector
1558
  ]
 
1562
  )
1563
 
1564
  # 清空按钮
1565
+ clear_btn.add([reconstruction_output, log_output])
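The reconstruction button is wired as a chain: `.click()` first clears the viewer, `.then()` posts the waiting message, and a further `.then()` runs `gradio_demo`; `clear_btn.add(...)` simply registers which components the ClearButton should reset. A minimal standalone sketch of both patterns (component names are illustrative):

```python
import gradio as gr

def reset_viewer():
    return None                      # clearing a Model3D just returns None

def show_status():
    return "Running reconstruction..."

with gr.Blocks() as sketch:
    viewer = gr.Model3D()
    status = gr.Textbox()
    run = gr.Button("Run")
    clear = gr.ClearButton()

    # Each .then() step starts after the previous one finishes.
    (run.click(fn=reset_viewer, outputs=[viewer])
        .then(fn=show_status, outputs=[status])
        .then(fn=lambda: (None, "Done"), outputs=[viewer, status]))

    # ClearButton.add registers components to reset when pressed.
    clear.add([viewer, status])
```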
1566
+
1567
  # 可视化参数实时更新
1568
  for component in [frame_filter, show_cam, show_mesh]:
1569
  component.change(
 
1585
  ],
1586
  outputs=[processed_data_state, depth_map, normal_map, measure_image, measure_points_state]
1587
  )
1588
+
1589
  # 深度图导航
1590
  prev_depth_btn.click(
1591
  fn=lambda pd, cs: navigate_depth_view(pd, cs, -1),
 
1642
  outputs=[measure_image, measure_points_state]
1643
  )
1644
 
 
 
 
 
 
1645
  demo.queue(max_size=20).launch(show_error=True, share=True, ssr_mode=False)