shriarul5273 committed
Commit 8ca4dce · 1 Parent(s): 8e11b47

added CREStereo and FoundationStereo code
Files changed (37)
  1. .gitattributes +1 -0
  2. .gitignore +3 -0
  3. CREStereo_demo/app.py +967 -0
  4. CREStereo_demo/app_local.py +889 -0
  5. CREStereo_demo/models/.gitkeep +0 -0
  6. CREStereo_demo/models/crestereo_eth3d.pth +3 -0
  7. CREStereo_demo/nets/__init__.py +1 -0
  8. CREStereo_demo/nets/attention/__init__.py +2 -0
  9. CREStereo_demo/nets/attention/linear_attention.py +81 -0
  10. CREStereo_demo/nets/attention/position_encoding.py +41 -0
  11. CREStereo_demo/nets/attention/transformer.py +100 -0
  12. CREStereo_demo/nets/corr.py +148 -0
  13. CREStereo_demo/nets/crestereo.py +258 -0
  14. CREStereo_demo/nets/extractor.py +123 -0
  15. CREStereo_demo/nets/update.py +91 -0
  16. CREStereo_demo/nets/utils/__init__.py +1 -0
  17. CREStereo_demo/nets/utils/utils.py +108 -0
  18. FoundationStereo_demo/Utils.py +160 -0
  19. FoundationStereo_demo/app.py +1138 -0
  20. FoundationStereo_demo/app_local.py +1292 -0
  21. FoundationStereo_demo/core/extractor.py +371 -0
  22. FoundationStereo_demo/core/foundation_stereo.py +277 -0
  23. FoundationStereo_demo/core/geometry.py +77 -0
  24. FoundationStereo_demo/core/submodule.py +588 -0
  25. FoundationStereo_demo/core/update.py +159 -0
  26. FoundationStereo_demo/core/utils/utils.py +64 -0
  27. FoundationStereo_demo/depth_anything/LICENSE.txt +201 -0
  28. FoundationStereo_demo/depth_anything/__init__.py +2 -0
  29. FoundationStereo_demo/depth_anything/blocks.py +153 -0
  30. FoundationStereo_demo/depth_anything/dpt.py +203 -0
  31. FoundationStereo_demo/depth_anything/util/transform.py +248 -0
  32. assets/example1/K.txt +2 -0
  33. assets/example1/left.png +3 -0
  34. assets/example1/right.png +3 -0
  35. assets/example2/K.txt +9 -0
  36. assets/example2/left.png +3 -0
  37. assets/example2/right.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
+ *.Identifier
+ __pycache__/
+ *.pyc
CREStereo_demo/app.py ADDED
@@ -0,0 +1,967 @@
1
+ """
2
+ CREStereo Gradio Demo with ZeroGPU Integration
3
+
4
+ This demo showcases the CREStereo model for stereo depth estimation.
5
+ Optimized for Hugging Face Spaces with ZeroGPU support.
6
+
7
+ Key ZeroGPU optimizations:
8
+ - @spaces.GPU decorators for GPU-intensive functions
9
+ - CUDA operations only within GPU context
10
+ - Memory-efficient inference with cleanup
11
+ - Safe CUDA initialization patterns
12
+ """
13
+
14
+ import os
15
+ import sys
16
+ import logging
17
+ import tempfile
18
+ import gc
19
+ from pathlib import Path
20
+ from typing import Optional, Tuple, Union
21
+ import numpy as np
22
+ import cv2
23
+ import gradio as gr
24
+ import imageio
25
+
26
+ # Import spaces BEFORE torch to ensure proper ZeroGPU initialization
27
+ import spaces
28
+
29
+ # Import torch after spaces - avoid any CUDA calls during import
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+ from torch.cuda.amp import autocast
34
+
35
+ # Completely avoid CUDA operations during import phase
36
+ # Do not set default tensor type or modify CUDA settings outside GPU context
37
+ # torch.set_default_tensor_type('torch.FloatTensor') # Commented out - causes CUDA init
38
+
39
+ # Do not modify CUDA settings during import - this can trigger CUDA initialization
40
+ # torch.backends.cudnn.enabled = False # Commented out
41
+ # torch.backends.cudnn.benchmark = False # Commented out
42
+
43
+ # Use current directory as base
44
+ current_dir = os.path.dirname(os.path.abspath(__file__))
45
+ base_dir = current_dir
46
+
47
+ # Add current directory to path for local imports
48
+ sys.path.insert(0, current_dir)
49
+
50
+ # Import local modules
51
+ from nets import Model
52
+
53
+ # Import Open3D with error handling
54
+ OPEN3D_AVAILABLE = False
55
+ try:
56
+ # Set Open3D to CPU mode to avoid CUDA initialization
57
+ os.environ['OPEN3D_CPU_RENDERING'] = '1'
58
+ # Don't import open3d here - do it inside functions
59
+ # import open3d as o3d
60
+ OPEN3D_AVAILABLE = True # Assume available, will check later
61
+ except Exception as e:
62
+ logging.warning(f"Open3D setup failed: {e}")
63
+ OPEN3D_AVAILABLE = False
64
+
65
+ # Configure logging
66
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
67
+
68
+ # Model configuration
69
+ MODEL_VARIANTS = {
70
+ "crestereo_eth3d": {
71
+ "display_name": "CREStereo ETH3D (Pre-trained model)",
72
+ "model_file": "models/crestereo_eth3d.pth",
73
+ "max_disp": 256
74
+ }
75
+ }
76
+
77
+ # Global variables for model caching
78
+ _cached_model = None
79
+ _cached_device = None
80
+ _cached_model_selection = None
81
+
82
+
83
+ class InputPadder:
84
+ """ Pads images such that dimensions are divisible by divis_by """
85
+ def __init__(self, dims, divis_by=8, force_square=False):
86
+ self.ht, self.wd = dims[-2:]
87
+ pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
88
+ pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
89
+
90
+ if force_square:
91
+ # Make the padded dimensions square
92
+ max_dim = max(self.ht + pad_ht, self.wd + pad_wd)
93
+ pad_ht = max_dim - self.ht
94
+ pad_wd = max_dim - self.wd
95
+
96
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
97
+
98
+ def pad(self, *inputs):
99
+ return [F.pad(x, self._pad, mode='replicate') for x in inputs]
100
+
101
+ def unpad(self, x):
102
+ ht, wd = x.shape[-2:]
103
+ c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
104
+ return x[..., c[0]:c[1], c[2]:c[3]]
105
+
106
+
107
+ def aggressive_cleanup():
108
+ """Perform basic cleanup - no CUDA operations outside GPU context"""
109
+ import gc
110
+ gc.collect()
111
+ logging.info("Performed basic memory cleanup")
112
+
113
+
114
+ @spaces.GPU
115
+ def initialize_gpu_context():
116
+ """Initialize GPU context safely for ZeroGPU"""
117
+ try:
118
+ # Set CUDA settings safely within GPU context
119
+ torch.set_default_tensor_type('torch.cuda.FloatTensor')
120
+ torch.backends.cudnn.enabled = True
121
+ torch.backends.cudnn.benchmark = True
122
+
123
+ # Check GPU availability and log info
124
+ if torch.cuda.is_available():
125
+ device_name = torch.cuda.get_device_name(0)
126
+ memory_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
127
+ logging.info(f"GPU initialized: {device_name}, Total memory: {memory_total:.2f}GB")
128
+ return True
129
+ else:
130
+ logging.error("CUDA not available after GPU context initialization")
131
+ return False
132
+ except Exception as e:
133
+ logging.error(f"GPU context initialization failed: {e}")
134
+ return False
135
+
136
+
137
+ @spaces.GPU
138
+ def check_gpu_memory():
139
+ """Check and log current GPU memory usage - only call within GPU context"""
140
+ try:
141
+ allocated = torch.cuda.memory_allocated(0) / 1024**3
142
+ reserved = torch.cuda.memory_reserved(0) / 1024**3
143
+ max_allocated = torch.cuda.max_memory_allocated(0) / 1024**3
144
+ total = torch.cuda.get_device_properties(0).total_memory / 1024**3
145
+
146
+ logging.info(f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB, Max: {max_allocated:.2f}GB, Total: {total:.2f}GB")
147
+ return allocated, reserved, max_allocated, total
148
+ except RuntimeError as e:
149
+ logging.warning(f"Failed to get GPU memory info: {e}")
150
+ return None, None, None, None
151
+
152
+
153
+ def get_available_models() -> dict:
154
+ """Get all available models with their display names"""
155
+ models = {}
156
+
157
+ # Check for local models
158
+ for variant, info in MODEL_VARIANTS.items():
159
+ model_path = os.path.join(current_dir, info["model_file"])
160
+
161
+ if os.path.exists(model_path):
162
+ display_name = info["display_name"]
163
+ models[display_name] = {
164
+ "model_path": model_path,
165
+ "variant": variant,
166
+ "max_disp": info["max_disp"],
167
+ "source": "local"
168
+ }
169
+
170
+ return models
171
+
172
+
173
+ def get_model_paths_from_selection(model_selection: str) -> Tuple[Optional[str], Optional[dict]]:
174
+ """Get model path and config from the selected model"""
175
+ models = get_available_models()
176
+
177
+ # Check if it's in our models dict
178
+ if model_selection in models:
179
+ model_info = models[model_selection]
180
+ logging.info(f"📁 Using local model: {model_selection}")
181
+ return model_info["model_path"], model_info
182
+
183
+ return None, None
184
+
185
+
186
+ @spaces.GPU
187
+ def load_model_for_inference(model_path: str, model_info: dict):
188
+ """Load CREStereo model for inference temporarily (demo-style)"""
189
+ # Set CUDA settings safely within GPU context
190
+ torch.set_default_tensor_type('torch.cuda.FloatTensor') # Now safe to use CUDA tensors
191
+ torch.backends.cudnn.enabled = True
192
+ torch.backends.cudnn.benchmark = True
193
+
194
+ # Check if CUDA is available after ZeroGPU initialization
195
+ if not torch.cuda.is_available():
196
+ raise RuntimeError("CUDA is not available. ZeroGPU initialization may have failed.")
197
+
198
+ # Use the first available CUDA device
199
+ device = torch.device("cuda")
200
+
201
+ # Set CUDA seed safely within GPU context
202
+ try:
203
+ random_seed = 0
204
+ torch.cuda.manual_seed_all(random_seed)
205
+ torch.backends.cudnn.deterministic = True
206
+ torch.backends.cudnn.benchmark = False
207
+ except Exception as e:
208
+ logging.warning(f"Could not set CUDA seed: {e}")
209
+
210
+ try:
211
+ # Create model
212
+ max_disp = model_info.get("max_disp", 256)
213
+ model = Model(max_disp=max_disp, mixed_precision=False, test_mode=True)
214
+
215
+ # Load checkpoint
216
+ ckpt = torch.load(model_path, map_location=device)
217
+ model.load_state_dict(ckpt, strict=True)
218
+ model.to(device)
219
+ model.eval()
220
+
221
+ logging.info("Loaded CREStereo model weights")
222
+
223
+ # Memory optimizations
224
+ torch.set_grad_enabled(False)
225
+ logging.info("Applied memory optimizations")
226
+
227
+ return model, device
228
+
229
+ except Exception as e:
230
+ logging.error(f"Model loading failed: {e}")
231
+ raise RuntimeError(f"Failed to load model: {e}")
232
+
233
+
234
+ def get_cached_model(model_selection: str):
235
+ """Get cached model or load new one if selection changed"""
236
+ global _cached_model, _cached_device, _cached_model_selection
237
+
238
+ # Get model paths from selection
239
+ model_path, model_info = get_model_paths_from_selection(model_selection)
240
+
241
+ if model_path is None or model_info is None:
242
+ raise ValueError(f"Selected model not found: {model_selection}")
243
+
244
+ # Check if we need to reload the model
245
+ if (_cached_model is None or
246
+ _cached_model_selection != model_selection):
247
+
248
+ # Clear previous model if exists
249
+ if _cached_model is not None:
250
+ del _cached_model
251
+ torch.cuda.empty_cache()
252
+ gc.collect()
253
+
254
+ logging.info(f"🚀 Loading model: {model_selection}")
255
+ _cached_model, _cached_device = load_model_for_inference(model_path, model_info)
256
+ _cached_model_selection = model_selection
257
+
258
+ logging.info(f"✅ Model loaded successfully: {model_selection}")
259
+ else:
260
+ logging.info(f"✅ Using cached model: {model_selection}")
261
+
262
+ return _cached_model, _cached_device
263
+
264
+
265
+ def clear_model_cache():
266
+ """Clear the cached model to free memory"""
267
+ global _cached_model, _cached_device, _cached_model_selection
268
+
269
+ if _cached_model is not None:
270
+ logging.info("Clearing model cache...")
271
+ del _cached_model
272
+ _cached_model = None
273
+ _cached_device = None
274
+ _cached_model_selection = None
275
+
276
+ # Simple cleanup
277
+ import gc
278
+ gc.collect()
279
+ torch.cuda.empty_cache()
280
+ logging.info("Model cache cleared")
281
+ else:
282
+ logging.info("No model in cache to clear")
283
+
284
+
285
+ def inference(left, right, model, device, n_iter=20):
286
+ """Run CREStereo inference on stereo pair"""
287
+ print("Model Forwarding...")
288
+ imgL = left.transpose(2, 0, 1)
289
+ imgR = right.transpose(2, 0, 1)
290
+ imgL = np.ascontiguousarray(imgL[None, :, :, :])
291
+ imgR = np.ascontiguousarray(imgR[None, :, :, :])
292
+
293
+ imgL = torch.tensor(imgL.astype("float32")).to(device)
294
+ imgR = torch.tensor(imgR.astype("float32")).to(device)
295
+
296
+ # Use InputPadder to handle any image size
297
+ padder = InputPadder(imgL.shape, divis_by=8)
298
+ imgL_padded, imgR_padded = padder.pad(imgL, imgR)
299
+
300
+ # Downsample for coarse prediction
301
+ imgL_dw2 = F.interpolate(
302
+ imgL_padded,
303
+ size=(imgL_padded.shape[2] // 2, imgL_padded.shape[3] // 2),
304
+ mode="bilinear",
305
+ align_corners=True,
306
+ )
307
+ imgR_dw2 = F.interpolate(
308
+ imgR_padded,
309
+ size=(imgL_padded.shape[2] // 2, imgL_padded.shape[3] // 2),
310
+ mode="bilinear",
311
+ align_corners=True,
312
+ )
313
+
314
+ with torch.inference_mode():
315
+ pred_flow_dw2 = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None)
316
+ pred_flow = model(imgL_padded, imgR_padded, iters=n_iter, flow_init=pred_flow_dw2)
317
+
318
+ # Unpad the result to original dimensions
319
+ pred_flow = padder.unpad(pred_flow)
320
+ pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
321
+
322
+ return pred_disp
323
+
324
+
325
+ def vis_disparity(disparity_map, max_val=None):
326
+ """Visualize disparity map"""
327
+ if max_val is None:
328
+ disp_vis = (disparity_map - disparity_map.min()) / (disparity_map.max() - disparity_map.min()) * 255.0
329
+ else:
330
+ disp_vis = np.clip(disparity_map / max_val * 255.0, 0, 255)
331
+
332
+ disp_vis = disp_vis.astype("uint8")
333
+ disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
334
+ disp_vis = cv2.cvtColor(disp_vis, cv2.COLOR_BGR2RGB)
335
+ return disp_vis
336
+
337
+
338
+ # Fixed with static duration
339
+ @spaces.GPU(duration=60) # Static 60 seconds for basic processing
340
+ def process_stereo_pair(model_selection: str, left_image: str, right_image: str,
341
+ progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], str]:
342
+ """
343
+ Main processing function for stereo pair (with model caching)
344
+ """
345
+ logging.info("Starting stereo pair processing...")
346
+
347
+ if left_image is None or right_image is None:
348
+ return None, "❌ Please upload both left and right images."
349
+
350
+ # Convert image paths to numpy arrays
351
+ logging.info(f"Loading images: left={left_image}, right={right_image}")
352
+
353
+ try:
354
+ # Load left image
355
+ if not os.path.exists(left_image):
356
+ logging.error(f"Left image file does not exist: {left_image}")
357
+ return None, f"❌ Left image file not found: {left_image}"
358
+
359
+ logging.info(f"Loading left image from: {left_image}")
360
+ left_img = cv2.imread(left_image)
361
+ if left_img is not None:
362
+ left_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB)
363
+ else:
364
+ # Try with imageio as fallback
365
+ left_img = imageio.imread(left_image)
366
+ if len(left_img.shape) == 3 and left_img.shape[2] == 4:
367
+ left_img = left_img[:, :, :3]
368
+
369
+ # Load right image
370
+ if not os.path.exists(right_image):
371
+ logging.error(f"Right image file does not exist: {right_image}")
372
+ return None, f"❌ Right image file not found: {right_image}"
373
+
374
+ logging.info(f"Loading right image from: {right_image}")
375
+ right_img = cv2.imread(right_image)
376
+ if right_img is not None:
377
+ right_img = cv2.cvtColor(right_img, cv2.COLOR_BGR2RGB)
378
+ else:
379
+ # Try with imageio as fallback
380
+ right_img = imageio.imread(right_image)
381
+ if len(right_img.shape) == 3 and right_img.shape[2] == 4:
382
+ right_img = right_img[:, :, :3]
383
+
384
+ logging.info(f"Images loaded successfully - Left: {left_img.shape}, Right: {right_img.shape}")
385
+
386
+ except Exception as e:
387
+ logging.error(f"Failed to load images: {e}")
388
+ return None, f"❌ Failed to load images: {str(e)}"
389
+
390
+ try:
391
+ # Get cached model
392
+ variant_name = model_selection.split('(')[0].strip() if '(' in model_selection else model_selection
393
+ progress(0.1, desc=f"Loading cached model ({variant_name})...")
394
+ logging.info("🚀 Getting cached model...")
395
+ model, device = get_cached_model(model_selection)
396
+ logging.info("✅ Cached model loaded successfully")
397
+
398
+ progress(0.2, desc="Preprocessing images...")
399
+
400
+ # Validate input images
401
+ if left_img.shape != right_img.shape:
402
+ return None, "❌ Left and right images must have the same dimensions."
403
+
404
+ H, W = left_img.shape[:2]
405
+
406
+ progress(0.5, desc="Running inference...")
407
+
408
+ # Process stereo pair
409
+ torch.cuda.empty_cache() # Clear any cached memory before inference
410
+
411
+ disp_cpu = inference(left_img, right_img, model, device, n_iter=20)
412
+
413
+ progress(0.8, desc="Creating visualization...")
414
+
415
+ # Create visualization
416
+ disparity_vis = vis_disparity(disp_cpu)
417
+ result_image = disparity_vis
418
+
419
+ progress(1.0, desc="Complete!")
420
+
421
+ # Create status message
422
+ valid_mask = ~np.isinf(disp_cpu)
423
+ min_disp = disp_cpu[valid_mask].min() if valid_mask.any() else 0
424
+ max_disp = disp_cpu[valid_mask].max() if valid_mask.any() else 0
425
+ mean_disp = disp_cpu[valid_mask].mean() if valid_mask.any() else 0
426
+
427
+ # Get model variant for status
428
+ variant = variant_name
429
+
430
+ # Check current memory usage
431
+ try:
432
+ current_memory = torch.cuda.memory_allocated(0) / 1024**3
433
+ max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
434
+ memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"
435
+ except:
436
+ memory_info = ""
437
+
438
+ status = f"""✅ Processing successful!
439
+ 🔧 Model: {variant}{memory_info}
440
+ 📊 Disparity Statistics:
441
+ • Range: {min_disp:.2f} - {max_disp:.2f}
442
+ • Mean: {mean_disp:.2f}
443
+ • Input size: {W}×{H}
444
+ • Valid pixels: {valid_mask.sum()}/{valid_mask.size}"""
445
+
446
+ return result_image, status
447
+
448
+ except Exception as e:
449
+ logging.error(f"Processing failed: {e}")
450
+ # Clean up GPU memory
451
+ torch.cuda.empty_cache()
452
+ gc.collect()
453
+ return None, f"❌ Error: {str(e)}"
454
+
455
+
456
+ # Fixed with static duration
457
+ @spaces.GPU(duration=120) # Static 120 seconds for depth processing
458
+ def process_with_depth(model_selection: str, left_image: str, right_image: str,
459
+ camera_matrix: str, baseline: float,
460
+ progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], Optional[str], Optional[str], str]:
461
+ """
462
+ Process stereo pair and generate depth map and point cloud (with model caching)
463
+ """
464
+ # Import Open3D
465
+ global OPEN3D_AVAILABLE
466
+ try:
467
+ import open3d as o3d
468
+ OPEN3D_AVAILABLE = True
469
+ except ImportError as e:
470
+ logging.warning(f"Open3D not available: {e}")
471
+ OPEN3D_AVAILABLE = False
472
+ return None, None, None, "❌ Open3D not available. Point cloud generation disabled."
473
+
474
+ if left_image is None or right_image is None:
475
+ return None, None, None, "❌ Please upload both left and right images."
476
+
477
+ try:
478
+ progress(0.1, desc="Parsing camera parameters...")
479
+
480
+ # Parse camera matrix
481
+ try:
482
+ K_values = list(map(float, camera_matrix.strip().split()))
483
+ if len(K_values) != 9:
484
+ return None, None, None, "❌ Camera matrix must contain exactly 9 values."
485
+ K = np.array(K_values).reshape(3, 3)
486
+ except ValueError:
487
+ return None, None, None, "❌ Invalid camera matrix format. Use space-separated numbers."
488
+
489
+ if baseline <= 0:
490
+ return None, None, None, "❌ Baseline must be positive."
491
+
492
+ # First get disparity using the same process as basic function
493
+ disparity_result, status = process_stereo_pair(model_selection, left_image, right_image, progress)
494
+
495
+ if disparity_result is None:
496
+ return None, None, None, status
497
+
498
+ # Load images again for depth processing
499
+ left_img = cv2.imread(left_image)
500
+ left_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB)
501
+
502
+ # Get disparity from model again (we need the raw values, not the visualization)
503
+ model, device = get_cached_model(model_selection)
504
+ disp_cpu = inference(left_img, cv2.cvtColor(cv2.imread(right_image), cv2.COLOR_BGR2RGB), model, device, n_iter=20)
505
+
506
+ progress(0.6, desc="Converting to depth...")
507
+
508
+ # Remove invisible points
509
+ H, W = disp_cpu.shape
510
+ yy, xx = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
511
+ us_right = xx - disp_cpu
512
+ invalid = us_right < 0
513
+ disp_cpu[invalid] = np.inf
514
+
515
+ # Convert to depth using the formula: depth = focal_length * baseline / disparity
516
+ depth = K[0, 0] * baseline / disp_cpu
517
+
518
+ # Visualize depth
519
+ depth_vis = vis_disparity(depth, max_val=10.0)
520
+
521
+ progress(0.8, desc="Generating point cloud...")
522
+
523
+ # Generate point cloud
524
+ fx, fy = K[0, 0], K[1, 1]
525
+ cx, cy = K[0, 2], K[1, 2]
526
+
527
+ # Create coordinate meshgrids
528
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
529
+
530
+ # Convert to 3D coordinates
531
+ valid_depth = ~np.isinf(depth)
532
+ z = depth[valid_depth] # Z coordinate (depth)
533
+ x = (u[valid_depth] - cx) * z / fx # X coordinate
534
+ y = (v[valid_depth] - cy) * z / fy # Y coordinate
535
+
536
+ # Stack coordinates (X, Y, Z)
537
+ points = np.stack([x, y, z], axis=-1)
538
+
539
+ # Get corresponding colors
540
+ colors = left_img[valid_depth]
541
+
542
+ # Filter points by depth range
543
+ depth_mask = (z > 0) & (z <= 10.0)
544
+ valid_points = points[depth_mask]
545
+ valid_colors = colors[depth_mask]
546
+
547
+ if len(valid_points) == 0:
548
+ return depth_vis, None, None, "⚠️ No valid points generated for point cloud."
549
+
550
+ # Subsample points for better performance
551
+ if len(valid_points) > 100000:
552
+ indices = np.random.choice(len(valid_points), 100000, replace=False)
553
+ valid_points = valid_points[indices]
554
+ valid_colors = valid_colors[indices]
555
+
556
+ # Transform coordinates for proper visualization
557
+ transformed_points = valid_points.copy()
558
+ transformed_points[:, 1] = -transformed_points[:, 1] # Flip Y axis
559
+ transformed_points[:, 2] = -transformed_points[:, 2] # Flip Z axis
560
+
561
+ # Generate point cloud
562
+ pcd = o3d.geometry.PointCloud()
563
+ pcd.points = o3d.utility.Vector3dVector(transformed_points)
564
+ pcd.colors = o3d.utility.Vector3dVector(valid_colors / 255.0)
565
+
566
+ progress(1.0, desc="Complete!")
567
+
568
+ # Check current memory usage
569
+ try:
570
+ current_memory = torch.cuda.memory_allocated(0) / 1024**3
571
+ max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
572
+ memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"
573
+ except:
574
+ memory_info = ""
575
+
576
+ variant = model_selection.split('(')[0].strip() if '(' in model_selection else model_selection
577
+
578
+ status = f"""✅ Depth processing successful!
579
+ 🔧 Model: {variant}{memory_info}
580
+ 📊 Statistics:
581
+ • Valid points: {len(valid_points):,}
582
+ • Depth range: {z.min():.2f} - {z.max():.2f} m
583
+ • Baseline: {baseline} m
584
+ • Point cloud generated with {len(valid_points)} points
585
+ • 3D visualization available"""
586
+
587
+ return depth_vis, None, None, status
588
+
589
+ except Exception as e:
590
+ logging.error(f"Depth processing failed: {e}")
591
+ torch.cuda.empty_cache()
592
+ gc.collect()
593
+ return None, None, None, f"❌ Error: {str(e)}"
594
+
595
+
596
+ def create_app() -> gr.Blocks:
597
+ """Create the Gradio application"""
598
+
599
+ # Get available models
600
+ try:
601
+ available_models = get_available_models()
602
+ logging.info(f"Successfully got available models: {len(available_models)} found")
603
+ except Exception as e:
604
+ logging.error(f"Failed to get available models: {e}")
605
+ available_models = {}
606
+
607
+ with gr.Blocks(
608
+ title="CREStereo - Stereo Depth Estimation",
609
+ theme=gr.themes.Soft(),
610
+ css="footer {visibility: hidden}",
611
+ delete_cache=(60, 60)
612
+ ) as app:
613
+
614
+ gr.Markdown("""
615
+ # 🔍 CREStereo: Practical Stereo Matching
616
+
617
+ Upload a pair of **rectified** stereo images to get disparity estimation using CREStereo.
618
+
619
+ ⚠️ **Important**: Images should be rectified (epipolar lines are horizontal) and undistorted.
620
+ ⚡ **GPU Powered**: Runs on CUDA-enabled GPUs for fast inference.
621
+ """)
622
+
623
+ # Instructions section
624
+ with gr.Accordion("📋 Instructions", open=False):
625
+ gr.Markdown("""
626
+ ## 🚀 How to Use This Demo
627
+
628
+ ### 🖼️ Input Requirements
629
+ 1. **Image Format**: Upload images in JPEG or PNG format.
630
+ 2. **Image Size**: Images should be of the same size and resolution.
631
+ 3. **Rectification**: Ensure images are rectified (epipolar lines are horizontal) and undistorted.
632
+ 4. **Camera Parameters**: For depth processing, provide camera matrix and baseline distance.
633
+
634
+ ### 📊 Using the Demo
635
+ 1. **Select Model**: Choose the CREStereo model variant
636
+ 2. **Upload Images**: Provide rectified stereo image pairs
637
+ 3. **Basic Processing**: Get disparity visualization
638
+ 4. **Advanced Processing**: Generate depth maps and 3D point clouds (requires camera parameters)
639
+
640
+ ### 📖 Original Work
641
+ This demo is based on CREStereo: Practical Stereo Matching via Cascaded Recurrent Network.
642
+ - **Paper**: [CREStereo: Practical Stereo Matching via Cascaded Recurrent Network](https://arxiv.org/abs/2203.11483)
643
+ - **Official Repository**: [https://github.com/megvii-research/CREStereo](https://github.com/megvii-research/CREStereo)
644
+ """)
645
+
646
+ # Model selection
647
+ with gr.Row():
648
+ all_choices = list(available_models.keys())
649
+
650
+ if not all_choices:
651
+ all_choices = ["No models found - Please ensure crestereo_eth3d.pth is in models/ directory"]
652
+
653
+ default_model = all_choices[0] if all_choices else None
654
+
655
+ model_selector = gr.Dropdown(
656
+ choices=all_choices,
657
+ value=default_model,
658
+ label="🎯 Select Model",
659
+ info="Choose the CREStereo model variant.",
660
+ interactive=True
661
+ )
662
+
663
+ with gr.Tabs():
664
+ # Basic stereo processing tab
665
+ with gr.TabItem("🖼️ Basic Stereo Processing"):
666
+ with gr.Row():
667
+ with gr.Column():
668
+ left_input = gr.Image(
669
+ label="📷 Left Image",
670
+ type="filepath",
671
+ height=300
672
+ )
673
+ right_input = gr.Image(
674
+ label="📷 Right Image",
675
+ type="filepath",
676
+ height=300
677
+ )
678
+
679
+ process_btn = gr.Button(
680
+ "🚀 Process Stereo Pair",
681
+ variant="primary",
682
+ size="lg"
683
+ )
684
+
685
+ with gr.Column():
686
+ output_image = gr.Image(
687
+ label="📊 Disparity Visualization",
688
+ height=400
689
+ )
690
+ status_text = gr.Textbox(
691
+ label="Status",
692
+ interactive=False,
693
+ lines=8
694
+ )
695
+
696
+ # Example images
697
+ examples_list = []
698
+
699
+ # Example 1
700
+ if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")):
701
+ examples_list.append([
702
+ os.path.join(current_dir, "assets", "example1", "left.png"),
703
+ os.path.join(current_dir, "assets", "example1", "right.png")
704
+ ])
705
+
706
+ # Example 2
707
+ if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")):
708
+ examples_list.append([
709
+ os.path.join(current_dir, "assets", "example2", "left.png"),
710
+ os.path.join(current_dir, "assets", "example2", "right.png")
711
+ ])
712
+
713
+ if examples_list:
714
+ gr.Examples(
715
+ examples=examples_list,
716
+ inputs=[left_input, right_input],
717
+ label="📋 Example Images"
718
+ )
719
+
720
+ # Advanced processing with depth
721
+ with gr.TabItem("📐 Advanced Processing (Depth & Point Cloud)"):
722
+ with gr.Row():
723
+ with gr.Column():
724
+ left_input_adv = gr.Image(
725
+ label="📷 Left Image",
726
+ type="filepath",
727
+ height=250
728
+ )
729
+ right_input_adv = gr.Image(
730
+ label="📷 Right Image",
731
+ type="filepath",
732
+ height=250
733
+ )
734
+
735
+ # Camera parameters
736
+ with gr.Group():
737
+ gr.Markdown("### 📹 Camera Parameters")
738
+ camera_matrix_input = gr.Textbox(
739
+ label="Camera Matrix (9 values: fx 0 cx 0 fy cy 0 0 1)",
740
+ value="",
741
+ )
742
+ baseline_input = gr.Number(
743
+ label="Baseline (meters)",
744
+ value=None,
745
+ minimum=0.001,
746
+ maximum=10.0,
747
+ step=0.001
748
+ )
749
+
750
+ process_depth_btn = gr.Button(
751
+ "🔬 Process with Depth",
752
+ variant="primary",
753
+ size="lg"
754
+ )
755
+
756
+ with gr.Column():
757
+ depth_output = gr.Image(
758
+ label="📏 Depth Visualization",
759
+ height=300
760
+ )
761
+ pointcloud_output = gr.File(
762
+ label="☁️ Point Cloud Download (.ply)",
763
+ file_types=[".ply"]
764
+ )
765
+ status_depth = gr.Textbox(
766
+ label="Status",
767
+ interactive=False,
768
+ lines=6
769
+ )
770
+
771
+ # 3D Point Cloud Visualization
772
+ with gr.Row():
773
+ pointcloud_3d = gr.Model3D(
774
+ label="🌐 3D Point Cloud Viewer",
775
+ clear_color=[0.0, 0.0, 0.0, 0.0],
776
+ height=400
777
+ )
778
+
779
+ # Example images for advanced processing
780
+ examples_advanced_list = []
781
+
782
+ # Try to read camera parameters from K.txt files
783
+ # Example 1
784
+ if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")):
785
+ k_file = os.path.join(current_dir, "assets", "example1", "K.txt")
786
+ camera_matrix_str = ""
787
+ baseline_val = 0.063 # default
788
+
789
+ if os.path.exists(k_file):
790
+ try:
791
+ with open(k_file, 'r') as f:
792
+ lines = f.readlines()
793
+ if len(lines) >= 1:
794
+ camera_matrix_str = lines[0].strip()
795
+ if len(lines) >= 2:
796
+ baseline_val = float(lines[1].strip())
797
+ except:
798
+ camera_matrix_str = "754.6680908203125 0.0 489.3794860839844 0.0 754.6680908203125 265.16162109375 0.0 0.0 1.0"
799
+
800
+ examples_advanced_list.append([
801
+ os.path.join(current_dir, "assets", "example1", "left.png"),
802
+ os.path.join(current_dir, "assets", "example1", "right.png"),
803
+ camera_matrix_str,
804
+ baseline_val
805
+ ])
806
+
807
+ # Example 2
808
+ if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")):
809
+ k_file = os.path.join(current_dir, "assets", "example2", "K.txt")
810
+ camera_matrix_str = ""
811
+ baseline_val = 0.537 # default
812
+
813
+ if os.path.exists(k_file):
814
+ try:
815
+ with open(k_file, 'r') as f:
816
+ lines = f.readlines()
817
+ if len(lines) >= 1:
818
+ camera_matrix_str = lines[0].strip()
819
+ if len(lines) >= 2:
820
+ baseline_val = float(lines[1].strip())
821
+ except:
822
+ camera_matrix_str = "1733.74 0.0 792.27 0.0 1733.74 541.89 0.0 0.0 1.0"
823
+
824
+ examples_advanced_list.append([
825
+ os.path.join(current_dir, "assets", "example2", "left.png"),
826
+ os.path.join(current_dir, "assets", "example2", "right.png"),
827
+ camera_matrix_str,
828
+ baseline_val
829
+ ])
830
+
831
+ if examples_advanced_list:
832
+ gr.Examples(
833
+ examples=examples_advanced_list,
834
+ inputs=[left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
835
+ label="📋 Example Images with Camera Parameters"
836
+ )
837
+
838
+ # Event handlers
839
+ if available_models:
840
+ process_btn.click(
841
+ fn=process_stereo_pair,
842
+ inputs=[model_selector, left_input, right_input],
843
+ outputs=[output_image, status_text],
844
+ show_progress=True
845
+ )
846
+
847
+ if OPEN3D_AVAILABLE:
848
+ process_depth_btn.click(
849
+ fn=process_with_depth,
850
+ inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
851
+ outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth],
852
+ show_progress=True
853
+ )
854
+ else:
855
+ process_depth_btn.click(
856
+ fn=lambda *args: (None, None, None, "❌ Open3D not available. Install with: pip install open3d"),
857
+ inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
858
+ outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth]
859
+ )
860
+ else:
861
+ # No models available
862
+ process_btn.click(
863
+ fn=lambda *args: (None, "❌ No models available. Please ensure crestereo_eth3d.pth is in models/ directory."),
864
+ inputs=[model_selector, left_input, right_input],
865
+ outputs=[output_image, status_text]
866
+ )
867
+
868
+ process_depth_btn.click(
869
+ fn=lambda *args: (None, None, None, "❌ No models available. Please ensure crestereo_eth3d.pth is in models/ directory."),
870
+ inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
871
+ outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth]
872
+ )
873
+
874
+ # Citation section at the bottom
875
+ with gr.Accordion("📖 Citation", open=False):
876
+ gr.Markdown("""
877
+ ### 📄 Please Cite the Original Paper
878
+
879
+ If you use this work in your research, please cite:
880
+
881
+ ```bibtex
882
+ @article{li2022practical,
883
+ title={Practical Stereo Matching via Cascaded Recurrent Network with Adaptive Correlation},
884
+ author={Li, Jiankun and Wang, Peisen and Xiong, Pengfei and Cai, Tao and Yan, Ziwei and Yang, Lei and Liu, Jiangyu and Fan, Haoqiang and Liu, Shuaicheng},
885
+ journal={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
886
+ pages={16263--16272},
887
+ year={2022}
888
+ }
889
+ ```
890
+ """)
891
+
892
+ # Footer
893
+ gr.Markdown("""
894
+ ---
895
+ ### 📝 Notes:
896
+ - **Input images must be rectified stereo pairs** (epipolar lines are horizontal)
897
+ - **⚡ GPU Acceleration**: Requires CUDA-compatible GPU
898
+ - **📦 Model Caching**: Models are cached for efficient repeated usage
899
+ - For best results, use high-quality rectified stereo pairs
900
+ - Model works on RGB images and supports various resolutions
901
+
902
+ ### 🔗 References:
903
+ - [CREStereo Paper](https://arxiv.org/abs/2203.11483)
904
+ - [Original GitHub Repository](https://github.com/megvii-research/CREStereo)
905
+ - [This PyTorch Implementation](https://github.com/ibaiGorordo/CREStereo-Pytorch)
906
+ """)
907
+
908
+ return app
909
+
910
+
911
+ def main():
912
+ """Main function to launch the app"""
913
+
914
+ # Ensure no CUDA operations during startup
915
+ if torch.cuda.is_available():
916
+ logging.warning("CUDA detected during startup - this should not happen in ZeroGPU")
917
+
918
+ logging.info("🚀 Starting CREStereo Gradio App...")
919
+
920
+ # Parse command line arguments
921
+ import argparse
922
+ parser = argparse.ArgumentParser(description="CREStereo Gradio App")
923
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
924
+ parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
925
+ parser.add_argument("--share", action="store_true", help="Create shareable link")
926
+ parser.add_argument("--debug", action="store_true", help="Enable debug mode")
927
+
928
+ args = parser.parse_args()
929
+
930
+ if args.debug:
931
+ logging.getLogger().setLevel(logging.DEBUG)
932
+
933
+ try:
934
+ # Create and launch app
935
+ logging.info("Creating Gradio app...")
936
+ app = create_app()
937
+ logging.info("✅ Gradio app created successfully")
938
+
939
+ logging.info(f"Launching app on {args.host}:{args.port}")
940
+ if args.share:
941
+ logging.info("Share link will be created")
942
+
943
+ # For ZeroGPU compatibility, launch with appropriate settings
944
+ app.launch(
945
+ server_name=args.host,
946
+ server_port=args.port,
947
+ share=args.share,
948
+ show_error=True,
949
+ favicon_path=None,
950
+ ssr_mode=False, # Disable SSR for ZeroGPU compatibility
951
+ allowed_paths=["./"] # Allow access to local files
952
+ )
953
+ except Exception as e:
954
+ logging.error(f"Failed to launch app: {e}")
955
+ raise
956
+
957
+
958
+ if __name__ == "__main__":
959
+ # Additional safety check for ZeroGPU environment
960
+ if 'SPACE_ID' in os.environ:
961
+ logging.info("Running in Hugging Face Spaces environment")
962
+
963
+ # Do not check CUDA status during startup - this can trigger CUDA initialization
964
+ # The CUDA status will be checked inside the @spaces.GPU decorated functions
965
+ logging.info("✅ CUDA status will be checked within GPU-decorated functions")
966
+
967
+ main()
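
For reference, the depth and point-cloud step in `process_with_depth` above is the standard pinhole back-projection for a rectified pair: depth = fx · baseline / disparity, then X = (u − cx)·Z/fx and Y = (v − cy)·Z/fy. The following is a minimal, self-contained sketch of that math; the intrinsics, baseline, and disparity values are illustrative placeholders, not taken from the demo's assets.

```python
# Sketch of the disparity -> depth -> 3D point conversion used in
# process_with_depth(). Assumes a rectified stereo pair, a 3x3 intrinsic
# matrix K, and a baseline in meters; the numbers below are made up.
import numpy as np

def disparity_to_points(disp: np.ndarray, K: np.ndarray, baseline: float,
                        max_depth: float = 10.0) -> np.ndarray:
    """Back-project an (H, W) disparity map into an (N, 3) array of 3D points."""
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]
    H, W = disp.shape

    # depth = fx * baseline / disparity; zero or negative disparities become inf
    with np.errstate(divide="ignore", invalid="ignore"):
        depth = fx * baseline / disp
    valid = np.isfinite(depth) & (depth > 0) & (depth <= max_depth)

    # Pixel grid -> metric camera coordinates
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    z = depth[valid]
    x = (u[valid] - cx) * z / fx
    y = (v[valid] - cy) * z / fy
    return np.stack([x, y, z], axis=-1)

# Example: fx = fy = 700 px, principal point (320, 240), 6.3 cm baseline,
# constant 50 px disparity -> depth ≈ 700 * 0.063 / 50 ≈ 0.88 m everywhere.
K = np.array([[700.0, 0.0, 320.0],
              [0.0, 700.0, 240.0],
              [0.0, 0.0, 1.0]])
points = disparity_to_points(np.full((480, 640), 50.0), K, baseline=0.063)
print(points.shape, float(points[:, 2].mean()))
```

The app additionally flips the Y and Z axes of the resulting points before handing them to Open3D, which only affects display orientation, not the geometry.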
CREStereo_demo/app_local.py ADDED
@@ -0,0 +1,889 @@
1
+ import os
2
+ import sys
3
+ import logging
4
+ import tempfile
5
+ import gc
6
+ from pathlib import Path
7
+ from typing import Optional, Tuple, Union
8
+ import numpy as np
9
+ import cv2
10
+ import gradio as gr
11
+ import imageio
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ # Set default tensor type if needed
17
+ # torch.set_default_tensor_type('torch.FloatTensor')
18
+
19
+ # CUDA backend settings
20
+ # torch.backends.cudnn.enabled = False
21
+ # torch.backends.cudnn.benchmark = False
22
+
23
+ # Use current directory as base
24
+ current_dir = os.path.dirname(os.path.abspath(__file__))
25
+ base_dir = current_dir
26
+
27
+ # Add current directory to path for local imports
28
+ sys.path.insert(0, current_dir)
29
+
30
+ # Import local modules
31
+ from nets import Model
32
+
33
+ # Import Open3D with error handling
34
+ OPEN3D_AVAILABLE = False
35
+ try:
36
+ # Set Open3D to CPU mode to avoid CUDA initialization
37
+ os.environ['OPEN3D_CPU_RENDERING'] = '1'
38
+ # Don't import open3d here - do it inside functions
39
+ # import open3d as o3d
40
+ OPEN3D_AVAILABLE = True # Assume available, will check later
41
+ except Exception as e:
42
+ logging.warning(f"Open3D setup failed: {e}")
43
+ OPEN3D_AVAILABLE = False
44
+
45
+ # Configure logging
46
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
47
+
48
+ # Model configuration
49
+ MODEL_VARIANTS = {
50
+ "crestereo_eth3d": {
51
+ "display_name": "CREStereo ETH3D (Pre-trained model)",
52
+ "model_file": "models/crestereo_eth3d.pth",
53
+ "max_disp": 256
54
+ }
55
+ }
56
+
57
+ # Global variables for model caching
58
+ _cached_model = None
59
+ _cached_device = None
60
+ _cached_model_selection = None
61
+
62
+
63
+ class InputPadder:
64
+ """ Pads images such that dimensions are divisible by divis_by """
65
+ def __init__(self, dims, divis_by=8, force_square=False):
66
+ self.ht, self.wd = dims[-2:]
67
+ pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
68
+ pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
69
+
70
+ if force_square:
71
+ # Make the padded dimensions square
72
+ max_dim = max(self.ht + pad_ht, self.wd + pad_wd)
73
+ pad_ht = max_dim - self.ht
74
+ pad_wd = max_dim - self.wd
75
+
76
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
77
+
78
+ def pad(self, *inputs):
79
+ return [F.pad(x, self._pad, mode='replicate') for x in inputs]
80
+
81
+ def unpad(self, x):
82
+ ht, wd = x.shape[-2:]
83
+ c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
84
+ return x[..., c[0]:c[1], c[2]:c[3]]
85
+
86
+
87
+ def aggressive_cleanup():
88
+ """Perform basic cleanup"""
89
+ import gc
90
+ gc.collect()
91
+ logging.info("Performed basic memory cleanup")
92
+
93
+
94
+ def check_gpu_memory():
95
+ """Check and log current GPU memory usage"""
96
+ try:
97
+ allocated = torch.cuda.memory_allocated(0) / 1024**3
98
+ reserved = torch.cuda.memory_reserved(0) / 1024**3
99
+ max_allocated = torch.cuda.max_memory_allocated(0) / 1024**3
100
+ total = torch.cuda.get_device_properties(0).total_memory / 1024**3
101
+
102
+ logging.info(f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB, Max: {max_allocated:.2f}GB, Total: {total:.2f}GB")
103
+ return allocated, reserved, max_allocated, total
104
+ except RuntimeError as e:
105
+ logging.warning(f"Failed to get GPU memory info: {e}")
106
+ return None, None, None, None
107
+
108
+
109
+ def get_available_models() -> dict:
110
+ """Get all available models with their display names"""
111
+ models = {}
112
+
113
+ # Check for local models
114
+ for variant, info in MODEL_VARIANTS.items():
115
+ model_path = os.path.join(current_dir, info["model_file"])
116
+
117
+ if os.path.exists(model_path):
118
+ display_name = info["display_name"]
119
+ models[display_name] = {
120
+ "model_path": model_path,
121
+ "variant": variant,
122
+ "max_disp": info["max_disp"],
123
+ "source": "local"
124
+ }
125
+
126
+ return models
127
+
128
+
129
+ def get_model_paths_from_selection(model_selection: str) -> Tuple[Optional[str], Optional[dict]]:
130
+ """Get model path and config from the selected model"""
131
+ models = get_available_models()
132
+
133
+ # Check if it's in our models dict
134
+ if model_selection in models:
135
+ model_info = models[model_selection]
136
+ logging.info(f"📁 Using local model: {model_selection}")
137
+ return model_info["model_path"], model_info
138
+
139
+ return None, None
140
+
141
+
142
+ def load_model_for_inference(model_path: str, model_info: dict):
143
+ """Load CREStereo model for inference"""
144
+ # Check if CUDA is available
145
+ if not torch.cuda.is_available():
146
+ raise RuntimeError("CUDA is not available.")
147
+
148
+ # Use the first available CUDA device
149
+ device = torch.device("cuda")
150
+
151
+ try:
152
+ # Create model
153
+ max_disp = model_info.get("max_disp", 256)
154
+ model = Model(max_disp=max_disp, mixed_precision=False, test_mode=True)
155
+
156
+ # Load checkpoint
157
+ ckpt = torch.load(model_path, map_location=device)
158
+ model.load_state_dict(ckpt, strict=True)
159
+ model.to(device)
160
+ model.eval()
161
+
162
+ logging.info("Loaded CREStereo model weights")
163
+
164
+ # Memory optimizations
165
+ torch.set_grad_enabled(False)
166
+ logging.info("Applied memory optimizations")
167
+
168
+ return model, device
169
+
170
+ except Exception as e:
171
+ logging.error(f"Model loading failed: {e}")
172
+ raise RuntimeError(f"Failed to load model: {e}")
173
+
174
+
175
+ def get_cached_model(model_selection: str):
176
+ """Get cached model or load new one if selection changed"""
177
+ global _cached_model, _cached_device, _cached_model_selection
178
+
179
+ # Get model paths from selection
180
+ model_path, model_info = get_model_paths_from_selection(model_selection)
181
+
182
+ if model_path is None or model_info is None:
183
+ raise ValueError(f"Selected model not found: {model_selection}")
184
+
185
+ # Check if we need to reload the model
186
+ if (_cached_model is None or
187
+ _cached_model_selection != model_selection):
188
+
189
+ # Clear previous model if exists
190
+ if _cached_model is not None:
191
+ del _cached_model
192
+ torch.cuda.empty_cache()
193
+ gc.collect()
194
+
195
+ logging.info(f"🚀 Loading model: {model_selection}")
196
+ _cached_model, _cached_device = load_model_for_inference(model_path, model_info)
197
+ _cached_model_selection = model_selection
198
+
199
+ logging.info(f"✅ Model loaded successfully: {model_selection}")
200
+ else:
201
+ logging.info(f"✅ Using cached model: {model_selection}")
202
+
203
+ return _cached_model, _cached_device
204
+
205
+
206
+ def clear_model_cache():
207
+ """Clear the cached model to free memory"""
208
+ global _cached_model, _cached_device, _cached_model_selection
209
+
210
+ if _cached_model is not None:
211
+ logging.info("Clearing model cache...")
212
+ del _cached_model
213
+ _cached_model = None
214
+ _cached_device = None
215
+ _cached_model_selection = None
216
+
217
+ # Simple cleanup
218
+ import gc
219
+ gc.collect()
220
+ torch.cuda.empty_cache()
221
+ logging.info("Model cache cleared")
222
+ else:
223
+ logging.info("No model in cache to clear")
224
+
225
+
226
+ def inference(left, right, model, device, n_iter=20):
227
+ """Run CREStereo inference on stereo pair"""
228
+ print("Model Forwarding...")
229
+ imgL = left.transpose(2, 0, 1)
230
+ imgR = right.transpose(2, 0, 1)
231
+ imgL = np.ascontiguousarray(imgL[None, :, :, :])
232
+ imgR = np.ascontiguousarray(imgR[None, :, :, :])
233
+
234
+ imgL = torch.tensor(imgL.astype("float32")).to(device)
235
+ imgR = torch.tensor(imgR.astype("float32")).to(device)
236
+
237
+ # Use InputPadder to handle any image size
238
+ padder = InputPadder(imgL.shape, divis_by=8)
239
+ imgL_padded, imgR_padded = padder.pad(imgL, imgR)
240
+
241
+ # Downsample for coarse prediction
242
+ imgL_dw2 = F.interpolate(
243
+ imgL_padded,
244
+ size=(imgL_padded.shape[2] // 2, imgL_padded.shape[3] // 2),
245
+ mode="bilinear",
246
+ align_corners=True,
247
+ )
248
+ imgR_dw2 = F.interpolate(
249
+ imgR_padded,
250
+ size=(imgL_padded.shape[2] // 2, imgL_padded.shape[3] // 2),
251
+ mode="bilinear",
252
+ align_corners=True,
253
+ )
254
+
255
+ with torch.inference_mode():
256
+ pred_flow_dw2 = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None)
257
+ pred_flow = model(imgL_padded, imgR_padded, iters=n_iter, flow_init=pred_flow_dw2)
258
+
259
+ # Unpad the result to original dimensions
260
+ pred_flow = padder.unpad(pred_flow)
261
+ pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
262
+
263
+ return pred_disp
264
+
265
+
266
+ def vis_disparity(disparity_map, max_val=None):
267
+ """Visualize disparity map"""
268
+ if max_val is None:
269
+ disp_vis = (disparity_map - disparity_map.min()) / (disparity_map.max() - disparity_map.min()) * 255.0
270
+ else:
271
+ disp_vis = np.clip(disparity_map / max_val * 255.0, 0, 255)
272
+
273
+ disp_vis = disp_vis.astype("uint8")
274
+ disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
275
+ disp_vis = cv2.cvtColor(disp_vis, cv2.COLOR_BGR2RGB)
276
+ return disp_vis
277
+
278
+
279
+ def process_stereo_pair(model_selection: str, left_image: str, right_image: str,
280
+ progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], str]:
281
+ """
282
+ Main processing function for stereo pair (with model caching)
283
+ """
284
+ logging.info("Starting stereo pair processing...")
285
+
286
+ if left_image is None or right_image is None:
287
+ return None, "❌ Please upload both left and right images."
288
+
289
+ # Convert image paths to numpy arrays
290
+ logging.info(f"Loading images: left={left_image}, right={right_image}")
291
+
292
+ try:
293
+ # Load left image
294
+ if not os.path.exists(left_image):
295
+ logging.error(f"Left image file does not exist: {left_image}")
296
+ return None, f"❌ Left image file not found: {left_image}"
297
+
298
+ logging.info(f"Loading left image from: {left_image}")
299
+ left_img = cv2.imread(left_image)
300
+ if left_img is not None:
301
+ left_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB)
302
+ else:
303
+ # Try with imageio as fallback
304
+ left_img = imageio.imread(left_image)
305
+ if len(left_img.shape) == 3 and left_img.shape[2] == 4:
306
+ left_img = left_img[:, :, :3]
307
+
308
+ # Load right image
309
+ if not os.path.exists(right_image):
310
+ logging.error(f"Right image file does not exist: {right_image}")
311
+ return None, f"❌ Right image file not found: {right_image}"
312
+
313
+ logging.info(f"Loading right image from: {right_image}")
314
+ right_img = cv2.imread(right_image)
315
+ if right_img is not None:
316
+ right_img = cv2.cvtColor(right_img, cv2.COLOR_BGR2RGB)
317
+ else:
318
+ # Try with imageio as fallback
319
+ right_img = imageio.imread(right_image)
320
+ if len(right_img.shape) == 3 and right_img.shape[2] == 4:
321
+ right_img = right_img[:, :, :3]
322
+
323
+ logging.info(f"Images loaded successfully - Left: {left_img.shape}, Right: {right_img.shape}")
324
+
325
+ except Exception as e:
326
+ logging.error(f"Failed to load images: {e}")
327
+ return None, f"❌ Failed to load images: {str(e)}"
328
+
329
+ try:
330
+ # Get cached model
331
+ variant_name = model_selection.split('(')[0].strip() if '(' in model_selection else model_selection
332
+ progress(0.1, desc=f"Loading cached model ({variant_name})...")
333
+ logging.info("🚀 Getting cached model...")
334
+ model, device = get_cached_model(model_selection)
335
+ logging.info("✅ Cached model loaded successfully")
336
+
337
+ progress(0.2, desc="Preprocessing images...")
338
+
339
+ # Validate input images
340
+ if left_img.shape != right_img.shape:
341
+ return None, "❌ Left and right images must have the same dimensions."
342
+
343
+ H, W = left_img.shape[:2]
344
+
345
+ progress(0.5, desc="Running inference...")
346
+
347
+ # Process stereo pair
348
+ torch.cuda.empty_cache() # Clear any cached memory before inference
349
+
350
+ disp_cpu = inference(left_img, right_img, model, device, n_iter=20)
351
+
352
+ progress(0.8, desc="Creating visualization...")
353
+
354
+ # Create visualization
355
+ disparity_vis = vis_disparity(disp_cpu)
356
+ result_image = disparity_vis
357
+
358
+ progress(1.0, desc="Complete!")
359
+
360
+ # Create status message
361
+ valid_mask = ~np.isinf(disp_cpu)
362
+ min_disp = disp_cpu[valid_mask].min() if valid_mask.any() else 0
363
+ max_disp = disp_cpu[valid_mask].max() if valid_mask.any() else 0
364
+ mean_disp = disp_cpu[valid_mask].mean() if valid_mask.any() else 0
365
+
366
+ # Get model variant for status
367
+ variant = variant_name
368
+
369
+ # Check current memory usage
370
+ try:
371
+ current_memory = torch.cuda.memory_allocated(0) / 1024**3
372
+ max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
373
+ memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"
374
+ except:
375
+ memory_info = ""
376
+
377
+ status = f"""✅ Processing successful!
378
+ 🔧 Model: {variant}{memory_info}
379
+ 📊 Disparity Statistics:
380
+ • Range: {min_disp:.2f} - {max_disp:.2f}
381
+ • Mean: {mean_disp:.2f}
382
+ • Input size: {W}×{H}
383
+ • Valid pixels: {valid_mask.sum()}/{valid_mask.size}"""
384
+
385
+ return result_image, status
386
+
387
+ except Exception as e:
388
+ logging.error(f"Processing failed: {e}")
389
+ # Clean up GPU memory
390
+ torch.cuda.empty_cache()
391
+ gc.collect()
392
+ return None, f"❌ Error: {str(e)}"
393
+
394
+
395
+ def process_with_depth(model_selection: str, left_image: str, right_image: str,
396
+ camera_matrix: str, baseline: float,
397
+ progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], Optional[str], Optional[str], str]:
398
+ """
399
+ Process stereo pair and generate depth map and point cloud (with model caching)
400
+ """
401
+ # Import Open3D
402
+ global OPEN3D_AVAILABLE
403
+ try:
404
+ import open3d as o3d
405
+ OPEN3D_AVAILABLE = True
406
+ except ImportError as e:
407
+ logging.warning(f"Open3D not available: {e}")
408
+ OPEN3D_AVAILABLE = False
409
+ return None, None, None, "❌ Open3D not available. Point cloud generation disabled."
410
+
411
+ if left_image is None or right_image is None:
412
+ return None, None, None, "❌ Please upload both left and right images."
413
+
414
+ try:
415
+ progress(0.1, desc="Parsing camera parameters...")
416
+
417
+ # Parse camera matrix
418
+ try:
419
+ K_values = list(map(float, camera_matrix.strip().split()))
420
+ if len(K_values) != 9:
421
+ return None, None, None, "❌ Camera matrix must contain exactly 9 values."
422
+ K = np.array(K_values).reshape(3, 3)
423
+ except ValueError:
424
+ return None, None, None, "❌ Invalid camera matrix format. Use space-separated numbers."
425
+
426
+ if baseline <= 0:
427
+ return None, None, None, "❌ Baseline must be positive."
428
+
429
+ # First get disparity using the same process as basic function
430
+ disparity_result, status = process_stereo_pair(model_selection, left_image, right_image, progress)
431
+
432
+ if disparity_result is None:
433
+ return None, None, None, status
434
+
435
+ # Load images again for depth processing
436
+ left_img = cv2.imread(left_image)
437
+ left_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB)
438
+
439
+ # Get disparity from model again (we need the raw values, not the visualization)
440
+ model, device = get_cached_model(model_selection)
441
+ disp_cpu = inference(left_img, cv2.cvtColor(cv2.imread(right_image), cv2.COLOR_BGR2RGB), model, device, n_iter=20)
442
+
443
+ progress(0.6, desc="Converting to depth...")
444
+
445
+ # Remove invisible points
446
+ H, W = disp_cpu.shape
447
+ yy, xx = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
448
+ us_right = xx - disp_cpu
449
+ invalid = us_right < 0
450
+ disp_cpu[invalid] = np.inf
451
+
452
+ # Convert to depth using the formula: depth = focal_length * baseline / disparity
453
+ depth = K[0, 0] * baseline / disp_cpu
454
+
455
+ # Visualize depth
456
+ depth_vis = vis_disparity(depth, max_val=10.0)
457
+
458
+ progress(0.8, desc="Generating point cloud...")
459
+
460
+ # Generate point cloud
461
+ fx, fy = K[0, 0], K[1, 1]
462
+ cx, cy = K[0, 2], K[1, 2]
463
+
464
+ # Create coordinate meshgrids
465
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
466
+
467
+ # Convert to 3D coordinates
468
+ valid_depth = ~np.isinf(depth)
469
+ z = depth[valid_depth] # Z coordinate (depth)
470
+ x = (u[valid_depth] - cx) * z / fx # X coordinate
471
+ y = (v[valid_depth] - cy) * z / fy # Y coordinate
472
+
473
+ # Stack coordinates (X, Y, Z)
474
+ points = np.stack([x, y, z], axis=-1)
475
+
476
+ # Get corresponding colors
477
+ colors = left_img[valid_depth]
478
+
479
+ # Filter points by depth range
480
+ depth_mask = (z > 0) & (z <= 10.0)
481
+ valid_points = points[depth_mask]
482
+ valid_colors = colors[depth_mask]
483
+
484
+ if len(valid_points) == 0:
485
+ return depth_vis, None, None, "⚠️ No valid points generated for point cloud."
486
+
487
+ # Subsample points for better performance
488
+ if len(valid_points) > 100000:
489
+ indices = np.random.choice(len(valid_points), 100000, replace=False)
490
+ valid_points = valid_points[indices]
491
+ valid_colors = valid_colors[indices]
492
+
493
+ # Transform coordinates for proper visualization
494
+ transformed_points = valid_points.copy()
495
+ transformed_points[:, 1] = -transformed_points[:, 1] # Flip Y axis
496
+ transformed_points[:, 2] = -transformed_points[:, 2] # Flip Z axis
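+ # Negating Y and Z converts the OpenCV camera frame (x right, y down, z forward)
+ # into a y-up frame so the cloud renders upright in the 3D viewer.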
497
+
498
+ # Generate point cloud
499
+ pcd = o3d.geometry.PointCloud()
500
+ pcd.points = o3d.utility.Vector3dVector(transformed_points)
501
+ pcd.colors = o3d.utility.Vector3dVector(valid_colors / 255.0)
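+ # Export the cloud so the download and 3D-viewer outputs can be populated
+ # (a minimal sketch: writes a .ply into the system temp dir; assumes the
+ # installed Gradio Model3D component accepts .ply files).
+ import tempfile
+ ply_fd, ply_path = tempfile.mkstemp(suffix=".ply")
+ os.close(ply_fd)
+ o3d.io.write_point_cloud(ply_path, pcd)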
502
+
503
+ progress(1.0, desc="Complete!")
504
+
505
+ # Check current memory usage
506
+ try:
507
+ current_memory = torch.cuda.memory_allocated(0) / 1024**3
508
+ max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
509
+ memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"
510
+ except Exception:
511
+ memory_info = ""
512
+
513
+ variant = model_selection.split('(')[0].strip() if '(' in model_selection else model_selection
514
+
515
+ status = f"""✅ Depth processing successful!
516
+ 🔧 Model: {variant}{memory_info}
517
+ 📊 Statistics:
518
+ • Valid points: {len(valid_points):,}
519
+ • Depth range: {z.min():.2f} - {z.max():.2f} m
520
+ • Baseline: {baseline} m
521
+ • Point cloud generated with {len(valid_points)} points
522
+ • 3D visualization available"""
523
+
524
+ return depth_vis, ply_path, ply_path, status
525
+
526
+ except Exception as e:
527
+ logging.error(f"Depth processing failed: {e}")
528
+ torch.cuda.empty_cache()
529
+ gc.collect()
530
+ return None, None, None, f"❌ Error: {str(e)}"
531
+
532
+
533
+ def create_app() -> gr.Blocks:
534
+ """Create the Gradio application"""
535
+
536
+ # Get available models
537
+ try:
538
+ available_models = get_available_models()
539
+ logging.info(f"Successfully got available models: {len(available_models)} found")
540
+ except Exception as e:
541
+ logging.error(f"Failed to get available models: {e}")
542
+ available_models = {}
543
+
544
+ with gr.Blocks(
545
+ title="CREStereo - Stereo Depth Estimation",
546
+ theme=gr.themes.Soft(),
547
+ css="footer {visibility: hidden}",
548
+ delete_cache=(60, 60)
549
+ ) as app:
550
+
551
+ gr.Markdown("""
552
+ # 🔍 CREStereo: Practical Stereo Matching
553
+
554
+ Upload a pair of **rectified** stereo images to get disparity estimation using CREStereo.
555
+
556
+ ⚠️ **Important**: Images should be rectified (epipolar lines are horizontal) and undistorted.
557
+ ⚡ **GPU Powered**: Runs on CUDA-enabled GPUs for fast inference.
558
+ """)
559
+
560
+ # Instructions section
561
+ with gr.Accordion("📋 Instructions", open=False):
562
+ gr.Markdown("""
563
+ ## 🚀 How to Use This Demo
564
+
565
+ ### 🖼️ Input Requirements
566
+ 1. **Image Format**: Upload images in JPEG or PNG format.
567
+ 2. **Image Size**: Images should be of the same size and resolution.
568
+ 3. **Rectification**: Ensure images are rectified (epipolar lines are horizontal) and undistorted.
569
+ 4. **Camera Parameters**: For depth processing, provide camera matrix and baseline distance.
570
+
571
+ ### 📊 Using the Demo
572
+ 1. **Select Model**: Choose the CREStereo model variant
573
+ 2. **Upload Images**: Provide rectified stereo image pairs
574
+ 3. **Basic Processing**: Get disparity visualization
575
+ 4. **Advanced Processing**: Generate depth maps and 3D point clouds (requires camera parameters)
576
+
577
+ ### 📖 Original Work
578
+ This demo is based on CREStereo: Practical Stereo Matching via Cascaded Recurrent Network.
579
+ - **Paper**: [CREStereo: Practical Stereo Matching via Cascaded Recurrent Network](https://arxiv.org/abs/2203.11483)
580
+ - **Official Repository**: [https://github.com/megvii-research/CREStereo](https://github.com/megvii-research/CREStereo)
581
+ """)
582
+
583
+ # Model selection
584
+ with gr.Row():
585
+ all_choices = list(available_models.keys())
586
+
587
+ if not all_choices:
588
+ all_choices = ["No models found - Please ensure crestereo_eth3d.pth is in models/ directory"]
589
+
590
+ default_model = all_choices[0] if all_choices else None
591
+
592
+ model_selector = gr.Dropdown(
593
+ choices=all_choices,
594
+ value=default_model,
595
+ label="🎯 Select Model",
596
+ info="Choose the CREStereo model variant.",
597
+ interactive=True
598
+ )
599
+
600
+ with gr.Tabs():
601
+ # Basic stereo processing tab
602
+ with gr.TabItem("🖼️ Basic Stereo Processing"):
603
+ with gr.Row():
604
+ with gr.Column():
605
+ left_input = gr.Image(
606
+ label="📷 Left Image",
607
+ type="filepath",
608
+ height=300
609
+ )
610
+ right_input = gr.Image(
611
+ label="📷 Right Image",
612
+ type="filepath",
613
+ height=300
614
+ )
615
+
616
+ process_btn = gr.Button(
617
+ "🚀 Process Stereo Pair",
618
+ variant="primary",
619
+ size="lg"
620
+ )
621
+
622
+ with gr.Column():
623
+ output_image = gr.Image(
624
+ label="📊 Disparity Visualization",
625
+ height=400
626
+ )
627
+ status_text = gr.Textbox(
628
+ label="Status",
629
+ interactive=False,
630
+ lines=8
631
+ )
632
+
633
+ # Example images
634
+ examples_list = []
635
+
636
+ # Example 1
637
+ if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")):
638
+ examples_list.append([
639
+ os.path.join(current_dir, "assets", "example1", "left.png"),
640
+ os.path.join(current_dir, "assets", "example1", "right.png")
641
+ ])
642
+
643
+ # Example 2
644
+ if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")):
645
+ examples_list.append([
646
+ os.path.join(current_dir, "assets", "example2", "left.png"),
647
+ os.path.join(current_dir, "assets", "example2", "right.png")
648
+ ])
649
+
650
+ if examples_list:
651
+ gr.Examples(
652
+ examples=examples_list,
653
+ inputs=[left_input, right_input],
654
+ label="📋 Example Images"
655
+ )
656
+
657
+ # Advanced processing with depth
658
+ with gr.TabItem("📐 Advanced Processing (Depth & Point Cloud)"):
659
+ with gr.Row():
660
+ with gr.Column():
661
+ left_input_adv = gr.Image(
662
+ label="📷 Left Image",
663
+ type="filepath",
664
+ height=250
665
+ )
666
+ right_input_adv = gr.Image(
667
+ label="📷 Right Image",
668
+ type="filepath",
669
+ height=250
670
+ )
671
+
672
+ # Camera parameters
673
+ with gr.Group():
674
+ gr.Markdown("### 📹 Camera Parameters")
675
+ camera_matrix_input = gr.Textbox(
676
+ label="Camera Matrix (9 values: fx 0 cx 0 fy cy 0 0 1)",
677
+ value="",
678
+ )
679
+ baseline_input = gr.Number(
680
+ label="Baseline (meters)",
681
+ value=None,
682
+ minimum=0.001,
683
+ maximum=10.0,
684
+ step=0.001
685
+ )
686
+
687
+ process_depth_btn = gr.Button(
688
+ "🔬 Process with Depth",
689
+ variant="primary",
690
+ size="lg"
691
+ )
692
+
693
+ with gr.Column():
694
+ depth_output = gr.Image(
695
+ label="📏 Depth Visualization",
696
+ height=300
697
+ )
698
+ pointcloud_output = gr.File(
699
+ label="☁️ Point Cloud Download (.ply)",
700
+ file_types=[".ply"]
701
+ )
702
+ status_depth = gr.Textbox(
703
+ label="Status",
704
+ interactive=False,
705
+ lines=6
706
+ )
707
+
708
+ # 3D Point Cloud Visualization
709
+ with gr.Row():
710
+ pointcloud_3d = gr.Model3D(
711
+ label="🌐 3D Point Cloud Viewer",
712
+ clear_color=[0.0, 0.0, 0.0, 0.0],
713
+ height=400
714
+ )
715
+
716
+ # Example images for advanced processing
717
+ examples_advanced_list = []
718
+
719
+ # Try to read camera parameters from K.txt files
720
+ # Example 1
721
+ if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")):
722
+ k_file = os.path.join(current_dir, "assets", "example1", "K.txt")
723
+ camera_matrix_str = ""
724
+ baseline_val = 0.063 # default
725
+
726
+ if os.path.exists(k_file):
727
+ try:
728
+ with open(k_file, 'r') as f:
729
+ lines = f.readlines()
730
+ if len(lines) >= 1:
731
+ camera_matrix_str = lines[0].strip()
732
+ if len(lines) >= 2:
733
+ baseline_val = float(lines[1].strip())
734
+ except Exception:
735
+ camera_matrix_str = "754.6680908203125 0.0 489.3794860839844 0.0 754.6680908203125 265.16162109375 0.0 0.0 1.0"
736
+
737
+ examples_advanced_list.append([
738
+ os.path.join(current_dir, "assets", "example1", "left.png"),
739
+ os.path.join(current_dir, "assets", "example1", "right.png"),
740
+ camera_matrix_str,
741
+ baseline_val
742
+ ])
743
+
744
+ # Example 2
745
+ if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")):
746
+ k_file = os.path.join(current_dir, "assets", "example2", "K.txt")
747
+ camera_matrix_str = ""
748
+ baseline_val = 0.537 # default
749
+
750
+ if os.path.exists(k_file):
751
+ try:
752
+ with open(k_file, 'r') as f:
753
+ lines = f.readlines()
754
+ if len(lines) >= 1:
755
+ camera_matrix_str = lines[0].strip()
756
+ if len(lines) >= 2:
757
+ baseline_val = float(lines[1].strip())
758
+ except Exception:
759
+ camera_matrix_str = "1733.74 0.0 792.27 0.0 1733.74 541.89 0.0 0.0 1.0"
760
+
761
+ examples_advanced_list.append([
762
+ os.path.join(current_dir, "assets", "example2", "left.png"),
763
+ os.path.join(current_dir, "assets", "example2", "right.png"),
764
+ camera_matrix_str,
765
+ baseline_val
766
+ ])
767
+
768
+ if examples_advanced_list:
769
+ gr.Examples(
770
+ examples=examples_advanced_list,
771
+ inputs=[left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
772
+ label="📋 Example Images with Camera Parameters"
773
+ )
774
+
775
+ # Event handlers
776
+ if available_models:
777
+ process_btn.click(
778
+ fn=process_stereo_pair,
779
+ inputs=[model_selector, left_input, right_input],
780
+ outputs=[output_image, status_text],
781
+ show_progress=True
782
+ )
783
+
784
+ if OPEN3D_AVAILABLE:
785
+ process_depth_btn.click(
786
+ fn=process_with_depth,
787
+ inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
788
+ outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth],
789
+ show_progress=True
790
+ )
791
+ else:
792
+ process_depth_btn.click(
793
+ fn=lambda *args: (None, None, None, "❌ Open3D not available. Install with: pip install open3d"),
794
+ inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
795
+ outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth]
796
+ )
797
+ else:
798
+ # No models available
799
+ process_btn.click(
800
+ fn=lambda *args: (None, "❌ No models available. Please ensure crestereo_eth3d.pth is in models/ directory."),
801
+ inputs=[model_selector, left_input, right_input],
802
+ outputs=[output_image, status_text]
803
+ )
804
+
805
+ process_depth_btn.click(
806
+ fn=lambda *args: (None, None, None, "❌ No models available. Please ensure crestereo_eth3d.pth is in models/ directory."),
807
+ inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
808
+ outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth]
809
+ )
810
+
811
+ # Citation section at the bottom
812
+ with gr.Accordion("📖 Citation", open=False):
813
+ gr.Markdown("""
814
+ ### 📄 Please Cite the Original Paper
815
+
816
+ If you use this work in your research, please cite:
817
+
818
+ ```bibtex
819
+ @article{li2022practical,
820
+ title={Practical Stereo Matching via Cascaded Recurrent Network with Adaptive Correlation},
821
+ author={Li, Jiankun and Wang, Peisen and Xiong, Pengfei and Cai, Tao and Yan, Ziwei and Yang, Lei and Liu, Jiangyu and Fan, Haoqiang and Liu, Shuaicheng},
822
+ journal={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
823
+ pages={16263--16272},
824
+ year={2022}
825
+ }
826
+ ```
827
+ """)
828
+
829
+ # Footer
830
+ gr.Markdown("""
831
+ ---
832
+ ### 📝 Notes:
833
+ - **Input images must be rectified stereo pairs** (epipolar lines are horizontal)
834
+ - **⚡ GPU Acceleration**: Requires CUDA-compatible GPU
835
+ - **📦 Model Caching**: Models are cached for efficient repeated usage
836
+ - For best results, use high-quality rectified stereo pairs
837
+ - Model works on RGB images and supports various resolutions
838
+
839
+ ### 🔗 References:
840
+ - [CREStereo Paper](https://arxiv.org/abs/2203.11483)
841
+ - [Original GitHub Repository](https://github.com/megvii-research/CREStereo)
842
+ - [This PyTorch Implementation](https://github.com/ibaiGorordo/CREStereo-Pytorch)
843
+ """)
844
+
845
+ return app
846
+
847
+
848
+ def main():
849
+ """Main function to launch the app"""
850
+
851
+ logging.info("🚀 Starting CREStereo Gradio App...")
852
+
853
+ # Parse command line arguments
854
+ import argparse
855
+ parser = argparse.ArgumentParser(description="CREStereo Gradio App")
856
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
857
+ parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
858
+ parser.add_argument("--debug", action="store_true", help="Enable debug mode")
859
+
860
+ args = parser.parse_args()
861
+
862
+ if args.debug:
863
+ logging.getLogger().setLevel(logging.DEBUG)
864
+
865
+ try:
866
+ # Create and launch app
867
+ logging.info("Creating Gradio app...")
868
+ app = create_app()
869
+ logging.info("✅ Gradio app created successfully")
870
+
871
+ logging.info(f"Launching app on {args.host}:{args.port}")
872
+
873
+ # Launch with appropriate settings
874
+ app.launch(
875
+ server_name=args.host,
876
+ server_port=args.port,
877
+ share=False,
878
+ show_error=True,
879
+ favicon_path=None,
880
+ ssr_mode=False,
881
+ allowed_paths=["./"]
882
+ )
883
+ except Exception as e:
884
+ logging.error(f"Failed to launch app: {e}")
885
+ raise
886
+
887
+
888
+ if __name__ == "__main__":
889
+ main()
CREStereo_demo/models/.gitkeep ADDED
File without changes
CREStereo_demo/models/crestereo_eth3d.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2271ab615015a73edd4759b0f7b25a4d82ffb654270b92d3811237da3d63aa6d
3
+ size 21763979
CREStereo_demo/nets/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .crestereo import CREStereo as Model
CREStereo_demo/nets/attention/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .transformer import LocalFeatureTransformer
2
+ from .position_encoding import PositionEncodingSine
CREStereo_demo/nets/attention/linear_attention.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
3
+ Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
4
+ """
5
+
6
+ import torch
7
+ from torch.nn import Module, Dropout
8
+
9
+
10
+ def elu_feature_map(x):
11
+ return torch.nn.functional.elu(x) + 1
12
+
13
+
14
+ class LinearAttention(Module):
15
+ def __init__(self, eps=1e-6):
16
+ super().__init__()
17
+ self.feature_map = elu_feature_map
18
+ self.eps = eps
19
+
20
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
21
+ """ Multi-Head linear attention proposed in "Transformers are RNNs"
22
+ Args:
23
+ queries: [N, L, H, D]
24
+ keys: [N, S, H, D]
25
+ values: [N, S, H, D]
26
+ q_mask: [N, L]
27
+ kv_mask: [N, S]
28
+ Returns:
29
+ queried_values: (N, L, H, D)
30
+ """
31
+ Q = self.feature_map(queries)
32
+ K = self.feature_map(keys)
33
+
34
+ # set padded position to zero
35
+ if q_mask is not None:
36
+ Q = Q * q_mask[:, :, None, None]
37
+ if kv_mask is not None:
38
+ K = K * kv_mask[:, :, None, None]
39
+ values = values * kv_mask[:, :, None, None]
40
+
41
+ v_length = values.size(1)
42
+ values = values / v_length # prevent fp16 overflow
43
+ KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
44
+ Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
45
+ queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
46
+
47
+ return queried_values.contiguous()
48
+
49
+
50
+ class FullAttention(Module):
51
+ def __init__(self, use_dropout=False, attention_dropout=0.1):
52
+ super().__init__()
53
+ self.use_dropout = use_dropout
54
+ self.dropout = Dropout(attention_dropout)
55
+
56
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
57
+ """ Multi-head scaled dot-product attention, a.k.a full attention.
58
+ Args:
59
+ queries: [N, L, H, D]
60
+ keys: [N, S, H, D]
61
+ values: [N, S, H, D]
62
+ q_mask: [N, L]
63
+ kv_mask: [N, S]
64
+ Returns:
65
+ queried_values: (N, L, H, D)
66
+ """
67
+
68
+ # Compute the unnormalized attention and apply the masks
69
+ QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
70
+ if kv_mask is not None:
71
+ QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))
72
+
73
+ # Compute the attention and the weighted average
74
+ softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
75
+ A = torch.softmax(softmax_temp * QK, dim=2)
76
+ if self.use_dropout:
77
+ A = self.dropout(A)
78
+
79
+ queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
80
+
81
+ return queried_values.contiguous()
CREStereo_demo/nets/attention/position_encoding.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ class PositionEncodingSine(nn.Module):
7
+ """
8
+ This is a sinusoidal position encoding that generalized to 2-dimensional images
9
+ """
10
+
11
+ def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=False):
12
+ """
13
+ Args:
14
+ max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
15
+ temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
16
+ the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
17
+ on the final performance. For now, we keep both impls for backward compatibility.
18
+ We will remove the buggy impl after re-training all variants of our released models.
19
+ """
20
+ super().__init__()
21
+ pe = torch.zeros((d_model, *max_shape))
22
+ y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
23
+ x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
24
+ if temp_bug_fix:
25
+ div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
26
+ else: # a buggy implementation (for backward compatibility only)
27
+ div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2))
28
+ div_term = div_term[:, None, None] # [C//4, 1, 1]
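+ # Channels are interleaved in groups of four as [sin(x), cos(x), sin(y), cos(y)],
+ # extending the 1D sinusoidal encoding to both image axes.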
29
+ pe[0::4, :, :] = torch.sin(x_position * div_term)
30
+ pe[1::4, :, :] = torch.cos(x_position * div_term)
31
+ pe[2::4, :, :] = torch.sin(y_position * div_term)
32
+ pe[3::4, :, :] = torch.cos(y_position * div_term)
33
+
34
+ self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W]
35
+
36
+ def forward(self, x):
37
+ """
38
+ Args:
39
+ x: [N, C, H, W]
40
+ """
41
+ return x + self.pe[:, :, :x.size(2), :x.size(3)].to(x.device)
CREStereo_demo/nets/attention/transformer.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import torch
3
+ import torch.nn as nn
4
+ from .linear_attention import LinearAttention, FullAttention
5
+
6
+ #Ref: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py
7
+ class LoFTREncoderLayer(nn.Module):
8
+ def __init__(self,
9
+ d_model,
10
+ nhead,
11
+ attention='linear'):
12
+ super(LoFTREncoderLayer, self).__init__()
13
+
14
+ self.dim = d_model // nhead
15
+ self.nhead = nhead
16
+
17
+ # multi-head attention
18
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
19
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
20
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
21
+ self.attention = LinearAttention() if attention == 'linear' else FullAttention()
22
+ self.merge = nn.Linear(d_model, d_model, bias=False)
23
+
24
+ # feed-forward network
25
+ self.mlp = nn.Sequential(
26
+ nn.Linear(d_model*2, d_model*2, bias=False),
27
+ nn.ReLU(),
28
+ nn.Linear(d_model*2, d_model, bias=False),
29
+ )
30
+
31
+ # norm and dropout
32
+ self.norm1 = nn.LayerNorm(d_model)
33
+ self.norm2 = nn.LayerNorm(d_model)
34
+
35
+ def forward(self, x, source, x_mask=None, source_mask=None):
36
+ """
37
+ Args:
38
+ x (torch.Tensor): [N, L, C]
39
+ source (torch.Tensor): [N, S, C]
40
+ x_mask (torch.Tensor): [N, L] (optional)
41
+ source_mask (torch.Tensor): [N, S] (optional)
42
+ """
43
+ bs = x.size(0)
44
+ query, key, value = x, source, source
45
+
46
+ # multi-head attention
47
+ query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
48
+ key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
49
+ value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
50
+ message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
51
+ message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
52
+ message = self.norm1(message)
53
+
54
+ # feed-forward network
55
+ message = self.mlp(torch.cat([x, message], dim=2))
56
+ message = self.norm2(message)
57
+
58
+ return x + message
59
+
60
+
61
+ class LocalFeatureTransformer(nn.Module):
62
+ """A Local Feature Transformer (LoFTR) module."""
63
+
64
+ def __init__(self, d_model, nhead, layer_names, attention):
65
+ super(LocalFeatureTransformer, self).__init__()
66
+
67
+ self.d_model = d_model
68
+ self.nhead = nhead
69
+ self.layer_names = layer_names
70
+ encoder_layer = LoFTREncoderLayer(d_model, nhead, attention)
71
+ self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
72
+ self._reset_parameters()
73
+
74
+ def _reset_parameters(self):
75
+ for p in self.parameters():
76
+ if p.dim() > 1:
77
+ nn.init.xavier_uniform_(p)
78
+
79
+ def forward(self, feat0, feat1, mask0=None, mask1=None):
80
+ """
81
+ Args:
82
+ feat0 (torch.Tensor): [N, L, C]
83
+ feat1 (torch.Tensor): [N, S, C]
84
+ mask0 (torch.Tensor): [N, L] (optional)
85
+ mask1 (torch.Tensor): [N, S] (optional)
86
+ """
87
+ assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
88
+
89
+ for layer, name in zip(self.layers, self.layer_names):
90
+
91
+ if name == 'self':
92
+ feat0 = layer(feat0, feat0, mask0, mask0)
93
+ feat1 = layer(feat1, feat1, mask1, mask1)
94
+ elif name == 'cross':
95
+ feat0 = layer(feat0, feat1, mask0, mask1)
96
+ feat1 = layer(feat1, feat0, mask1, mask0)
97
+ else:
98
+ raise KeyError
99
+
100
+ return feat0, feat1
CREStereo_demo/nets/corr.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from .utils import bilinear_sampler, coords_grid, manual_pad
7
+
8
+ class AGCL:
9
+ """
10
+ Implementation of Adaptive Group Correlation Layer (AGCL).
11
+ """
12
+
13
+ def __init__(self, fmap1, fmap2, att=None):
14
+ self.fmap1 = fmap1
15
+ self.fmap2 = fmap2
16
+
17
+ self.att = att
18
+
19
+ self.coords = coords_grid(fmap1.shape[0], fmap1.shape[2], fmap1.shape[3], fmap1.device)
20
+
21
+ def __call__(self, flow, extra_offset, small_patch=False, iter_mode=False):
22
+ if iter_mode:
23
+ corr = self.corr_iter(self.fmap1, self.fmap2, flow, small_patch)
24
+ else:
25
+ corr = self.corr_att_offset(
26
+ self.fmap1, self.fmap2, flow, extra_offset, small_patch
27
+ )
28
+ return corr
29
+
30
+ def get_correlation(self, left_feature, right_feature, psize=(3, 3), dilate=(1, 1)):
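+ # Local correlation: for every offset in the (dilated) psize window, shift the padded
+ # right features and take the channel-wise mean of their product with the left features.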
31
+
32
+ N, C, H, W = left_feature.shape
33
+
34
+ di_y, di_x = dilate[0], dilate[1]
35
+ pady, padx = psize[0] // 2 * di_y, psize[1] // 2 * di_x
36
+
37
+ right_pad = manual_pad(right_feature, pady, padx)
38
+
39
+ corr_list = []
40
+ for h in range(0, pady * 2 + 1, di_y):
41
+ for w in range(0, padx * 2 + 1, di_x):
42
+ right_crop = right_pad[:, :, h : h + H, w : w + W]
43
+ assert right_crop.shape == left_feature.shape
44
+ corr = torch.mean(left_feature * right_crop, dim=1, keepdims=True)
45
+ corr_list.append(corr)
46
+
47
+ corr_final = torch.cat(corr_list, dim=1)
48
+
49
+ return corr_final
50
+
51
+ def corr_iter(self, left_feature, right_feature, flow, small_patch):
52
+
53
+ coords = self.coords + flow
54
+ coords = coords.permute(0, 2, 3, 1)
55
+ right_feature = bilinear_sampler(right_feature, coords)
56
+
57
+ if small_patch:
58
+ psize_list = [(3, 3), (3, 3), (3, 3), (3, 3)]
59
+ dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
60
+ else:
61
+ psize_list = [(1, 9), (1, 9), (1, 9), (1, 9)]
62
+ dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
63
+
64
+ N, C, H, W = left_feature.shape
65
+ lefts = torch.split(left_feature, left_feature.shape[1]//4, dim=1)
66
+ rights = torch.split(right_feature, right_feature.shape[1]//4, dim=1)
67
+
68
+ corrs = []
69
+ for i in range(len(psize_list)):
70
+ corr = self.get_correlation(
71
+ lefts[i], rights[i], psize_list[i], dilate_list[i]
72
+ )
73
+ corrs.append(corr)
74
+
75
+ final_corr = torch.cat(corrs, dim=1)
76
+
77
+ return final_corr
78
+
79
+ def corr_att_offset(
80
+ self, left_feature, right_feature, flow, extra_offset, small_patch
81
+ ):
82
+
83
+ N, C, H, W = left_feature.shape
84
+
85
+ if self.att is not None:
86
+ left_feature = left_feature.permute(0, 2, 3, 1).reshape(N, H * W, C) # 'n c h w -> n (h w) c'
87
+ right_feature = right_feature.permute(0, 2, 3, 1).reshape(N, H * W, C) # 'n c h w -> n (h w) c'
88
+ # 'n (h w) c -> n c h w'
89
+ left_feature, right_feature = self.att(left_feature, right_feature)
90
+ # 'n (h w) c -> n c h w'
91
+ left_feature, right_feature = [
92
+ x.reshape(N, H, W, C).permute(0, 3, 1, 2)
93
+ for x in [left_feature, right_feature]
94
+ ]
95
+
96
+ lefts = torch.split(left_feature, left_feature.shape[1]//4, dim=1)
97
+ rights = torch.split(right_feature, right_feature.shape[1]//4, dim=1)
98
+
99
+ C = C // 4
100
+
101
+ if small_patch:
102
+ psize_list = [(3, 3), (3, 3), (3, 3), (3, 3)]
103
+ dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
104
+ else:
105
+ psize_list = [(1, 9), (1, 9), (1, 9), (1, 9)]
106
+ dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
107
+
108
+ search_num = 9
109
+ extra_offset = extra_offset.reshape(N, search_num, 2, H, W).permute(0, 1, 3, 4, 2) # [N, search_num, H, W, 2]
110
+
111
+ corrs = []
112
+ for i in range(len(psize_list)):
113
+ left_feature, right_feature = lefts[i], rights[i]
114
+ psize, dilate = psize_list[i], dilate_list[i]
115
+
116
+ psizey, psizex = psize[0], psize[1]
117
+ dilatey, dilatex = dilate[0], dilate[1]
118
+
119
+ ry = psizey // 2 * dilatey
120
+ rx = psizex // 2 * dilatex
121
+ x_grid, y_grid = torch.meshgrid(torch.arange(-rx, rx + 1, dilatex, device=self.fmap1.device),
122
+ torch.arange(-ry, ry + 1, dilatey, device=self.fmap1.device), indexing='xy')
123
+
124
+ offsets = torch.stack((x_grid, y_grid))
125
+ offsets = offsets.reshape(2, -1).permute(1, 0)
126
+ for d in sorted((0, 2, 3)):
127
+ offsets = offsets.unsqueeze(d)
128
+ offsets = offsets.repeat_interleave(N, dim=0)
129
+ offsets = offsets + extra_offset
130
+
131
+ coords = self.coords + flow # [N, 2, H, W]
132
+ coords = coords.permute(0, 2, 3, 1) # [N, H, W, 2]
133
+ coords = torch.unsqueeze(coords, 1) + offsets
134
+ coords = coords.reshape(N, -1, W, 2) # [N, search_num*H, W, 2]
135
+
136
+ right_feature = bilinear_sampler(
137
+ right_feature, coords
138
+ ) # [N, C, search_num*H, W]
139
+ right_feature = right_feature.reshape(N, C, -1, H, W) # [N, C, search_num, H, W]
140
+ left_feature = left_feature.unsqueeze(2).repeat_interleave(right_feature.shape[2], dim=2)
141
+
142
+ corr = torch.mean(left_feature * right_feature, dim=1)
143
+
144
+ corrs.append(corr)
145
+
146
+ final_corr = torch.cat(corrs, dim=1)
147
+
148
+ return final_corr
CREStereo_demo/nets/crestereo.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .update import BasicUpdateBlock
6
+ from .extractor import BasicEncoder
7
+ from .corr import AGCL
8
+
9
+ from .attention import PositionEncodingSine, LocalFeatureTransformer
10
+
11
+ try:
12
+ autocast = torch.cuda.amp.autocast
13
+ except:
14
+ # dummy autocast for PyTorch < 1.6
15
+ class autocast:
16
+ def __init__(self, enabled):
17
+ pass
18
+ def __enter__(self):
19
+ pass
20
+ def __exit__(self, *args):
21
+ pass
22
+
23
+ #Ref: https://github.com/princeton-vl/RAFT/blob/master/core/raft.py
24
+ class CREStereo(nn.Module):
25
+ def __init__(self, max_disp=192, mixed_precision=False, test_mode=False):
26
+ super(CREStereo, self).__init__()
27
+
28
+ self.max_flow = max_disp
29
+ self.mixed_precision = mixed_precision
30
+ self.test_mode = test_mode
31
+
32
+ self.hidden_dim = 128
33
+ self.context_dim = 128
34
+ self.dropout = 0
35
+
36
+ self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=self.dropout)
37
+ self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)
38
+
39
+ # loftr
40
+ self.self_att_fn = LocalFeatureTransformer(
41
+ d_model=256, nhead=8, layer_names=["self"] * 1, attention="linear"
42
+ )
43
+ self.cross_att_fn = LocalFeatureTransformer(
44
+ d_model=256, nhead=8, layer_names=["cross"] * 1, attention="linear"
45
+ )
46
+
47
+ # adaptive search
48
+ self.search_num = 9
49
+ self.conv_offset_16 = nn.Conv2d(
50
+ 256, self.search_num * 2, kernel_size=3, stride=1, padding=1
51
+ )
52
+ self.conv_offset_8 = nn.Conv2d(
53
+ 256, self.search_num * 2, kernel_size=3, stride=1, padding=1
54
+ )
55
+ self.range_16 = 1
56
+ self.range_8 = 1
57
+
58
+ def freeze_bn(self):
59
+ for m in self.modules():
60
+ if isinstance(m, nn.BatchNorm2d):
61
+ m.eval()
62
+
63
+ def convex_upsample(self, flow, mask, rate=4):
64
+ """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
65
+ N, _, H, W = flow.shape
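+ # Each fine-resolution flow value is a convex combination of its 3x3 coarse neighbours;
+ # the 9 weights per output pixel come from the predicted mask via a softmax.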
66
+ # print(flow.shape, mask.shape, rate)
67
+ mask = mask.view(N, 1, 9, rate, rate, H, W)
68
+ mask = torch.softmax(mask, dim=2)
69
+
70
+ up_flow = F.unfold(rate * flow, [3,3], padding=1)
71
+ up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
72
+
73
+ up_flow = torch.sum(mask * up_flow, dim=2)
74
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
75
+ return up_flow.reshape(N, 2, rate*H, rate*W)
76
+
77
+ def zero_init(self, fmap):
78
+ N, C, H, W = fmap.shape
79
+ _x = torch.zeros([N, 1, H, W], dtype=torch.float32)
80
+ _y = torch.zeros([N, 1, H, W], dtype=torch.float32)
81
+ zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device)
82
+ return zero_flow
83
+
84
+ def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False):
85
+ """ Estimate optical flow between pair of frames """
86
+
87
+ image1 = 2 * (image1 / 255.0) - 1.0
88
+ image2 = 2 * (image2 / 255.0) - 1.0
89
+
90
+ image1 = image1.contiguous()
91
+ image2 = image2.contiguous()
92
+
93
+ hdim = self.hidden_dim
94
+ cdim = self.context_dim
95
+
96
+ # run the feature network
97
+ with autocast(enabled=self.mixed_precision):
98
+ fmap1, fmap2 = self.fnet([image1, image2])
99
+
100
+ fmap1 = fmap1.float()
101
+ fmap2 = fmap2.float()
102
+
103
+ with autocast(enabled=self.mixed_precision):
104
+
105
+ # 1/4 -> 1/8
106
+ # feature
107
+ fmap1_dw8 = F.avg_pool2d(fmap1, 2, stride=2)
108
+ fmap2_dw8 = F.avg_pool2d(fmap2, 2, stride=2)
109
+
110
+ # offset
111
+ offset_dw8 = self.conv_offset_8(fmap1_dw8)
112
+ offset_dw8 = self.range_8 * (torch.sigmoid(offset_dw8) - 0.5) * 2.0
113
+
114
+ # context
115
+ net, inp = torch.split(fmap1, [hdim,hdim], dim=1)
116
+ net = torch.tanh(net)
117
+ inp = F.relu(inp)
118
+ net_dw8 = F.avg_pool2d(net, 2, stride=2)
119
+ inp_dw8 = F.avg_pool2d(inp, 2, stride=2)
120
+
121
+ # 1/4 -> 1/16
122
+ # feature
123
+ fmap1_dw16 = F.avg_pool2d(fmap1, 4, stride=4)
124
+ fmap2_dw16 = F.avg_pool2d(fmap2, 4, stride=4)
125
+ offset_dw16 = self.conv_offset_16(fmap1_dw16)
126
+ offset_dw16 = self.range_16 * (torch.sigmoid(offset_dw16) - 0.5) * 2.0
127
+
128
+ # context
129
+ net_dw16 = F.avg_pool2d(net, 4, stride=4)
130
+ inp_dw16 = F.avg_pool2d(inp, 4, stride=4)
131
+
132
+ # positional encoding and self-attention
133
+ pos_encoding_fn_small = PositionEncodingSine(
134
+ d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
135
+ )
136
+ # 'n c h w -> n (h w) c'
137
+ x_tmp = pos_encoding_fn_small(fmap1_dw16)
138
+ fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
139
+ # 'n c h w -> n (h w) c'
140
+ x_tmp = pos_encoding_fn_small(fmap2_dw16)
141
+ fmap2_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
142
+
143
+ fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)
144
+ fmap1_dw16, fmap2_dw16 = [
145
+ x.reshape(x.shape[0], image1.shape[2] // 16, -1, x.shape[2]).permute(0, 3, 1, 2)
146
+ for x in [fmap1_dw16, fmap2_dw16]
147
+ ]
148
+
149
+ corr_fn = AGCL(fmap1, fmap2)
150
+ corr_fn_dw8 = AGCL(fmap1_dw8, fmap2_dw8)
151
+ corr_fn_att_dw16 = AGCL(fmap1_dw16, fmap2_dw16, att=self.cross_att_fn)
152
+
153
+ # Cascaded refinement (1/16 + 1/8 + 1/4)
154
+ predictions = []
155
+ flow = None
156
+ flow_up = None
157
+ if flow_init is not None:
158
+ scale = fmap1.shape[2] / flow_init.shape[2]
159
+ flow = -scale * F.interpolate(
160
+ flow_init,
161
+ size=(fmap1.shape[2], fmap1.shape[3]),
162
+ mode="bilinear",
163
+ align_corners=True,
164
+ )
165
+ else:
166
+ # zero initialization
167
+ flow_dw16 = self.zero_init(fmap1_dw16)
168
+
169
+ # Recurrent Update Module
170
+ # RUM: 1/16
171
+ for itr in range(iters // 2):
172
+ if itr % 2 == 0:
173
+ small_patch = False
174
+ else:
175
+ small_patch = True
176
+
177
+ flow_dw16 = flow_dw16.detach()
178
+ out_corrs = corr_fn_att_dw16(
179
+ flow_dw16, offset_dw16, small_patch=small_patch
180
+ )
181
+
182
+ with autocast(enabled=self.mixed_precision):
183
+ net_dw16, up_mask, delta_flow = self.update_block(
184
+ net_dw16, inp_dw16, out_corrs, flow_dw16
185
+ )
186
+
187
+ flow_dw16 = flow_dw16 + delta_flow
188
+ flow = self.convex_upsample(flow_dw16, up_mask, rate=4)
189
+ flow_up = -4 * F.interpolate(
190
+ flow,
191
+ size=(4 * flow.shape[2], 4 * flow.shape[3]),
192
+ mode="bilinear",
193
+ align_corners=True,
194
+ )
195
+ predictions.append(flow_up)
196
+
197
+ scale = fmap1_dw8.shape[2] / flow.shape[2]
198
+ flow_dw8 = -scale * F.interpolate(
199
+ flow,
200
+ size=(fmap1_dw8.shape[2], fmap1_dw8.shape[3]),
201
+ mode="bilinear",
202
+ align_corners=True,
203
+ )
204
+
205
+ # RUM: 1/8
206
+ for itr in range(iters // 2):
207
+ if itr % 2 == 0:
208
+ small_patch = False
209
+ else:
210
+ small_patch = True
211
+
212
+ flow_dw8 = flow_dw8.detach()
213
+ out_corrs = corr_fn_dw8(flow_dw8, offset_dw8, small_patch=small_patch)
214
+
215
+ with autocast(enabled=self.mixed_precision):
216
+ net_dw8, up_mask, delta_flow = self.update_block(
217
+ net_dw8, inp_dw8, out_corrs, flow_dw8
218
+ )
219
+
220
+ flow_dw8 = flow_dw8 + delta_flow
221
+ flow = self.convex_upsample(flow_dw8, up_mask, rate=4)
222
+ flow_up = -2 * F.interpolate(
223
+ flow,
224
+ size=(2 * flow.shape[2], 2 * flow.shape[3]),
225
+ mode="bilinear",
226
+ align_corners=True,
227
+ )
228
+ predictions.append(flow_up)
229
+
230
+ scale = fmap1.shape[2] / flow.shape[2]
231
+ flow = -scale * F.interpolate(
232
+ flow,
233
+ size=(fmap1.shape[2], fmap1.shape[3]),
234
+ mode="bilinear",
235
+ align_corners=True,
236
+ )
237
+
238
+ # RUM: 1/4
239
+ for itr in range(iters):
240
+ if itr % 2 == 0:
241
+ small_patch = False
242
+ else:
243
+ small_patch = True
244
+
245
+ flow = flow.detach()
246
+ out_corrs = corr_fn(flow, None, small_patch=small_patch, iter_mode=True)
247
+
248
+ with autocast(enabled=self.mixed_precision):
249
+ net, up_mask, delta_flow = self.update_block(net, inp, out_corrs, flow)
250
+
251
+ flow = flow + delta_flow
252
+ flow_up = -self.convex_upsample(flow, up_mask, rate=4)
253
+ predictions.append(flow_up)
254
+
255
+ if self.test_mode:
256
+ return flow_up
257
+
258
+ return predictions
CREStereo_demo/nets/extractor.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ # Ref: https://github.com/princeton-vl/RAFT/blob/master/core/extractor.py
6
+ class ResidualBlock(nn.Module):
7
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
8
+ super(ResidualBlock, self).__init__()
9
+
10
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
11
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12
+ self.relu = nn.ReLU(inplace=True)
13
+
14
+ num_groups = planes // 8
15
+
16
+ if norm_fn == 'group':
17
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
20
+
21
+ elif norm_fn == 'batch':
22
+ self.norm1 = nn.BatchNorm2d(planes)
23
+ self.norm2 = nn.BatchNorm2d(planes)
24
+ self.norm3 = nn.BatchNorm2d(planes)
25
+
26
+ elif norm_fn == 'instance':
27
+ self.norm1 = nn.InstanceNorm2d(planes, affine=False)
28
+ self.norm2 = nn.InstanceNorm2d(planes, affine=False)
29
+ self.norm3 = nn.InstanceNorm2d(planes, affine=False)
30
+
31
+ elif norm_fn == 'none':
32
+ self.norm1 = nn.Sequential()
33
+ self.norm2 = nn.Sequential()
34
+ self.norm3 = nn.Sequential()
35
+
36
+ self.downsample = nn.Sequential(
37
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
38
+
39
+
40
+ def forward(self, x):
41
+ y = x
42
+ y = self.relu(self.norm1(self.conv1(y)))
43
+ y = self.relu(self.norm2(self.conv2(y)))
44
+
45
+ x = self.downsample(x)
46
+
47
+ return self.relu(x+y)
48
+
49
+
50
+ class BasicEncoder(nn.Module):
51
+ def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
52
+ super(BasicEncoder, self).__init__()
53
+ self.norm_fn = norm_fn
54
+
55
+ if self.norm_fn == 'group':
56
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
57
+
58
+ elif self.norm_fn == 'batch':
59
+ self.norm1 = nn.BatchNorm2d(64)
60
+
61
+ elif self.norm_fn == 'instance':
62
+ self.norm1 = nn.InstanceNorm2d(64, affine=False)
63
+
64
+ elif self.norm_fn == 'none':
65
+ self.norm1 = nn.Sequential()
66
+
67
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
68
+ self.relu1 = nn.ReLU(inplace=True)
69
+
70
+ self.in_planes = 64
71
+ self.layer1 = self._make_layer(64, stride=1)
72
+ self.layer2 = self._make_layer(96, stride=2)
73
+ self.layer3 = self._make_layer(128, stride=1)
74
+
75
+ # output convolution
76
+ self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
77
+
78
+ self.dropout = None
79
+ if dropout > 0:
80
+ self.dropout = nn.Dropout2d(p=dropout)
81
+
82
+ for m in self.modules():
83
+ if isinstance(m, nn.Conv2d):
84
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
85
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
86
+ if m.weight is not None:
87
+ nn.init.constant_(m.weight, 1)
88
+ if m.bias is not None:
89
+ nn.init.constant_(m.bias, 0)
90
+
91
+ def _make_layer(self, dim, stride=1):
92
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
93
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
94
+ layers = (layer1, layer2)
95
+
96
+ self.in_planes = dim
97
+ return nn.Sequential(*layers)
98
+
99
+ def forward(self, x):
100
+
101
+ # if input is list, combine batch dimension
102
+ is_list = isinstance(x, tuple) or isinstance(x, list)
103
+ if is_list:
104
+ batch_dim = x[0].shape[0]
105
+ x = torch.cat(x, dim=0)
106
+
107
+ x = self.conv1(x)
108
+ x = self.norm1(x)
109
+ x = self.relu1(x)
110
+
111
+ x = self.layer1(x)
112
+ x = self.layer2(x)
113
+ x = self.layer3(x)
114
+
115
+ x = self.conv2(x)
116
+
117
+ if self.dropout is not None:
118
+ x = self.dropout(x)
119
+
120
+ if is_list:
121
+ x = torch.split(x, x.shape[0]//2, dim=0)
122
+
123
+ return x
CREStereo_demo/nets/update.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ #Ref: https://github.com/princeton-vl/RAFT/blob/master/core/update.py
6
+ class FlowHead(nn.Module):
7
+ def __init__(self, input_dim=128, hidden_dim=256):
8
+ super(FlowHead, self).__init__()
9
+ self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10
+ self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11
+ self.relu = nn.ReLU(inplace=True)
12
+
13
+ def forward(self, x):
14
+ return self.conv2(self.relu(self.conv1(x)))
15
+
16
+
17
+ class SepConvGRU(nn.Module):
18
+ def __init__(self, hidden_dim=128, input_dim=192+128):
19
+ super(SepConvGRU, self).__init__()
20
+ self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
21
+ self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
22
+ self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
23
+
24
+ self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
25
+ self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
26
+ self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
27
+
28
+ def forward(self, h, x):
29
+ # horizontal
30
+ hx = torch.cat([h, x], dim=1)
31
+ z = torch.sigmoid(self.convz1(hx))
32
+ r = torch.sigmoid(self.convr1(hx))
33
+ q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
34
+ h = (1-z) * h + z * q
35
+
36
+ # vertical
37
+ hx = torch.cat([h, x], dim=1)
38
+ z = torch.sigmoid(self.convz2(hx))
39
+ r = torch.sigmoid(self.convr2(hx))
40
+ q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
41
+ h = (1-z) * h + z * q
42
+
43
+ return h
44
+
45
+
46
+ class BasicMotionEncoder(nn.Module):
47
+ def __init__(self, cor_planes):
48
+ super(BasicMotionEncoder, self).__init__()
49
+
50
+ self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
51
+ self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
52
+ self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
53
+ self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
54
+ self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
55
+
56
+ def forward(self, flow, corr):
57
+ cor = F.relu(self.convc1(corr))
58
+ cor = F.relu(self.convc2(cor))
59
+ flo = F.relu(self.convf1(flow))
60
+ flo = F.relu(self.convf2(flo))
61
+
62
+ cor_flo = torch.cat([cor, flo], dim=1)
63
+ out = F.relu(self.conv(cor_flo))
64
+ return torch.cat([out, flow], dim=1)
65
+
66
+
67
+ class BasicUpdateBlock(nn.Module):
68
+ def __init__(self, hidden_dim, cor_planes, mask_size=8):
69
+ super(BasicUpdateBlock, self).__init__()
70
+
71
+ self.encoder = BasicMotionEncoder(cor_planes)
72
+ self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
73
+ self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
74
+
75
+ self.mask = nn.Sequential(
76
+ nn.Conv2d(128, 256, 3, padding=1),
77
+ nn.ReLU(inplace=True),
78
+ nn.Conv2d(256, mask_size**2 *9, 1, padding=0))
79
+
80
+ def forward(self, net, inp, corr, flow, upsample=True):
81
+ # print(inp.shape, corr.shape, flow.shape)
82
+ motion_features = self.encoder(flow, corr)
83
+ # print(motion_features.shape, inp.shape)
84
+ inp = torch.cat((inp, motion_features), dim=1)
85
+
86
+ net = self.gru(net, inp)
87
+ delta_flow = self.flow_head(net)
88
+
89
+ # scale mask to balance gradients
90
+ mask = .25 * self.mask(net)
91
+ return net, mask, delta_flow
CREStereo_demo/nets/utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .utils import bilinear_sampler, coords_grid, manual_pad
CREStereo_demo/nets/utils/utils.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+
5
+ #Ref: https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py
6
+
7
+ def bilinear_sampler(img, coords, mode='bilinear', mask=False):
8
+ """ Wrapper for grid_sample, uses pixel coordinates """
9
+ H, W = img.shape[-2:]
10
+ xgrid, ygrid = coords.split([1,1], dim=-1)
11
+ xgrid = 2*xgrid/(W-1) - 1
12
+ ygrid = 2*ygrid/(H-1) - 1
13
+
14
+ grid = torch.cat([xgrid, ygrid], dim=-1)
15
+ # img = F.grid_sample(img, grid, align_corners=True)
16
+ img = bilinear_grid_sample(img, grid, align_corners=True)
17
+
18
+ if mask:
19
+ mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
20
+ return img, mask.float()
21
+
22
+ return img
23
+
24
+ def coords_grid(batch, ht, wd, device):
25
+ coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device), indexing='ij')
26
+ coords = torch.stack(coords[::-1], dim=0).float()
27
+ return coords[None].repeat(batch, 1, 1, 1)
28
+
29
+ def manual_pad(x, pady, padx):
30
+
31
+ pad = (padx, padx, pady, pady)
32
+ return F.pad(x.clone().detach(), pad, "replicate")
33
+
34
+ # Ref: https://zenn.dev/pinto0309/scraps/7d4032067d0160
35
+ def bilinear_grid_sample(im, grid, align_corners=False):
36
+ """Given an input and a flow-field grid, computes the output using input
37
+ values and pixel locations from grid. Supported only bilinear interpolation
38
+ method to sample the input pixels.
39
+
40
+ Args:
41
+ im (torch.Tensor): Input feature map, shape (N, C, H, W)
42
+ grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
43
+ align_corners {bool}: If set to True, the extrema (-1 and 1) are
44
+ considered as referring to the center points of the input’s
45
+ corner pixels. If set to False, they are instead considered as
46
+ referring to the corner points of the input’s corner pixels,
47
+ making the sampling more resolution agnostic.
48
+
49
+ Returns:
50
+ torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
51
+ """
52
+ n, c, h, w = im.shape
53
+ gn, gh, gw, _ = grid.shape
54
+ assert n == gn
55
+
56
+ x = grid[:, :, :, 0]
57
+ y = grid[:, :, :, 1]
58
+
59
+ if align_corners:
60
+ x = ((x + 1) / 2) * (w - 1)
61
+ y = ((y + 1) / 2) * (h - 1)
62
+ else:
63
+ x = ((x + 1) * w - 1) / 2
64
+ y = ((y + 1) * h - 1) / 2
65
+
66
+ x = x.view(n, -1)
67
+ y = y.view(n, -1)
68
+
69
+ x0 = torch.floor(x).long()
70
+ y0 = torch.floor(y).long()
71
+ x1 = x0 + 1
72
+ y1 = y0 + 1
73
+
74
+ wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
75
+ wb = ((x1 - x) * (y - y0)).unsqueeze(1)
76
+ wc = ((x - x0) * (y1 - y)).unsqueeze(1)
77
+ wd = ((x - x0) * (y - y0)).unsqueeze(1)
78
+
79
+ # Apply default for grid_sample function zero padding
80
+ im_padded = torch.nn.functional.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
81
+ padded_h = h + 2
82
+ padded_w = w + 2
83
+ # save points positions after padding
84
+ x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1
85
+
86
+ # Clip coordinates to padded image size
87
+ x0 = torch.where(x0 < 0, torch.tensor(0, device=im.device), x0)
88
+ x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1, device=im.device), x0)
89
+ x1 = torch.where(x1 < 0, torch.tensor(0, device=im.device), x1)
90
+ x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1, device=im.device), x1)
91
+ y0 = torch.where(y0 < 0, torch.tensor(0, device=im.device), y0)
92
+ y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1, device=im.device), y0)
93
+ y1 = torch.where(y1 < 0, torch.tensor(0, device=im.device), y1)
94
+ y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1, device=im.device), y1)
95
+
96
+ im_padded = im_padded.view(n, c, -1)
97
+
98
+ x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
99
+ x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
100
+ x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
101
+ x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
102
+
103
+ Ia = torch.gather(im_padded, 2, x0_y0)
104
+ Ib = torch.gather(im_padded, 2, x0_y1)
105
+ Ic = torch.gather(im_padded, 2, x1_y0)
106
+ Id = torch.gather(im_padded, 2, x1_y1)
107
+
108
+ return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
FoundationStereo_demo/Utils.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+
10
+ import os, sys, time, pickle, itertools, datetime, imageio, logging, joblib, importlib, argparse
11
+ # Import torch and related modules only when needed inside functions to avoid CUDA init
12
+ # import torch, torchvision # Moved to function-level imports
13
+ # import torch.nn.functional as F # Moved to function-level imports
14
+ # import torch.nn as nn # Moved to function-level imports
15
+ from functools import partial
16
+ import pandas as pd
17
+ # Import open3d only when needed to avoid CUDA conflicts
18
+ # import open3d as o3d # Moved to function-level imports
19
+ import cv2
20
+ import numpy as np
21
+ # Removed transformations import to avoid ModuleNotFoundError
22
+ code_dir = os.path.dirname(os.path.realpath(__file__))
23
+ sys.path.append(code_dir)
24
+
25
+
26
+
27
+ def set_logging_format(level=logging.INFO):
28
+ importlib.reload(logging)
29
+ FORMAT = '%(message)s'
30
+ logging.basicConfig(level=level, format=FORMAT, datefmt='%m-%d|%H:%M:%S')
31
+
32
+ # Only call set_logging_format when explicitly needed, not during import
33
+ # set_logging_format() # Commented out to avoid automatic execution
34
+
35
+
36
+
37
+ def set_seed(random_seed=0):
38
+ import torch # Import torch only when function is called
39
+ import random
40
+ import numpy as np
41
+
42
+ np.random.seed(random_seed)
43
+ random.seed(random_seed)
44
+ torch.manual_seed(random_seed)
45
+ # Skip CUDA seeding to avoid initialization issues in ZeroGPU
46
+ # CUDA seeding should be done within @spaces.GPU context
47
+ try:
48
+ # Only try CUDA operations if we're already in a CUDA context
49
+ if hasattr(torch.cuda, '_initialized') and torch.cuda._initialized:
50
+ if torch.cuda.is_available():
51
+ torch.cuda.manual_seed_all(random_seed)
52
+ except (RuntimeError, AttributeError):
53
+ pass # CUDA not initialized yet or not available
54
+ torch.backends.cudnn.deterministic = True
55
+ torch.backends.cudnn.benchmark = False
56
+
57
+
58
+ def toOpen3dCloud(points,colors=None,normals=None):
59
+ import open3d as o3d # Import only when function is called
60
+
61
+ cloud = o3d.geometry.PointCloud()
62
+ cloud.points = o3d.utility.Vector3dVector(points.astype(np.float64))
63
+ if colors is not None:
64
+ if colors.max()>1:
65
+ colors = colors/255.0
66
+ cloud.colors = o3d.utility.Vector3dVector(colors.astype(np.float64))
67
+ if normals is not None:
68
+ cloud.normals = o3d.utility.Vector3dVector(normals.astype(np.float64))
69
+ return cloud
70
+
71
+
72
+
73
+ def depth2xyzmap(depth:np.ndarray, K, uvs:np.ndarray=None, zmin=0.1):
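+ # Back-project a depth map into an (H, W, 3) map of camera-frame XYZ points using the
+ # pinhole intrinsics K; pixels with depth below zmin are treated as invalid and zeroed.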
74
+ invalid_mask = (depth<zmin)
75
+ H,W = depth.shape[:2]
76
+ if uvs is None:
77
+ vs,us = np.meshgrid(np.arange(0,H),np.arange(0,W), sparse=False, indexing='ij')
78
+ vs = vs.reshape(-1)
79
+ us = us.reshape(-1)
80
+ else:
81
+ uvs = uvs.round().astype(int)
82
+ us = uvs[:,0]
83
+ vs = uvs[:,1]
84
+ zs = depth[vs,us]
85
+ xs = (us-K[0,2])*zs/K[0,0]
86
+ ys = (vs-K[1,2])*zs/K[1,1]
87
+ pts = np.stack((xs.reshape(-1),ys.reshape(-1),zs.reshape(-1)), 1) #(N,3)
88
+ xyz_map = np.zeros((H,W,3), dtype=np.float32)
89
+ xyz_map[vs,us] = pts
90
+ if invalid_mask.any():
91
+ xyz_map[invalid_mask] = 0
92
+ return xyz_map
93
+
94
+
95
+
96
+ def freeze_model(model):
97
+ # This function now works with any model passed to it
98
+ # No need to import torch at module level
99
+ model = model.eval()
100
+ for p in model.parameters():
101
+ p.requires_grad = False
102
+ for p in model.buffers():
103
+ p.requires_grad = False
104
+ return model
105
+
106
+
107
+
108
+ def get_resize_keep_aspect_ratio(H, W, divider=16, max_H=1232, max_W=1232):
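+ # Round H and W up to multiples of `divider`; if the result exceeds max_H/max_W, shrink
+ # the larger side to its limit and re-round the other, approximately keeping aspect ratio.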
109
+ assert max_H%divider==0
110
+ assert max_W%divider==0
111
+
112
+ def round_by_divider(x):
113
+ return int(np.ceil(x/divider)*divider)
114
+
115
+ H_resize = round_by_divider(H) #!NOTE KITTI width=1242
116
+ W_resize = round_by_divider(W)
117
+ if H_resize>max_H or W_resize>max_W:
118
+ if H_resize>W_resize:
119
+ W_resize = round_by_divider(W_resize*max_H/H_resize)
120
+ H_resize = max_H
121
+ else:
122
+ H_resize = round_by_divider(H_resize*max_W/W_resize)
123
+ W_resize = max_W
124
+ return int(H_resize), int(W_resize)
125
+
126
+
127
+ def vis_disparity(disp, min_val=None, max_val=None, invalid_thres=np.inf, color_map=cv2.COLORMAP_TURBO, cmap=None, other_output={}):
128
+ """
129
+ @disp: np array (H,W)
130
+ @invalid_thres: > thres is invalid
131
+ """
132
+ disp = disp.copy()
133
+ H,W = disp.shape[:2]
134
+ invalid_mask = disp>=invalid_thres
135
+ if (invalid_mask==0).sum()==0:
136
+ other_output['min_val'] = None
137
+ other_output['max_val'] = None
138
+ return np.zeros((H,W,3))
139
+ if min_val is None:
140
+ min_val = disp[invalid_mask==0].min()
141
+ if max_val is None:
142
+ max_val = disp[invalid_mask==0].max()
143
+ other_output['min_val'] = min_val
144
+ other_output['max_val'] = max_val
145
+ vis = ((disp-min_val)/(max_val-min_val)).clip(0,1) * 255
146
+ if cmap is None:
147
+ vis = cv2.applyColorMap(vis.clip(0, 255).astype(np.uint8), color_map)[...,::-1]
148
+ else:
149
+ vis = cmap(vis.astype(np.uint8))[...,:3]*255
150
+ if invalid_mask.any():
151
+ vis[invalid_mask] = 0
152
+ return vis.astype(np.uint8)
153
+
154
+
155
+
156
+ def depth_uint8_decoding(depth_uint8, scale=1000):
157
+ depth_uint8 = depth_uint8.astype(float)
158
+ out = depth_uint8[...,0]*255*255 + depth_uint8[...,1]*255 + depth_uint8[...,2]
159
+ return out/float(scale)
160
+
FoundationStereo_demo/app.py ADDED
@@ -0,0 +1,1138 @@
1
+ import os
2
+ import sys
3
+ import logging
4
+ import tempfile
5
+ import zipfile
6
+ import gc
7
+ from pathlib import Path
8
+ from typing import Optional, Tuple, Union
9
+ import numpy as np
10
+ import cv2
11
+ import gradio as gr
12
+ import imageio
13
+
14
+ # Import spaces BEFORE torch to ensure proper ZeroGPU initialization
15
+ import spaces
16
+
17
+ # Import torch after spaces - avoid any CUDA calls during import
18
+ import torch
19
+
20
+ # Completely avoid CUDA operations during import phase
21
+ # Do not set default tensor type or modify CUDA settings outside GPU context
22
+ # torch.set_default_tensor_type('torch.FloatTensor') # Commented out - causes CUDA init
23
+
24
+ # Import other safe modules
25
+ from omegaconf import OmegaConf
26
+ from huggingface_hub import hf_hub_download, snapshot_download
27
+
28
+ # Do not modify CUDA settings during import - this can trigger CUDA initialization
29
+ # torch.backends.cudnn.enabled = False # Commented out
30
+ # torch.backends.cudnn.benchmark = False # Commented out
31
+
32
+ # Use current directory as base (gradio_app folder)
33
+ current_dir = os.path.dirname(os.path.abspath(__file__))
34
+ base_dir = current_dir # gradio_app folder
35
+
36
+ # Add current directory to path for local imports
37
+ sys.path.insert(0, current_dir)
38
+
39
+ # DO NOT import any local modules here that might use CUDA
40
+ # All local module imports will be done inside GPU-decorated functions
41
+
42
+ # Import Open3D with error handling - avoid any CUDA operations
43
+ OPEN3D_AVAILABLE = False # Will be set properly in GPU context
44
+ try:
45
+ # Set Open3D to CPU mode to avoid CUDA initialization
46
+ os.environ['OPEN3D_CPU_RENDERING'] = '1'
47
+ # Don't import open3d here - do it inside GPU functions
48
+ # import open3d as o3d
49
+ OPEN3D_AVAILABLE = True # Assume available, will check inside GPU context
50
+ except Exception as e:
51
+ logging.warning(f"Open3D setup failed: {e}")
52
+ OPEN3D_AVAILABLE = False
53
+
54
+ # Configure logging
55
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
56
+
57
+ # Hugging Face model repository configuration
58
+ HF_REPO_ID = "shriarul5273/FoundationStereo_models"
59
+ MODEL_VARIANTS = {
60
+ "11-33-40": {
61
+ "display_name": "FoundationStereo (Low-cost variant - 11-33-40)",
62
+ "model_file": "pretrained_models/11-33-40/model_best_bp2.pth",
63
+ "config_file": "pretrained_models/11-33-40/cfg.yaml"
64
+ },
65
+ "23-51-11": {
66
+ "display_name": "FoundationStereo (High-quality variant - 23-51-11)",
67
+ "model_file": "pretrained_models/23-51-11/model_best_bp2.pth",
68
+ "config_file": "pretrained_models/23-51-11/cfg.yaml"
69
+ }
70
+ }
71
+
72
+ # Global variables for model caching
73
+ MODEL_PATH: Optional[str] = None
+ CONFIG_PATH: Optional[str] = None
75
+
76
+ # Model cache to avoid reloading when selection doesn't change
77
+ _cached_model = None
78
+ _cached_device = None
79
+ _cached_model_selection = None
80
+
81
+
82
+ def aggressive_cleanup():
83
+ """Perform basic cleanup - no CUDA operations outside GPU context"""
84
+ import gc
85
+ gc.collect()
86
+ logging.info("Performed basic memory cleanup")
87
+
88
+
89
+ @spaces.GPU
90
+ def check_gpu_memory():
91
+ """Check and log current GPU memory usage - only call within GPU context"""
92
+ try:
93
+ allocated = torch.cuda.memory_allocated(0) / 1024**3
94
+ reserved = torch.cuda.memory_reserved(0) / 1024**3
95
+ max_allocated = torch.cuda.max_memory_allocated(0) / 1024**3
96
+ total = torch.cuda.get_device_properties(0).total_memory / 1024**3
97
+
98
+ logging.info(f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB, Max: {max_allocated:.2f}GB, Total: {total:.2f}GB")
99
+ return allocated, reserved, max_allocated, total
100
+ except RuntimeError as e:
101
+ logging.warning(f"Failed to get GPU memory info: {e}")
102
+ return None, None, None, None
103
+
104
+
105
+ def download_model_from_hf(variant: str, force_download: bool = False) -> Tuple[str, str]:
106
+ """
107
+ Download model and config files from Hugging Face Hub
108
+
109
+ Args:
110
+ variant: Model variant ("11-33-40" or "23-51-11")
111
+ force_download: Force re-download even if files exist locally
112
+
113
+ Returns:
114
+ Tuple of (model_path, config_path)
115
+ """
116
+ if variant not in MODEL_VARIANTS:
117
+ raise ValueError(f"Unknown model variant: {variant}. Available: {list(MODEL_VARIANTS.keys())}")
118
+
119
+ variant_info = MODEL_VARIANTS[variant]
120
+
121
+ try:
122
+ if not force_download:
123
+ logging.info(f"📦 Checking cache for model variant: {variant}")
124
+ else:
125
+ logging.info(f"🔄 Force downloading model variant: {variant}")
126
+
127
+ # Download model file
128
+ model_path = hf_hub_download(
129
+ repo_id=HF_REPO_ID,
130
+ filename=variant_info["model_file"],
131
+ force_download=force_download,
132
+ local_dir_use_symlinks=False
133
+ )
134
+
135
+ # Download config file
136
+ config_path = hf_hub_download(
137
+ repo_id=HF_REPO_ID,
138
+ filename=variant_info["config_file"],
139
+ force_download=force_download,
140
+ local_dir_use_symlinks=False
141
+ )
142
+
143
+ if force_download:
144
+ logging.info(f"✅ Successfully downloaded {variant} model files")
145
+ else:
146
+ logging.info(f"✅ Successfully loaded {variant} model files from cache")
147
+
148
+ logging.debug(f"Model: {model_path}")
149
+ logging.debug(f"Config: {config_path}")
150
+
151
+ return model_path, config_path
152
+
153
+ except Exception as e:
154
+ logging.error(f"Failed to download model {variant}: {e}")
155
+ raise RuntimeError(f"Failed to download model {variant} from Hugging Face: {e}")
156
+
157
+
158
+ def get_available_models() -> dict:
159
+ """Get all available models with their display names and download info"""
160
+ models = {}
161
+
162
+ # First check local models (legacy support)
163
+ search_dirs = [
164
+ os.path.join(current_dir, "pretrained_models"),
165
+ os.path.join(os.path.dirname(current_dir), "pretrained_models")
166
+ ]
167
+
168
+ for search_dir in search_dirs:
169
+ if os.path.exists(search_dir):
170
+ for model_dir in os.listdir(search_dir):
171
+ model_path = os.path.join(search_dir, model_dir, "model_best_bp2.pth")
172
+ cfg_path = os.path.join(search_dir, model_dir, "cfg.yaml")
173
+
174
+ if os.path.exists(model_path) and os.path.exists(cfg_path):
175
+ # Create a descriptive name for the model
176
+ if model_dir == "11-33-40":
177
+ display_name = "FoundationStereo (Low-cost variant - 11-33-40) [Local]"
178
+ elif model_dir == "23-51-11":
179
+ display_name = "FoundationStereo (High-quality variant - 23-51-11) [Local]"
180
+ else:
181
+ display_name = f"FoundationStereo ({model_dir}) [Local]"
182
+
183
+ models[display_name] = {
184
+ "model_path": model_path,
185
+ "config_path": cfg_path,
186
+ "variant": model_dir,
187
+ "source": "local"
188
+ }
189
+
190
+ # Add Hugging Face models
191
+ for variant, info in MODEL_VARIANTS.items():
192
+ display_name = f"{info['display_name']} [Hugging Face]"
193
+ models[display_name] = {
194
+ "model_path": None, # Will be downloaded when needed
195
+ "config_path": None, # Will be downloaded when needed
196
+ "variant": variant,
197
+ "source": "huggingface"
198
+ }
199
+
200
+ return models
201
+
202
+
203
+ def find_model_path() -> Tuple[Optional[str], Optional[str]]:
204
+ """Find available model and config paths (legacy function for backward compatibility)"""
205
+ models = get_available_models()
206
+ if models:
207
+ # Prefer Hugging Face models over local ones
208
+ # First try to find HF low-cost variant
209
+ for display_name in models:
210
+ if "11-33-40" in display_name and "[Hugging Face]" in display_name:
211
+ return get_model_paths_from_selection(display_name)
212
+
213
+ # Then try local low-cost variant
214
+ for display_name in models:
215
+ if "11-33-40" in display_name:
216
+ return get_model_paths_from_selection(display_name)
217
+
218
+ # If no low-cost variant, return the first available
219
+ first_model_name = next(iter(models.keys()))
220
+ return get_model_paths_from_selection(first_model_name)
221
+ return None, None
222
+
223
+
224
+ def get_model_paths_from_selection(model_selection: str) -> Tuple[Optional[str], Optional[str]]:
225
+ """Get model and config paths from the selected model"""
226
+ models = get_available_models()
227
+
228
+ # Check if it's in our models dict
229
+ if model_selection in models:
230
+ model_info = models[model_selection]
231
+
232
+ # If it's a Hugging Face model, download it first (or get from cache)
233
+ if model_info["source"] == "huggingface":
234
+ variant = model_info["variant"]
235
+ try:
236
+ logging.info(f"📦 Retrieving {variant} model from cache...")
237
+ model_path, config_path = download_model_from_hf(variant, force_download=False)
238
+ return model_path, config_path
239
+ except Exception as e:
240
+ logging.error(f"Failed to get model {variant} from cache: {e}")
241
+ return None, None
242
+ else:
243
+ # Local model
244
+ logging.info(f"📁 Using local model: {model_selection}")
245
+ return model_info["model_path"], model_info["config_path"]
246
+
247
+ # Handle direct HF model selection (fallback)
248
+ elif "[Hugging Face]" in model_selection:
249
+ if "11-33-40" in model_selection:
250
+ variant = "11-33-40"
251
+ elif "23-51-11" in model_selection:
252
+ variant = "23-51-11"
253
+ else:
254
+ logging.error(f"Unknown HF model variant in: {model_selection}")
255
+ return None, None
256
+
257
+ try:
258
+ logging.info(f"📦 Retrieving {variant} model from cache...")
259
+ model_path, config_path = download_model_from_hf(variant, force_download=False)
260
+ return model_path, config_path
261
+ except Exception as e:
262
+ logging.error(f"Failed to get model {variant} from cache: {e}")
263
+ return None, None
264
+
265
+ return None, None
266
+
267
+
268
+ def get_cached_model(model_selection: str):
+ """Resolve the selected model's files from the pre-downloaded cache and load it (the model is reloaded on every call, as required for ZeroGPU)"""
270
+ global _cached_model, _cached_device, _cached_model_selection
271
+
272
+ # Get model paths from selection
273
+ model_path, config_path = get_model_paths_from_selection(model_selection)
274
+
275
+ if model_path is None or config_path is None:
276
+ raise ValueError(f"Selected model not found: {model_selection}")
277
+
278
+ # Load model fresh for each inference (ZeroGPU optimized)
279
+ # Since models are pre-downloaded, this should be fast
280
+ logging.info(f"🚀 Loading cached model: {model_selection}")
281
+ model, device = load_model_for_inference(model_path, config_path)
282
+
283
+ logging.info(f"✅ Model loaded successfully from cache: {model_selection}")
284
+ return model, device
285
+
286
+
287
+ def clear_model_cache():
288
+ """Clear the cached model to free memory"""
289
+ global _cached_model, _cached_device, _cached_model_selection
290
+
291
+ if _cached_model is not None:
292
+ logging.info("Clearing model cache...")
293
+ del _cached_model
294
+ _cached_model = None
295
+ _cached_device = None
296
+ _cached_model_selection = None
297
+
298
+ # Simple cleanup
299
+ import gc
300
+ gc.collect()
301
+ logging.info("Model cache cleared")
302
+ else:
303
+ logging.info("No model in cache to clear")
304
+
305
+
306
+ @spaces.GPU
307
+ def load_model_for_inference(model_path: str, cfg_path: str):
308
+ """Load model temporarily for inference (demo-style)"""
309
+ # Set CUDA settings safely within GPU context
310
+ torch.set_default_tensor_type('torch.cuda.FloatTensor') # Now safe to use CUDA tensors
311
+ torch.backends.cudnn.enabled = True
312
+ torch.backends.cudnn.benchmark = True
313
+
314
+ # Import these inside the function to avoid early CUDA initialization
315
+ try:
316
+ # Import selectively to avoid CUDA calls in Utils
317
+ from core.foundation_stereo import FoundationStereo
318
+ from omegaconf import OmegaConf
319
+ logging.info("Successfully imported required modules")
320
+
321
+ # Import set_logging_format safely
322
+ from Utils import set_logging_format
323
+ set_logging_format()
324
+
325
+ # Manual seed setting to avoid CUDA calls in Utils.set_seed
326
+ import random
327
+ random_seed = 0
328
+ np.random.seed(random_seed)
329
+ random.seed(random_seed)
330
+ torch.manual_seed(random_seed)
331
+ # CUDA seeding will be done after device is available
332
+
333
+ logging.info("Set logging format and seed")
334
+ except Exception as e:
335
+ logging.error(f"Failed to import modules: {e}")
336
+ raise RuntimeError(f"Import failed: {e}")
337
+
338
+ # Check if CUDA is available after ZeroGPU initialization
339
+ if not torch.cuda.is_available():
340
+ raise RuntimeError("CUDA is not available. ZeroGPU initialization may have failed.")
341
+
342
+ # Use the first available CUDA device
343
+ device = torch.device("cuda")
344
+
345
+ # Now set CUDA seed safely within GPU context
346
+ try:
347
+ torch.cuda.manual_seed_all(random_seed)
348
+ torch.backends.cudnn.deterministic = True
349
+ torch.backends.cudnn.benchmark = False
350
+ except Exception as e:
351
+ logging.warning(f"Could not set CUDA seed: {e}")
352
+
353
+ try:
354
+ # Load config
355
+ cfg = OmegaConf.load(cfg_path)
356
+ cfg.setdefault("vit_size", "vitl")
357
+ logging.info("Loaded config file")
358
+
359
+ # Create model
360
+ model = FoundationStereo(cfg).to(device)
361
+ model.eval()
362
+ logging.info("Created model")
363
+
364
+ # Load checkpoint
365
+ ckpt = torch.load(model_path, map_location=device)
366
+ model.load_state_dict(ckpt["model"], strict=True)
367
+ logging.info("Loaded model weights")
368
+
369
+ # Memory optimizations
370
+ torch.set_grad_enabled(False)
371
+ model.half() # Use half precision
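+ # NOTE: FP16 weights roughly halve GPU memory use; if precision-related artifacts ever show up,
+ # keeping the model in float32 is a reasonable fallback at the cost of more memory.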
372
+ logging.info("Applied memory optimizations")
373
+
374
+ return model, device
375
+
376
+ except Exception as e:
377
+ logging.error(f"Model loading failed: {e}")
378
+ raise RuntimeError(f"Failed to load model: {e}")
379
+
380
+
381
+ # Fixed with static duration
382
+ @spaces.GPU(duration=60) # Static 60 seconds for basic processing
383
+ def process_stereo_pair(model_selection: str, left_image: np.ndarray, right_image: np.ndarray,
384
+ progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], str]:
385
+ """
386
+ Main processing function for stereo pair (with model caching)
387
+ """
388
+ logging.info("Starting stereo pair processing...")
389
+
390
+ if left_image is None or right_image is None:
391
+ return None, "❌ Please upload both left and right images."
392
+
393
+ try:
394
+ # Import these inside to avoid early CUDA calls
395
+ logging.info("Importing required modules...")
396
+ from core.utils.utils import InputPadder
397
+ # Import vis_disparity safely - it shouldn't have CUDA calls but be careful
398
+ from Utils import vis_disparity
399
+ logging.info("✅ Successfully imported processing modules")
400
+
401
+ # Get cached model (will load if not cached or selection changed)
402
+ variant_name = model_selection.split('(')[1].split(')')[0] if '(' in model_selection else model_selection
403
+ progress(0.1, desc=f"Loading cached model ({variant_name})...")
404
+ logging.info("🚀 Getting cached model...")
405
+ model, device = get_cached_model(model_selection)
406
+ logging.info("✅ Cached model loaded successfully")
407
+
408
+ progress(0.2, desc="Preprocessing images...")
409
+
410
+ # Validate input images
411
+ if left_image.shape != right_image.shape:
412
+ return None, "❌ Left and right images must have the same dimensions."
413
+
414
+ H, W = left_image.shape[:2]
415
+
416
+ # Convert to torch tensors and ensure they are contiguous
417
+ img0 = torch.as_tensor(left_image).to(device).half()[None].permute(0,3,1,2).contiguous()
418
+ img1 = torch.as_tensor(right_image).to(device).half()[None].permute(0,3,1,2).contiguous()
419
+
420
+ # Pad images and ensure contiguity
421
+ padder = InputPadder(img0.shape, divis_by=32, force_square=False)
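+ # divis_by=32 pads H and W up to multiples of 32 so the network's downsampling stages divide evenly;
+ # the padding is removed again with padder.unpad() after inference.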
422
+ img0, img1 = padder.pad(img0, img1)
423
+
424
+ # Ensure padded tensors are contiguous
425
+ img0 = img0.contiguous()
426
+ img1 = img1.contiguous()
427
+
428
+ progress(0.5, desc="Running inference...")
429
+
430
+ # Process stereo pair with autocast and ensure clean memory state
431
+ torch.cuda.empty_cache() # Clear any cached memory before inference
432
+
433
+ try:
434
+ with torch.amp.autocast("cuda", enabled=True):
435
+ # Ensure tensors are in the right format for cuDNN
436
+ if not img0.is_contiguous():
437
+ img0 = img0.contiguous()
438
+ if not img1.is_contiguous():
439
+ img1 = img1.contiguous()
440
+
441
+ disp = model.forward(img0, img1, iters=32, test_mode=True)
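+ # iters is the number of iterative refinement steps; more iterations generally yield finer disparity at the cost of runtime.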
442
+ except RuntimeError as e:
443
+ if "cuDNN" in str(e):
444
+ # Fallback: disable cuDNN optimizations and retry
445
+ logging.warning(f"cuDNN error encountered, retrying with fallback: {e}")
446
+ torch.backends.cudnn.enabled = False
447
+ try:
448
+ with torch.amp.autocast("cuda", enabled=True):
449
+ disp = model.forward(img0, img1, iters=32, test_mode=True)
450
+ finally:
451
+ torch.backends.cudnn.enabled = True # Re-enable for future use
452
+ else:
453
+ raise e
454
+
455
+ # Unpad and convert to numpy
456
+ disp = padder.unpad(disp.float())
457
+ disp_cpu = disp.data.cpu().numpy().reshape(H, W)
458
+
459
+ progress(0.8, desc="Creating visualization...")
460
+
461
+ # Create visualization - ONLY disparity
462
+ disparity_vis = vis_disparity(disp_cpu)
463
+ result_image = disparity_vis
464
+
465
+ progress(1.0, desc="Complete!")
466
+
467
+ # Clean up intermediate tensors
468
+ del img0, img1, disp
469
+
470
+ # For ZeroGPU: Clean up model after inference
471
+ del model
472
+ torch.cuda.empty_cache()
473
+ gc.collect()
474
+
475
+ # Create status message
476
+ valid_mask = disp_cpu != np.inf
477
+ min_disp = disp_cpu[valid_mask].min() if valid_mask.any() else 0
478
+ max_disp = disp_cpu[valid_mask].max() if valid_mask.any() else 0
479
+ mean_disp = disp_cpu[valid_mask].mean() if valid_mask.any() else 0
480
+
481
+ # Get model variant for status
482
+ variant = model_selection.split('(')[1].split(')')[0] if '(' in model_selection else "Unknown"
483
+
484
+ # Check current memory usage (safely within GPU context)
485
+ current_memory = torch.cuda.memory_allocated(0) / 1024**3
486
+ max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
487
+ memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"
488
+
489
+ status = f"""✅ Processing successful!
490
+ 🔧 Model: {variant} (ZeroGPU){memory_info}
491
+ 📊 Disparity Statistics:
492
+ • Range: {min_disp:.2f} - {max_disp:.2f}
493
+ • Mean: {mean_disp:.2f}
494
+ • Input size: {W}×{H}
495
+ • Valid pixels: {valid_mask.sum()}/{valid_mask.size}"""
496
+
497
+ return result_image, status
498
+
499
+ except Exception as e:
500
+ logging.error(f"Processing failed: {e}")
501
+ # Cleanup on error
502
+ if 'img0' in locals():
503
+ del img0
504
+ if 'img1' in locals():
505
+ del img1
506
+ if 'disp' in locals():
507
+ del disp
508
+ if 'model' in locals():
509
+ del model
510
+ # Clean up GPU memory
511
+ torch.cuda.empty_cache()
512
+ gc.collect()
513
+ return None, f"❌ Error: {str(e)}"
514
+
515
+
516
+ # Fixed with static duration
517
+ @spaces.GPU(duration=120) # Static 120 seconds for depth processing
518
+ def process_with_depth(model_selection: str, left_image: np.ndarray, right_image: np.ndarray,
519
+ camera_matrix: str, baseline: float,
520
+ progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], Optional[str], Optional[str], str]:
521
+ """
522
+ Process stereo pair and generate depth map and point cloud (with model caching)
523
+ """
524
+ # Import these inside to avoid early CUDA calls
525
+ from core.utils.utils import InputPadder
526
+ # Import vis_disparity safely within GPU context
527
+ from Utils import vis_disparity
528
+
529
+ # Import Open3D inside GPU context
530
+ global OPEN3D_AVAILABLE
531
+ try:
532
+ import open3d as o3d
533
+ OPEN3D_AVAILABLE = True
534
+ except ImportError as e:
535
+ logging.warning(f"Open3D not available: {e}")
536
+ OPEN3D_AVAILABLE = False
537
+ return None, None, None, "❌ Open3D not available. Point cloud generation disabled."
538
+
539
+ if left_image is None or right_image is None:
540
+ return None, None, None, "❌ Please upload both left and right images."
541
+
542
+ try:
543
+ progress(0.1, desc="Parsing camera parameters...")
544
+
545
+ # Parse camera matrix
546
+ try:
547
+ K_values = list(map(float, camera_matrix.strip().split()))
548
+ if len(K_values) != 9:
549
+ return None, None, None, "❌ Camera matrix must contain exactly 9 values."
550
+ K = np.array(K_values).reshape(3, 3)
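+ # Row-major 3x3 intrinsics: [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]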
551
+ except ValueError:
552
+ return None, None, None, "❌ Invalid camera matrix format. Use space-separated numbers."
553
+
554
+ if baseline <= 0:
555
+ return None, None, None, "❌ Baseline must be positive."
556
+
557
+ variant = model_selection.split('(')[1].split(')')[0] if '(' in model_selection else "Unknown"
558
+ progress(0.2, desc=f"Loading cached model ({variant})...")
559
+
560
+ # Get cached model (will load if not cached or selection changed)
561
+ model, device = get_cached_model(model_selection)
562
+
563
+ progress(0.4, desc="Running stereo inference...")
564
+
565
+ # Get disparity using the same process as the basic function
566
+ H, W = left_image.shape[:2]
567
+ img0 = torch.as_tensor(left_image).to(device).half()[None].permute(0,3,1,2).contiguous()
568
+ img1 = torch.as_tensor(right_image).to(device).half()[None].permute(0,3,1,2).contiguous()
569
+
570
+ padder = InputPadder(img0.shape, divis_by=32, force_square=False)
571
+ img0, img1 = padder.pad(img0, img1)
572
+
573
+ # Ensure padded tensors are contiguous
574
+ img0 = img0.contiguous()
575
+ img1 = img1.contiguous()
576
+
577
+ # Clear cache and ensure clean memory state before inference
578
+ torch.cuda.empty_cache()
579
+
580
+ try:
581
+ with torch.amp.autocast("cuda", enabled=True):
582
+ # Double-check tensor contiguity before cuDNN operations
583
+ if not img0.is_contiguous():
584
+ img0 = img0.contiguous()
585
+ if not img1.is_contiguous():
586
+ img1 = img1.contiguous()
587
+
588
+ disp = model.forward(img0, img1, iters=32, test_mode=True)
589
+ except RuntimeError as e:
590
+ if "cuDNN" in str(e):
591
+ # Fallback: disable cuDNN optimizations and retry
592
+ logging.warning(f"cuDNN error encountered in depth processing, retrying with fallback: {e}")
593
+ torch.backends.cudnn.enabled = False
594
+ try:
595
+ with torch.amp.autocast("cuda", enabled=True):
596
+ disp = model.forward(img0, img1, iters=32, test_mode=True)
597
+ finally:
598
+ torch.backends.cudnn.enabled = True # Re-enable for future use
599
+ else:
600
+ raise e
601
+
602
+ disp = padder.unpad(disp.float())
603
+ disp_cpu = disp.data.cpu().numpy().reshape(H, W)
604
+
605
+ # Clean up intermediate tensors early
606
+ del img0, img1, disp
607
+
608
+ # For ZeroGPU: Keep model reference for rest of processing
609
+ torch.cuda.empty_cache()
610
+
611
+ progress(0.6, desc="Converting to depth...")
612
+
613
+ # Remove invisible points (same as in original demo)
614
+ yy, xx = np.meshgrid(np.arange(disp_cpu.shape[0]), np.arange(disp_cpu.shape[1]), indexing='ij')
615
+ us_right = xx - disp_cpu
616
+ invalid = us_right < 0
617
+ disp_cpu[invalid] = np.inf
618
+
619
+ # Convert to depth using the formula from the original demo
620
+ depth = K[0, 0] * baseline / disp_cpu
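+ # Standard rectified-stereo relation: depth = fx * baseline / disparity (pixels masked to inf above map to depth 0 here).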
621
+
622
+ # Visualize depth (no rotation)
623
+ depth_vis = vis_disparity(depth, max_val=10.0)
624
+
625
+ progress(0.8, desc="Generating point cloud...")
626
+
627
+ # Generate point cloud with proper coordinate transformation
628
+ fx, fy = K[0, 0], K[1, 1]
629
+ cx, cy = K[0, 2], K[1, 2]
630
+
631
+ # Create coordinate meshgrids
632
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
633
+
634
+ # Convert to 3D coordinates (proper camera coordinate system)
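+ # Pinhole back-projection: X = (u - cx) * Z / fx, Y = (v - cy) * Z / fy, with Z read from the depth map.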
635
+ valid_depth = depth != np.inf
636
+ z = depth[valid_depth] # Z coordinate (depth)
637
+ x = (u[valid_depth] - cx) * z / fx # X coordinate
638
+ y = (v[valid_depth] - cy) * z / fy # Y coordinate
639
+
640
+ # Stack coordinates (X, Y, Z)
641
+ points = np.stack([x, y, z], axis=-1)
642
+
643
+ # Get corresponding colors
644
+ colors = left_image[valid_depth]
645
+
646
+ # Filter points by depth range
647
+ depth_mask = (z > 0) & (z <= 10.0)
648
+ valid_points = points[depth_mask]
649
+ valid_colors = colors[depth_mask]
650
+
651
+ if len(valid_points) == 0:
652
+ return depth_vis, None, None, "⚠️ No valid points generated for point cloud."
653
+
654
+ # Subsample points for better 3D visualization performance
655
+ if len(valid_points) > 100000:
656
+ indices = np.random.choice(len(valid_points), 100000, replace=False)
657
+ valid_points = valid_points[indices]
658
+ valid_colors = valid_colors[indices]
659
+
660
+ # Transform coordinates for proper visualization orientation
661
+ # Standard computer vision: X right, Y down, Z forward
662
+ # For better 3D viewing: X right, Y up, Z backward
663
+ transformed_points = valid_points.copy()
664
+ transformed_points[:, 1] = -transformed_points[:, 1] # Flip Y axis
665
+ transformed_points[:, 2] = -transformed_points[:, 2] # Flip Z axis
666
+
667
+ # Create point cloud using transformed coordinates
668
+ pcd = o3d.geometry.PointCloud()
669
+ pcd.points = o3d.utility.Vector3dVector(transformed_points)
670
+ pcd.colors = o3d.utility.Vector3dVector(valid_colors / 255.0)
671
+
672
+ # Save point cloud for download (.ply)
673
+ temp_ply_file = tempfile.NamedTemporaryFile(delete=False, suffix='.ply')
674
+ o3d.io.write_point_cloud(temp_ply_file.name, pcd)
675
+
676
+ # Create OBJ file for 3D visualization (better Gradio compatibility)
677
+ temp_obj_file = tempfile.NamedTemporaryFile(delete=False, suffix='.obj')
678
+
679
+ # Write OBJ file with proper vertex colors
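+ # Per-vertex "v x y z r g b" colors are a widely supported OBJ extension rather than part of the core spec.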
680
+ with open(temp_obj_file.name, 'w') as f:
681
+ f.write("# Point cloud generated from stereo depth\n")
682
+ f.write(f"# Total points: {len(valid_points)}\n")
683
+
684
+ # Write vertices with RGB colors (0-1 range)
685
+ for i, (point, color) in enumerate(zip(transformed_points, valid_colors)):
686
+ # Ensure colors are in 0-1 range
687
+ r, g, b = np.clip(color / 255.0, 0, 1)
688
+ f.write(f"v {point[0]:.6f} {point[1]:.6f} {point[2]:.6f} {r:.6f} {g:.6f} {b:.6f}\n")
689
+
690
+ progress(1.0, desc="Complete!")
691
+
692
+ # For ZeroGPU: Clean up model after inference
693
+ del model
694
+ torch.cuda.empty_cache()
695
+ gc.collect()
696
+
697
+ # Check current memory usage (safely within GPU context)
698
+ current_memory = torch.cuda.memory_allocated(0) / 1024**3
699
+ max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
700
+ memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"
701
+
702
+ status = f"""✅ Depth processing successful!
703
+ 🔧 Model: {variant} (ZeroGPU){memory_info}
704
+ 📊 Statistics:
705
+ • Valid points: {len(valid_points):,}
706
+ • Depth range: {z.min():.2f} - {z.max():.2f} m
707
+ • Baseline: {baseline} m
708
+ • Point cloud saved with {len(valid_points)} points
709
+ • 3D visualization ready (corrected orientation)"""
710
+
711
+ return depth_vis, temp_ply_file.name, temp_obj_file.name, status
712
+
713
+ except Exception as e:
714
+ logging.error(f"Depth processing failed: {e}")
715
+ # Cleanup on error
716
+ if 'img0' in locals():
717
+ del img0
718
+ if 'img1' in locals():
719
+ del img1
720
+ if 'disp' in locals():
721
+ del disp
722
+ if 'model' in locals():
723
+ del model
724
+ # Clean up GPU memory
725
+ torch.cuda.empty_cache()
726
+ gc.collect()
727
+ return None, None, None, f"❌ Error: {str(e)}"
728
+
729
+
730
+ def preload_all_models():
731
+ """Pre-download all Hugging Face models to cache during startup"""
732
+ logging.info("🔄 Pre-downloading all models to cache...")
733
+
734
+ downloaded_models = {}
735
+
736
+ for variant, info in MODEL_VARIANTS.items():
737
+ try:
738
+ logging.info(f"📥 Downloading {variant} model to cache...")
739
+ model_path, config_path = download_model_from_hf(variant, force_download=False)
740
+ downloaded_models[variant] = {
741
+ "model_path": model_path,
742
+ "config_path": config_path,
743
+ "display_name": info["display_name"]
744
+ }
745
+ logging.info(f"✅ {variant} model cached successfully")
746
+ except Exception as e:
747
+ logging.warning(f"⚠️ Failed to download {variant} model: {e}")
748
+ # Continue with other models even if one fails
749
+
750
+ logging.info(f"✅ Model pre-loading complete. {len(downloaded_models)}/{len(MODEL_VARIANTS)} models cached.")
751
+ return downloaded_models
752
+
753
+
754
+ def create_app() -> gr.Blocks:
755
+ """Create the Gradio application"""
756
+
757
+ global MODEL_PATH, CONFIG_PATH
758
+
759
+ # Debug: Print current directory and check for files
760
+ print(f"Current directory: {current_dir}")
761
+ print(f"Python working directory: {os.getcwd()}")
762
+
763
+ # Pre-download all models to cache
764
+ try:
765
+ cached_models = preload_all_models()
766
+ logging.info(f"Pre-loaded {len(cached_models)} models to cache")
767
+ except Exception as e:
768
+ logging.error(f"Failed to pre-load models: {e}")
769
+ cached_models = {}
770
+
771
+ # Get available models (this should be safe as it only does file system operations)
772
+ try:
773
+ available_models = get_available_models()
774
+ logging.info(f"Successfully got available models: {len(available_models)} found")
775
+ except Exception as e:
776
+ logging.error(f"Failed to get available models: {e}")
777
+ available_models = {}
778
+
779
+ # Find model and config paths (legacy) - should be safe as well
780
+ try:
781
+ MODEL_PATH, CONFIG_PATH = find_model_path()
782
+ logging.info("Successfully found model paths")
783
+ except Exception as e:
784
+ logging.error(f"Failed to find model paths: {e}")
785
+ MODEL_PATH, CONFIG_PATH = None, None
786
+
787
+ with gr.Blocks(
788
+ title="FoundationStereo - Stereo Depth Estimation",
789
+ theme=gr.themes.Soft(),
790
+ css="footer {visibility: hidden}",
791
+ delete_cache=(60, 60) # Delete cache after 60 seconds for ZeroGPU
792
+ ) as app:
793
+
794
+ gr.Markdown("""
795
+ # 🔍 FoundationStereo: Zero-Shot Stereo Matching
796
+
797
+ Upload a pair of **rectified** stereo images to get disparity estimation.
798
+
799
+ ⚠️ **Important**: Images should be rectified (epipolar lines are horizontal) and undistorted.
800
+ ⚡ **ZeroGPU Powered**: Runs on high-performance A100 GPUs for fast inference.
801
+ 📦 **Smart Caching**: All models are pre-downloaded for instant model switching.
802
+ """)
803
+
804
+ # Instructions section
805
+ with gr.Accordion("📋 Instructions to Run This Repository", open=False):
806
+ gr.Markdown("""
807
+ ## 🚀 How to Run This Demo
808
+ This is a **demo application** showcasing the FoundationStereo model for stereo disparity estimation.
809
+
810
+ ### 🖼️ Input Requirements
811
+
812
+ 1. **Image Format**: Upload images in JPEG or PNG format.
813
+ 2. **Image Size**: Images should be of the same size and resolution.
814
+ 3. **Rectification**: Ensure images are rectified (epipolar lines are horizontal) and undistorted.
815
+ 4. **Camera Parameters**: For advanced processing, provide camera parameters (camera matrix and baseline).
816
+
817
+ ### 📊 Using the Demo
818
+
819
+ 1. **Select Model**: Choose between low-cost (11-33-40) or high-quality (23-51-11) variants
820
+ 2. **Upload Images**: Provide rectified stereo image pairs
821
+ 3. **Basic Processing**: Get disparity visualization
822
+ 4. **Advanced Processing**: Generate depth maps and 3D point clouds (requires camera parameters)
823
+
824
+ ### Original Work
825
+
826
+ This demo is based on the original FoundationStereo research. Please visit the official resources:
827
+ - **Paper**: [FoundationStereo: Zero-Shot Stereo Matching via Foundation Model](https://arxiv.org/abs/2501.09898)
828
+ - **Project Page**: [https://nvlabs.github.io/FoundationStereo/](https://nvlabs.github.io/FoundationStereo/)
829
+ - **Official Repository**: [https://github.com/NVlabs/FoundationStereo](https://github.com/NVlabs/FoundationStereo)
830
+
831
+ **⚠️ Demo Notice**: This is a demonstration interface. For research and production use, please refer to the original repository and follow the official implementation guidelines.
832
+ """)
833
+
834
+ # Model selection
835
+ with gr.Row():
836
+ # Always include Hugging Face models in the choices
837
+ all_choices = list(available_models.keys())
838
+
839
+ # If no models found, add the HF models manually
840
+ if not all_choices:
841
+ all_choices = [
842
+ "FoundationStereo (Low-cost variant - 11-33-40) [Hugging Face]",
843
+ "FoundationStereo (High-quality variant - 23-51-11) [Hugging Face]"
844
+ ]
845
+
846
+ # Get default model (prefer Hugging Face low-cost variant)
847
+ default_model = None
848
+
849
+ # First try Hugging Face low-cost variant
850
+ for name in all_choices:
851
+ if "11-33-40" in name and "[Hugging Face]" in name:
852
+ default_model = name
853
+ break
854
+
855
+ # If no HF low-cost variant, try any low-cost variant
856
+ if default_model is None:
857
+ for name in all_choices:
858
+ if "11-33-40" in name:
859
+ default_model = name
860
+ break
861
+
862
+ # If no low-cost variant, use first available
863
+ if default_model is None:
864
+ default_model = all_choices[0] if all_choices else None
865
+
866
+ model_selector = gr.Dropdown(
867
+ choices=all_choices,
868
+ value=default_model,
869
+ label="🎯 Select Model",
870
+ info="Choose the FoundationStereo model variant. Hugging Face models download automatically.",
871
+ interactive=True
872
+ )
873
+
874
+ with gr.Tabs():
875
+ # Basic stereo processing tab
876
+ with gr.TabItem("🖼️ Basic Stereo Processing"):
877
+ with gr.Row():
878
+ with gr.Column():
879
+ left_input = gr.Image(
880
+ label="📷 Left Image",
881
+ type="numpy",
882
+ height=300
883
+ )
884
+ right_input = gr.Image(
885
+ label="📷 Right Image",
886
+ type="numpy",
887
+ height=300
888
+ )
889
+
890
+ process_btn = gr.Button(
891
+ "🚀 Process Stereo Pair",
892
+ variant="primary",
893
+ size="lg"
894
+ )
895
+
896
+ with gr.Column():
897
+ output_image = gr.Image(
898
+ label="📊 Disparity Visualization",
899
+ height=400
900
+ )
901
+ status_text = gr.Textbox(
902
+ label="Status",
903
+ interactive=False,
904
+ lines=8
905
+ )
906
+
907
+ # Example images
908
+ examples_list = []
909
+
910
+ # Example 1
911
+ if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")):
912
+ examples_list.append([
913
+ os.path.join(current_dir, "assets", "example1", "left.png"),
914
+ os.path.join(current_dir, "assets", "example1", "right.png")
915
+ ])
916
+
917
+ # Example 2
918
+ if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")):
919
+ examples_list.append([
920
+ os.path.join(current_dir, "assets", "example2", "left.png"),
921
+ os.path.join(current_dir, "assets", "example2", "right.png")
922
+ ])
923
+
924
+
925
+
926
+ gr.Examples(
927
+ examples=examples_list,
928
+ inputs=[left_input, right_input],
929
+ label="📋 Example Images"
930
+ )
931
+
932
+ # Advanced processing with depth
933
+ with gr.TabItem("📐 Advanced Processing (Depth & Point Cloud)"):
934
+ with gr.Row():
935
+ with gr.Column():
936
+ left_input_adv = gr.Image(
937
+ label="📷 Left Image",
938
+ type="numpy",
939
+ height=250
940
+ )
941
+ right_input_adv = gr.Image(
942
+ label="📷 Right Image",
943
+ type="numpy",
944
+ height=250
945
+ )
946
+
947
+ # Camera parameters
948
+ with gr.Group():
949
+ gr.Markdown("### 📹 Camera Parameters")
950
+ camera_matrix_input = gr.Textbox(
951
+ label="Camera Matrix (9 values: fx 0 cx 0 fy cy 0 0 1)",
952
+ value="",
953
+
954
+ )
955
+ baseline_input = gr.Number(
956
+ label="Baseline (meters)",
957
+ value=None,
958
+ minimum=0.001,
959
+ maximum=10.0,
960
+ step=0.001
961
+ )
962
+
963
+ process_depth_btn = gr.Button(
964
+ "🔬 Process with Depth",
965
+ variant="primary",
966
+ size="lg"
967
+ )
968
+
969
+ with gr.Column():
970
+ depth_output = gr.Image(
971
+ label="📏 Depth Visualization",
972
+ height=300
973
+ )
974
+ pointcloud_output = gr.File(
975
+ label="☁️ Point Cloud Download (.ply)",
976
+ file_types=[".ply"]
977
+ )
978
+ status_depth = gr.Textbox(
979
+ label="Status",
980
+ interactive=False,
981
+ lines=6
982
+ )
983
+
984
+ # 3D Point Cloud Visualization
985
+ with gr.Row():
986
+ pointcloud_3d = gr.Model3D(
987
+ label="🌐 3D Point Cloud Viewer",
988
+ clear_color=[0.0, 0.0, 0.0, 0.0],
989
+ height=400
990
+ )
991
+
992
+ # Example images for advanced processing
993
+ examples_advanced_list = []
994
+
995
+ # Example 1 - Camera parameters from K.txt
996
+ if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")):
997
+ examples_advanced_list.append([
998
+ os.path.join(current_dir, "assets", "example1", "left.png"),
999
+ os.path.join(current_dir, "assets", "example1", "right.png"),
1000
+ "754.6680908203125 0.0 489.3794860839844 0.0 754.6680908203125 265.16162109375 0.0 0.0 1.0", # Camera matrix
1001
+ 0.063 # Baseline in meters
1002
+ ])
1003
+
1004
+ # Example 2 - Camera parameters from K.txt
1005
+ if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")):
1006
+ examples_advanced_list.append([
1007
+ os.path.join(current_dir, "assets", "example2", "left.png"),
1008
+ os.path.join(current_dir, "assets", "example2", "right.png"),
1009
+ "1733.74 0.0 792.27 0.0 1733.74 541.89 0.0 0.0 1.0", # Camera matrix
1010
+ 0.537 # Baseline in meters (converted from 536.62mm)
1011
+ ])
1012
+
1013
+
1014
+
1015
+ gr.Examples(
1016
+ examples=examples_advanced_list,
1017
+ inputs=[left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
1018
+ label="📋 Example Images with Camera Parameters"
1019
+ )
1020
+
1021
+ # Event handlers - Always enable since we have HF models
1022
+ process_btn.click(
1023
+ fn=process_stereo_pair,
1024
+ inputs=[model_selector, left_input, right_input],
1025
+ outputs=[output_image, status_text],
1026
+ show_progress=True
1027
+ )
1028
+
1029
+ if OPEN3D_AVAILABLE:
1030
+ process_depth_btn.click(
1031
+ fn=process_with_depth,
1032
+ inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
1033
+ outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth],
1034
+ show_progress=True
1035
+ )
1036
+ else:
1037
+ process_depth_btn.click(
1038
+ fn=lambda *args: (None, None, None, "❌ Open3D not available. Install with: pip install open3d"),
1039
+ inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
1040
+ outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth]
1041
+ )
1042
+
1043
+ # Citation section at the bottom
1044
+ with gr.Accordion("📖 Citation", open=False):
1045
+ gr.Markdown("""
1046
+ ### 📄 Please Cite the Original Paper
1047
+
1048
+ If you use this work in your research, please cite:
1049
+
1050
+ ```bibtex
1051
+ @article{wen2025stereo,
1052
+ title={FoundationStereo: Zero-Shot Stereo Matching},
1053
+ author={Bowen Wen and Matthew Trepte and Joseph Aribido and Jan Kautz and Orazio Gallo and Stan Birchfield},
1054
+ journal={CVPR},
1055
+ year={2025}
1056
+ }
1057
+ ```
1058
+ """)
1059
+
1060
+ # Footer
1061
+ gr.Markdown(f"""
1062
+ ---
1063
+ ### 📝 Notes:
1064
+ - **Input images must be rectified stereo pairs** (epipolar lines are horizontal)
1065
+ - **🤗 Hugging Face Integration**: Models are automatically downloaded from `{HF_REPO_ID}`
1066
+ - **📦 Smart Caching**: All models are pre-downloaded and cached for instant switching
1067
+ - **⚡ ZeroGPU Acceleration**: Powered by high-performance A100 GPUs
1068
+ - For best results, use PNG images without lossy compression
1069
+ - Model works on RGB images but also supports monochrome/IR stereo pairs
1070
+ - **Optimized for Spaces**: Memory-efficient inference on shared infrastructure
1071
+
1072
+ ### 🔗 References:
1073
+ - [FoundationStereo Paper](https://arxiv.org/abs/2501.09898)
1074
+ - [Project Website](https://nvlabs.github.io/FoundationStereo/)
1075
+ - [GitHub Repository](https://github.com/NVlabs/FoundationStereo)
1076
+ - [Hugging Face Models](https://huggingface.co/{HF_REPO_ID})
1077
+ """)
1078
+
1079
+ return app
1080
+
1081
+
1082
+ def main():
1083
+ """Main function to launch the app"""
1084
+
1085
+ # Ensure no CUDA operations during startup
1086
+ if torch.cuda.is_available():
1087
+ logging.warning("CUDA detected during startup - this should not happen in ZeroGPU")
1088
+
1089
+ logging.info("🚀 Starting FoundationStereo Gradio App...")
1090
+
1091
+ # Parse command line arguments
1092
+ import argparse
1093
+ parser = argparse.ArgumentParser(description="FoundationStereo Gradio App")
1094
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
1095
+ parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
1096
+ parser.add_argument("--share", action="store_true", help="Create shareable link")
1097
+ parser.add_argument("--debug", action="store_true", help="Enable debug mode")
1098
+
1099
+ args = parser.parse_args()
1100
+
1101
+ if args.debug:
1102
+ logging.getLogger().setLevel(logging.DEBUG)
1103
+
1104
+ try:
1105
+ # Create and launch app
1106
+ logging.info("Creating Gradio app...")
1107
+ app = create_app()
1108
+ logging.info("✅ Gradio app created successfully")
1109
+
1110
+ logging.info(f"Launching app on {args.host}:{args.port}")
1111
+ if args.share:
1112
+ logging.info("Share link will be created")
1113
+
1114
+ # For ZeroGPU compatibility, launch with appropriate settings
1115
+ app.launch(
1116
+ server_name=args.host,
1117
+ server_port=args.port,
1118
+ share=args.share,
1119
+ show_error=True,
1120
+ favicon_path=None,
1121
+ ssr_mode=False, # Disable SSR for ZeroGPU compatibility
1122
+ allowed_paths=["./"] # Allow access to local files
1123
+ )
1124
+ except Exception as e:
1125
+ logging.error(f"Failed to launch app: {e}")
1126
+ raise
1127
+
1128
+
1129
+ if __name__ == "__main__":
1130
+ # Additional safety check for ZeroGPU environment
1131
+ if 'SPACE_ID' in os.environ:
1132
+ logging.info("Running in Hugging Face Spaces environment")
1133
+
1134
+ # Do not check CUDA status during startup - this can trigger CUDA initialization
1135
+ # The CUDA status will be checked inside the @spaces.GPU decorated functions
1136
+ logging.info("✅ CUDA status will be checked within GPU-decorated functions")
1137
+
1138
+ main()
FoundationStereo_demo/app_local.py ADDED
@@ -0,0 +1,1292 @@
1
+ import os
2
+ import sys
3
+ import logging
4
+ import tempfile
5
+ import zipfile
6
+ import gc
7
+ from pathlib import Path
8
+ from typing import Optional, Tuple, Union
9
+ import numpy as np
10
+ import cv2
11
+ import gradio as gr
12
+ import imageio
13
+
14
+ import torch
15
+
16
+ # Set default tensor type if needed
17
+ # torch.set_default_tensor_type('torch.FloatTensor')
18
+
19
+ # Import other safe modules
20
+ from omegaconf import OmegaConf
21
+ from huggingface_hub import hf_hub_download, snapshot_download
22
+
23
+ # CUDA backend settings
24
+ # torch.backends.cudnn.enabled = False
25
+ # torch.backends.cudnn.benchmark = False
26
+
27
+ # Use current directory as base (gradio_app folder)
28
+ current_dir = os.path.dirname(os.path.abspath(__file__))
29
+ base_dir = current_dir # gradio_app folder
30
+
31
+ # Add current directory to path for local imports
32
+ sys.path.insert(0, current_dir)
33
+
34
+ # DO NOT import any local modules here that might use CUDA
35
+ # All local module imports will be done inside functions
36
+
37
+ # Import Open3D with error handling
38
+ OPEN3D_AVAILABLE = False
39
+ try:
40
+ # Set Open3D to CPU mode to avoid CUDA initialization
41
+ os.environ['OPEN3D_CPU_RENDERING'] = '1'
42
+ # Don't import open3d here - do it inside functions
43
+ # import open3d as o3d
44
+ OPEN3D_AVAILABLE = True # Assume available, will check later
45
+ except Exception as e:
46
+ logging.warning(f"Open3D setup failed: {e}")
47
+ OPEN3D_AVAILABLE = False
48
+
49
+ # Configure logging
50
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
51
+
52
+ # Hugging Face model repository configuration
53
+ HF_REPO_ID = "shriarul5273/FoundationStereo_models"
54
+ MODEL_VARIANTS = {
55
+ "11-33-40": {
56
+ "display_name": "FoundationStereo (Low-cost variant - 11-33-40)",
57
+ "model_file": "pretrained_models/11-33-40/model_best_bp2.pth",
58
+ "config_file": "pretrained_models/11-33-40/cfg.yaml"
59
+ },
60
+ "23-51-11": {
61
+ "display_name": "FoundationStereo (High-quality variant - 23-51-11)",
62
+ "model_file": "pretrained_models/23-51-11/model_best_bp2.pth",
63
+ "config_file": "pretrained_models/23-51-11/cfg.yaml"
64
+ }
65
+ }
66
+
67
+ # Global variables for model caching
68
+ MODEL_PATH: Optional[str] = None
+ CONFIG_PATH: Optional[str] = None
70
+
71
+ # Model cache to avoid reloading when selection doesn't change
72
+ _cached_model = None
73
+ _cached_device = None
74
+ _cached_model_selection = None
75
+
76
+
77
+ def aggressive_cleanup():
78
+ """Perform basic cleanup"""
79
+ import gc
80
+ gc.collect()
81
+ logging.info("Performed basic memory cleanup")
82
+
83
+
84
+ def check_gpu_memory():
85
+ """Check and log current GPU memory usage"""
86
+ try:
87
+ allocated = torch.cuda.memory_allocated(0) / 1024**3
88
+ reserved = torch.cuda.memory_reserved(0) / 1024**3
89
+ max_allocated = torch.cuda.max_memory_allocated(0) / 1024**3
90
+ total = torch.cuda.get_device_properties(0).total_memory / 1024**3
91
+
92
+ logging.info(f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB, Max: {max_allocated:.2f}GB, Total: {total:.2f}GB")
93
+ return allocated, reserved, max_allocated, total
94
+ except RuntimeError as e:
95
+ logging.warning(f"Failed to get GPU memory info: {e}")
96
+ return None, None, None, None
97
+
98
+
99
+ def download_model_from_hf(variant: str, force_download: bool = False) -> Tuple[str, str]:
100
+ """
101
+ Download model and config files from Hugging Face Hub
102
+
103
+ Args:
104
+ variant: Model variant ("11-33-40" or "23-51-11")
105
+ force_download: Force re-download even if files exist locally
106
+
107
+ Returns:
108
+ Tuple of (model_path, config_path)
109
+ """
110
+ if variant not in MODEL_VARIANTS:
111
+ raise ValueError(f"Unknown model variant: {variant}. Available: {list(MODEL_VARIANTS.keys())}")
112
+
113
+ variant_info = MODEL_VARIANTS[variant]
114
+
115
+ try:
116
+ if not force_download:
117
+ logging.info(f"📦 Checking cache for model variant: {variant}")
118
+ else:
119
+ logging.info(f"🔄 Force downloading model variant: {variant}")
120
+
121
+ # Download model file
122
+ model_path = hf_hub_download(
123
+ repo_id=HF_REPO_ID,
124
+ filename=variant_info["model_file"],
125
+ force_download=force_download,
126
+ local_dir_use_symlinks=False
127
+ )
128
+
129
+ # Download config file
130
+ config_path = hf_hub_download(
131
+ repo_id=HF_REPO_ID,
132
+ filename=variant_info["config_file"],
133
+ force_download=force_download,
134
+ local_dir_use_symlinks=False
135
+ )
136
+
137
+ if force_download:
138
+ logging.info(f"✅ Successfully downloaded {variant} model files")
139
+ else:
140
+ logging.info(f"✅ Successfully loaded {variant} model files from cache")
141
+
142
+ logging.debug(f"Model: {model_path}")
143
+ logging.debug(f"Config: {config_path}")
144
+
145
+ return model_path, config_path
146
+
147
+ except Exception as e:
148
+ logging.error(f"Failed to download model {variant}: {e}")
149
+ raise RuntimeError(f"Failed to download model {variant} from Hugging Face: {e}")
150
+
151
+
152
+ def get_available_models() -> dict:
153
+ """Get all available models with their display names and download info"""
154
+ models = {}
155
+
156
+ # First check local models (legacy support)
157
+ search_dirs = [
158
+ os.path.join(current_dir, "pretrained_models"),
159
+ os.path.join(os.path.dirname(current_dir), "pretrained_models")
160
+ ]
161
+
162
+ for search_dir in search_dirs:
163
+ if os.path.exists(search_dir):
164
+ for model_dir in os.listdir(search_dir):
165
+ model_path = os.path.join(search_dir, model_dir, "model_best_bp2.pth")
166
+ cfg_path = os.path.join(search_dir, model_dir, "cfg.yaml")
167
+
168
+ if os.path.exists(model_path) and os.path.exists(cfg_path):
169
+ # Create a descriptive name for the model
170
+ if model_dir == "11-33-40":
171
+ display_name = "FoundationStereo (Low-cost variant - 11-33-40) [Local]"
172
+ elif model_dir == "23-51-11":
173
+ display_name = "FoundationStereo (High-quality variant - 23-51-11) [Local]"
174
+ else:
175
+ display_name = f"FoundationStereo ({model_dir}) [Local]"
176
+
177
+ models[display_name] = {
178
+ "model_path": model_path,
179
+ "config_path": cfg_path,
180
+ "variant": model_dir,
181
+ "source": "local"
182
+ }
183
+
184
+ # Add Hugging Face models
185
+ for variant, info in MODEL_VARIANTS.items():
186
+ display_name = f"{info['display_name']} [Hugging Face]"
187
+ models[display_name] = {
188
+ "model_path": None, # Will be downloaded when needed
189
+ "config_path": None, # Will be downloaded when needed
190
+ "variant": variant,
191
+ "source": "huggingface"
192
+ }
193
+
194
+ return models
195
+
196
+
197
+ def find_model_path() -> Tuple[Optional[str], Optional[str]]:
198
+ """Find available model and config paths (legacy function for backward compatibility)"""
199
+ models = get_available_models()
200
+ if models:
201
+ # Prefer Hugging Face models over local ones
202
+ # First try to find HF low-cost variant
203
+ for display_name in models:
204
+ if "11-33-40" in display_name and "[Hugging Face]" in display_name:
205
+ return get_model_paths_from_selection(display_name)
206
+
207
+ # Then try local low-cost variant
208
+ for display_name in models:
209
+ if "11-33-40" in display_name:
210
+ return get_model_paths_from_selection(display_name)
211
+
212
+ # If no low-cost variant, return the first available
213
+ first_model_name = next(iter(models.keys()))
214
+ return get_model_paths_from_selection(first_model_name)
215
+ return None, None
216
+
217
+
218
+ def get_model_paths_from_selection(model_selection: str) -> Tuple[Optional[str], Optional[str]]:
219
+ """Get model and config paths from the selected model"""
220
+ models = get_available_models()
221
+
222
+ # Check if it's in our models dict
223
+ if model_selection in models:
224
+ model_info = models[model_selection]
225
+
226
+ # If it's a Hugging Face model, download it first (or get from cache)
227
+ if model_info["source"] == "huggingface":
228
+ variant = model_info["variant"]
229
+ try:
230
+ logging.info(f"📦 Retrieving {variant} model from cache...")
231
+ model_path, config_path = download_model_from_hf(variant, force_download=False)
232
+ return model_path, config_path
233
+ except Exception as e:
234
+ logging.error(f"Failed to get model {variant} from cache: {e}")
235
+ return None, None
236
+ else:
237
+ # Local model
238
+ logging.info(f"📁 Using local model: {model_selection}")
239
+ return model_info["model_path"], model_info["config_path"]
240
+
241
+ # Handle direct HF model selection (fallback)
242
+ elif "[Hugging Face]" in model_selection:
243
+ if "11-33-40" in model_selection:
244
+ variant = "11-33-40"
245
+ elif "23-51-11" in model_selection:
246
+ variant = "23-51-11"
247
+ else:
248
+ logging.error(f"Unknown HF model variant in: {model_selection}")
249
+ return None, None
250
+
251
+ try:
252
+ logging.info(f"📦 Retrieving {variant} model from cache...")
253
+ model_path, config_path = download_model_from_hf(variant, force_download=False)
254
+ return model_path, config_path
255
+ except Exception as e:
256
+ logging.error(f"Failed to get model {variant} from cache: {e}")
257
+ return None, None
258
+
259
+ return None, None
260
+
261
+
262
+ def get_cached_model(model_selection: str):
263
+ """Get cached model or load new one if selection changed"""
264
+ global _cached_model, _cached_device, _cached_model_selection
265
+
266
+ # Get model paths from selection
267
+ model_path, config_path = get_model_paths_from_selection(model_selection)
268
+
269
+ if model_path is None or config_path is None:
270
+ raise ValueError(f"Selected model not found: {model_selection}")
271
+
272
+ # Load model fresh for each inference
273
+ # Since models are pre-downloaded, this should be fast
274
+ logging.info(f"🚀 Loading cached model: {model_selection}")
275
+ model, device = load_model_for_inference(model_path, config_path)
276
+
277
+ logging.info(f"✅ Model loaded successfully from cache: {model_selection}")
278
+ return model, device
279
+
280
+
281
+ def clear_model_cache():
282
+ """Clear the cached model to free memory"""
283
+ global _cached_model, _cached_device, _cached_model_selection
284
+
285
+ if _cached_model is not None:
286
+ logging.info("Clearing model cache...")
287
+ del _cached_model
288
+ _cached_model = None
289
+ _cached_device = None
290
+ _cached_model_selection = None
291
+
292
+ # Simple cleanup
293
+ import gc
294
+ gc.collect()
295
+ logging.info("Model cache cleared")
296
+ else:
297
+ logging.info("No model in cache to clear")
298
+
299
+
300
+ def load_model_for_inference(model_path: str, cfg_path: str):
301
+ """Load model temporarily for inference"""
302
+ # Set CUDA settings
303
+ torch.set_default_tensor_type('torch.cuda.FloatTensor')
304
+ torch.backends.cudnn.enabled = True
305
+ torch.backends.cudnn.benchmark = True
306
+
307
+ # Import required modules
308
+ try:
309
+ # Import selectively to avoid CUDA calls in Utils
310
+ from core.foundation_stereo import FoundationStereo
311
+ from omegaconf import OmegaConf
312
+ logging.info("Successfully imported required modules")
313
+
314
+ # Import set_logging_format safely
315
+ from Utils import set_logging_format
316
+ set_logging_format()
317
+
318
+ # Manual seed setting to avoid CUDA calls in Utils.set_seed
319
+ import random
320
+ random_seed = 0
321
+ np.random.seed(random_seed)
322
+ random.seed(random_seed)
323
+ torch.manual_seed(random_seed)
324
+ # CUDA seeding will be done after device is available
325
+
326
+ logging.info("Set logging format and seed")
327
+ except Exception as e:
328
+ logging.error(f"Failed to import modules: {e}")
329
+ raise RuntimeError(f"Import failed: {e}")
330
+
331
+ # Check if CUDA is available
332
+ if not torch.cuda.is_available():
333
+ raise RuntimeError("CUDA is not available.")
334
+
335
+ # Use the first available CUDA device
336
+ device = torch.device("cuda")
337
+
338
+ # Set CUDA seed
339
+ try:
340
+ torch.cuda.manual_seed_all(random_seed)
341
+ torch.backends.cudnn.deterministic = True
342
+ torch.backends.cudnn.benchmark = False
343
+ except Exception as e:
344
+ logging.warning(f"Could not set CUDA seed: {e}")
345
+
346
+ try:
347
+ # Load config
348
+ cfg = OmegaConf.load(cfg_path)
349
+ cfg.setdefault("vit_size", "vitl")
350
+ logging.info("Loaded config file")
351
+
352
+ # Create model
353
+ model = FoundationStereo(cfg).to(device)
354
+ model.eval()
355
+ logging.info("Created model")
356
+
357
+ # Load checkpoint
358
+ ckpt = torch.load(model_path, map_location=device)
359
+ model.load_state_dict(ckpt["model"], strict=True)
360
+ logging.info("Loaded model weights")
361
+
362
+ # Memory optimizations
363
+ torch.set_grad_enabled(False)
364
+ model.half() # Use half precision
365
+ logging.info("Applied memory optimizations")
366
+
367
+ return model, device
368
+
369
+ except Exception as e:
370
+ logging.error(f"Model loading failed: {e}")
371
+ raise RuntimeError(f"Failed to load model: {e}")
372
+
373
+
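
> Editor's note: a minimal usage sketch of `load_model_for_inference` follows. It assumes a checkpoint and config resolved by `get_model_paths_from_selection` (the paths below are placeholders) and mirrors the half-precision, 32-iteration call pattern used later in this file; the dummy inputs are already divisible by 32, so no padding is needed.

```python
# Sketch only: placeholder paths, dummy zero inputs sized to be divisible by 32.
import torch

model, device = load_model_for_inference("models/model_best_bp2.pth", "models/cfg.yaml")
left = torch.zeros(1, 3, 480, 640, device=device, dtype=torch.half)   # (B, C, H, W), RGB in 0-255
right = torch.zeros(1, 3, 480, 640, device=device, dtype=torch.half)
with torch.no_grad(), torch.amp.autocast("cuda", enabled=True):
    disp = model.forward(left, right, iters=32, test_mode=True)        # disparity at input resolution
print(disp.shape)
```
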
374
+ def process_stereo_pair(model_selection: str, left_image: str, right_image: str,
375
+ progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], str]:
376
+ """
377
+ Main processing function for stereo pair (with model caching)
378
+ """
379
+ logging.info("Starting stereo pair processing...")
380
+
381
+ if left_image is None or right_image is None:
382
+ return None, "❌ Please upload both left and right images."
383
+
384
+ # Convert image paths to numpy arrays
385
+ logging.info(f"Loading images: left={left_image}, right={right_image}")
386
+
387
+ try:
388
+ # Load left image
389
+ if left_image is None:
390
+ return None, "❌ Please upload a left image."
391
+
392
+ # Check if file exists first
393
+ if not os.path.exists(left_image):
394
+ logging.error(f"Left image file does not exist: {left_image}")
395
+ return None, f"❌ Left image file not found: {left_image}"
396
+
397
+ logging.info(f"Loading left image from: {left_image}")
398
+ left_img = None
399
+
400
+ # Try multiple loading methods
401
+ try:
402
+ # Method 1: OpenCV
403
+ left_img = cv2.imread(left_image)
404
+ if left_img is not None:
405
+ left_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB)
406
+ logging.info("Left image loaded with OpenCV")
407
+ except Exception as e:
408
+ logging.warning(f"OpenCV failed for left image: {e}")
409
+
410
+ if left_img is None:
411
+ try:
412
+ # Method 2: PIL
413
+ from PIL import Image
414
+ with Image.open(left_image) as pil_img:
415
+ left_img = np.array(pil_img.convert('RGB'))
416
+ logging.info("Left image loaded with PIL")
417
+ except Exception as e:
418
+ logging.warning(f"PIL failed for left image: {e}")
419
+
420
+ if left_img is None:
421
+ try:
422
+ # Method 3: imageio
423
+ left_img = imageio.imread(left_image)
424
+ if len(left_img.shape) == 3 and left_img.shape[2] == 4:
425
+ # RGBA to RGB
426
+ left_img = left_img[:, :, :3]
427
+ logging.info("Left image loaded with imageio")
428
+ except Exception as e:
429
+ logging.warning(f"imageio failed for left image: {e}")
430
+
431
+ if left_img is None:
432
+ return None, f"❌ Failed to load left image with any method: {left_image}"
433
+
434
+ # Load right image
435
+ if right_image is None:
436
+ return None, "❌ Please upload a right image."
437
+
438
+ # Check if file exists first
439
+ if not os.path.exists(right_image):
440
+ logging.error(f"Right image file does not exist: {right_image}")
441
+ return None, f"❌ Right image file not found: {right_image}"
442
+
443
+ logging.info(f"Loading right image from: {right_image}")
444
+ right_img = None
445
+
446
+ # Try multiple loading methods
447
+ try:
448
+ # Method 1: OpenCV
449
+ right_img = cv2.imread(right_image)
450
+ if right_img is not None:
451
+ right_img = cv2.cvtColor(right_img, cv2.COLOR_BGR2RGB)
452
+ logging.info("Right image loaded with OpenCV")
453
+ except Exception as e:
454
+ logging.warning(f"OpenCV failed for right image: {e}")
455
+
456
+ if right_img is None:
457
+ try:
458
+ # Method 2: PIL
459
+ from PIL import Image
460
+ with Image.open(right_image) as pil_img:
461
+ right_img = np.array(pil_img.convert('RGB'))
462
+ logging.info("Right image loaded with PIL")
463
+ except Exception as e:
464
+ logging.warning(f"PIL failed for right image: {e}")
465
+
466
+ if right_img is None:
467
+ try:
468
+ # Method 3: imageio
469
+ right_img = imageio.imread(right_image)
470
+ if len(right_img.shape) == 3 and right_img.shape[2] == 4:
471
+ # RGBA to RGB
472
+ right_img = right_img[:, :, :3]
473
+ logging.info("Right image loaded with imageio")
474
+ except Exception as e:
475
+ logging.warning(f"imageio failed for right image: {e}")
476
+
477
+ if right_img is None:
478
+ return None, f"❌ Failed to load right image with any method: {right_image}"
479
+
480
+ # Update variables
481
+ left_image = left_img
482
+ right_image = right_img
483
+
484
+ logging.info(f"Images loaded successfully - Left: {left_image.shape}, Right: {right_image.shape}")
485
+
486
+ except Exception as e:
487
+ logging.error(f"Failed to load images: {e}")
488
+ return None, f"❌ Failed to load images: {str(e)}"
489
+
490
+ try:
491
+ # Import these inside to avoid early CUDA calls
492
+ logging.info("Importing required modules...")
493
+ from core.utils.utils import InputPadder
494
+ # Import vis_disparity safely - it shouldn't have CUDA calls but be careful
495
+ from Utils import vis_disparity
496
+ logging.info("✅ Successfully imported processing modules")
497
+
498
+ # Get cached model (will load if not cached or selection changed)
499
+ variant_name = model_selection.split('(')[1].split(')')[0] if '(' in model_selection else model_selection
500
+ progress(0.1, desc=f"Loading cached model ({variant_name})...")
501
+ logging.info("🚀 Getting cached model...")
502
+ model, device = get_cached_model(model_selection)
503
+ logging.info("✅ Cached model loaded successfully")
504
+
505
+ progress(0.2, desc="Preprocessing images...")
506
+
507
+ # Validate input images
508
+ if left_image.shape != right_image.shape:
509
+ return None, "❌ Left and right images must have the same dimensions."
510
+
511
+ H, W = left_image.shape[:2]
512
+
513
+ # Convert to torch tensors and ensure they are contiguous
514
+ img0 = torch.as_tensor(left_image).to(device).half()[None].permute(0,3,1,2).contiguous()
515
+ img1 = torch.as_tensor(right_image).to(device).half()[None].permute(0,3,1,2).contiguous()
516
+
517
+ # Pad images and ensure contiguity
518
+ padder = InputPadder(img0.shape, divis_by=32, force_square=False)
519
+ img0, img1 = padder.pad(img0, img1)
520
+
521
+ # Ensure padded tensors are contiguous
522
+ img0 = img0.contiguous()
523
+ img1 = img1.contiguous()
524
+
525
+ progress(0.5, desc="Running inference...")
526
+
527
+ # Process stereo pair with autocast and ensure clean memory state
528
+ torch.cuda.empty_cache() # Clear any cached memory before inference
529
+
530
+ try:
531
+ with torch.amp.autocast("cuda", enabled=True):
532
+ # Ensure tensors are in the right format for cuDNN
533
+ if not img0.is_contiguous():
534
+ img0 = img0.contiguous()
535
+ if not img1.is_contiguous():
536
+ img1 = img1.contiguous()
537
+
538
+ disp = model.forward(img0, img1, iters=32, test_mode=True)
539
+ except RuntimeError as e:
540
+ if "cuDNN" in str(e):
541
+ # Fallback: disable cuDNN optimizations and retry
542
+ logging.warning(f"cuDNN error encountered, retrying with fallback: {e}")
543
+ torch.backends.cudnn.enabled = False
544
+ try:
545
+ with torch.amp.autocast("cuda", enabled=True):
546
+ disp = model.forward(img0, img1, iters=32, test_mode=True)
547
+ finally:
548
+ torch.backends.cudnn.enabled = True # Re-enable for future use
549
+ else:
550
+ raise e
551
+
552
+ # Unpad and convert to numpy
553
+ disp = padder.unpad(disp.float())
554
+ disp_cpu = disp.data.cpu().numpy().reshape(H, W)
555
+
556
+ progress(0.8, desc="Creating visualization...")
557
+
558
+ # Create visualization - ONLY disparity
559
+ disparity_vis = vis_disparity(disp_cpu)
560
+ result_image = disparity_vis
561
+
562
+ progress(1.0, desc="Complete!")
563
+
564
+ # Clean up intermediate tensors
565
+ del img0, img1, disp
566
+
567
+ # Clean up model after inference
568
+ del model
569
+ torch.cuda.empty_cache()
570
+ gc.collect()
571
+
572
+ # Create status message
573
+ valid_mask = disp_cpu != np.inf
574
+ min_disp = disp_cpu[valid_mask].min() if valid_mask.any() else 0
575
+ max_disp = disp_cpu[valid_mask].max() if valid_mask.any() else 0
576
+ mean_disp = disp_cpu[valid_mask].mean() if valid_mask.any() else 0
577
+
578
+ # Get model variant for status
579
+ variant = model_selection.split('(')[1].split(')')[0] if '(' in model_selection else "Unknown"
580
+
581
+ # Check current memory usage
582
+ current_memory = torch.cuda.memory_allocated(0) / 1024**3
583
+ max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
584
+ memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"
585
+
586
+ status = f"""✅ Processing successful!
587
+ 🔧 Model: {variant}{memory_info}
588
+ 📊 Disparity Statistics:
589
+ • Range: {min_disp:.2f} - {max_disp:.2f}
590
+ • Mean: {mean_disp:.2f}
591
+ • Input size: {W}×{H}
592
+ • Valid pixels: {valid_mask.sum()}/{valid_mask.size}"""
593
+
594
+ return result_image, status
595
+
596
+ except Exception as e:
597
+ logging.error(f"Processing failed: {e}")
598
+ # Cleanup on error
599
+ if 'img0' in locals():
600
+ del img0
601
+ if 'img1' in locals():
602
+ del img1
603
+ if 'disp' in locals():
604
+ del disp
605
+ if 'model' in locals():
606
+ del model
607
+ # Clean up GPU memory
608
+ torch.cuda.empty_cache()
609
+ gc.collect()
610
+ return None, f"❌ Error: {str(e)}"
611
+
612
+
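
> Editor's note: both processing functions repeat the same OpenCV → PIL → imageio fallback chain once per image. A compact, factored version of that chain is sketched below for reference; it is not wired into the app.

```python
# Reference helper mirroring the image-loading fallback chain used above (not part of the app).
import numpy as np
import cv2
import imageio
from PIL import Image

def load_rgb(path: str):
    """Return an HxWx3 RGB uint8 array, or None if every backend fails."""
    img = cv2.imread(path)                     # 1) OpenCV (BGR)
    if img is not None:
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    try:                                       # 2) PIL
        with Image.open(path) as pil_img:
            return np.array(pil_img.convert("RGB"))
    except Exception:
        pass
    try:                                       # 3) imageio
        img = imageio.imread(path)
        if img.ndim == 2:                      # grayscale -> 3 channels
            img = np.stack([img] * 3, axis=-1)
        return img[..., :3]                    # drop alpha if present
    except Exception:
        return None
```
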
613
+ def process_with_depth(model_selection: str, left_image: str, right_image: str,
614
+ camera_matrix: str, baseline: float,
615
+ progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], Optional[str], Optional[str], str]:
616
+ """
617
+ Process stereo pair and generate depth map and point cloud (with model caching)
618
+ """
619
+ from core.utils.utils import InputPadder
620
+ from Utils import vis_disparity
621
+
622
+ # Import Open3D
623
+ global OPEN3D_AVAILABLE
624
+ try:
625
+ import open3d as o3d
626
+ OPEN3D_AVAILABLE = True
627
+ except ImportError as e:
628
+ logging.warning(f"Open3D not available: {e}")
629
+ OPEN3D_AVAILABLE = False
630
+ return None, None, None, "❌ Open3D not available. Point cloud generation disabled."
631
+
632
+ if left_image is None or right_image is None:
633
+ return None, None, None, "❌ Please upload both left and right images."
634
+
635
+ # Convert image paths to numpy arrays
636
+ logging.info(f"Loading images: left={left_image}, right={right_image}")
637
+
638
+ try:
639
+ # Load left image
640
+ if left_image is None:
641
+ return None, None, None, "❌ Left image is None."
642
+ if not os.path.exists(left_image):
643
+ return None, None, None, f"❌ Left image file does not exist: {left_image}"
644
+ left_img = None
645
+ # Try OpenCV
646
+ try:
647
+ left_img = cv2.imread(left_image)
648
+ if left_img is not None:
649
+ left_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB)
650
+ except Exception as e:
651
+ logging.warning(f"OpenCV failed for left image: {e}")
652
+ # Try PIL if OpenCV fails
653
+ if left_img is None:
654
+ try:
655
+ from PIL import Image
656
+ left_img = np.array(Image.open(left_image).convert('RGB'))
657
+ except Exception as e:
658
+ logging.warning(f"PIL failed for left image: {e}")
659
+ # Try imageio if PIL fails
660
+ if left_img is None:
661
+ try:
662
+ import imageio
663
+ left_img = imageio.imread(left_image)
664
+ if left_img.ndim == 2:
665
+ left_img = np.stack([left_img]*3, axis=-1)
666
+ elif left_img.shape[2] == 4:
667
+ left_img = left_img[..., :3]
668
+ except Exception as e:
669
+ logging.warning(f"imageio failed for left image: {e}")
670
+ if left_img is None:
671
+ return None, None, None, f"❌ Could not load left image: {left_image}"
672
+
673
+ # Load right image
674
+ if right_image is None:
675
+ return None, None, None, "❌ Right image is None."
676
+ if not os.path.exists(right_image):
677
+ return None, None, None, f"❌ Right image file does not exist: {right_image}"
678
+ right_img = None
679
+ # Try OpenCV
680
+ try:
681
+ right_img = cv2.imread(right_image)
682
+ if right_img is not None:
683
+ right_img = cv2.cvtColor(right_img, cv2.COLOR_BGR2RGB)
684
+ except Exception as e:
685
+ logging.warning(f"OpenCV failed for right image: {e}")
686
+ # Try PIL if OpenCV fails
687
+ if right_img is None:
688
+ try:
689
+ from PIL import Image
690
+ right_img = np.array(Image.open(right_image).convert('RGB'))
691
+ except Exception as e:
692
+ logging.warning(f"PIL failed for right image: {e}")
693
+ # Try imageio if PIL fails
694
+ if right_img is None:
695
+ try:
696
+ import imageio
697
+ right_img = imageio.imread(right_image)
698
+ if right_img.ndim == 2:
699
+ right_img = np.stack([right_img]*3, axis=-1)
700
+ elif right_img.shape[2] == 4:
701
+ right_img = right_img[..., :3]
702
+ except Exception as e:
703
+ logging.warning(f"imageio failed for right image: {e}")
704
+ if right_img is None:
705
+ return None, None, None, f"❌ Could not load right image: {right_image}"
706
+
707
+ # Update variables
708
+ left_image = left_img
709
+ right_image = right_img
710
+
711
+ logging.info(f"Images loaded successfully - Left: {left_image.shape}, Right: {right_image.shape}")
712
+
713
+ except Exception as e:
714
+ logging.error(f"Failed to load images: {e}")
715
+ return None, None, None, f"❌ Failed to load images: {str(e)}"
716
+
717
+ try:
718
+ progress(0.1, desc="Parsing camera parameters...")
719
+
720
+ # Parse camera matrix
721
+ try:
722
+ K_values = list(map(float, camera_matrix.strip().split()))
723
+ if len(K_values) != 9:
724
+ return None, None, None, "❌ Camera matrix must contain exactly 9 values."
725
+ K = np.array(K_values).reshape(3, 3)
726
+ except ValueError:
727
+ return None, None, None, "❌ Invalid camera matrix format. Use space-separated numbers."
728
+
729
+ if baseline <= 0:
730
+ return None, None, None, "❌ Baseline must be positive."
731
+
732
+ variant = model_selection.split('(')[1].split(')')[0] if '(' in model_selection else "Unknown"
733
+ progress(0.2, desc=f"Loading cached model ({variant})...")
734
+
735
+ # Get cached model (will load if not cached or selection changed)
736
+ model, device = get_cached_model(model_selection)
737
+
738
+ progress(0.4, desc="Running stereo inference...")
739
+
740
+ # Get disparity using the same process as the basic function
741
+ H, W = left_image.shape[:2]
742
+ img0 = torch.as_tensor(left_image).to(device).half()[None].permute(0,3,1,2).contiguous()
743
+ img1 = torch.as_tensor(right_image).to(device).half()[None].permute(0,3,1,2).contiguous()
744
+
745
+ padder = InputPadder(img0.shape, divis_by=32, force_square=False)
746
+ img0, img1 = padder.pad(img0, img1)
747
+
748
+ # Ensure padded tensors are contiguous
749
+ img0 = img0.contiguous()
750
+ img1 = img1.contiguous()
751
+
752
+ # Clear cache and ensure clean memory state before inference
753
+ torch.cuda.empty_cache()
754
+
755
+ try:
756
+ with torch.amp.autocast("cuda", enabled=True):
757
+ # Double-check tensor contiguity before cuDNN operations
758
+ if not img0.is_contiguous():
759
+ img0 = img0.contiguous()
760
+ if not img1.is_contiguous():
761
+ img1 = img1.contiguous()
762
+
763
+ disp = model.forward(img0, img1, iters=32, test_mode=True)
764
+ except RuntimeError as e:
765
+ if "cuDNN" in str(e):
766
+ # Fallback: disable cuDNN optimizations and retry
767
+ logging.warning(f"cuDNN error encountered in depth processing, retrying with fallback: {e}")
768
+ torch.backends.cudnn.enabled = False
769
+ try:
770
+ with torch.amp.autocast("cuda", enabled=True):
771
+ disp = model.forward(img0, img1, iters=32, test_mode=True)
772
+ finally:
773
+ torch.backends.cudnn.enabled = True # Re-enable for future use
774
+ else:
775
+ raise e
776
+
777
+ disp = padder.unpad(disp.float())
778
+ disp_cpu = disp.data.cpu().numpy().reshape(H, W)
779
+
780
+ # Clean up intermediate tensors early
781
+ del img0, img1, disp
782
+
783
+ # Keep model reference for rest of processing
784
+ torch.cuda.empty_cache()
785
+
786
+ progress(0.6, desc="Converting to depth...")
787
+
788
+ # Remove invisible points (same as in original demo)
789
+ yy, xx = np.meshgrid(np.arange(disp_cpu.shape[0]), np.arange(disp_cpu.shape[1]), indexing='ij')
790
+ us_right = xx - disp_cpu
791
+ invalid = us_right < 0
792
+ disp_cpu[invalid] = np.inf
793
+
794
+ # Convert to depth using the formula from the original demo
795
+ depth = K[0, 0] * baseline / disp_cpu
796
+
797
+ # Visualize depth (no rotation)
798
+ depth_vis = vis_disparity(depth, max_val=10.0)
799
+
800
+ progress(0.8, desc="Generating point cloud...")
801
+
802
+ # Generate point cloud with proper coordinate transformation
803
+ fx, fy = K[0, 0], K[1, 1]
804
+ cx, cy = K[0, 2], K[1, 2]
805
+
806
+ # Create coordinate meshgrids
807
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
808
+
809
+ # Convert to 3D coordinates (proper camera coordinate system)
810
+ valid_depth = depth != np.inf
811
+ z = depth[valid_depth] # Z coordinate (depth)
812
+ x = (u[valid_depth] - cx) * z / fx # X coordinate
813
+ y = (v[valid_depth] - cy) * z / fy # Y coordinate
814
+
815
+ # Stack coordinates (X, Y, Z)
816
+ points = np.stack([x, y, z], axis=-1)
817
+
818
+ # Get corresponding colors
819
+ colors = left_image[valid_depth]
820
+
821
+ # Filter points by depth range
822
+ depth_mask = (z > 0) & (z <= 10.0)
823
+ valid_points = points[depth_mask]
824
+ valid_colors = colors[depth_mask]
825
+
826
+ if len(valid_points) == 0:
827
+ return depth_vis, None, None, "⚠️ No valid points generated for point cloud."
828
+
829
+ # Subsample points for better 3D visualization performance
830
+ if len(valid_points) > 100000:
831
+ indices = np.random.choice(len(valid_points), 100000, replace=False)
832
+ valid_points = valid_points[indices]
833
+ valid_colors = valid_colors[indices]
834
+
835
+ # Transform coordinates for proper visualization orientation
836
+ # Standard computer vision: X right, Y down, Z forward
837
+ # For better 3D viewing: X right, Y up, Z backward
838
+ transformed_points = valid_points.copy()
839
+ transformed_points[:, 1] = -transformed_points[:, 1] # Flip Y axis
840
+ transformed_points[:, 2] = -transformed_points[:, 2] # Flip Z axis
841
+
842
+ # Generate point cloud using transformed coordinates
843
+ pcd = o3d.geometry.PointCloud()
844
+ pcd.points = o3d.utility.Vector3dVector(transformed_points)
845
+ pcd.colors = o3d.utility.Vector3dVector(valid_colors / 255.0)
846
+
847
+ progress(1.0, desc="Complete!")
848
+
849
+ # Clean up model after inference
850
+ del model
851
+ torch.cuda.empty_cache()
852
+ gc.collect()
853
+
854
+ # Check current memory usage
855
+ current_memory = torch.cuda.memory_allocated(0) / 1024**3
856
+ max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
857
+ memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"
858
+
859
+ status = f"""✅ Depth processing successful!
860
+ 🔧 Model: {variant}{memory_info}
861
+ 📊 Statistics:
862
+ • Valid points: {len(valid_points):,}
863
+ • Depth range: {z.min():.2f} - {z.max():.2f} m
864
+ • Baseline: {baseline} m
865
+ • Point cloud generated with {len(valid_points)} points (not saved to file)
866
+ • Point cloud computed in memory only (the 3D viewer and file outputs are not populated in this demo)"""
867
+
868
+ return depth_vis, None, None, status
869
+
870
+ except Exception as e:
871
+ logging.error(f"Depth processing failed: {e}")
872
+ # Cleanup on error
873
+ if 'img0' in locals():
874
+ del img0
875
+ if 'img1' in locals():
876
+ del img1
877
+ if 'disp' in locals():
878
+ del disp
879
+ if 'model' in locals():
880
+ del model
881
+ # Clean up GPU memory
882
+ torch.cuda.empty_cache()
883
+ gc.collect()
884
+ return None, None, None, f"❌ Error: {str(e)}"
885
+
886
+
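
> Editor's note: the geometry in `process_with_depth` above is the standard pinhole relation, depth Z = fx · baseline / disparity, with back-projection X = (u − cx)·Z/fx and Y = (v − cy)·Z/fy. A small worked example using the example1 intrinsics and baseline (the disparity and pixel coordinates below are illustrative):

```python
# Worked example of the disparity -> depth -> 3D point conversion used above.
import numpy as np

# Intrinsics in the 9-value row-major layout the UI expects (example1 values).
K = np.array([754.668, 0.0, 489.379,
              0.0, 754.668, 265.162,
              0.0, 0.0, 1.0]).reshape(3, 3)
fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
baseline = 0.063            # meters (example1)

d = 47.0                    # predicted disparity in pixels (illustrative)
u, v = 620.0, 300.0         # pixel coordinates (illustrative)

Z = fx * baseline / d       # ~1.01 m
X = (u - cx) * Z / fx       # ~0.18 m right of the optical axis
Y = (v - cy) * Z / fy       # ~0.05 m below it (image y points down)
print(X, Y, Z)
```
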
887
+ def preload_all_models():
888
+ """Pre-download all Hugging Face models to cache during startup"""
889
+ logging.info("🔄 Pre-downloading all models to cache...")
890
+
891
+ downloaded_models = {}
892
+
893
+ for variant, info in MODEL_VARIANTS.items():
894
+ try:
895
+ logging.info(f"📥 Downloading {variant} model to cache...")
896
+ model_path, config_path = download_model_from_hf(variant, force_download=False)
897
+ downloaded_models[variant] = {
898
+ "model_path": model_path,
899
+ "config_path": config_path,
900
+ "display_name": info["display_name"]
901
+ }
902
+ logging.info(f"✅ {variant} model cached successfully")
903
+ except Exception as e:
904
+ logging.warning(f"⚠️ Failed to download {variant} model: {e}")
905
+ # Continue with other models even if one fails
906
+
907
+ logging.info(f"✅ Model pre-loading complete. {len(downloaded_models)}/{len(MODEL_VARIANTS)} models cached.")
908
+ return downloaded_models
909
+
910
+
911
+ def create_app() -> gr.Blocks:
912
+ """Create the Gradio application"""
913
+
914
+ global MODEL_PATH, CONFIG_PATH
915
+
916
+ # Debug: Print current directory and check for files
917
+ print(f"Current directory: {current_dir}")
918
+ print(f"Python working directory: {os.getcwd()}")
919
+
920
+ # Pre-download all models to cache
921
+ try:
922
+ cached_models = preload_all_models()
923
+ logging.info(f"Pre-loaded {len(cached_models)} models to cache")
924
+ except Exception as e:
925
+ logging.error(f"Failed to pre-load models: {e}")
926
+ cached_models = {}
927
+
928
+ # Get available models (this should be safe as it only does file system operations)
929
+ try:
930
+ available_models = get_available_models()
931
+ logging.info(f"Successfully got available models: {len(available_models)} found")
932
+ except Exception as e:
933
+ logging.error(f"Failed to get available models: {e}")
934
+ available_models = {}
935
+
936
+ # Find model and config paths (legacy) - should be safe as well
937
+ try:
938
+ MODEL_PATH, CONFIG_PATH = find_model_path()
939
+ logging.info("Successfully found model paths")
940
+ except Exception as e:
941
+ logging.error(f"Failed to find model paths: {e}")
942
+ MODEL_PATH, CONFIG_PATH = None, None
943
+
944
+ with gr.Blocks(
945
+ title="FoundationStereo - Stereo Depth Estimation",
946
+ theme=gr.themes.Soft(),
947
+ css="footer {visibility: hidden}",
948
+ delete_cache=(60, 60) # Delete cache after 60 seconds
949
+ ) as app:
950
+
951
+ gr.Markdown("""
952
+ # 🔍 FoundationStereo: Zero-Shot Stereo Matching
953
+
954
+ Upload a pair of **rectified** stereo images to get disparity estimation.
955
+
956
+ ⚠️ **Important**: Images should be rectified (epipolar lines are horizontal) and undistorted.
957
+ ⚡ **GPU Powered**: Runs on high-performance GPUs for fast inference.
958
+ 📦 **Smart Caching**: All models are pre-downloaded for instant model switching.
959
+ """)
960
+
961
+ # Instructions section
962
+ with gr.Accordion("📋 Instructions to Run This Repository", open=False):
963
+ gr.Markdown("""
964
+ ## 🚀 How to Run This Demo
965
+ This is a **demo application** showcasing the FoundationStereo model for stereo disparity estimation.
966
+
967
+ ### 🖼️ Input Requirements
968
+
969
+ 1. **Image Format**: Upload images in JPEG or PNG format.
970
+ 2. **Image Size**: The left and right images must have identical dimensions.
971
+ 3. **Rectification**: Ensure images are rectified (epipolar lines are horizontal) and undistorted.
972
+ 4. **Camera Parameters**: For advanced processing, provide camera parameters (camera matrix and baseline).
973
+
974
+ ### 📊 Using the Demo
975
+
976
+ 1. **Select Model**: Choose between low-cost (11-33-40) or high-quality (23-51-11) variants
977
+ 2. **Upload Images**: Provide rectified stereo image pairs
978
+ 3. **Basic Processing**: Get disparity visualization
979
+ 4. **Advanced Processing**: Generate depth maps and 3D point clouds (requires camera parameters)
980
+
981
+ ### Original Work
982
+
983
+ This demo is based on the original FoundationStereo research. Please visit the official resources:
984
+ - **Paper**: [FoundationStereo: Zero-Shot Stereo Matching via Foundation Model](https://arxiv.org/abs/2501.09898)
985
+ - **Project Page**: [https://nvlabs.github.io/FoundationStereo/](https://nvlabs.github.io/FoundationStereo/)
986
+ - **Official Repository**: [https://github.com/NVlabs/FoundationStereo](https://github.com/NVlabs/FoundationStereo)
987
+
988
+ **⚠️ Demo Notice**: This is a demonstration interface. For research and production use, please refer to the original repository and follow the official implementation guidelines.
989
+ """)
990
+
991
+ # Model selection
992
+ with gr.Row():
993
+ # Always include Hugging Face models in the choices
994
+ all_choices = list(available_models.keys())
995
+
996
+ # If no models found, add the HF models manually
997
+ if not all_choices:
998
+ all_choices = [
999
+ "FoundationStereo (Low-cost variant - 11-33-40) [Hugging Face]",
1000
+ "FoundationStereo (High-quality variant - 23-51-11) [Hugging Face]"
1001
+ ]
1002
+
1003
+ # Get default model (prefer Hugging Face low-cost variant)
1004
+ default_model = None
1005
+
1006
+ # First try Hugging Face low-cost variant
1007
+ for name in all_choices:
1008
+ if "11-33-40" in name and "[Hugging Face]" in name:
1009
+ default_model = name
1010
+ break
1011
+
1012
+ # If no HF low-cost variant, try any low-cost variant
1013
+ if default_model is None:
1014
+ for name in all_choices:
1015
+ if "11-33-40" in name:
1016
+ default_model = name
1017
+ break
1018
+
1019
+ # If no low-cost variant, use first available
1020
+ if default_model is None:
1021
+ default_model = all_choices[0] if all_choices else None
1022
+
1023
+ model_selector = gr.Dropdown(
1024
+ choices=all_choices,
1025
+ value=default_model,
1026
+ label="🎯 Select Model",
1027
+ info="Choose the FoundationStereo model variant. Hugging Face models download automatically.",
1028
+ interactive=True
1029
+ )
1030
+
1031
+ with gr.Tabs():
1032
+ # Basic stereo processing tab
1033
+ with gr.TabItem("🖼️ Basic Stereo Processing"):
1034
+ with gr.Row():
1035
+ with gr.Column():
1036
+ left_input = gr.Image(
1037
+ label="📷 Left Image",
1038
+ type="filepath",
1039
+ height=300
1040
+ )
1041
+ right_input = gr.Image(
1042
+ label="📷 Right Image",
1043
+ type="filepath",
1044
+ height=300
1045
+ )
1046
+
1047
+ process_btn = gr.Button(
1048
+ "🚀 Process Stereo Pair",
1049
+ variant="primary",
1050
+ size="lg"
1051
+ )
1052
+
1053
+ with gr.Column():
1054
+ output_image = gr.Image(
1055
+ label="📊 Disparity Visualization",
1056
+ height=400
1057
+ )
1058
+ status_text = gr.Textbox(
1059
+ label="Status",
1060
+ interactive=False,
1061
+ lines=8
1062
+ )
1063
+
1064
+ # Example images
1065
+ examples_list = []
1066
+
1067
+ # Example 1
1068
+ if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")):
1069
+ examples_list.append([
1070
+ os.path.join(current_dir, "assets", "example1", "left.png"),
1071
+ os.path.join(current_dir, "assets", "example1", "right.png")
1072
+ ])
1073
+
1074
+ # Example 2
1075
+ if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")):
1076
+ examples_list.append([
1077
+ os.path.join(current_dir, "assets", "example2", "left.png"),
1078
+ os.path.join(current_dir, "assets", "example2", "right.png")
1079
+ ])
1080
+
1081
+
1082
+
1083
+ gr.Examples(
1084
+ examples=examples_list,
1085
+ inputs=[left_input, right_input],
1086
+ label="📋 Example Images"
1087
+ )
1088
+
1089
+ # Advanced processing with depth
1090
+ with gr.TabItem("📐 Advanced Processing (Depth & Point Cloud)"):
1091
+ with gr.Row():
1092
+ with gr.Column():
1093
+ left_input_adv = gr.Image(
1094
+ label="📷 Left Image",
1095
+ type="filepath",
1096
+ height=250
1097
+ )
1098
+ right_input_adv = gr.Image(
1099
+ label="📷 Right Image",
1100
+ type="filepath",
1101
+ height=250
1102
+ )
1103
+
1104
+ # Camera parameters
1105
+ with gr.Group():
1106
+ gr.Markdown("### 📹 Camera Parameters")
1107
+ camera_matrix_input = gr.Textbox(
1108
+ label="Camera Matrix (9 values: fx 0 cx 0 fy cy 0 0 1)",
1109
+ value="",
1110
+
1111
+ )
1112
+ baseline_input = gr.Number(
1113
+ label="Baseline (meters)",
1114
+ value=None,
1115
+ minimum=0.001,
1116
+ maximum=10.0,
1117
+ step=0.001
1118
+ )
1119
+
1120
+ process_depth_btn = gr.Button(
1121
+ "🔬 Process with Depth",
1122
+ variant="primary",
1123
+ size="lg"
1124
+ )
1125
+
1126
+ with gr.Column():
1127
+ depth_output = gr.Image(
1128
+ label="📏 Depth Visualization",
1129
+ height=300
1130
+ )
1131
+ pointcloud_output = gr.File(
1132
+ label="☁️ Point Cloud Download (.ply)",
1133
+ file_types=[".ply"]
1134
+ )
1135
+ status_depth = gr.Textbox(
1136
+ label="Status",
1137
+ interactive=False,
1138
+ lines=6
1139
+ )
1140
+
1141
+ # 3D Point Cloud Visualization
1142
+ with gr.Row():
1143
+ pointcloud_3d = gr.Model3D(
1144
+ label="🌐 3D Point Cloud Viewer",
1145
+ clear_color=[0.0, 0.0, 0.0, 0.0],
1146
+ height=400
1147
+ )
1148
+
1149
+ # Example images for advanced processing
1150
+ examples_advanced_list = []
1151
+
1152
+ # Example 1 - Camera parameters from K.txt
1153
+ if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")):
1154
+ examples_advanced_list.append([
1155
+ os.path.join(current_dir, "assets", "example1", "left.png"),
1156
+ os.path.join(current_dir, "assets", "example1", "right.png"),
1157
+ "754.6680908203125 0.0 489.3794860839844 0.0 754.6680908203125 265.16162109375 0.0 0.0 1.0", # Camera matrix
1158
+ 0.063 # Baseline in meters
1159
+ ])
1160
+
1161
+ # Example 2 - Camera parameters from K.txt
1162
+ if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")):
1163
+ examples_advanced_list.append([
1164
+ os.path.join(current_dir, "assets", "example2", "left.png"),
1165
+ os.path.join(current_dir, "assets", "example2", "right.png"),
1166
+ "1733.74 0.0 792.27 0.0 1733.74 541.89 0.0 0.0 1.0", # Camera matrix
1167
+ 0.537 # Baseline in meters (converted from 536.62mm)
1168
+ ])
1169
+
1170
+
1171
+
1172
+ gr.Examples(
1173
+ examples=examples_advanced_list,
1174
+ inputs=[left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
1175
+ label="📋 Example Images with Camera Parameters"
1176
+ )
1177
+
1178
+ # Event handlers - Always enable since we have HF models
1179
+ process_btn.click(
1180
+ fn=process_stereo_pair,
1181
+ inputs=[model_selector, left_input, right_input],
1182
+ outputs=[output_image, status_text],
1183
+ show_progress=True
1184
+ )
1185
+
1186
+ if OPEN3D_AVAILABLE:
1187
+ process_depth_btn.click(
1188
+ fn=process_with_depth,
1189
+ inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
1190
+ outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth],
1191
+ show_progress=True
1192
+ )
1193
+ else:
1194
+ process_depth_btn.click(
1195
+ fn=lambda *args: (None, None, None, "❌ Open3D not available. Install with: pip install open3d"),
1196
+ inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
1197
+ outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth]
1198
+ )
1199
+
1200
+ # Citation section at the bottom
1201
+ with gr.Accordion("📖 Citation", open=False):
1202
+ gr.Markdown("""
1203
+ ### 📄 Please Cite the Original Paper
1204
+
1205
+ If you use this work in your research, please cite:
1206
+
1207
+ ```bibtex
1208
+ @article{wen2025stereo,
1209
+ title={FoundationStereo: Zero-Shot Stereo Matching},
1210
+ author={Bowen Wen and Matthew Trepte and Joseph Aribido and Jan Kautz and Orazio Gallo and Stan Birchfield},
1211
+ journal={CVPR},
1212
+ year={2025}
1213
+ }
1214
+ ```
1215
+ """)
1216
+
1217
+ # Footer
1218
+ gr.Markdown(f"""
1219
+ ---
1220
+ ### 📝 Notes:
1221
+ - **Input images must be rectified stereo pairs** (epipolar lines are horizontal)
1222
+ - **🤗 Hugging Face Integration**: Models are automatically downloaded from `{HF_REPO_ID}`
1223
+ - **📦 Smart Caching**: All models are pre-downloaded and cached for instant switching
1224
+ - **⚡ GPU Acceleration**: Powered by high-performance GPUs
1225
+ - For best results, use PNG images without lossy compression
1226
+ - Model works on RGB images but also supports monochrome/IR stereo pairs
1227
+ - **Optimized for Performance**: Memory-efficient inference
1228
+
1229
+ ### 🔗 References:
1230
+ - [FoundationStereo Paper](https://arxiv.org/abs/2501.09898)
1231
+ - [Project Website](https://nvlabs.github.io/FoundationStereo/)
1232
+ - [GitHub Repository](https://github.com/NVlabs/FoundationStereo)
1233
+ - [Hugging Face Models](https://huggingface.co/{HF_REPO_ID})
1234
+ """)
1235
+
1236
+ return app
1237
+
1238
+
1239
+ def main():
1240
+ """Main function to launch the app"""
1241
+
1242
+ # Ensure no CUDA operations during startup
1243
+ if torch.cuda.is_available():
1244
+ logging.warning("CUDA detected during startup")
1245
+
1246
+ logging.info("🚀 Starting FoundationStereo Gradio App...")
1247
+
1248
+ # Parse command line arguments
1249
+ import argparse
1250
+ parser = argparse.ArgumentParser(description="FoundationStereo Gradio App")
1251
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
1252
+ parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
1253
+ parser.add_argument("--debug", action="store_true", help="Enable debug mode")
1254
+
1255
+ args = parser.parse_args()
1256
+
1257
+ if args.debug:
1258
+ logging.getLogger().setLevel(logging.DEBUG)
1259
+
1260
+ try:
1261
+ # Create and launch app
1262
+ logging.info("Creating Gradio app...")
1263
+ app = create_app()
1264
+ logging.info("✅ Gradio app created successfully")
1265
+
1266
+ logging.info(f"Launching app on {args.host}:{args.port}")
1267
+
1268
+ # Launch with appropriate settings
1269
+ app.launch(
1270
+ server_name=args.host,
1271
+ server_port=args.port,
1272
+ share=False,
1273
+ show_error=True,
1274
+ favicon_path=None,
1275
+ ssr_mode=False, # Disable SSR for compatibility
1276
+ allowed_paths=["./"] # Allow access to local files
1277
+ )
1278
+ except Exception as e:
1279
+ logging.error(f"Failed to launch app: {e}")
1280
+ raise
1281
+
1282
+
1283
+ if __name__ == "__main__":
1284
+ # Additional safety check for Spaces environment
1285
+ if 'SPACE_ID' in os.environ:
1286
+ logging.info("Running in Hugging Face Spaces environment")
1287
+
1288
+ # Do not check CUDA status during startup - this can trigger CUDA initialization
1289
+ # The CUDA status will be checked inside the GPU decorated functions
1290
+ logging.info("✅ CUDA status will be checked within GPU functions")
1291
+
1292
+ main()
FoundationStereo_demo/core/extractor.py ADDED
@@ -0,0 +1,371 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+
10
+ import torch,logging,os,sys,urllib,warnings
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ code_dir = os.path.dirname(os.path.realpath(__file__))
14
+ sys.path.append(f'{code_dir}/../')
15
+ from core.submodule import *
16
+ from Utils import *
17
+ import timm
18
+
19
+
20
+ class ResidualBlock(nn.Module):
21
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
22
+ super(ResidualBlock, self).__init__()
23
+
24
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
25
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
26
+ self.relu = nn.ReLU(inplace=True)
27
+
28
+ num_groups = planes // 8
29
+
30
+ if norm_fn == 'group':
31
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
32
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
33
+ if not (stride == 1 and in_planes == planes):
34
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
35
+
36
+ elif norm_fn == 'batch':
37
+ self.norm1 = nn.BatchNorm2d(planes)
38
+ self.norm2 = nn.BatchNorm2d(planes)
39
+ if not (stride == 1 and in_planes == planes):
40
+ self.norm3 = nn.BatchNorm2d(planes)
41
+
42
+ elif norm_fn == 'instance':
43
+ self.norm1 = nn.InstanceNorm2d(planes)
44
+ self.norm2 = nn.InstanceNorm2d(planes)
45
+ if not (stride == 1 and in_planes == planes):
46
+ self.norm3 = nn.InstanceNorm2d(planes)
47
+
48
+ elif norm_fn=='layer':
49
+ self.norm1 = LayerNorm2d(planes)
50
+ self.norm2 = LayerNorm2d(planes)
51
+ if not (stride == 1 and in_planes == planes):
52
+ self.norm3 = LayerNorm2d(planes)
53
+
54
+ elif norm_fn == 'none':
55
+ self.norm1 = nn.Sequential()
56
+ self.norm2 = nn.Sequential()
57
+ if not (stride == 1 and in_planes == planes):
58
+ self.norm3 = nn.Sequential()
59
+
60
+ if stride == 1 and in_planes == planes:
61
+ self.downsample = None
62
+
63
+ else:
64
+ self.downsample = nn.Sequential(
65
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
66
+
67
+
68
+ def forward(self, x):
69
+ y = x
70
+ y = self.conv1(y)
71
+ y = self.norm1(y)
72
+ y = self.relu(y)
73
+ y = self.conv2(y)
74
+ y = self.norm2(y)
75
+ y = self.relu(y)
76
+
77
+ if self.downsample is not None:
78
+ x = self.downsample(x)
79
+
80
+ return self.relu(x+y)
81
+
82
+
83
+
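
> Editor's note: a quick shape check for `ResidualBlock`, assuming it is run inside this demo's environment so that `core.extractor` (and its dependencies such as `timm`) are importable:

```python
# Shape sanity check for ResidualBlock: stride 2 halves H and W, channels go 64 -> 96.
import torch
from core.extractor import ResidualBlock

block = ResidualBlock(in_planes=64, planes=96, norm_fn="group", stride=2)
x = torch.randn(2, 64, 80, 80)
y = block(x)
print(y.shape)  # torch.Size([2, 96, 40, 40])
```
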
84
+ class MultiBasicEncoder(nn.Module):
85
+ def __init__(self, output_dim=[128], norm_fn='batch', dropout=0.0, downsample=3):
86
+ super(MultiBasicEncoder, self).__init__()
87
+ self.norm_fn = norm_fn
88
+ self.downsample = downsample
89
+
90
+ if self.norm_fn == 'group':
91
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
92
+
93
+ elif self.norm_fn == 'batch':
94
+ self.norm1 = nn.BatchNorm2d(64)
95
+
96
+ elif self.norm_fn == 'instance':
97
+ self.norm1 = nn.InstanceNorm2d(64)
98
+
99
+ elif self.norm_fn=='layer':
100
+ self.norm1 = LayerNorm2d(64)
101
+
102
+ elif self.norm_fn == 'none':
103
+ self.norm1 = nn.Sequential()
104
+
105
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1 + (downsample > 2), padding=3)
106
+ self.relu1 = nn.ReLU(inplace=True)
107
+
108
+ self.in_planes = 64
109
+ self.layer1 = self._make_layer(64, stride=1)
110
+ self.layer2 = self._make_layer(96, stride=1 + (downsample > 1))
111
+ self.layer3 = self._make_layer(128, stride=1 + (downsample > 0))
112
+ self.layer4 = self._make_layer(128, stride=2)
113
+ self.layer5 = self._make_layer(128, stride=2)
114
+
115
+ output_list = []
116
+
117
+ for dim in output_dim:
118
+ conv_out = nn.Sequential(
119
+ ResidualBlock(128, 128, self.norm_fn, stride=1),
120
+ nn.Conv2d(128, dim[2], 3, padding=1))
121
+ output_list.append(conv_out)
122
+
123
+ self.outputs04 = nn.ModuleList(output_list)
124
+
125
+ output_list = []
126
+ for dim in output_dim:
127
+ conv_out = nn.Sequential(
128
+ ResidualBlock(128, 128, self.norm_fn, stride=1),
129
+ nn.Conv2d(128, dim[1], 3, padding=1))
130
+ output_list.append(conv_out)
131
+
132
+ self.outputs08 = nn.ModuleList(output_list)
133
+
134
+ output_list = []
135
+ for dim in output_dim:
136
+ conv_out = nn.Conv2d(128, dim[0], 3, padding=1)
137
+ output_list.append(conv_out)
138
+
139
+ self.outputs16 = nn.ModuleList(output_list)
140
+
141
+ if dropout > 0:
142
+ self.dropout = nn.Dropout2d(p=dropout)
143
+ else:
144
+ self.dropout = None
145
+
146
+ for m in self.modules():
147
+ if isinstance(m, nn.Conv2d):
148
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
149
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
150
+ if m.weight is not None:
151
+ nn.init.constant_(m.weight, 1)
152
+ if m.bias is not None:
153
+ nn.init.constant_(m.bias, 0)
154
+
155
+ def _make_layer(self, dim, stride=1):
156
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
157
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
158
+ layers = (layer1, layer2)
159
+
160
+ self.in_planes = dim
161
+ return nn.Sequential(*layers)
162
+
163
+ def forward(self, x, dual_inp=False, num_layers=3):
164
+
165
+ x = self.conv1(x)
166
+ x = self.norm1(x)
167
+ x = self.relu1(x)
168
+ x = self.layer1(x)
169
+ x = self.layer2(x)
170
+ x = self.layer3(x)
171
+ if dual_inp:
172
+ v = x
173
+ x = x[:(x.shape[0]//2)]
174
+
175
+ outputs04 = [f(x) for f in self.outputs04]
176
+ if num_layers == 1:
177
+ return (outputs04, v) if dual_inp else (outputs04,)
178
+
179
+ y = self.layer4(x)
180
+ outputs08 = [f(y) for f in self.outputs08]
181
+
182
+ if num_layers == 2:
183
+ return (outputs04, outputs08, v) if dual_inp else (outputs04, outputs08)
184
+
185
+ z = self.layer5(y)
186
+ outputs16 = [f(z) for f in self.outputs16]
187
+
188
+ return (outputs04, outputs08, outputs16, v) if dual_inp else (outputs04, outputs08, outputs16)
189
+
190
+
191
+
192
+ class ContextNetDino(MultiBasicEncoder):
193
+ def __init__(self, args, output_dim=[128], norm_fn='batch', downsample=3):
194
+ nn.Module.__init__(self)
195
+ self.args = args
196
+ self.patch_size = 14
197
+ self.image_size = 518
198
+ self.vit_feat_dim = 384
199
+ code_dir = os.path.dirname(os.path.realpath(__file__))
200
+
201
+ self.out_dims = output_dim
202
+
203
+ self.norm_fn = norm_fn
204
+
205
+ if self.norm_fn == 'group':
206
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
207
+
208
+ elif self.norm_fn == 'batch':
209
+ self.norm1 = nn.BatchNorm2d(64)
210
+
211
+ elif self.norm_fn == 'instance':
212
+ self.norm1 = nn.InstanceNorm2d(64)
213
+
214
+ elif self.norm_fn=='layer':
215
+ self.norm1 = LayerNorm2d(64)
216
+
217
+ elif self.norm_fn == 'none':
218
+ self.norm1 = nn.Sequential()
219
+
220
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1 + (downsample > 2), padding=3)
221
+ self.relu1 = nn.ReLU(inplace=True)
222
+
223
+ self.in_planes = 64
224
+ self.layer1 = self._make_layer(64, stride=1)
225
+ self.layer2 = self._make_layer(96, stride=1 + (downsample > 1))
226
+ self.layer3 = self._make_layer(128, stride=1 + (downsample > 0))
227
+ self.layer4 = self._make_layer(128, stride=2)
228
+ self.layer5 = self._make_layer(128, stride=2)
229
+ self.down = nn.Sequential(
230
+ nn.Conv2d(128, 128, kernel_size=4, stride=4, padding=0),
231
+ nn.BatchNorm2d(128),
232
+ )
233
+ vit_dim = DepthAnythingFeature.model_configs[self.args.vit_size]['features']//2
234
+ self.conv2 = BasicConv(128+vit_dim, 128, kernel_size=3, padding=1)
235
+ self.norm = nn.BatchNorm2d(256)
236
+
237
+ output_list = []
238
+ for dim in output_dim:
239
+ conv_out = nn.Sequential(
240
+ ResidualBlock(128, 128, self.norm_fn, stride=1),
241
+ nn.Conv2d(128, dim[2], 3, padding=1))
242
+ output_list.append(conv_out)
243
+
244
+ self.outputs04 = nn.ModuleList(output_list)
245
+
246
+ output_list = []
247
+ for dim in output_dim:
248
+ conv_out = nn.Sequential(
249
+ ResidualBlock(128, 128, self.norm_fn, stride=1),
250
+ nn.Conv2d(128, dim[1], 3, padding=1))
251
+ output_list.append(conv_out)
252
+
253
+ self.outputs08 = nn.ModuleList(output_list)
254
+
255
+ output_list = []
256
+ for dim in output_dim:
257
+ conv_out = nn.Conv2d(128, dim[0], 3, padding=1)
258
+ output_list.append(conv_out)
259
+
260
+ self.outputs16 = nn.ModuleList(output_list)
261
+
262
+ def forward(self, x_in, vit_feat, dual_inp=False, num_layers=3):
263
+ B,C,H,W = x_in.shape
264
+ x = self.conv1(x_in)
265
+ x = self.norm1(x)
266
+ x = self.relu1(x)
267
+ x = self.layer1(x)
268
+ x = self.layer2(x)
269
+ x = self.layer3(x)
270
+
271
+ divider = np.lcm(self.patch_size, 16)
272
+ H_resize, W_resize = get_resize_keep_aspect_ratio(H,W, divider=divider, max_H=1344, max_W=1344)
273
+ x = torch.cat([x, vit_feat], dim=1)
274
+ x = self.conv2(x)
275
+ outputs04 = [f(x) for f in self.outputs04]
276
+
277
+ y = self.layer4(x)
278
+ outputs08 = [f(y) for f in self.outputs08]
279
+
280
+ z = self.layer5(y)
281
+ outputs16 = [f(z) for f in self.outputs16]
282
+
283
+ return (outputs04, outputs08, outputs16)
284
+
285
+
286
+ class DepthAnythingFeature(nn.Module):
287
+ model_configs = {
288
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
289
+ 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
290
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}
291
+ }
292
+
293
+ def __init__(self, encoder='vits'):
294
+ super().__init__()
295
+ from depth_anything.dpt import DepthAnything
296
+ self.encoder = encoder
297
+ depth_anything = DepthAnything(self.model_configs[encoder])
298
+ self.depth_anything = depth_anything
299
+
300
+ self.intermediate_layer_idx = { #!NOTE For V2
301
+ 'vits': [2, 5, 8, 11],
302
+ 'vitb': [2, 5, 8, 11],
303
+ 'vitl': [4, 11, 17, 23],
304
+ 'vitg': [9, 19, 29, 39]
305
+ }
306
+
307
+
308
+ def forward(self, x):
309
+ """
310
+ @x: (B,C,H,W)
311
+ """
312
+ h, w = x.shape[-2:]
313
+ features = self.depth_anything.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
314
+
315
+
316
+ patch_size = self.depth_anything.pretrained.patch_size
317
+ patch_h, patch_w = h // patch_size, w // patch_size
318
+ out, path_1, path_2, path_3, path_4, disp = self.depth_anything.depth_head.forward(features, patch_h, patch_w, return_intermediate=True)
319
+
320
+ return {'out':out, 'path_1':path_1, 'path_2':path_2, 'path_3':path_3, 'path_4':path_4, 'features':features, 'disp':disp} # path_1 is 1/2; path_2 is 1/4
321
+
322
+
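
> Editor's note: `Feature.forward` below resizes the ViT input so both sides are divisible by lcm(patch_size=14, 16) = 112 while roughly preserving the aspect ratio; the exact logic lives in `Utils.get_resize_keep_aspect_ratio`. A simplified sketch of that constraint (not the Utils implementation):

```python
# Simplified sketch of the divisible-by-112 resize constraint (not the Utils code).
import numpy as np

def resize_to_multiple(H, W, divider=int(np.lcm(14, 16)), max_side=1344):
    scale = min(max_side / max(H, W), 1.0)                 # never exceed max_side
    H2 = max(int(round(H * scale / divider)), 1) * divider
    W2 = max(int(round(W * scale / divider)), 1) * divider
    return H2, W2

print(resize_to_multiple(540, 960))  # (560, 1008)
```
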
323
+ class Feature(nn.Module):
324
+ def __init__(self, args):
325
+ super(Feature, self).__init__()
326
+ self.args = args
327
+ model = timm.create_model('edgenext_small', pretrained=True, features_only=False)
328
+ self.stem = model.stem
329
+ self.stages = model.stages
330
+ chans = [48, 96, 160, 304]
331
+ self.chans = chans
332
+ self.dino = DepthAnythingFeature(encoder=self.args.vit_size)
333
+ self.dino = freeze_model(self.dino)
334
+ vit_feat_dim = DepthAnythingFeature.model_configs[self.args.vit_size]['features']//2
335
+
336
+ self.deconv32_16 = Conv2x_IN(chans[3], chans[2], deconv=True, concat=True)
337
+ self.deconv16_8 = Conv2x_IN(chans[2]*2, chans[1], deconv=True, concat=True)
338
+ self.deconv8_4 = Conv2x_IN(chans[1]*2, chans[0], deconv=True, concat=True)
339
+ self.conv4 = nn.Sequential(
340
+ BasicConv(chans[0]*2+vit_feat_dim, chans[0]*2+vit_feat_dim, kernel_size=3, stride=1, padding=1, norm='instance'),
341
+ ResidualBlock(chans[0]*2+vit_feat_dim, chans[0]*2+vit_feat_dim, norm_fn='instance'),
342
+ ResidualBlock(chans[0]*2+vit_feat_dim, chans[0]*2+vit_feat_dim, norm_fn='instance'),
343
+ )
344
+
345
+ self.patch_size = 14
346
+ self.d_out = [chans[0]*2+vit_feat_dim, chans[1]*2, chans[2]*2, chans[3]]
347
+
348
+ def forward(self, x):
349
+ B,C,H,W = x.shape
350
+ divider = np.lcm(self.patch_size, 16)
351
+ H_resize, W_resize = get_resize_keep_aspect_ratio(H,W, divider=divider, max_H=1344, max_W=1344)
352
+ x_in_ = F.interpolate(x, size=(H_resize, W_resize), mode='bicubic', align_corners=False)
353
+ self.dino = self.dino.eval()
354
+ with torch.no_grad():
355
+ output = self.dino(x_in_)
356
+ vit_feat = output['out']
357
+ vit_feat = F.interpolate(vit_feat, size=(H//4,W//4), mode='bilinear', align_corners=True)
358
+ x = self.stem(x)
359
+ x4 = self.stages[0](x)
360
+ x8 = self.stages[1](x4)
361
+ x16 = self.stages[2](x8)
362
+ x32 = self.stages[3](x16)
363
+
364
+ x16 = self.deconv32_16(x32, x16)
365
+ x8 = self.deconv16_8(x16, x8)
366
+ x4 = self.deconv8_4(x8, x4)
367
+ x4 = torch.cat([x4, vit_feat], dim=1)
368
+ x4 = self.conv4(x4)
369
+ return [x4, x8, x16, x32], vit_feat
370
+
371
+
FoundationStereo_demo/core/foundation_stereo.py ADDED
@@ -0,0 +1,277 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+
10
+ import torch,pdb,logging,timm
11
+ import torchvision # Add missing torchvision import
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import sys,os
15
+ code_dir = os.path.dirname(os.path.realpath(__file__))
16
+ sys.path.append(f'{code_dir}/../')
17
+ from core.update import *
18
+ from core.extractor import *
19
+ from core.geometry import Combined_Geo_Encoding_Volume
20
+ from core.submodule import *
21
+ from core.utils.utils import *
22
+ from Utils import *
23
+ import time,huggingface_hub
24
+
25
+
26
+ try:
27
+ autocast = torch.cuda.amp.autocast
28
+ except:
29
+ class autocast:
30
+ def __init__(self, enabled):
31
+ pass
32
+ def __enter__(self):
33
+ pass
34
+ def __exit__(self, *args):
35
+ pass
36
+
37
+
38
+ def normalize_image(img):
39
+ '''
40
+ @img: (B,C,H,W) in range 0-255, RGB order
41
+ '''
42
+ tf = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
43
+ normalized = tf(img/255.0)
44
+ return normalized.contiguous() # Ensure contiguous tensor
45
+
46
+
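
> Editor's note: a standalone check of `normalize_image`'s contract (0-255 RGB in, ImageNet-normalized and contiguous out); the printed values are what the standard mean/std produce for a pure-white pixel.

```python
# Standalone check of the ImageNet normalization applied by normalize_image.
import torch
import torchvision

tf = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
x = torch.full((1, 3, 4, 4), 255.0)       # pure white, 0-255 range
out = tf(x / 255.0).contiguous()
print(out[0, :, 0, 0])                    # ≈ tensor([2.2489, 2.4286, 2.6400])
```
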
47
+ class hourglass(nn.Module):
48
+ def __init__(self, cfg, in_channels, feat_dims=None):
49
+ super().__init__()
50
+ self.cfg = cfg
51
+ self.conv1 = nn.Sequential(BasicConv(in_channels, in_channels*2, is_3d=True, bn=True, relu=True, kernel_size=3,
52
+ padding=1, stride=2, dilation=1),
53
+ Conv3dNormActReduced(in_channels*2, in_channels*2, kernel_size=3, kernel_disp=17))
54
+
55
+ self.conv2 = nn.Sequential(BasicConv(in_channels*2, in_channels*4, is_3d=True, bn=True, relu=True, kernel_size=3,
56
+ padding=1, stride=2, dilation=1),
57
+ Conv3dNormActReduced(in_channels*4, in_channels*4, kernel_size=3, kernel_disp=17))
58
+
59
+ self.conv3 = nn.Sequential(BasicConv(in_channels*4, in_channels*6, is_3d=True, bn=True, relu=True, kernel_size=3,
60
+ padding=1, stride=2, dilation=1),
61
+ Conv3dNormActReduced(in_channels*6, in_channels*6, kernel_size=3, kernel_disp=17))
62
+
63
+
64
+ self.conv3_up = BasicConv(in_channels*6, in_channels*4, deconv=True, is_3d=True, bn=True,
65
+ relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))
66
+
67
+ self.conv2_up = BasicConv(in_channels*4, in_channels*2, deconv=True, is_3d=True, bn=True,
68
+ relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))
69
+
70
+ self.conv1_up = BasicConv(in_channels*2, in_channels, deconv=True, is_3d=True, bn=True,
71
+ relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))
72
+ self.conv_out = nn.Sequential(
73
+ Conv3dNormActReduced(in_channels, in_channels, kernel_size=3, kernel_disp=17),
74
+ Conv3dNormActReduced(in_channels, in_channels, kernel_size=3, kernel_disp=17),
75
+ )
76
+
77
+ self.agg_0 = nn.Sequential(BasicConv(in_channels*8, in_channels*4, is_3d=True, kernel_size=1, padding=0, stride=1),
78
+ Conv3dNormActReduced(in_channels*4, in_channels*4, kernel_size=3, kernel_disp=17),
79
+ Conv3dNormActReduced(in_channels*4, in_channels*4, kernel_size=3, kernel_disp=17),)
80
+
81
+ self.agg_1 = nn.Sequential(BasicConv(in_channels*4, in_channels*2, is_3d=True, kernel_size=1, padding=0, stride=1),
82
+ Conv3dNormActReduced(in_channels*2, in_channels*2, kernel_size=3, kernel_disp=17),
83
+ Conv3dNormActReduced(in_channels*2, in_channels*2, kernel_size=3, kernel_disp=17))
84
+ self.atts = nn.ModuleDict({
85
+ "4": CostVolumeDisparityAttention(d_model=in_channels, nhead=4, dim_feedforward=in_channels, norm_first=False, num_transformer=4, max_len=self.cfg['max_disp']//16),
86
+ })
87
+ self.conv_patch = nn.Sequential(
88
+ nn.Conv3d(in_channels, in_channels, kernel_size=4, stride=4, padding=0, groups=in_channels),
89
+ nn.BatchNorm3d(in_channels),
90
+ )
91
+
92
+ self.feature_att_8 = FeatureAtt(in_channels*2, feat_dims[1])
93
+ self.feature_att_16 = FeatureAtt(in_channels*4, feat_dims[2])
94
+ self.feature_att_32 = FeatureAtt(in_channels*6, feat_dims[3])
95
+ self.feature_att_up_16 = FeatureAtt(in_channels*4, feat_dims[2])
96
+ self.feature_att_up_8 = FeatureAtt(in_channels*2, feat_dims[1])
97
+
98
+ def forward(self, x, features):
99
+ conv1 = self.conv1(x)
100
+ conv1 = self.feature_att_8(conv1, features[1])
101
+
102
+ conv2 = self.conv2(conv1)
103
+ conv2 = self.feature_att_16(conv2, features[2])
104
+
105
+ conv3 = self.conv3(conv2)
106
+ conv3 = self.feature_att_32(conv3, features[3])
107
+
108
+ conv3_up = self.conv3_up(conv3)
109
+ conv2 = torch.cat((conv3_up, conv2), dim=1)
110
+ conv2 = self.agg_0(conv2)
111
+ conv2 = self.feature_att_up_16(conv2, features[2])
112
+
113
+ conv2_up = self.conv2_up(conv2)
114
+ conv1 = torch.cat((conv2_up, conv1), dim=1)
115
+ conv1 = self.agg_1(conv1)
116
+ conv1 = self.feature_att_up_8(conv1, features[1])
117
+
118
+ conv = self.conv1_up(conv1)
119
+ x = self.conv_patch(x)
120
+ x = self.atts["4"](x)
121
+ x = F.interpolate(x, scale_factor=4, mode='trilinear', align_corners=False)
122
+ conv = conv + x
123
+ conv = self.conv_out(conv)
124
+
125
+ return conv
126
+
127
+
128
+
129
+ class FoundationStereo(nn.Module, huggingface_hub.PyTorchModelHubMixin):
130
+ def __init__(self, args):
131
+ super().__init__()
132
+ self.args = args
133
+
134
+ context_dims = args.hidden_dims
135
+ self.cv_group = 8
136
+ volume_dim = 28
137
+
138
+ self.cnet = ContextNetDino(args, output_dim=[args.hidden_dims, context_dims], downsample=args.n_downsample)
139
+ self.update_block = BasicSelectiveMultiUpdateBlock(self.args, self.args.hidden_dims[0], volume_dim=volume_dim)
140
+ self.sam = SpatialAttentionExtractor()
141
+ self.cam = ChannelAttentionEnhancement(self.args.hidden_dims[0])
142
+
143
+ self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, kernel_size=3, padding=3//2) for i in range(self.args.n_gru_layers)])
144
+
145
+ self.feature = Feature(args)
146
+ self.proj_cmb = nn.Conv2d(self.feature.d_out[0], 12, kernel_size=1, padding=0)
147
+
148
+ self.stem_2 = nn.Sequential(
149
+ BasicConv_IN(3, 32, kernel_size=3, stride=2, padding=1),
150
+ nn.Conv2d(32, 32, 3, 1, 1, bias=False),
151
+ nn.InstanceNorm2d(32), nn.ReLU()
152
+ )
153
+ self.stem_4 = nn.Sequential(
154
+ BasicConv_IN(32, 48, kernel_size=3, stride=2, padding=1),
155
+ nn.Conv2d(48, 48, 3, 1, 1, bias=False),
156
+ nn.InstanceNorm2d(48), nn.ReLU()
157
+ )
158
+
159
+
160
+ self.spx_2_gru = Conv2x(32, 32, True, bn=False)
161
+ self.spx_gru = nn.Sequential(
162
+ nn.ConvTranspose2d(2*32, 9, kernel_size=4, stride=2, padding=1),
163
+ )
164
+
165
+
166
+ self.corr_stem = nn.Sequential(
167
+ nn.Conv3d(32, volume_dim, kernel_size=1),
168
+ BasicConv(volume_dim, volume_dim, kernel_size=3, padding=1, is_3d=True),
169
+ ResnetBasicBlock3D(volume_dim, volume_dim, kernel_size=3, stride=1, padding=1),
170
+ ResnetBasicBlock3D(volume_dim, volume_dim, kernel_size=3, stride=1, padding=1),
171
+ )
172
+ self.corr_feature_att = FeatureAtt(volume_dim, self.feature.d_out[0])
173
+ self.cost_agg = hourglass(cfg=self.args, in_channels=volume_dim, feat_dims=self.feature.d_out)
174
+ self.classifier = nn.Sequential(
175
+ BasicConv(volume_dim, volume_dim//2, kernel_size=3, padding=1, is_3d=True),
176
+ ResnetBasicBlock3D(volume_dim//2, volume_dim//2, kernel_size=3, stride=1, padding=1),
177
+ nn.Conv3d(volume_dim//2, 1, kernel_size=7, padding=3),
178
+ )
179
+
180
+ r = self.args.corr_radius
181
+ dx = torch.linspace(-r, r, 2*r+1, requires_grad=False).reshape(1, 1, 2*r+1, 1)
182
+ self.dx = dx
183
+
184
+
185
+ def upsample_disp(self, disp, mask_feat_4, stem_2x):
186
+
187
+ with autocast(enabled=self.args.mixed_precision):
188
+ xspx = self.spx_2_gru(mask_feat_4, stem_2x) # 1/2 resolution
189
+ spx_pred = self.spx_gru(xspx)
190
+ spx_pred = F.softmax(spx_pred, 1)
191
+ up_disp = context_upsample(disp*4., spx_pred).unsqueeze(1)
192
+
193
+ return up_disp.float()
194
+
195
+
196
+ def forward(self, image1, image2, iters=12, flow_init=None, test_mode=False, low_memory=False, init_disp=None):
197
+ """ Estimate disparity between pair of frames """
198
+ B = len(image1)
199
+ low_memory = low_memory or (self.args.get('low_memory', False))
200
+ image1 = normalize_image(image1)
201
+ image2 = normalize_image(image2)
202
+ with autocast(enabled=self.args.mixed_precision):
203
+ out, vit_feat = self.feature(torch.cat([image1, image2], dim=0))
204
+ vit_feat = vit_feat[:B]
205
+ features_left = [o[:B] for o in out]
206
+ features_right = [o[B:] for o in out]
207
+ stem_2x = self.stem_2(image1)
208
+
209
+ gwc_volume = build_gwc_volume(features_left[0], features_right[0], self.args.max_disp//4, self.cv_group) # Group-wise correlation volume (B, N_group, max_disp, H, W)
210
+ left_tmp = self.proj_cmb(features_left[0])
211
+ right_tmp = self.proj_cmb(features_right[0])
212
+ concat_volume = build_concat_volume(left_tmp, right_tmp, maxdisp=self.args.max_disp//4)
213
+ del left_tmp, right_tmp
214
+ comb_volume = torch.cat([gwc_volume, concat_volume], dim=1)
215
+ comb_volume = self.corr_stem(comb_volume)
216
+ comb_volume = self.corr_feature_att(comb_volume, features_left[0])
217
+ comb_volume = self.cost_agg(comb_volume, features_left)
218
+
219
+ # Init disp from geometry encoding volume
220
+ prob = F.softmax(self.classifier(comb_volume).squeeze(1), dim=1) #(B, max_disp, H, W)
221
+ if init_disp is None:
222
+ init_disp = disparity_regression(prob, self.args.max_disp//4) # Weighted sum of disparity
223
+
224
+ cnet_list = self.cnet(image1, vit_feat=vit_feat, num_layers=self.args.n_gru_layers) #(1/4, 1/8, 1/16)
225
+ cnet_list = list(cnet_list)
226
+ net_list = [torch.tanh(x[0]) for x in cnet_list] # Hidden information
227
+ inp_list = [torch.relu(x[1]) for x in cnet_list] # Context information list of pyramid levels
228
+ inp_list = [self.cam(x) * x for x in inp_list]
229
+ att = [self.sam(x) for x in inp_list]
230
+
231
+ geo_fn = Combined_Geo_Encoding_Volume(features_left[0].float(), features_right[0].float(), comb_volume.float(), num_levels=self.args.corr_levels, dx=self.dx)
232
+ b, c, h, w = features_left[0].shape
233
+ coords = torch.arange(w, dtype=torch.float, device=init_disp.device).reshape(1,1,w,1).repeat(b, h, 1, 1) # (B,H,W,1) Horizontal only
234
+ disp = init_disp.float()
235
+ disp_preds = []
236
+
237
+ # GRUs iterations to update disparity (1/4 resolution)
238
+ for itr in range(iters):
239
+ disp = disp.detach()
240
+ geo_feat = geo_fn(disp, coords, low_memory=low_memory)
241
+ with autocast(enabled=self.args.mixed_precision):
242
+ net_list, mask_feat_4, delta_disp = self.update_block(net_list, inp_list, geo_feat, disp, att)
243
+
244
+ disp = disp + delta_disp.float()
245
+ if test_mode and itr < iters-1:
246
+ continue
247
+
248
+ # upsample predictions
249
+ disp_up = self.upsample_disp(disp.float(), mask_feat_4.float(), stem_2x.float())
250
+ disp_preds.append(disp_up)
251
+
252
+
253
+ if test_mode:
254
+ return disp_up
255
+
256
+ return init_disp, disp_preds
257
+
258
+
259
+ def run_hierachical(self, image1, image2, iters=12, test_mode=False, low_memory=False, small_ratio=0.5):
260
+ B,_,H,W = image1.shape
261
+ img1_small = F.interpolate(image1, scale_factor=small_ratio, align_corners=False, mode='bilinear')
262
+ img2_small = F.interpolate(image2, scale_factor=small_ratio, align_corners=False, mode='bilinear')
263
+ padder = InputPadder(img1_small.shape[-2:], divis_by=32, force_square=False)
264
+ img1_small, img2_small = padder.pad(img1_small, img2_small)
265
+ disp_small = self.forward(img1_small, img2_small, test_mode=True, iters=iters, low_memory=low_memory)
266
+ disp_small = padder.unpad(disp_small.float())
267
+ disp_small_up = F.interpolate(disp_small, size=(H,W), mode='bilinear', align_corners=True) * 1/small_ratio
268
+ disp_small_up = disp_small_up.clip(0, None)
269
+
270
+ padder = InputPadder(image1.shape[-2:], divis_by=32, force_square=False)
271
+ image1, image2, disp_small_up = padder.pad(image1, image2, disp_small_up)
272
+ disp_small_up += padder._pad[0]
273
+ init_disp = F.interpolate(disp_small_up, scale_factor=0.25, mode='bilinear', align_corners=True) * 0.25 # Init disp will be 1/4
274
+ disp = self.forward(image1, image2, iters=iters, test_mode=test_mode, low_memory=low_memory, init_disp=init_disp)
275
+ disp = padder.unpad(disp.float())
276
+ return disp
277
+
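The hierarchical path above leans on the fact that disparity scales linearly with image width: a map estimated at small_ratio resolution has to be multiplied by 1/small_ratio once upsampled to full size, and scaled by 0.25 again when downsampled to the 1/4-resolution initialization. A minimal, self-contained sketch of that scaling rule on a toy disparity ramp (torch only; sizes and values are illustrative, not the model's):

import torch
import torch.nn.functional as F

# Toy full-resolution disparity: a horizontal ramp of shape (B,1,H,W).
H, W = 64, 128
disp_full = torch.arange(W, dtype=torch.float32).view(1, 1, 1, W).expand(1, 1, H, W).contiguous()

# Pretend it was estimated at half resolution: both the grid and the values shrink.
small_ratio = 0.5
disp_half = F.interpolate(disp_full, scale_factor=small_ratio, mode='bilinear',
                          align_corners=False) * small_ratio

# Upsample back to full size and rescale the values, mirroring
# disp_small_up = F.interpolate(...) * 1/small_ratio in run_hierachical.
disp_up = F.interpolate(disp_half, size=(H, W), mode='bilinear',
                        align_corners=True) / small_ratio

print((disp_up - disp_full).abs().mean())  # small relative to the 0..W-1 disparity range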
FoundationStereo_demo/core/geometry.py ADDED
@@ -0,0 +1,77 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+
10
+ import torch,pdb,os,sys
11
+ import torch.nn.functional as F
12
+ from core.utils.utils import bilinear_sampler
13
+ code_dir = os.path.dirname(os.path.realpath(__file__))
14
+ sys.path.append(f'{code_dir}/../')
15
+ from Utils import *
16
+
17
+ class Combined_Geo_Encoding_Volume:
18
+ def __init__(self, init_fmap1, init_fmap2, geo_volume, num_levels=2, dx=None):
19
+ self.num_levels = num_levels
20
+ self.geo_volume_pyramid = []
21
+ self.init_corr_pyramid = []
22
+ self.dx = dx
23
+
24
+ # all pairs correlation
25
+ init_corr = Combined_Geo_Encoding_Volume.corr(init_fmap1, init_fmap2)
26
+
27
+ b, h, w, _, w2 = init_corr.shape
28
+ b, c, d, h, w = geo_volume.shape
29
+ geo_volume = geo_volume.permute(0, 3, 4, 1, 2).reshape(b*h*w, c, 1, d).contiguous()
30
+
31
+ init_corr = init_corr.reshape(b*h*w, 1, 1, w2)
32
+ self.geo_volume_pyramid.append(geo_volume)
33
+ self.init_corr_pyramid.append(init_corr)
34
+ for i in range(self.num_levels-1):
35
+ geo_volume = F.avg_pool2d(geo_volume, [1,2], stride=[1,2])
36
+ self.geo_volume_pyramid.append(geo_volume)
37
+
38
+ for i in range(self.num_levels-1):
39
+ init_corr = F.avg_pool2d(init_corr, [1,2], stride=[1,2])
40
+ self.init_corr_pyramid.append(init_corr)
41
+
42
+
43
+ def __call__(self, disp, coords, low_memory=False):
44
+ b, _, h, w = disp.shape
45
+ self.dx = self.dx.to(disp.device)
46
+ out_pyramid = []
47
+ for i in range(self.num_levels):
48
+ geo_volume = self.geo_volume_pyramid[i]
49
+ x0 = self.dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i
50
+ y0 = torch.zeros_like(x0)
51
+
52
+ disp_lvl = torch.cat([x0,y0], dim=-1)
53
+ geo_volume = bilinear_sampler(geo_volume, disp_lvl, low_memory=low_memory)
54
+ geo_volume = geo_volume.reshape(b, h, w, -1)
55
+
56
+ init_corr = self.init_corr_pyramid[i]
57
+ init_x0 = coords.reshape(b*h*w, 1, 1, 1)/2**i - disp.reshape(b*h*w, 1, 1, 1) / 2**i + self.dx # X on right image
58
+ init_coords_lvl = torch.cat([init_x0,y0], dim=-1)
59
+ init_corr = bilinear_sampler(init_corr, init_coords_lvl, low_memory=low_memory)
60
+ init_corr = init_corr.reshape(b, h, w, -1)
61
+
62
+ out_pyramid.append(geo_volume)
63
+ out_pyramid.append(init_corr)
64
+ out_pyramid = torch.cat(out_pyramid, dim=-1)
65
+ return out_pyramid.permute(0, 3, 1, 2).contiguous() #(B,C,H,W)
66
+
67
+
68
+ @staticmethod
69
+ def corr(fmap1, fmap2):
70
+ B, D, H, W1 = fmap1.shape
71
+ _, _, _, W2 = fmap2.shape
72
+ fmap1 = fmap1.reshape(B, D, H, W1)
73
+ fmap2 = fmap2.reshape(B, D, H, W2)
74
+ with torch.cuda.amp.autocast(enabled=False):
75
+ corr = torch.einsum('aijk,aijh->ajkh', F.normalize(fmap1.float(), dim=1), F.normalize(fmap2.float(), dim=1))
76
+ corr = corr.reshape(B, H, W1, 1, W2)
77
+ return corr
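The pyramid lookup above samples two things per pixel: the geometry volume along its disparity axis at disp + dx, and the all-pairs correlation at the right-image column x_left - disp + dx, with dx spanning [-corr_radius, corr_radius]. A small sketch of how those 1-D sampling positions are formed (pure torch, toy numbers; r stands in for args.corr_radius):

import torch

r = 4                                  # lookup radius (corr_radius in the config)
dx = torch.linspace(-r, r, 2 * r + 1)  # window offsets, same role as self.dx

x_left = torch.tensor(100.0)           # column of the pixel in the left image
disp = torch.tensor(23.5)              # current disparity estimate at that pixel

x_geo = disp + dx                      # positions along the disparity axis of the volume
x_right = x_left - disp + dx           # columns sampled in the right image (x - d + dx)

print(x_geo)    # 19.5 ... 27.5
print(x_right)  # 72.5 ... 80.5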
FoundationStereo_demo/core/submodule.py ADDED
@@ -0,0 +1,588 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+
10
+ import torch,pdb,os,sys
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import numpy as np
14
+ code_dir = os.path.dirname(os.path.realpath(__file__))
15
+ sys.path.append(f'{code_dir}/../')
16
+ from Utils import *
17
+
18
+
19
+ def _is_contiguous(tensor: torch.Tensor) -> bool:
20
+ if torch.jit.is_scripting():
21
+ return tensor.is_contiguous()
22
+ else:
23
+ return tensor.is_contiguous(memory_format=torch.contiguous_format)
24
+
25
+
26
+ class LayerNorm2d(nn.LayerNorm):
27
+ r""" https://huggingface.co/spaces/Roll20/pet_score/blob/b258ef28152ab0d5b377d9142a23346f863c1526/lib/timm/models/convnext.py#L85
28
+ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
29
+ """
30
+
31
+ def __init__(self, normalized_shape, eps=1e-6):
32
+ super().__init__(normalized_shape, eps=eps)
33
+
34
+ def forward(self, x) -> torch.Tensor:
35
+ """
36
+ @x: (B,C,H,W)
37
+ """
38
+ if _is_contiguous(x):
39
+ return F.layer_norm(x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2).contiguous()
40
+ else:
41
+ s, u = torch.var_mean(x, dim=1, keepdim=True)
42
+ x = (x - u) * torch.rsqrt(s + self.eps)
43
+ x = x * self.weight[:, None, None] + self.bias[:, None, None]
44
+ return x
45
+
46
+
47
+
48
+ class BasicConv(nn.Module):
49
+
50
+ def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, bn=True, relu=True, norm='batch', **kwargs):
51
+ super(BasicConv, self).__init__()
52
+
53
+ self.relu = relu
54
+ self.use_bn = bn
55
+ self.bn = nn.Identity()
56
+ if is_3d:
57
+ if deconv:
58
+ self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs)
59
+ else:
60
+ self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
61
+ if self.use_bn:
62
+ if norm=='batch':
63
+ self.bn = nn.BatchNorm3d(out_channels)
64
+ elif norm=='instance':
65
+ self.bn = nn.InstanceNorm3d(out_channels)
66
+ else:
67
+ if deconv:
68
+ self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
69
+ else:
70
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
71
+ if self.use_bn:
72
+ if norm=='batch':
73
+ self.bn = nn.BatchNorm2d(out_channels)
74
+ elif norm=='instance':
75
+ self.bn = nn.InstanceNorm2d(out_channels)
76
+
77
+ def forward(self, x):
78
+ x = self.conv(x)
79
+ if self.use_bn:
80
+ x = self.bn(x)
81
+ if self.relu:
82
+ x = nn.LeakyReLU()(x)#, inplace=True)
83
+ return x
84
+
85
+
86
+ class Conv3dNormActReduced(nn.Module):
87
+ def __init__(self, C_in, C_out, hidden=None, kernel_size=3, kernel_disp=None, stride=1, norm=nn.BatchNorm3d):
88
+ super().__init__()
89
+ if kernel_disp is None:
90
+ kernel_disp = kernel_size
91
+ if hidden is None:
92
+ hidden = C_out
93
+ self.conv1 = nn.Sequential(
94
+ nn.Conv3d(C_in, hidden, kernel_size=(1,kernel_size,kernel_size), padding=(0, kernel_size//2, kernel_size//2), stride=(1, stride, stride)),
95
+ norm(hidden),
96
+ nn.ReLU(),
97
+ )
98
+ self.conv2 = nn.Sequential(
99
+ nn.Conv3d(hidden, C_out, kernel_size=(kernel_disp, 1, 1), padding=(kernel_disp//2, 0, 0), stride=(stride, 1, 1)),
100
+ norm(C_out),
101
+ nn.ReLU(),
102
+ )
103
+
104
+
105
+ def forward(self, x):
106
+ """
107
+ @x: (B,C,D,H,W)
108
+ """
109
+ x = self.conv1(x)
110
+ x = self.conv2(x)
111
+ return x
112
+
113
+
114
+
115
+
116
+ class ResnetBasicBlock(nn.Module):
117
+ def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=nn.BatchNorm2d, bias=False):
118
+ super().__init__()
119
+ self.norm_layer = norm_layer
120
+ if groups != 1 or base_width != 64:
121
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
122
+ if dilation > 1:
123
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
124
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
125
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=stride, bias=bias, padding=padding)
126
+ if self.norm_layer is not None:
127
+ self.bn1 = norm_layer(planes)
128
+ self.relu = nn.ReLU(inplace=True)
129
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, stride=stride, bias=bias, padding=padding)
130
+ if self.norm_layer is not None:
131
+ self.bn2 = norm_layer(planes)
132
+ self.downsample = downsample
133
+ self.stride = stride
134
+
135
+
136
+ def forward(self, x):
137
+ identity = x
138
+
139
+ out = self.conv1(x)
140
+ if self.norm_layer is not None:
141
+ out = self.bn1(out)
142
+ out = self.relu(out)
143
+
144
+ out = self.conv2(out)
145
+ if self.norm_layer is not None:
146
+ out = self.bn2(out)
147
+
148
+ if self.downsample is not None:
149
+ identity = self.downsample(x)
150
+ out += identity
151
+ out = self.relu(out)
152
+
153
+ return out
154
+
155
+
156
+ class ResnetBasicBlock3D(nn.Module):
157
+ def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=nn.BatchNorm3d, bias=False):
158
+ super().__init__()
159
+ self.norm_layer = norm_layer
160
+ if groups != 1 or base_width != 64:
161
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
162
+ if dilation > 1:
163
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
164
+ self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=kernel_size, stride=stride, bias=bias, padding=padding)
165
+ if self.norm_layer is not None:
166
+ self.bn1 = norm_layer(planes)
167
+ self.relu = nn.ReLU(inplace=True)
168
+ self.conv2 = nn.Conv3d(planes, planes, kernel_size=kernel_size, stride=stride, bias=bias, padding=padding)
169
+ if self.norm_layer is not None:
170
+ self.bn2 = norm_layer(planes)
171
+ self.downsample = downsample
172
+ self.stride = stride
173
+
174
+
175
+ def forward(self, x):
176
+ identity = x
177
+
178
+ out = self.conv1(x)
179
+ if self.norm_layer is not None:
180
+ out = self.bn1(out)
181
+ out = self.relu(out)
182
+
183
+ out = self.conv2(out)
184
+ if self.norm_layer is not None:
185
+ out = self.bn2(out)
186
+
187
+ if self.downsample is not None:
188
+ identity = self.downsample(x)
189
+ out += identity
190
+ out = self.relu(out)
191
+
192
+ return out
193
+
194
+
195
+ class FlashMultiheadAttention(nn.Module):
196
+ def __init__(self, embed_dim, num_heads):
197
+ super().__init__()
198
+ self.num_heads = num_heads
199
+ self.embed_dim = embed_dim
200
+ self.head_dim = embed_dim // num_heads
201
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
202
+
203
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
204
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
205
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
206
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
207
+
208
+ def forward(self, query, key, value, attn_mask=None, window_size=(-1,-1)):
209
+ """
210
+ @query: (B,L,C)
211
+ """
212
+ B,L,C = query.shape
213
+ Q = self.q_proj(query)
214
+ K = self.k_proj(key)
215
+ V = self.v_proj(value)
216
+
217
+ Q = Q.view(Q.size(0), Q.size(1), self.num_heads, self.head_dim)
218
+ K = K.view(K.size(0), K.size(1), self.num_heads, self.head_dim)
219
+ V = V.view(V.size(0), V.size(1), self.num_heads, self.head_dim)
220
+
221
+ attn_output = F.scaled_dot_product_attention(Q, K, V)
222
+
223
+ attn_output = attn_output.reshape(B,L,-1)
224
+ output = self.out_proj(attn_output)
225
+
226
+ return output
227
+
228
+
229
+
230
+ class FlashAttentionTransformerEncoderLayer(nn.Module):
231
+ def __init__(self, embed_dim, num_heads, dim_feedforward, dropout=0.1, act=nn.GELU, norm=nn.LayerNorm):
232
+ super().__init__()
233
+ self.self_attn = FlashMultiheadAttention(embed_dim, num_heads)
234
+ self.act = act()
235
+
236
+ self.linear1 = nn.Linear(embed_dim, dim_feedforward)
237
+ self.dropout = nn.Dropout(dropout)
238
+ self.linear2 = nn.Linear(dim_feedforward, embed_dim)
239
+
240
+ self.norm1 = norm(embed_dim)
241
+ self.norm2 = norm(embed_dim)
242
+ self.dropout1 = nn.Dropout(dropout)
243
+ self.dropout2 = nn.Dropout(dropout)
244
+
245
+ def forward(self, src, src_mask=None, window_size=(-1, -1)):
246
+ src2 = self.self_attn(src, src, src, src_mask, window_size=window_size)
247
+ src = src + self.dropout1(src2)
248
+ src = self.norm1(src)
249
+
250
+ src2 = self.linear2(self.dropout(self.act(self.linear1(src))))
251
+ src = src + self.dropout2(src2)
252
+ src = self.norm2(src)
253
+
254
+ return src
255
+
256
+
257
+
258
+ class UpsampleConv(nn.Module):
259
+ def __init__(self, C_in, C_out, is_3d=False, kernel_size=3, bias=True, stride=1, padding=1):
260
+ super().__init__()
261
+ self.is_3d = is_3d
262
+ if is_3d:
263
+ self.conv = nn.Conv3d(C_in, C_out, kernel_size=kernel_size, stride=1, padding=kernel_size//2, bias=bias)
264
+ else:
265
+ self.conv = nn.Conv2d(C_in, C_out, kernel_size=kernel_size, stride=1, padding=kernel_size//2, bias=bias)
266
+
267
+ def forward(self, x):
268
+ if self.is_3d:
269
+ mode = 'trilinear'
270
+ else:
271
+ mode = 'bilinear'
272
+ x = F.interpolate(x, size=None, scale_factor=2, align_corners=False, mode=mode)
273
+ x = self.conv(x)
274
+ return x
275
+
276
+
277
+
278
+ class Conv2x(nn.Module):
279
+
280
+ def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, bn=True, relu=True, keep_dispc=False):
281
+ super(Conv2x, self).__init__()
282
+ self.concat = concat
283
+ self.is_3d = is_3d
284
+ if deconv and is_3d:
285
+ kernel = (4, 4, 4)
286
+ elif deconv:
287
+ kernel = 4
288
+ else:
289
+ kernel = 3
290
+
291
+ if deconv and is_3d and keep_dispc:
292
+ kernel = (1, 4, 4)
293
+ stride = (1, 2, 2)
294
+ padding = (0, 1, 1)
295
+ self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=bn, relu=True, kernel_size=kernel, stride=stride, padding=padding)
296
+ else:
297
+ self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=bn, relu=True, kernel_size=kernel, stride=2, padding=1)
298
+
299
+ if self.concat:
300
+ mul = 2 if keep_concat else 1
301
+ self.conv2 = BasicConv(out_channels*2, out_channels*mul, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1)
302
+ else:
303
+ self.conv2 = BasicConv(out_channels, out_channels, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1)
304
+
305
+ def forward(self, x, rem):
306
+ x = self.conv1(x)
307
+ if x.shape != rem.shape:
308
+ x = F.interpolate(x, size=(rem.shape[-2], rem.shape[-1]), mode='bilinear')
309
+ if self.concat:
310
+ x = torch.cat((x, rem), 1)
311
+ else:
312
+ x = x + rem
313
+ x = self.conv2(x)
314
+ return x
315
+
316
+
317
+ class BasicConv_IN(nn.Module):
318
+
319
+ def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, IN=True, relu=True, **kwargs):
320
+ super(BasicConv_IN, self).__init__()
321
+
322
+ self.relu = relu
323
+ self.use_in = IN
324
+ if is_3d:
325
+ if deconv:
326
+ self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs)
327
+ else:
328
+ self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
329
+ self.IN = nn.InstanceNorm3d(out_channels)
330
+ else:
331
+ if deconv:
332
+ self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
333
+ else:
334
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
335
+ self.IN = nn.InstanceNorm2d(out_channels)
336
+
337
+ def forward(self, x):
338
+ x = self.conv(x)
339
+ if self.use_in:
340
+ x = self.IN(x)
341
+ if self.relu:
342
+ x = nn.LeakyReLU()(x)#, inplace=True)
343
+ return x
344
+
345
+
346
+ class Conv2x_IN(nn.Module):
347
+
348
+ def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, IN=True, relu=True, keep_dispc=False):
349
+ super(Conv2x_IN, self).__init__()
350
+ self.concat = concat
351
+ self.is_3d = is_3d
352
+ if deconv and is_3d:
353
+ kernel = (4, 4, 4)
354
+ elif deconv:
355
+ kernel = 4
356
+ else:
357
+ kernel = 3
358
+
359
+ if deconv and is_3d and keep_dispc:
360
+ kernel = (1, 4, 4)
361
+ stride = (1, 2, 2)
362
+ padding = (0, 1, 1)
363
+ self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=stride, padding=padding)
364
+ else:
365
+ self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=2, padding=1)
366
+
367
+ if self.concat:
368
+ mul = 2 if keep_concat else 1
369
+ self.conv2 = ResnetBasicBlock(out_channels*2, out_channels*mul, kernel_size=3, stride=1, padding=1, norm_layer=nn.InstanceNorm2d)
370
+ else:
371
+ self.conv2 = BasicConv_IN(out_channels, out_channels, False, is_3d, IN, relu, kernel_size=3, stride=1, padding=1)
372
+
373
+ def forward(self, x, rem):
374
+ x = self.conv1(x)
375
+ if x.shape != rem.shape:
376
+ x = F.interpolate(x, size=(rem.shape[-2], rem.shape[-1]), mode='bilinear')
377
+ if self.concat:
378
+ x = torch.cat((x, rem), 1)
379
+ else:
380
+ x = x + rem
381
+ x = self.conv2(x)
382
+ return x
383
+
384
+
385
+ def groupwise_correlation(fea1, fea2, num_groups):
386
+ B, C, H, W = fea1.shape
387
+ assert C % num_groups == 0, f"C:{C}, num_groups:{num_groups}"
388
+ channels_per_group = C // num_groups
389
+ fea1 = fea1.reshape(B, num_groups, channels_per_group, H, W)
390
+ fea2 = fea2.reshape(B, num_groups, channels_per_group, H, W)
391
+ with torch.cuda.amp.autocast(enabled=False):
392
+ cost = (F.normalize(fea1.float(), dim=2) * F.normalize(fea2.float(), dim=2)).sum(dim=2) #!NOTE Divide first for numerical stability
393
+ assert cost.shape == (B, num_groups, H, W)
394
+ return cost
395
+
396
+ def build_gwc_volume(refimg_fea, targetimg_fea, maxdisp, num_groups, stride=1):
397
+ """
398
+ @refimg_fea: left image feature
399
+ @targetimg_fea: right image feature
400
+ """
401
+ B, C, H, W = refimg_fea.shape
402
+ volume = refimg_fea.new_zeros([B, num_groups, maxdisp, H, W])
403
+ for i in range(maxdisp):
404
+ if i > 0:
405
+ volume[:, :, i, :, i:] = groupwise_correlation(refimg_fea[:, :, :, i:], targetimg_fea[:, :, :, :-i], num_groups)
406
+ else:
407
+ volume[:, :, i, :, :] = groupwise_correlation(refimg_fea, targetimg_fea, num_groups)
408
+ volume = volume.contiguous()
409
+ return volume
410
+
411
+
412
+
413
+ def build_concat_volume(refimg_fea, targetimg_fea, maxdisp):
414
+ B, C, H, W = refimg_fea.shape
415
+ volume = refimg_fea.new_zeros([B, 2 * C, maxdisp, H, W])
416
+ for i in range(maxdisp):
417
+ if i > 0:
418
+ volume[:, :C, i, :, :] = refimg_fea[:, :, :, :]
419
+ volume[:, C:, i, :, i:] = targetimg_fea[:, :, :, :-i]
420
+ else:
421
+ volume[:, :C, i, :, :] = refimg_fea
422
+ volume[:, C:, i, :, :] = targetimg_fea
423
+ volume = volume.contiguous()
424
+ return volume
425
+
426
+
427
+
428
+ def disparity_regression(x, maxdisp):
429
+ assert len(x.shape) == 4
430
+ disp_values = torch.arange(0, maxdisp, dtype=x.dtype, device=x.device)
431
+ disp_values = disp_values.reshape(1, maxdisp, 1, 1)
432
+ return torch.sum(x * disp_values, 1, keepdim=True)
433
+
434
+
435
+ class FeatureAtt(nn.Module):
436
+ def __init__(self, cv_chan, feat_chan):
437
+ super(FeatureAtt, self).__init__()
438
+
439
+ self.feat_att = nn.Sequential(
440
+ BasicConv(feat_chan, feat_chan//2, kernel_size=1, stride=1, padding=0),
441
+ nn.Conv2d(feat_chan//2, cv_chan, 1)
442
+ )
443
+
444
+ def forward(self, cv, feat):
445
+ '''
446
+ @cv: cost volume (B,C,D,H,W)
447
+ @feat: (B,C,H,W)
448
+ '''
449
+ feat_att = self.feat_att(feat).unsqueeze(2) #(B,C,1,H,W)
450
+ cv = torch.sigmoid(feat_att)*cv
451
+ return cv
452
+
453
+ def context_upsample(disp_low, up_weights):
454
+ """
455
+ @disp_low: (b,1,h,w) 1/4 resolution
456
+ @up_weights: (b,9,4*h,4*w) Image resolution
457
+ """
458
+ b, c, h, w = disp_low.shape
459
+
460
+ disp_unfold = F.unfold(disp_low.reshape(b,c,h,w),3,1,1).reshape(b,-1,h,w)
461
+ disp_unfold = F.interpolate(disp_unfold,(h*4,w*4),mode='nearest').reshape(b,9,h*4,w*4)
462
+
463
+ disp = (disp_unfold*up_weights).sum(1)
464
+
465
+ return disp
466
+
467
+
468
+
469
+ class PositionalEmbedding(nn.Module):
470
+ def __init__(self, d_model, max_len=512):
471
+ super().__init__()
472
+
473
+ # Compute the positional encodings once in log space.
474
+ pe = torch.zeros(max_len, d_model).float()
475
+ pe.requires_grad = False
476
+
477
+ position = torch.arange(0, max_len).float().unsqueeze(1) #(N,1)
478
+ div_term = (torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model)).exp()[None]
479
+
480
+ pe[:, 0::2] = torch.sin(position * div_term) #(N, d_model/2)
481
+ pe[:, 1::2] = torch.cos(position * div_term)
482
+
483
+ pe = pe.unsqueeze(0)
484
+ self.pe = pe
485
+ # self.register_buffer('pe', pe) #(1, max_len, D)
486
+
487
+
488
+ def forward(self, x, resize_embed=False):
489
+ '''
490
+ @x: (B,N,D)
491
+ '''
492
+ self.pe = self.pe.to(x.device).to(x.dtype)
493
+ pe = self.pe
494
+ if pe.shape[1]<x.shape[1]:
495
+ if resize_embed:
496
+ pe = F.interpolate(pe.permute(0,2,1), size=x.shape[1], mode='linear', align_corners=False).permute(0,2,1)
497
+ else:
498
+ raise RuntimeError(f'x:{x.shape}, pe:{pe.shape}')
499
+ return x + pe[:, :x.size(1)]
500
+
501
+
502
+
503
+ class CostVolumeDisparityAttention(nn.Module):
504
+ def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1, act=nn.GELU, norm_first=False, num_transformer=6, max_len=512, resize_embed=False):
505
+ super().__init__()
506
+ self.resize_embed = resize_embed
507
+ self.sa = nn.ModuleList([])
508
+ for _ in range(num_transformer):
509
+ self.sa.append(FlashAttentionTransformerEncoderLayer(embed_dim=d_model, num_heads=nhead, dim_feedforward=dim_feedforward, act=act, dropout=dropout))
510
+ self.pos_embed0 = PositionalEmbedding(d_model, max_len=max_len)
511
+
512
+
513
+ def forward(self, cv, window_size=(-1,-1)):
514
+ """
515
+ @cv: (B,C,D,H,W) where D is max disparity
516
+ """
517
+ x = cv
518
+ B,C,D,H,W = x.shape
519
+ x = x.permute(0,3,4,2,1).reshape(B*H*W, D, C)
520
+ x = self.pos_embed0(x, resize_embed=self.resize_embed) #!NOTE No resize since disparity is pre-determined
521
+ for i in range(len(self.sa)):
522
+ x = self.sa[i](x, window_size=window_size)
523
+ x = x.reshape(B,H,W,D,C).permute(0,4,3,1,2)
524
+
525
+ return x
526
+
527
+
528
+
529
+ class ChannelAttentionEnhancement(nn.Module):
530
+ def __init__(self, in_planes, ratio=16):
531
+ super(ChannelAttentionEnhancement, self).__init__()
532
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
533
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
534
+
535
+ self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False),
536
+ nn.ReLU(),
537
+ nn.Conv2d(in_planes // 16, in_planes, 1, bias=False))
538
+ self.sigmoid = nn.Sigmoid()
539
+
540
+ def forward(self, x):
541
+ avg_out = self.fc(self.avg_pool(x))
542
+ max_out = self.fc(self.max_pool(x))
543
+ out = avg_out + max_out
544
+ return self.sigmoid(out)
545
+
546
+ class SpatialAttentionExtractor(nn.Module):
547
+ def __init__(self, kernel_size=7):
548
+ super(SpatialAttentionExtractor, self).__init__()
549
+
550
+ self.samconv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
551
+ self.sigmoid = nn.Sigmoid()
552
+
553
+ def forward(self, x):
554
+ avg_out = torch.mean(x, dim=1, keepdim=True)
555
+ max_out, _ = torch.max(x, dim=1, keepdim=True)
556
+ x = torch.cat([avg_out, max_out], dim=1)
557
+ x = self.samconv(x)
558
+ return self.sigmoid(x)
559
+
560
+
561
+
562
+ class EdgeNextConvEncoder(nn.Module):
563
+ def __init__(self, dim, layer_scale_init_value=1e-6, expan_ratio=4, kernel_size=7, norm='layer'):
564
+ super().__init__()
565
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
566
+ if norm=='layer':
567
+ self.norm = LayerNorm2d(dim, eps=1e-6)
568
+ else:
569
+ self.norm = nn.Identity()
570
+ self.pwconv1 = nn.Linear(dim, expan_ratio * dim)
571
+ self.act = nn.GELU()
572
+ self.pwconv2 = nn.Linear(expan_ratio * dim, dim)
573
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) if layer_scale_init_value > 0 else None
574
+
575
+ def forward(self, x):
576
+ input = x
577
+ x = self.dwconv(x)
578
+ x = self.norm(x)
579
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
580
+ x = self.pwconv1(x)
581
+ x = self.act(x)
582
+ x = self.pwconv2(x)
583
+ if self.gamma is not None:
584
+ x = self.gamma * x
585
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
586
+
587
+ x = input + x
588
+ return x
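The classifier -> softmax -> disparity_regression chain in foundation_stereo.py turns the aggregated volume into an initial disparity as an expectation over the disparity axis (a soft argmin). A self-contained sketch of that step with a random stand-in cost volume (torch only; shapes are illustrative):

import torch
import torch.nn.functional as F

B, D, H, W = 1, 48, 8, 8                  # D plays the role of max_disp // 4
cost = torch.randn(B, D, H, W)            # stand-in for classifier(comb_volume).squeeze(1)

prob = F.softmax(cost, dim=1)             # per-pixel distribution over disparities
disp_values = torch.arange(D, dtype=prob.dtype).view(1, D, 1, 1)
init_disp = (prob * disp_values).sum(dim=1, keepdim=True)   # (B,1,H,W), as in disparity_regression

print(init_disp.shape)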
FoundationStereo_demo/core/update.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+
10
+ import torch,pdb,os,sys
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from einops import rearrange
14
+ from torch import einsum
15
+ code_dir = os.path.dirname(os.path.realpath(__file__))
16
+ sys.path.append(f'{code_dir}/../')
17
+ from core.submodule import *
18
+ from core.extractor import *
19
+
20
+ class DispHead(nn.Module):
21
+ def __init__(self, input_dim=128, hidden_dim=256, output_dim=1):
22
+ super(DispHead, self).__init__()
23
+ self.conv = nn.Sequential(
24
+ nn.Conv2d(input_dim, input_dim, kernel_size=3, padding=1),
25
+ nn.ReLU(),
26
+ EdgeNextConvEncoder(input_dim, expan_ratio=4, kernel_size=7, norm=None),
27
+ EdgeNextConvEncoder(input_dim, expan_ratio=4, kernel_size=7, norm=None),
28
+ nn.Conv2d(input_dim, output_dim, 3, padding=1),
29
+ )
30
+
31
+ def forward(self, x):
32
+ return self.conv(x)
33
+
34
+ class ConvGRU(nn.Module):
35
+ def __init__(self, hidden_dim, input_dim, kernel_size=3):
36
+ super(ConvGRU, self).__init__()
37
+ self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
38
+ self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
39
+ self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
40
+
41
+ def forward(self, h, cz, cr, cq, *x_list):
42
+ x = torch.cat(x_list, dim=1)
43
+ hx = torch.cat([h, x], dim=1)
44
+ z = torch.sigmoid(self.convz(hx) + cz)
45
+ r = torch.sigmoid(self.convr(hx) + cr)
46
+ q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)) + cq)
47
+ h = (1-z) * h + z * q
48
+ return h
49
+
50
+
51
+ class BasicMotionEncoder(nn.Module):
52
+ def __init__(self, args, ngroup=8):
53
+ super(BasicMotionEncoder, self).__init__()
54
+ self.args = args
55
+ cor_planes = args.corr_levels * (2*args.corr_radius + 1) * (ngroup+1)
56
+ self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
57
+ self.convc2 = nn.Conv2d(256, 256, 3, padding=1)
58
+ self.convd1 = nn.Conv2d(1, 64, 7, padding=3)
59
+ self.convd2 = nn.Conv2d(64, 64, 3, padding=1)
60
+ self.conv = nn.Conv2d(64+256, 128-1, 3, padding=1)
61
+
62
+ def forward(self, disp, corr):
63
+ cor = F.relu(self.convc1(corr))
64
+ cor = F.relu(self.convc2(cor))
65
+ disp_ = F.relu(self.convd1(disp))
66
+ disp_ = F.relu(self.convd2(disp_))
67
+
68
+ cor_disp = torch.cat([cor, disp_], dim=1)
69
+ out = F.relu(self.conv(cor_disp))
70
+ return torch.cat([out, disp], dim=1)
71
+
72
+ def pool2x(x):
73
+ return F.avg_pool2d(x, 3, stride=2, padding=1)
74
+
75
+ def pool4x(x):
76
+ return F.avg_pool2d(x, 5, stride=4, padding=1)
77
+
78
+ def interp(x, dest):
79
+ interp_args = {'mode': 'bilinear', 'align_corners': True}
80
+ return F.interpolate(x, dest.shape[2:], **interp_args)
81
+
82
+
83
+ class RaftConvGRU(nn.Module):
84
+ def __init__(self, hidden_dim=128, input_dim=256, kernel_size=3):
85
+ super().__init__()
86
+ self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size // 2)
87
+ self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size // 2)
88
+ self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size // 2)
89
+
90
+ def forward(self, h, x, hx):
91
+ z = torch.sigmoid(self.convz(hx))
92
+ r = torch.sigmoid(self.convr(hx))
93
+ q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
94
+ h = (1-z) * h + z * q
95
+ return h
96
+
97
+
98
+ class SelectiveConvGRU(nn.Module):
99
+ def __init__(self, hidden_dim=128, input_dim=256, small_kernel_size=1, large_kernel_size=3, patch_size=None):
100
+ super(SelectiveConvGRU, self).__init__()
101
+ self.conv0 = nn.Sequential(
102
+ nn.Conv2d(input_dim, input_dim, kernel_size=3, padding=1),
103
+ nn.ReLU(),
104
+ )
105
+ self.conv1 = nn.Sequential(
106
+ nn.Conv2d(input_dim+hidden_dim, input_dim+hidden_dim, kernel_size=3, padding=1),
107
+ nn.ReLU(),
108
+ )
109
+ self.small_gru = RaftConvGRU(hidden_dim, input_dim, small_kernel_size)
110
+ self.large_gru = RaftConvGRU(hidden_dim, input_dim, large_kernel_size)
111
+
112
+ def forward(self, att, h, *x):
113
+ x = torch.cat(x, dim=1)
114
+ x = self.conv0(x)
115
+ hx = torch.cat([x, h], dim=1)
116
+ hx = self.conv1(hx)
117
+ h = self.small_gru(h, x, hx) * att + self.large_gru(h, x, hx) * (1 - att)
118
+
119
+ return h
120
+
121
+
122
+ class BasicSelectiveMultiUpdateBlock(nn.Module):
123
+ def __init__(self, args, hidden_dim=128, volume_dim=8):
124
+ super().__init__()
125
+ self.args = args
126
+ self.encoder = BasicMotionEncoder(args, volume_dim)
127
+
128
+ if args.n_gru_layers == 3:
129
+ self.gru16 = SelectiveConvGRU(hidden_dim, hidden_dim * 2)
130
+ if args.n_gru_layers >= 2:
131
+ self.gru08 = SelectiveConvGRU(hidden_dim, hidden_dim * (args.n_gru_layers == 3) + hidden_dim * 2)
132
+ self.gru04 = SelectiveConvGRU(hidden_dim, hidden_dim * (args.n_gru_layers > 1) + hidden_dim * 2)
133
+ self.disp_head = DispHead(hidden_dim, 256)
134
+ self.mask = nn.Sequential(
135
+ nn.Conv2d(128, 64, 3, padding=1),
136
+ nn.ReLU(inplace=True),
137
+ nn.Conv2d(64, 32, 3, padding=1),
138
+ nn.ReLU(inplace=True),
139
+ )
140
+
141
+ def forward(self, net, inp, corr, disp, att):
142
+ if self.args.n_gru_layers == 3:
143
+ net[2] = self.gru16(att[2], net[2], inp[2], pool2x(net[1]))
144
+ if self.args.n_gru_layers >= 2:
145
+ if self.args.n_gru_layers > 2:
146
+ net[1] = self.gru08(att[1], net[1], inp[1], pool2x(net[0]), interp(net[2], net[1]))
147
+ else:
148
+ net[1] = self.gru08(att[1], net[1], inp[1], pool2x(net[0]))
149
+
150
+ motion_features = self.encoder(disp, corr)
151
+ motion_features = torch.cat([inp[0], motion_features], dim=1)
152
+ if self.args.n_gru_layers > 1:
153
+ net[0] = self.gru04(att[0], net[0], motion_features, interp(net[1], net[0]))
154
+
155
+ delta_disp = self.disp_head(net[0])
156
+
157
+ # scale mask to balance gradients
158
+ mask = .25 * self.mask(net[0])
159
+ return net, mask, delta_disp
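The GRU variants above all share the same gating arithmetic: an update gate z, a reset gate r, and a candidate state q, blended into the new hidden state. A minimal stand-alone sketch of one such update with small ad-hoc convolutions (dimensions are illustrative, not the model's real hidden sizes):

import torch
import torch.nn as nn

hidden_dim, input_dim = 8, 8
convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)

h = torch.zeros(1, hidden_dim, 4, 4)     # hidden state
x = torch.randn(1, input_dim, 4, 4)      # motion features / context input

hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(convz(hx))             # update gate: how much of the state to overwrite
r = torch.sigmoid(convr(hx))             # reset gate: how much old state feeds the candidate
q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))  # candidate state
h = (1 - z) * h + z * q                  # same blend used by ConvGRU / RaftConvGRU above

print(h.shape)   # torch.Size([1, 8, 4, 4])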
FoundationStereo_demo/core/utils/utils.py ADDED
@@ -0,0 +1,64 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+
10
+
11
+ import torch,pdb,logging
12
+ import torch.nn.functional as F
13
+ import numpy as np
14
+ from scipy import interpolate
15
+
16
+
17
+ class InputPadder:
18
+ """ Pads images such that dimensions are divisible by divis_by (8 by default) """
19
+ def __init__(self, dims, mode='sintel', divis_by=8, force_square=False):
20
+ self.ht, self.wd = dims[-2:]
21
+ if force_square:
22
+ max_side = max(self.ht, self.wd)
23
+ pad_ht = ((max_side // divis_by) + 1) * divis_by - self.ht
24
+ pad_wd = ((max_side // divis_by) + 1) * divis_by - self.wd
25
+ else:
26
+ pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
27
+ pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
28
+ if mode == 'sintel':
29
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
30
+ else:
31
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
32
+
33
+ def pad(self, *inputs):
34
+ assert all((x.ndim == 4) for x in inputs)
35
+ # Ensure padded tensors are contiguous to avoid cuDNN issues
36
+ return [F.pad(x, self._pad, mode='replicate').contiguous() for x in inputs]
37
+
38
+ def unpad(self, x):
39
+ assert x.ndim == 4
40
+ ht, wd = x.shape[-2:]
41
+ c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
42
+ # Ensure unpadded tensor is contiguous
43
+ return x[..., c[0]:c[1], c[2]:c[3]].contiguous()
44
+
45
+
46
+ def bilinear_sampler(img, coords, mode='bilinear', mask=False, low_memory=False):
47
+ """ Wrapper for grid_sample, uses pixel coordinates """
48
+ H, W = img.shape[-2:]
49
+ xgrid, ygrid = coords.split([1,1], dim=-1)
50
+ xgrid = 2*xgrid/(W-1) - 1 # Normalize to [-1,1]
51
+ assert torch.unique(ygrid).numel() == 1 and H == 1 # This is a stereo problem
52
+ grid = torch.cat([xgrid, ygrid], dim=-1).to(img.dtype).contiguous()
53
+ img = F.grid_sample(img, grid, align_corners=True).contiguous()
54
+ if mask:
55
+ mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
56
+ return img, mask.float().contiguous()
57
+ return img
58
+
59
+
60
+ def coords_grid(batch, ht, wd):
61
+ coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
62
+ coords = torch.stack(coords[::-1], dim=0).float()
63
+ return coords[None].repeat(batch, 1, 1, 1)
64
+
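A quick round-trip check of the padding logic above, used the same way run_hierachical does with divis_by=32. The import path assumes FoundationStereo_demo/ is on sys.path (and scipy installed for the module-level import); otherwise the class can be pasted in directly:

import torch
from core.utils.utils import InputPadder   # assumes FoundationStereo_demo/ is on sys.path

left = torch.randn(1, 3, 423, 747)          # arbitrary size, not divisible by 32
right = torch.randn(1, 3, 423, 747)

padder = InputPadder(left.shape[-2:], divis_by=32, force_square=False)
left_p, right_p = padder.pad(left, right)
print(left_p.shape)                         # both spatial dims are now multiples of 32

print(torch.equal(padder.unpad(left_p), left))   # True: pad followed by unpad is lossless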
FoundationStereo_demo/depth_anything/LICENSE.txt ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
FoundationStereo_demo/depth_anything/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # depth_anything package
2
+ # This file allows depth_anything to be imported as a package
FoundationStereo_demo/depth_anything/blocks.py ADDED
@@ -0,0 +1,153 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
5
+ scratch = nn.Module()
6
+
7
+ out_shape1 = out_shape
8
+ out_shape2 = out_shape
9
+ out_shape3 = out_shape
10
+ if len(in_shape) >= 4:
11
+ out_shape4 = out_shape
12
+
13
+ if expand:
14
+ out_shape1 = out_shape
15
+ out_shape2 = out_shape*2
16
+ out_shape3 = out_shape*4
17
+ if len(in_shape) >= 4:
18
+ out_shape4 = out_shape*8
19
+
20
+ scratch.layer1_rn = nn.Conv2d(
21
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
22
+ )
23
+ scratch.layer2_rn = nn.Conv2d(
24
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
25
+ )
26
+ scratch.layer3_rn = nn.Conv2d(
27
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
28
+ )
29
+ if len(in_shape) >= 4:
30
+ scratch.layer4_rn = nn.Conv2d(
31
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
32
+ )
33
+
34
+ return scratch
35
+
36
+
37
+ class ResidualConvUnit(nn.Module):
38
+ """Residual convolution module.
39
+ """
40
+
41
+ def __init__(self, features, activation, bn):
42
+ """Init.
43
+
44
+ Args:
45
+ features (int): number of features
46
+ """
47
+ super().__init__()
48
+
49
+ self.bn = bn
50
+
51
+ self.groups=1
52
+
53
+ self.conv1 = nn.Conv2d(
54
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
55
+ )
56
+
57
+ self.conv2 = nn.Conv2d(
58
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
59
+ )
60
+
61
+ if self.bn==True:
62
+ self.bn1 = nn.BatchNorm2d(features)
63
+ self.bn2 = nn.BatchNorm2d(features)
64
+
65
+ self.activation = activation
66
+
67
+ self.skip_add = nn.quantized.FloatFunctional()
68
+
69
+ def forward(self, x):
70
+ """Forward pass.
71
+
72
+ Args:
73
+ x (tensor): input
74
+
75
+ Returns:
76
+ tensor: output
77
+ """
78
+
79
+ out = self.activation(x)
80
+ out = self.conv1(out)
81
+ if self.bn==True:
82
+ out = self.bn1(out)
83
+
84
+ out = self.activation(out)
85
+ out = self.conv2(out)
86
+ if self.bn==True:
87
+ out = self.bn2(out)
88
+
89
+ if self.groups > 1:
90
+ out = self.conv_merge(out)
91
+
92
+ return self.skip_add.add(out, x)
93
+
94
+
95
+ class FeatureFusionBlock(nn.Module):
96
+ """Feature fusion block.
97
+ """
98
+
99
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
100
+ """Init.
101
+
102
+ Args:
103
+ features (int): number of features
104
+ """
105
+ super(FeatureFusionBlock, self).__init__()
106
+
107
+ self.deconv = deconv
108
+ self.align_corners = align_corners
109
+
110
+ self.groups=1
111
+
112
+ self.expand = expand
113
+ out_features = features
114
+ if self.expand==True:
115
+ out_features = features//2
116
+
117
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
118
+
119
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
120
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
121
+
122
+ self.skip_add = nn.quantized.FloatFunctional()
123
+
124
+ self.size=size
125
+
126
+ def forward(self, *xs, size=None):
127
+ """Forward pass.
128
+
129
+ Returns:
130
+ tensor: output
131
+ """
132
+ output = xs[0]
133
+
134
+ if len(xs) == 2:
135
+ res = self.resConfUnit1(xs[1])
136
+ output = self.skip_add.add(output, res)
137
+
138
+ output = self.resConfUnit2(output)
139
+
140
+ if (size is None) and (self.size is None):
141
+ modifier = {"scale_factor": 2}
142
+ elif size is None:
143
+ modifier = {"size": self.size}
144
+ else:
145
+ modifier = {"size": size}
146
+
147
+ output = nn.functional.interpolate(
148
+ output, **modifier, mode="bilinear", align_corners=self.align_corners
149
+ )
150
+
151
+ output = self.out_conv(output)
152
+
153
+ return output
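The fusion decoder above follows the DPT design: each backbone level is first projected by a scratch layer, then merged top-down by FeatureFusionBlock, which refines the skip input, adds it to the incoming path, refines the sum, and upsamples. A minimal smoke-test sketch, assuming FoundationStereo_demo is on sys.path; the channel counts and tensor sizes below are made up for illustration:

import torch
import torch.nn as nn
from depth_anything.blocks import FeatureFusionBlock, _make_scratch  # import path is an assumption

# Project hypothetical backbone features to a common width of 64 channels.
scratch = _make_scratch(in_shape=[48, 96, 192, 384], out_shape=64)
fusion = FeatureFusionBlock(64, nn.ReLU(False), bn=False, align_corners=True)

decoder_path = scratch.layer4_rn(torch.randn(1, 384, 16, 16))  # incoming top-down path
skip = scratch.layer3_rn(torch.randn(1, 192, 16, 16))          # skip feature at the same spatial size

# resConfUnit1 refines the skip, the sum is refined by resConfUnit2, then upsampled to `size`.
out = fusion(decoder_path, skip, size=(32, 32))
print(out.shape)  # torch.Size([1, 64, 32, 32])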
FoundationStereo_demo/depth_anything/dpt.py ADDED
@@ -0,0 +1,203 @@
1
+ import argparse
2
+ import torch,os,sys,pdb
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ code_dir = os.path.dirname(os.path.realpath(__file__))
6
+ sys.path.append(f'{code_dir}/../')
7
+ from depth_anything.blocks import FeatureFusionBlock, _make_scratch
8
+
9
+
10
+ def _make_fusion_block(features, use_bn, size = None):
11
+ return FeatureFusionBlock(
12
+ features,
13
+ nn.ReLU(False),
14
+ deconv=False,
15
+ bn=use_bn,
16
+ expand=False,
17
+ align_corners=True,
18
+ size=size,
19
+ )
20
+
21
+
22
+ class DPTHead(nn.Module):
23
+ def __init__(self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False):
24
+ super(DPTHead, self).__init__()
25
+
26
+ self.nclass = nclass
27
+ self.use_clstoken = use_clstoken
28
+
29
+ self.projects = nn.ModuleList([
30
+ nn.Conv2d(
31
+ in_channels=in_channels,
32
+ out_channels=out_channel,
33
+ kernel_size=1,
34
+ stride=1,
35
+ padding=0,
36
+ ) for out_channel in out_channels
37
+ ])
38
+
39
+ self.resize_layers = nn.ModuleList([
40
+ nn.ConvTranspose2d(
41
+ in_channels=out_channels[0],
42
+ out_channels=out_channels[0],
43
+ kernel_size=4,
44
+ stride=4,
45
+ padding=0),
46
+ nn.ConvTranspose2d(
47
+ in_channels=out_channels[1],
48
+ out_channels=out_channels[1],
49
+ kernel_size=2,
50
+ stride=2,
51
+ padding=0),
52
+ nn.Identity(),
53
+ nn.Conv2d(
54
+ in_channels=out_channels[3],
55
+ out_channels=out_channels[3],
56
+ kernel_size=3,
57
+ stride=2,
58
+ padding=1)
59
+ ])
60
+
61
+ if use_clstoken:
62
+ self.readout_projects = nn.ModuleList()
63
+ for _ in range(len(self.projects)):
64
+ self.readout_projects.append(
65
+ nn.Sequential(
66
+ nn.Linear(2 * in_channels, in_channels),
67
+ nn.GELU()))
68
+
69
+ self.scratch = _make_scratch(
70
+ out_channels,
71
+ features,
72
+ groups=1,
73
+ expand=False,
74
+ )
75
+
76
+ self.scratch.stem_transpose = None
77
+
78
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
79
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
80
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
81
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
82
+
83
+ head_features_1 = features
84
+ head_features_2 = 32
85
+
86
+ if nclass > 1:
87
+ self.scratch.output_conv = nn.Sequential(
88
+ nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
89
+ nn.ReLU(True),
90
+ nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
91
+ )
92
+ else:
93
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
94
+
95
+ self.scratch.output_conv2 = nn.Sequential(
96
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
97
+ nn.ReLU(True),
98
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
99
+ nn.ReLU(True),
100
+ nn.Identity(),
101
+ )
102
+
103
+ def forward(self, out_features, patch_h, patch_w, return_intermediate=False, patch_size=14):
104
+ out = []
105
+ for i, x in enumerate(out_features):
106
+ if self.use_clstoken:
107
+ x, cls_token = x[0], x[1]
108
+ readout = cls_token.unsqueeze(1).expand_as(x)
109
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
110
+ else:
111
+ x = x[0]
112
+
113
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
114
+
115
+ x = self.projects[i](x)
116
+ x = self.resize_layers[i](x)
117
+
118
+ out.append(x)
119
+
120
+ layer_1, layer_2, layer_3, layer_4 = out
121
+
122
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
123
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
124
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
125
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
126
+
127
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
128
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
129
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
130
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
131
+
132
+ out = self.scratch.output_conv1(path_1)
133
+ out = F.interpolate(out, (int(patch_h * patch_size), int(patch_w * patch_size)), mode="bilinear", align_corners=True)
134
+ if return_intermediate:
135
+ depth = self.scratch.output_conv2(out)
136
+ depth = F.relu(depth)
137
+ disp = 1/depth
138
+ disp[depth==0] = 0
139
+ disp = disp/disp.max()
140
+ return out, path_1, path_2, path_3, path_4, disp
141
+
142
+ else:
143
+ out = self.scratch.output_conv2(out)
144
+ return out
145
+
146
+
147
+ class DPT_DINOv2(nn.Module):
148
+ def __init__(self, encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024], use_bn=False, use_clstoken=False, pretrained_dino=False):
149
+ super(DPT_DINOv2, self).__init__()
150
+
151
+ assert encoder in ['vits', 'vitb', 'vitl']
152
+
153
+ # in case the Internet connection is not stable, please load the DINOv2 locally
154
+ # if localhub:
155
+ # self.pretrained = torch.hub.load('torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False)
156
+ # else:
157
+ self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder), pretrained=pretrained_dino)
158
+
159
+
160
+ dim = self.pretrained.blocks[0].attn.qkv.in_features
161
+
162
+ self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
163
+
164
+ def forward(self, x):
165
+ h, w = x.shape[-2:]
166
+
167
+ features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
168
+ patch_size = self.pretrained.patch_size
169
+ patch_h, patch_w = h // patch_size, w // patch_size
170
+ output = self.depth_head(features, patch_h, patch_w, patch_size=patch_size, return_intermediate=True)
171
+ return output
172
+
173
+
174
+ class DepthAnything(DPT_DINOv2):
175
+ def __init__(self, config):
176
+ super().__init__(**config)
177
+
178
+ def forward(self, x):
179
+ h, w = x.shape[-2:]
180
+
181
+ features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
182
+ patch_size = self.pretrained.patch_size
183
+ patch_h, patch_w = h // patch_size, w // patch_size
184
+ depth = self.depth_head(features, patch_h, patch_w, patch_size=patch_size)
185
+ depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
186
+ depth = F.relu(depth)
187
+
188
+ return depth.squeeze(1)
189
+
190
+
191
+ if __name__ == '__main__':
192
+ parser = argparse.ArgumentParser()
193
+ parser.add_argument(
194
+ "--encoder",
195
+ default="vits",
196
+ type=str,
197
+ choices=["vits", "vitb", "vitl"],
198
+ )
199
+ args = parser.parse_args()
200
+
201
+ model = DepthAnything.from_pretrained("LiheYoung/depth_anything_{:}14".format(args.encoder))
202
+
203
+ print(model)
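A hedged usage sketch for the model above. The encoder, features, and out_channels values are illustrative choices, not read from the demo, and constructing DPT_DINOv2 pulls DINOv2 through torch.hub, so network access (or a local hub cache) is needed:

import torch
from depth_anything.dpt import DPT_DINOv2  # import path is an assumption

model = DPT_DINOv2(encoder='vits', features=64, out_channels=[48, 96, 192, 384],
                   pretrained_dino=False).eval()

# Input height and width must be multiples of the DINOv2 patch size (14).
x = torch.randn(1, 3, 224, 280)
with torch.no_grad():
    feat, path_1, path_2, path_3, path_4, disp = model(x)

# `feat` is a dense feature map at the input resolution, `path_1..4` are the intermediate
# fusion outputs, and `disp` is a max-normalized relative inverse-depth map.
print(feat.shape, disp.shape)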
FoundationStereo_demo/depth_anything/util/transform.py ADDED
@@ -0,0 +1,248 @@
1
+ import random
2
+ from PIL import Image, ImageOps, ImageFilter
3
+ import torch
4
+ from torchvision import transforms
5
+ import torch.nn.functional as F
6
+
7
+ import numpy as np
8
+ import cv2
9
+ import math
10
+
11
+
12
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
13
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
14
+
15
+ Args:
16
+ sample (dict): sample
17
+ size (tuple): image size
18
+
19
+ Returns:
20
+ tuple: new size
21
+ """
22
+ shape = list(sample["disparity"].shape)
23
+
24
+ if shape[0] >= size[0] and shape[1] >= size[1]:
25
+ return sample
26
+
27
+ scale = [0, 0]
28
+ scale[0] = size[0] / shape[0]
29
+ scale[1] = size[1] / shape[1]
30
+
31
+ scale = max(scale)
32
+
33
+ shape[0] = math.ceil(scale * shape[0])
34
+ shape[1] = math.ceil(scale * shape[1])
35
+
36
+ # resize
37
+ sample["image"] = cv2.resize(
38
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
39
+ )
40
+
41
+ sample["disparity"] = cv2.resize(
42
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
43
+ )
44
+ sample["mask"] = cv2.resize(
45
+ sample["mask"].astype(np.float32),
46
+ tuple(shape[::-1]),
47
+ interpolation=cv2.INTER_NEAREST,
48
+ )
49
+ sample["mask"] = sample["mask"].astype(bool)
50
+
51
+ return tuple(shape)
52
+
53
+
54
+ class Resize(object):
55
+ """Resize sample to given size (width, height).
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ width,
61
+ height,
62
+ resize_target=True,
63
+ keep_aspect_ratio=False,
64
+ ensure_multiple_of=1,
65
+ resize_method="lower_bound",
66
+ image_interpolation_method=cv2.INTER_AREA,
67
+ ):
68
+ """Init.
69
+
70
+ Args:
71
+ width (int): desired output width
72
+ height (int): desired output height
73
+ resize_target (bool, optional):
74
+ True: Resize the full sample (image, mask, target).
75
+ False: Resize image only.
76
+ Defaults to True.
77
+ keep_aspect_ratio (bool, optional):
78
+ True: Keep the aspect ratio of the input sample.
79
+ Output sample might not have the given width and height, and
80
+ resize behaviour depends on the parameter 'resize_method'.
81
+ Defaults to False.
82
+ ensure_multiple_of (int, optional):
83
+ Output width and height is constrained to be multiple of this parameter.
84
+ Defaults to 1.
85
+ resize_method (str, optional):
86
+ "lower_bound": Output will be at least as large as the given size.
87
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
88
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
89
+ Defaults to "lower_bound".
90
+ """
91
+ self.__width = width
92
+ self.__height = height
93
+
94
+ self.__resize_target = resize_target
95
+ self.__keep_aspect_ratio = keep_aspect_ratio
96
+ self.__multiple_of = ensure_multiple_of
97
+ self.__resize_method = resize_method
98
+ self.__image_interpolation_method = image_interpolation_method
99
+
100
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
101
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
102
+
103
+ if max_val is not None and y > max_val:
104
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
105
+
106
+ if y < min_val:
107
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
108
+
109
+ return y
110
+
111
+ def get_size(self, width, height):
112
+ # determine new height and width
113
+ scale_height = self.__height / height
114
+ scale_width = self.__width / width
115
+
116
+ if self.__keep_aspect_ratio:
117
+ if self.__resize_method == "lower_bound":
118
+ # scale such that output size is lower bound
119
+ if scale_width > scale_height:
120
+ # fit width
121
+ scale_height = scale_width
122
+ else:
123
+ # fit height
124
+ scale_width = scale_height
125
+ elif self.__resize_method == "upper_bound":
126
+ # scale such that output size is upper bound
127
+ if scale_width < scale_height:
128
+ # fit width
129
+ scale_height = scale_width
130
+ else:
131
+ # fit height
132
+ scale_width = scale_height
133
+ elif self.__resize_method == "minimal":
134
+ # scale as little as possible
135
+ if abs(1 - scale_width) < abs(1 - scale_height):
136
+ # fit width
137
+ scale_height = scale_width
138
+ else:
139
+ # fit height
140
+ scale_width = scale_height
141
+ else:
142
+ raise ValueError(
143
+ f"resize_method {self.__resize_method} not implemented"
144
+ )
145
+
146
+ if self.__resize_method == "lower_bound":
147
+ new_height = self.constrain_to_multiple_of(
148
+ scale_height * height, min_val=self.__height
149
+ )
150
+ new_width = self.constrain_to_multiple_of(
151
+ scale_width * width, min_val=self.__width
152
+ )
153
+ elif self.__resize_method == "upper_bound":
154
+ new_height = self.constrain_to_multiple_of(
155
+ scale_height * height, max_val=self.__height
156
+ )
157
+ new_width = self.constrain_to_multiple_of(
158
+ scale_width * width, max_val=self.__width
159
+ )
160
+ elif self.__resize_method == "minimal":
161
+ new_height = self.constrain_to_multiple_of(scale_height * height)
162
+ new_width = self.constrain_to_multiple_of(scale_width * width)
163
+ else:
164
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
165
+
166
+ return (new_width, new_height)
167
+
168
+ def __call__(self, sample):
169
+ width, height = self.get_size(
170
+ sample["image"].shape[1], sample["image"].shape[0]
171
+ )
172
+
173
+ # resize sample
174
+ sample["image"] = cv2.resize(
175
+ sample["image"],
176
+ (width, height),
177
+ interpolation=self.__image_interpolation_method,
178
+ )
179
+
180
+ if self.__resize_target:
181
+ if "disparity" in sample:
182
+ sample["disparity"] = cv2.resize(
183
+ sample["disparity"],
184
+ (width, height),
185
+ interpolation=cv2.INTER_NEAREST,
186
+ )
187
+
188
+ if "depth" in sample:
189
+ sample["depth"] = cv2.resize(
190
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
191
+ )
192
+
193
+ if "semseg_mask" in sample:
194
+ # sample["semseg_mask"] = cv2.resize(
195
+ # sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
196
+ # )
197
+ sample["semseg_mask"] = F.interpolate(torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode='nearest').numpy()[0, 0]
198
+
199
+ if "mask" in sample:
200
+ sample["mask"] = cv2.resize(
201
+ sample["mask"].astype(np.float32),
202
+ (width, height),
203
+ interpolation=cv2.INTER_NEAREST,
204
+ )
205
+ # sample["mask"] = sample["mask"].astype(bool)
206
+
207
+ # print(sample['image'].shape, sample['depth'].shape)
208
+ return sample
209
+
210
+
211
+ class NormalizeImage(object):
212
+ """Normlize image by given mean and std.
213
+ """
214
+
215
+ def __init__(self, mean, std):
216
+ self.__mean = mean
217
+ self.__std = std
218
+
219
+ def __call__(self, sample):
220
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
221
+
222
+ return sample
223
+
224
+
225
+ class PrepareForNet(object):
226
+ """Prepare sample for usage as network input.
227
+ """
228
+
229
+ def __init__(self):
230
+ pass
231
+
232
+ def __call__(self, sample):
233
+ image = np.transpose(sample["image"], (2, 0, 1))
234
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
235
+
236
+ if "mask" in sample:
237
+ sample["mask"] = sample["mask"].astype(np.float32)
238
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
239
+
240
+ if "depth" in sample:
241
+ depth = sample["depth"].astype(np.float32)
242
+ sample["depth"] = np.ascontiguousarray(depth)
243
+
244
+ if "semseg_mask" in sample:
245
+ sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
246
+ sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
247
+
248
+ return sample
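The three transforms above are designed to be chained into a single preprocessing pipeline for the monocular branch. A hedged sketch, assuming ImageNet normalization statistics, a 518-pixel lower bound rounded to multiples of 14, and the example image shipped in assets/ (all of these values are illustrative choices, not read from the demo apps):

import cv2
import torch
from torchvision.transforms import Compose
from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet  # import path is an assumption

transform = Compose([
    Resize(width=518, height=518, resize_target=False, keep_aspect_ratio=True,
           ensure_multiple_of=14, resize_method='lower_bound',
           image_interpolation_method=cv2.INTER_CUBIC),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

# Each transform operates on a dict with an "image" key (H x W x 3, float in [0, 1]).
image = cv2.cvtColor(cv2.imread('assets/example1/left.png'), cv2.COLOR_BGR2RGB) / 255.0
sample = transform({'image': image})
tensor = torch.from_numpy(sample['image']).unsqueeze(0)  # (1, 3, H, W), H and W multiples of 14
print(tensor.shape)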
assets/example1/K.txt ADDED
@@ -0,0 +1,2 @@
1
+ 754.6680908203125 0.0 489.3794860839844 0.0 754.6680908203125 265.16162109375 0.0 0.0 1.0
2
+ 0.063
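The two values above appear to be a row-major 3x3 intrinsic matrix (fx 0 cx; 0 fy cy; 0 0 1) followed by the stereo baseline, presumably in meters. A hypothetical parsing sketch that also turns an example disparity into metric depth via the pinhole relation Z = fx * B / d:

import numpy as np

# Hypothetical parser for this two-line K.txt layout (format inferred from the file above).
with open('assets/example1/K.txt') as f:
    lines = f.read().splitlines()

K = np.array([float(v) for v in lines[0].split()]).reshape(3, 3)  # row-major intrinsics
baseline = float(lines[1])                                        # baseline, assumed to be in meters

fx = K[0, 0]
disparity_px = 40.0                     # an arbitrary example disparity in pixels
depth_m = fx * baseline / disparity_px  # pinhole stereo: Z = f * B / d
print(f"fx={fx:.1f}, baseline={baseline} m, depth at d=40 px: {depth_m:.2f} m")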
assets/example1/left.png ADDED

Git LFS Details

  • SHA256: f080f84d7b2e28ba110eea80f0504cad9390f2399d6ffde6cd1e668200c3ef48
  • Pointer size: 131 Bytes
  • Size of remote file: 719 kB
assets/example1/right.png ADDED

Git LFS Details

  • SHA256: 337991aa0b35417ae64d6f66819522c233a4455eacfd34c9dc114ad569ec50f4
  • Pointer size: 131 Bytes
  • Size of remote file: 720 kB
assets/example2/K.txt ADDED
@@ -0,0 +1,9 @@
1
+ cam0=[1733.74 0 792.27; 0 1733.74 541.89; 0 0 1]
2
+ cam1=[1733.74 0 792.27; 0 1733.74 541.89; 0 0 1]
3
+ doffs=0
4
+ baseline=536.62
5
+ width=1920
6
+ height=1080
7
+ ndisp=170
8
+ vmin=55
9
+ vmax=142
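This second calibration file follows the Middlebury calib.txt convention: bracketed row-major camera matrices, a disparity offset doffs, a baseline (in millimeters here), the image size, and disparity bounds. A hypothetical parser sketch; with a disparity offset the depth relation becomes Z = baseline * f / (d + doffs):

import numpy as np

def parse_middlebury_calib(path):
    # Hypothetical helper: reads key=value lines, turning "[a b; c d]" blocks into arrays.
    calib = {}
    for line in open(path):
        line = line.strip()
        if not line:
            continue
        key, value = line.split('=', 1)
        if value.startswith('['):
            rows = value.strip('[]').split(';')
            calib[key] = np.array([[float(v) for v in row.split()] for row in rows])
        else:
            calib[key] = float(value)
    return calib

calib = parse_middlebury_calib('assets/example2/K.txt')
fx = calib['cam0'][0, 0]
depth_mm = calib['baseline'] * fx / (60.0 + calib['doffs'])  # 60 px is an arbitrary example disparity
print(fx, depth_mm)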
assets/example2/left.png ADDED

Git LFS Details

  • SHA256: 280be6eac4b525eee6d49f0afd32c11ef0b83d2cad3e77e946fe525fda16a355
  • Pointer size: 132 Bytes
  • Size of remote file: 2.76 MB
assets/example2/right.png ADDED

Git LFS Details

  • SHA256: 97be2568394bae63e26bf62343bfd04adfb372c9c96710784745bd7130a0c7d8
  • Pointer size: 132 Bytes
  • Size of remote file: 2.76 MB