Spaces: Running on Zero

Commit · 8ca4dce
1 Parent(s): 8e11b47

added CREStereo and FoundationStereo code
Browse files

- .gitattributes +1 -0
- .gitignore +3 -0
- CREStereo_demo/app.py +967 -0
- CREStereo_demo/app_local.py +889 -0
- CREStereo_demo/models/.gitkeep +0 -0
- CREStereo_demo/models/crestereo_eth3d.pth +3 -0
- CREStereo_demo/nets/__init__.py +1 -0
- CREStereo_demo/nets/attention/__init__.py +2 -0
- CREStereo_demo/nets/attention/linear_attention.py +81 -0
- CREStereo_demo/nets/attention/position_encoding.py +41 -0
- CREStereo_demo/nets/attention/transformer.py +100 -0
- CREStereo_demo/nets/corr.py +148 -0
- CREStereo_demo/nets/crestereo.py +258 -0
- CREStereo_demo/nets/extractor.py +123 -0
- CREStereo_demo/nets/update.py +91 -0
- CREStereo_demo/nets/utils/__init__.py +1 -0
- CREStereo_demo/nets/utils/utils.py +108 -0
- FoundationStereo_demo/Utils.py +160 -0
- FoundationStereo_demo/app.py +1138 -0
- FoundationStereo_demo/app_local.py +1292 -0
- FoundationStereo_demo/core/extractor.py +371 -0
- FoundationStereo_demo/core/foundation_stereo.py +277 -0
- FoundationStereo_demo/core/geometry.py +77 -0
- FoundationStereo_demo/core/submodule.py +588 -0
- FoundationStereo_demo/core/update.py +159 -0
- FoundationStereo_demo/core/utils/utils.py +64 -0
- FoundationStereo_demo/depth_anything/LICENSE.txt +201 -0
- FoundationStereo_demo/depth_anything/__init__.py +2 -0
- FoundationStereo_demo/depth_anything/blocks.py +153 -0
- FoundationStereo_demo/depth_anything/dpt.py +203 -0
- FoundationStereo_demo/depth_anything/util/transform.py +248 -0
- assets/example1/K.txt +2 -0
- assets/example1/left.png +3 -0
- assets/example1/right.png +3 -0
- assets/example2/K.txt +9 -0
- assets/example2/left.png +3 -0
- assets/example2/right.png +3 -0
.gitattributes CHANGED

@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED

@@ -0,0 +1,3 @@
+*.Identifier
+__pycache__/
+*.pyc
CREStereo_demo/app.py ADDED
@@ -0,0 +1,967 @@
"""
CREStereo Gradio Demo with ZeroGPU Integration

This demo showcases the CREStereo model for stereo depth estimation.
Optimized for Hugging Face Spaces with ZeroGPU support.

Key ZeroGPU optimizations:
- @spaces.GPU decorators for GPU-intensive functions
- CUDA operations only within GPU context
- Memory-efficient inference with cleanup
- Safe CUDA initialization patterns
"""

import os
import sys
import logging
import tempfile
import gc
from pathlib import Path
from typing import Optional, Tuple, Union
import numpy as np
import cv2
import gradio as gr
import imageio

# Import spaces BEFORE torch to ensure proper ZeroGPU initialization
import spaces

# Import torch after spaces - avoid any CUDA calls during import
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

# Completely avoid CUDA operations during import phase
# Do not set default tensor type or modify CUDA settings outside GPU context
# torch.set_default_tensor_type('torch.FloatTensor')  # Commented out - causes CUDA init

# Do not modify CUDA settings during import - this can trigger CUDA initialization
# torch.backends.cudnn.enabled = False  # Commented out
# torch.backends.cudnn.benchmark = False  # Commented out

# Use current directory as base
current_dir = os.path.dirname(os.path.abspath(__file__))
base_dir = current_dir

# Add current directory to path for local imports
sys.path.insert(0, current_dir)

# Import local modules
from nets import Model

# Import Open3D with error handling
OPEN3D_AVAILABLE = False
try:
    # Set Open3D to CPU mode to avoid CUDA initialization
    os.environ['OPEN3D_CPU_RENDERING'] = '1'
    # Don't import open3d here - do it inside functions
    # import open3d as o3d
    OPEN3D_AVAILABLE = True  # Assume available, will check later
except Exception as e:
    logging.warning(f"Open3D setup failed: {e}")
    OPEN3D_AVAILABLE = False

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Model configuration
MODEL_VARIANTS = {
    "crestereo_eth3d": {
        "display_name": "CREStereo ETH3D (Pre-trained model)",
        "model_file": "models/crestereo_eth3d.pth",
        "max_disp": 256
    }
}

# Global variables for model caching
_cached_model = None
_cached_device = None
_cached_model_selection = None

class InputPadder:
|
| 84 |
+
""" Pads images such that dimensions are divisible by divis_by """
|
| 85 |
+
def __init__(self, dims, divis_by=8, force_square=False):
|
| 86 |
+
self.ht, self.wd = dims[-2:]
|
| 87 |
+
pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
|
| 88 |
+
pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
|
| 89 |
+
|
| 90 |
+
if force_square:
|
| 91 |
+
# Make the padded dimensions square
|
| 92 |
+
max_dim = max(self.ht + pad_ht, self.wd + pad_wd)
|
| 93 |
+
pad_ht = max_dim - self.ht
|
| 94 |
+
pad_wd = max_dim - self.wd
|
| 95 |
+
|
| 96 |
+
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
|
| 97 |
+
|
| 98 |
+
def pad(self, *inputs):
|
| 99 |
+
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
|
| 100 |
+
|
| 101 |
+
def unpad(self, x):
|
| 102 |
+
ht, wd = x.shape[-2:]
|
| 103 |
+
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
|
| 104 |
+
return x[..., c[0]:c[1], c[2]:c[3]]
|
| 105 |
+
|
| 106 |
+
|
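# Worked example of the padding arithmetic above, assuming a hypothetical
# 1x3x375x1242 input and divis_by=8: pad_ht = (((375 // 8) + 1) * 8 - 375) % 8
# = 1 and pad_wd = (((1242 // 8) + 1) * 8 - 1242) % 8 = 6, so _pad becomes
# [3, 3, 0, 1] (left, right, top, bottom) and the padded tensor is 1x3x376x1248:
#
#   x = torch.zeros(1, 3, 375, 1242)
#   padder = InputPadder(x.shape, divis_by=8)
#   (xp,) = padder.pad(x)                      # xp.shape == (1, 3, 376, 1248)
#   assert padder.unpad(xp).shape == x.shape   # crops back to 375x1242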

def aggressive_cleanup():
    """Perform basic cleanup - no CUDA operations outside GPU context"""
    import gc
    gc.collect()
    logging.info("Performed basic memory cleanup")


@spaces.GPU
def initialize_gpu_context():
    """Initialize GPU context safely for ZeroGPU"""
    try:
        # Set CUDA settings safely within GPU context
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

        # Check GPU availability and log info
        if torch.cuda.is_available():
            device_name = torch.cuda.get_device_name(0)
            memory_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
            logging.info(f"GPU initialized: {device_name}, Total memory: {memory_total:.2f}GB")
            return True
        else:
            logging.error("CUDA not available after GPU context initialization")
            return False
    except Exception as e:
        logging.error(f"GPU context initialization failed: {e}")
        return False


@spaces.GPU
def check_gpu_memory():
    """Check and log current GPU memory usage - only call within GPU context"""
    try:
        allocated = torch.cuda.memory_allocated(0) / 1024**3
        reserved = torch.cuda.memory_reserved(0) / 1024**3
        max_allocated = torch.cuda.max_memory_allocated(0) / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3

        logging.info(f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB, Max: {max_allocated:.2f}GB, Total: {total:.2f}GB")
        return allocated, reserved, max_allocated, total
    except RuntimeError as e:
        logging.warning(f"Failed to get GPU memory info: {e}")
        return None, None, None, None


def get_available_models() -> dict:
    """Get all available models with their display names"""
    models = {}

    # Check for local models
    for variant, info in MODEL_VARIANTS.items():
        model_path = os.path.join(current_dir, info["model_file"])

        if os.path.exists(model_path):
            display_name = info["display_name"]
            models[display_name] = {
                "model_path": model_path,
                "variant": variant,
                "max_disp": info["max_disp"],
                "source": "local"
            }

    return models


def get_model_paths_from_selection(model_selection: str) -> Tuple[Optional[str], Optional[dict]]:
    """Get model path and config from the selected model"""
    models = get_available_models()

    # Check if it's in our models dict
    if model_selection in models:
        model_info = models[model_selection]
        logging.info(f"📁 Using local model: {model_selection}")
        return model_info["model_path"], model_info

    return None, None


@spaces.GPU
def load_model_for_inference(model_path: str, model_info: dict):
    """Load CREStereo model for inference temporarily (demo-style)"""
    # Set CUDA settings safely within GPU context
    torch.set_default_tensor_type('torch.cuda.FloatTensor')  # Now safe to use CUDA tensors
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # Check if CUDA is available after ZeroGPU initialization
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. ZeroGPU initialization may have failed.")

    # Use the first available CUDA device
    device = torch.device("cuda")

    # Set CUDA seed safely within GPU context
    try:
        random_seed = 0
        torch.cuda.manual_seed_all(random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except Exception as e:
        logging.warning(f"Could not set CUDA seed: {e}")

    try:
        # Create model
        max_disp = model_info.get("max_disp", 256)
        model = Model(max_disp=max_disp, mixed_precision=False, test_mode=True)

        # Load checkpoint
        ckpt = torch.load(model_path, map_location=device)
        model.load_state_dict(ckpt, strict=True)
        model.to(device)
        model.eval()

        logging.info("Loaded CREStereo model weights")

        # Memory optimizations
        torch.set_grad_enabled(False)
        logging.info("Applied memory optimizations")

        return model, device

    except Exception as e:
        logging.error(f"Model loading failed: {e}")
        raise RuntimeError(f"Failed to load model: {e}")


def get_cached_model(model_selection: str):
    """Get cached model or load new one if selection changed"""
    global _cached_model, _cached_device, _cached_model_selection

    # Get model paths from selection
    model_path, model_info = get_model_paths_from_selection(model_selection)

    if model_path is None or model_info is None:
        raise ValueError(f"Selected model not found: {model_selection}")

    # Check if we need to reload the model
    if (_cached_model is None or
            _cached_model_selection != model_selection):

        # Clear previous model if exists
        if _cached_model is not None:
            del _cached_model
            torch.cuda.empty_cache()
            gc.collect()

        logging.info(f"🚀 Loading model: {model_selection}")
        _cached_model, _cached_device = load_model_for_inference(model_path, model_info)
        _cached_model_selection = model_selection

        logging.info(f"✅ Model loaded successfully: {model_selection}")
    else:
        logging.info(f"✅ Using cached model: {model_selection}")

    return _cached_model, _cached_device


def clear_model_cache():
    """Clear the cached model to free memory"""
    global _cached_model, _cached_device, _cached_model_selection

    if _cached_model is not None:
        logging.info("Clearing model cache...")
        del _cached_model
        _cached_model = None
        _cached_device = None
        _cached_model_selection = None

        # Simple cleanup
        import gc
        gc.collect()
        torch.cuda.empty_cache()
        logging.info("Model cache cleared")
    else:
        logging.info("No model in cache to clear")


def inference(left, right, model, device, n_iter=20):
    """Run CREStereo inference on stereo pair"""
    print("Model Forwarding...")
    imgL = left.transpose(2, 0, 1)
    imgR = right.transpose(2, 0, 1)
    imgL = np.ascontiguousarray(imgL[None, :, :, :])
    imgR = np.ascontiguousarray(imgR[None, :, :, :])

    imgL = torch.tensor(imgL.astype("float32")).to(device)
    imgR = torch.tensor(imgR.astype("float32")).to(device)

    # Use InputPadder to handle any image size
    padder = InputPadder(imgL.shape, divis_by=8)
    imgL_padded, imgR_padded = padder.pad(imgL, imgR)

    # Downsample for coarse prediction
    imgL_dw2 = F.interpolate(
        imgL_padded,
        size=(imgL_padded.shape[2] // 2, imgL_padded.shape[3] // 2),
        mode="bilinear",
        align_corners=True,
    )
    imgR_dw2 = F.interpolate(
        imgR_padded,
        size=(imgL_padded.shape[2] // 2, imgL_padded.shape[3] // 2),
        mode="bilinear",
        align_corners=True,
    )

    with torch.inference_mode():
        pred_flow_dw2 = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None)
        pred_flow = model(imgL_padded, imgR_padded, iters=n_iter, flow_init=pred_flow_dw2)

    # Unpad the result to original dimensions
    pred_flow = padder.unpad(pred_flow)
    pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()

    return pred_disp

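# The two model calls above are CREStereo's cascaded coarse-to-fine scheme: a
# first pass at half resolution produces pred_flow_dw2, which then seeds the
# full-resolution pass through flow_init. A minimal calling sketch (file names
# are hypothetical):
#
#   left = cv2.cvtColor(cv2.imread("left.png"), cv2.COLOR_BGR2RGB)
#   right = cv2.cvtColor(cv2.imread("right.png"), cv2.COLOR_BGR2RGB)
#   disp = inference(left, right, model, device, n_iter=20)
#   # disp is an HxW float array; disp[y, x] is the horizontal pixel shift
#   # of (x, y) between the left and right views.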
+
def vis_disparity(disparity_map, max_val=None):
|
| 326 |
+
"""Visualize disparity map"""
|
| 327 |
+
if max_val is None:
|
| 328 |
+
disp_vis = (disparity_map - disparity_map.min()) / (disparity_map.max() - disparity_map.min()) * 255.0
|
| 329 |
+
else:
|
| 330 |
+
disp_vis = np.clip(disparity_map / max_val * 255.0, 0, 255)
|
| 331 |
+
|
| 332 |
+
disp_vis = disp_vis.astype("uint8")
|
| 333 |
+
disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
|
| 334 |
+
disp_vis = cv2.cvtColor(disp_vis, cv2.COLOR_BGR2RGB)
|
| 335 |
+
return disp_vis
|
| 336 |
+
|
| 337 |
+
|
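# Note on the two scaling modes above: with max_val=None the map is min-max
# normalized, so colors are only comparable within one image; with a fixed
# max_val (e.g. max_val=10.0 in the depth visualization below) the colormap is
# absolute, and a value of 5.0 always maps to intensity 5.0 / 10.0 * 255 ≈ 128.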

# Fixed with static duration
@spaces.GPU(duration=60)  # Static 60 seconds for basic processing
def process_stereo_pair(model_selection: str, left_image: str, right_image: str,
                        progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], str]:
    """
    Main processing function for stereo pair (with model caching)
    """
    logging.info("Starting stereo pair processing...")

    if left_image is None or right_image is None:
        return None, "❌ Please upload both left and right images."

    # Convert image paths to numpy arrays
    logging.info(f"Loading images: left={left_image}, right={right_image}")

    try:
        # Load left image
        if not os.path.exists(left_image):
            logging.error(f"Left image file does not exist: {left_image}")
            return None, f"❌ Left image file not found: {left_image}"

        logging.info(f"Loading left image from: {left_image}")
        left_img = cv2.imread(left_image)
        if left_img is not None:
            left_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB)
        else:
            # Try with imageio as fallback
            left_img = imageio.imread(left_image)
            if len(left_img.shape) == 3 and left_img.shape[2] == 4:
                left_img = left_img[:, :, :3]

        # Load right image
        if not os.path.exists(right_image):
            logging.error(f"Right image file does not exist: {right_image}")
            return None, f"❌ Right image file not found: {right_image}"

        logging.info(f"Loading right image from: {right_image}")
        right_img = cv2.imread(right_image)
        if right_img is not None:
            right_img = cv2.cvtColor(right_img, cv2.COLOR_BGR2RGB)
        else:
            # Try with imageio as fallback
            right_img = imageio.imread(right_image)
            if len(right_img.shape) == 3 and right_img.shape[2] == 4:
                right_img = right_img[:, :, :3]

        logging.info(f"Images loaded successfully - Left: {left_img.shape}, Right: {right_img.shape}")

    except Exception as e:
        logging.error(f"Failed to load images: {e}")
        return None, f"❌ Failed to load images: {str(e)}"

    try:
        # Get cached model
        variant_name = model_selection.split('(')[0].strip() if '(' in model_selection else model_selection
        progress(0.1, desc=f"Loading cached model ({variant_name})...")
        logging.info("🚀 Getting cached model...")
        model, device = get_cached_model(model_selection)
        logging.info("✅ Cached model loaded successfully")

        progress(0.2, desc="Preprocessing images...")

        # Validate input images
        if left_img.shape != right_img.shape:
            return None, "❌ Left and right images must have the same dimensions."

        H, W = left_img.shape[:2]

        progress(0.5, desc="Running inference...")

        # Process stereo pair
        torch.cuda.empty_cache()  # Clear any cached memory before inference

        disp_cpu = inference(left_img, right_img, model, device, n_iter=20)

        progress(0.8, desc="Creating visualization...")

        # Create visualization
        disparity_vis = vis_disparity(disp_cpu)
        result_image = disparity_vis

        progress(1.0, desc="Complete!")

        # Create status message
        valid_mask = ~np.isinf(disp_cpu)
        min_disp = disp_cpu[valid_mask].min() if valid_mask.any() else 0
        max_disp = disp_cpu[valid_mask].max() if valid_mask.any() else 0
        mean_disp = disp_cpu[valid_mask].mean() if valid_mask.any() else 0

        # Get model variant for status
        variant = variant_name

        # Check current memory usage
        try:
            current_memory = torch.cuda.memory_allocated(0) / 1024**3
            max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
            memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"
        except:
            memory_info = ""

        status = f"""✅ Processing successful!
🔧 Model: {variant}{memory_info}
📊 Disparity Statistics:
• Range: {min_disp:.2f} - {max_disp:.2f}
• Mean: {mean_disp:.2f}
• Input size: {W}×{H}
• Valid pixels: {valid_mask.sum()}/{valid_mask.size}"""

        return result_image, status

    except Exception as e:
        logging.error(f"Processing failed: {e}")
        # Clean up GPU memory
        torch.cuda.empty_cache()
        gc.collect()
        return None, f"❌ Error: {str(e)}"


# Fixed with static duration
@spaces.GPU(duration=120)  # Static 120 seconds for depth processing
def process_with_depth(model_selection: str, left_image: str, right_image: str,
                       camera_matrix: str, baseline: float,
                       progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], Optional[str], Optional[str], str]:
    """
    Process stereo pair and generate depth map and point cloud (with model caching)
    """
    # Import Open3D
    global OPEN3D_AVAILABLE
    try:
        import open3d as o3d
        OPEN3D_AVAILABLE = True
    except ImportError as e:
        logging.warning(f"Open3D not available: {e}")
        OPEN3D_AVAILABLE = False
        return None, None, None, "❌ Open3D not available. Point cloud generation disabled."

    if left_image is None or right_image is None:
        return None, None, None, "❌ Please upload both left and right images."

    try:
        progress(0.1, desc="Parsing camera parameters...")

        # Parse camera matrix
        try:
            K_values = list(map(float, camera_matrix.strip().split()))
            if len(K_values) != 9:
                return None, None, None, "❌ Camera matrix must contain exactly 9 values."
            K = np.array(K_values).reshape(3, 3)
        except ValueError:
            return None, None, None, "❌ Invalid camera matrix format. Use space-separated numbers."

        if baseline <= 0:
            return None, None, None, "❌ Baseline must be positive."

        # First get disparity using the same process as basic function
        disparity_result, status = process_stereo_pair(model_selection, left_image, right_image, progress)

        if disparity_result is None:
            return None, None, None, status

        # Load images again for depth processing
        left_img = cv2.imread(left_image)
        left_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB)

        # Get disparity from model again (we need the raw values, not the visualization)
        model, device = get_cached_model(model_selection)
        disp_cpu = inference(left_img, cv2.cvtColor(cv2.imread(right_image), cv2.COLOR_BGR2RGB), model, device, n_iter=20)

        progress(0.6, desc="Converting to depth...")

        # Remove invisible points
        H, W = disp_cpu.shape
        yy, xx = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
        us_right = xx - disp_cpu
        invalid = us_right < 0
        disp_cpu[invalid] = np.inf

        # Convert to depth using the formula: depth = focal_length * baseline / disparity
        depth = K[0, 0] * baseline / disp_cpu

        # Visualize depth
        depth_vis = vis_disparity(depth, max_val=10.0)

        progress(0.8, desc="Generating point cloud...")

        # Generate point cloud
        fx, fy = K[0, 0], K[1, 1]
        cx, cy = K[0, 2], K[1, 2]

        # Create coordinate meshgrids
        u, v = np.meshgrid(np.arange(W), np.arange(H))

        # Convert to 3D coordinates
        valid_depth = ~np.isinf(depth)
        z = depth[valid_depth]  # Z coordinate (depth)
        x = (u[valid_depth] - cx) * z / fx  # X coordinate
        y = (v[valid_depth] - cy) * z / fy  # Y coordinate

        # Stack coordinates (X, Y, Z)
        points = np.stack([x, y, z], axis=-1)

        # Get corresponding colors
        colors = left_img[valid_depth]

        # Filter points by depth range
        depth_mask = (z > 0) & (z <= 10.0)
        valid_points = points[depth_mask]
        valid_colors = colors[depth_mask]

        if len(valid_points) == 0:
            return depth_vis, None, None, "⚠️ No valid points generated for point cloud."

        # Subsample points for better performance
        if len(valid_points) > 100000:
            indices = np.random.choice(len(valid_points), 100000, replace=False)
            valid_points = valid_points[indices]
            valid_colors = valid_colors[indices]

        # Transform coordinates for proper visualization
        transformed_points = valid_points.copy()
        transformed_points[:, 1] = -transformed_points[:, 1]  # Flip Y axis
        transformed_points[:, 2] = -transformed_points[:, 2]  # Flip Z axis

        # Generate point cloud
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(transformed_points)
        pcd.colors = o3d.utility.Vector3dVector(valid_colors / 255.0)

        progress(1.0, desc="Complete!")

        # Check current memory usage
        try:
            current_memory = torch.cuda.memory_allocated(0) / 1024**3
            max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
            memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"
        except:
            memory_info = ""

        variant = model_selection.split('(')[0].strip() if '(' in model_selection else model_selection

        status = f"""✅ Depth processing successful!
🔧 Model: {variant}{memory_info}
📊 Statistics:
• Valid points: {len(valid_points):,}
• Depth range: {z.min():.2f} - {z.max():.2f} m
• Baseline: {baseline} m
• Point cloud generated with {len(valid_points)} points
• 3D visualization available"""

        return depth_vis, None, None, status

    except Exception as e:
        logging.error(f"Depth processing failed: {e}")
        torch.cuda.empty_cache()
        gc.collect()
        return None, None, None, f"❌ Error: {str(e)}"

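# Worked example of the pinhole relations used above, with hypothetical
# numbers: for fx = 754.67 px and baseline = 0.063 m (the example1 defaults
# further down), a disparity of 30 px gives depth
#
#   z = fx * baseline / disparity = 754.67 * 0.063 / 30 ≈ 1.58 m
#
# and back-projection recovers metric coordinates for pixel (u, v):
#
#   x = (u - cx) * z / fx,   y = (v - cy) * z / fy
#
# e.g. u = 600 with cx = 489.38 gives x = (600 - 489.38) * 1.58 / 754.67 ≈ 0.23 m.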

def create_app() -> gr.Blocks:
    """Create the Gradio application"""

    # Get available models
    try:
        available_models = get_available_models()
        logging.info(f"Successfully got available models: {len(available_models)} found")
    except Exception as e:
        logging.error(f"Failed to get available models: {e}")
        available_models = {}

    with gr.Blocks(
        title="CREStereo - Stereo Depth Estimation",
        theme=gr.themes.Soft(),
        css="footer {visibility: hidden}",
        delete_cache=(60, 60)
    ) as app:

        gr.Markdown("""
        # 🔍 CREStereo: Practical Stereo Matching

        Upload a pair of **rectified** stereo images to get disparity estimation using CREStereo.

        ⚠️ **Important**: Images should be rectified (epipolar lines are horizontal) and undistorted.
        ⚡ **GPU Powered**: Runs on CUDA-enabled GPUs for fast inference.
        """)

        # Instructions section
        with gr.Accordion("📋 Instructions", open=False):
            gr.Markdown("""
            ## 🚀 How to Use This Demo

            ### 🖼️ Input Requirements
            1. **Image Format**: Upload images in JPEG or PNG format.
            2. **Image Size**: Images should be of the same size and resolution.
            3. **Rectification**: Ensure images are rectified (epipolar lines are horizontal) and undistorted.
            4. **Camera Parameters**: For depth processing, provide camera matrix and baseline distance.

            ### 📊 Using the Demo
            1. **Select Model**: Choose the CREStereo model variant
            2. **Upload Images**: Provide rectified stereo image pairs
            3. **Basic Processing**: Get disparity visualization
            4. **Advanced Processing**: Generate depth maps and 3D point clouds (requires camera parameters)

            ### 📖 Original Work
            This demo is based on CREStereo: Practical Stereo Matching via Cascaded Recurrent Network.
            - **Paper**: [CREStereo: Practical Stereo Matching via Cascaded Recurrent Network](https://arxiv.org/abs/2203.11483)
            - **Official Repository**: [https://github.com/megvii-research/CREStereo](https://github.com/megvii-research/CREStereo)
            """)

        # Model selection
        with gr.Row():
            all_choices = list(available_models.keys())

            if not all_choices:
                all_choices = ["No models found - Please ensure crestereo_eth3d.pth is in models/ directory"]

            default_model = all_choices[0] if all_choices else None

            model_selector = gr.Dropdown(
                choices=all_choices,
                value=default_model,
                label="🎯 Select Model",
                info="Choose the CREStereo model variant.",
                interactive=True
            )

        with gr.Tabs():
            # Basic stereo processing tab
            with gr.TabItem("🖼️ Basic Stereo Processing"):
                with gr.Row():
                    with gr.Column():
                        left_input = gr.Image(
                            label="📷 Left Image",
                            type="filepath",
                            height=300
                        )
                        right_input = gr.Image(
                            label="📷 Right Image",
                            type="filepath",
                            height=300
                        )

                        process_btn = gr.Button(
                            "🚀 Process Stereo Pair",
                            variant="primary",
                            size="lg"
                        )

                    with gr.Column():
                        output_image = gr.Image(
                            label="📊 Disparity Visualization",
                            height=400
                        )
                        status_text = gr.Textbox(
                            label="Status",
                            interactive=False,
                            lines=8
                        )

                # Example images
                examples_list = []

                # Example 1
                if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")):
                    examples_list.append([
                        os.path.join(current_dir, "assets", "example1", "left.png"),
                        os.path.join(current_dir, "assets", "example1", "right.png")
                    ])

                # Example 2
                if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")):
                    examples_list.append([
                        os.path.join(current_dir, "assets", "example2", "left.png"),
                        os.path.join(current_dir, "assets", "example2", "right.png")
                    ])

                if examples_list:
                    gr.Examples(
                        examples=examples_list,
                        inputs=[left_input, right_input],
                        label="📋 Example Images"
                    )

            # Advanced processing with depth
            with gr.TabItem("📐 Advanced Processing (Depth & Point Cloud)"):
                with gr.Row():
                    with gr.Column():
                        left_input_adv = gr.Image(
                            label="📷 Left Image",
                            type="filepath",
                            height=250
                        )
                        right_input_adv = gr.Image(
                            label="📷 Right Image",
                            type="filepath",
                            height=250
                        )

                        # Camera parameters
                        with gr.Group():
                            gr.Markdown("### 📹 Camera Parameters")
                            camera_matrix_input = gr.Textbox(
                                label="Camera Matrix (9 values: fx 0 cx 0 fy cy 0 0 1)",
                                value="",
                            )
                            baseline_input = gr.Number(
                                label="Baseline (meters)",
                                value=None,
                                minimum=0.001,
                                maximum=10.0,
                                step=0.001
                            )

                        process_depth_btn = gr.Button(
                            "🔬 Process with Depth",
                            variant="primary",
                            size="lg"
                        )

                    with gr.Column():
                        depth_output = gr.Image(
                            label="📏 Depth Visualization",
                            height=300
                        )
                        pointcloud_output = gr.File(
                            label="☁️ Point Cloud Download (.ply)",
                            file_types=[".ply"]
                        )
                        status_depth = gr.Textbox(
                            label="Status",
                            interactive=False,
                            lines=6
                        )

                # 3D Point Cloud Visualization
                with gr.Row():
                    pointcloud_3d = gr.Model3D(
                        label="🌐 3D Point Cloud Viewer",
                        clear_color=[0.0, 0.0, 0.0, 0.0],
                        height=400
                    )

                # Example images for advanced processing
                examples_advanced_list = []

                # Try to read camera parameters from K.txt files
                # Example 1
                if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")):
                    k_file = os.path.join(current_dir, "assets", "example1", "K.txt")
                    camera_matrix_str = ""
                    baseline_val = 0.063  # default

                    if os.path.exists(k_file):
                        try:
                            with open(k_file, 'r') as f:
                                lines = f.readlines()
                                if len(lines) >= 1:
                                    camera_matrix_str = lines[0].strip()
                                if len(lines) >= 2:
                                    baseline_val = float(lines[1].strip())
                        except:
                            camera_matrix_str = "754.6680908203125 0.0 489.3794860839844 0.0 754.6680908203125 265.16162109375 0.0 0.0 1.0"

                    examples_advanced_list.append([
                        os.path.join(current_dir, "assets", "example1", "left.png"),
                        os.path.join(current_dir, "assets", "example1", "right.png"),
                        camera_matrix_str,
                        baseline_val
                    ])

                # Example 2
                if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")):
                    k_file = os.path.join(current_dir, "assets", "example2", "K.txt")
                    camera_matrix_str = ""
                    baseline_val = 0.537  # default

                    if os.path.exists(k_file):
                        try:
                            with open(k_file, 'r') as f:
                                lines = f.readlines()
                                if len(lines) >= 1:
                                    camera_matrix_str = lines[0].strip()
                                if len(lines) >= 2:
                                    baseline_val = float(lines[1].strip())
                        except:
                            camera_matrix_str = "1733.74 0.0 792.27 0.0 1733.74 541.89 0.0 0.0 1.0"

                    examples_advanced_list.append([
                        os.path.join(current_dir, "assets", "example2", "left.png"),
                        os.path.join(current_dir, "assets", "example2", "right.png"),
                        camera_matrix_str,
                        baseline_val
                    ])

                if examples_advanced_list:
                    gr.Examples(
                        examples=examples_advanced_list,
                        inputs=[left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
                        label="📋 Example Images with Camera Parameters"
                    )

        # Event handlers
        if available_models:
            process_btn.click(
                fn=process_stereo_pair,
                inputs=[model_selector, left_input, right_input],
                outputs=[output_image, status_text],
                show_progress=True
            )

            if OPEN3D_AVAILABLE:
                process_depth_btn.click(
                    fn=process_with_depth,
                    inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
                    outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth],
                    show_progress=True
                )
            else:
                process_depth_btn.click(
                    fn=lambda *args: (None, None, None, "❌ Open3D not available. Install with: pip install open3d"),
                    inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
                    outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth]
                )
        else:
            # No models available
            process_btn.click(
                fn=lambda *args: (None, "❌ No models available. Please ensure crestereo_eth3d.pth is in models/ directory."),
                inputs=[model_selector, left_input, right_input],
                outputs=[output_image, status_text]
            )

            process_depth_btn.click(
                fn=lambda *args: (None, None, None, "❌ No models available. Please ensure crestereo_eth3d.pth is in models/ directory."),
                inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
                outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth]
            )

        # Citation section at the bottom
        with gr.Accordion("📖 Citation", open=False):
            gr.Markdown("""
            ### 📄 Please Cite the Original Paper

            If you use this work in your research, please cite:

            ```bibtex
            @article{li2022practical,
              title={Practical Stereo Matching via Cascaded Recurrent Network with Adaptive Correlation},
              author={Li, Jiankun and Wang, Peisen and Xiong, Pengfei and Cai, Tao and Yan, Ziwei and Yang, Lei and Liu, Jiangyu and Fan, Haoqiang and Liu, Shuaicheng},
              journal={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
              pages={16263--16272},
              year={2022}
            }
            ```
            """)

        # Footer
        gr.Markdown("""
        ---
        ### 📝 Notes:
        - **Input images must be rectified stereo pairs** (epipolar lines are horizontal)
        - **⚡ GPU Acceleration**: Requires CUDA-compatible GPU
        - **📦 Model Caching**: Models are cached for efficient repeated usage
        - For best results, use high-quality rectified stereo pairs
        - Model works on RGB images and supports various resolutions

        ### 🔗 References:
        - [CREStereo Paper](https://arxiv.org/abs/2203.11483)
        - [Original GitHub Repository](https://github.com/megvii-research/CREStereo)
        - [This PyTorch Implementation](https://github.com/ibaiGorordo/CREStereo-Pytorch)
        """)

    return app

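# The advanced examples above imply a simple two-line layout for each
# assets/exampleN/K.txt (inferred from the parsing code: line 1 holds the nine
# row-major camera-matrix values, line 2 the baseline in meters), e.g.:
#
#   754.6680908203125 0.0 489.3794860839844 0.0 754.6680908203125 265.16162109375 0.0 0.0 1.0
#   0.063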

def main():
    """Main function to launch the app"""

    # Ensure no CUDA operations during startup
    if torch.cuda.is_available():
        logging.warning("CUDA detected during startup - this should not happen in ZeroGPU")

    logging.info("🚀 Starting CREStereo Gradio App...")

    # Parse command line arguments
    import argparse
    parser = argparse.ArgumentParser(description="CREStereo Gradio App")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
    parser.add_argument("--share", action="store_true", help="Create shareable link")
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")

    args = parser.parse_args()

    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    try:
        # Create and launch app
        logging.info("Creating Gradio app...")
        app = create_app()
        logging.info("✅ Gradio app created successfully")

        logging.info(f"Launching app on {args.host}:{args.port}")
        if args.share:
            logging.info("Share link will be created")

        # For ZeroGPU compatibility, launch with appropriate settings
        app.launch(
            server_name=args.host,
            server_port=args.port,
            share=args.share,
            show_error=True,
            favicon_path=None,
            ssr_mode=False,  # Disable SSR for ZeroGPU compatibility
            allowed_paths=["./"]  # Allow access to local files
        )
    except Exception as e:
        logging.error(f"Failed to launch app: {e}")
        raise


if __name__ == "__main__":
    # Additional safety check for ZeroGPU environment
    if 'SPACE_ID' in os.environ:
        logging.info("Running in Hugging Face Spaces environment")

    # Do not check CUDA status during startup - this can trigger CUDA initialization
    # The CUDA status will be checked inside the @spaces.GPU decorated functions
    logging.info("✅ CUDA status will be checked within GPU-decorated functions")

    main()

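Given the argparse flags defined in main() above, a local launch of this file would look like the following (the port value and the share flag are illustrative choices, not requirements from the commit):

    python CREStereo_demo/app.py --host 0.0.0.0 --port 7860 --share
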
CREStereo_demo/app_local.py ADDED
@@ -0,0 +1,889 @@
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import logging
|
| 4 |
+
import tempfile
|
| 5 |
+
import gc
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Optional, Tuple, Union
|
| 8 |
+
import numpy as np
|
| 9 |
+
import cv2
|
| 10 |
+
import gradio as gr
|
| 11 |
+
import imageio
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
# Set default tensor type if needed
|
| 17 |
+
# torch.set_default_tensor_type('torch.FloatTensor')
|
| 18 |
+
|
| 19 |
+
# CUDA backend settings
|
| 20 |
+
# torch.backends.cudnn.enabled = False
|
| 21 |
+
# torch.backends.cudnn.benchmark = False
|
| 22 |
+
|
| 23 |
+
# Use current directory as base
|
| 24 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 25 |
+
base_dir = current_dir
|
| 26 |
+
|
| 27 |
+
# Add current directory to path for local imports
|
| 28 |
+
sys.path.insert(0, current_dir)
|
| 29 |
+
|
| 30 |
+
# Import local modules
|
| 31 |
+
from nets import Model
|
| 32 |
+
|
| 33 |
+
# Import Open3D with error handling
|
| 34 |
+
OPEN3D_AVAILABLE = False
|
| 35 |
+
try:
|
| 36 |
+
# Set Open3D to CPU mode to avoid CUDA initialization
|
| 37 |
+
os.environ['OPEN3D_CPU_RENDERING'] = '1'
|
| 38 |
+
# Don't import open3d here - do it inside functions
|
| 39 |
+
# import open3d as o3d
|
| 40 |
+
OPEN3D_AVAILABLE = True # Assume available, will check later
|
| 41 |
+
except Exception as e:
|
| 42 |
+
logging.warning(f"Open3D setup failed: {e}")
|
| 43 |
+
OPEN3D_AVAILABLE = False
|
| 44 |
+
|
| 45 |
+
# Configure logging
|
| 46 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 47 |
+
|
| 48 |
+
# Model configuration
|
| 49 |
+
MODEL_VARIANTS = {
|
| 50 |
+
"crestereo_eth3d": {
|
| 51 |
+
"display_name": "CREStereo ETH3D (Pre-trained model)",
|
| 52 |
+
"model_file": "models/crestereo_eth3d.pth",
|
| 53 |
+
"max_disp": 256
|
| 54 |
+
}
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
# Global variables for model caching
|
| 58 |
+
_cached_model = None
|
| 59 |
+
_cached_device = None
|
| 60 |
+
_cached_model_selection = None
|
| 61 |
+
|
| 62 |
+
|
class InputPadder:
    """ Pads images such that dimensions are divisible by divis_by """
    def __init__(self, dims, divis_by=8, force_square=False):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
        pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by

        if force_square:
            # Make the padded dimensions square
            max_dim = max(self.ht + pad_ht, self.wd + pad_wd)
            pad_ht = max_dim - self.ht
            pad_wd = max_dim - self.wd

        self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]

    def pad(self, *inputs):
        return [F.pad(x, self._pad, mode='replicate') for x in inputs]

    def unpad(self, x):
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
        return x[..., c[0]:c[1], c[2]:c[3]]

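# Usage sketch for InputPadder (illustrative, not part of the original file;
# the tensor sizes are hypothetical):
#     imgL = torch.zeros(1, 3, 375, 1242)           # e.g. a KITTI-sized frame
#     padder = InputPadder(imgL.shape, divis_by=8)
#     imgL_p, imgR_p = padder.pad(imgL, imgR)       # both become 1x3x376x1248
#     disp = padder.unpad(disp_p)                   # cropped back to 375x1242
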
def aggressive_cleanup():
    """Perform basic cleanup"""
    gc.collect()
    logging.info("Performed basic memory cleanup")


def check_gpu_memory():
    """Check and log current GPU memory usage"""
    try:
        allocated = torch.cuda.memory_allocated(0) / 1024**3
        reserved = torch.cuda.memory_reserved(0) / 1024**3
        max_allocated = torch.cuda.max_memory_allocated(0) / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3

        logging.info(f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB, Max: {max_allocated:.2f}GB, Total: {total:.2f}GB")
        return allocated, reserved, max_allocated, total
    except RuntimeError as e:
        logging.warning(f"Failed to get GPU memory info: {e}")
        return None, None, None, None


def get_available_models() -> dict:
    """Get all available models with their display names"""
    models = {}

    # Check for local models
    for variant, info in MODEL_VARIANTS.items():
        model_path = os.path.join(current_dir, info["model_file"])

        if os.path.exists(model_path):
            display_name = info["display_name"]
            models[display_name] = {
                "model_path": model_path,
                "variant": variant,
                "max_disp": info["max_disp"],
                "source": "local"
            }

    return models


def get_model_paths_from_selection(model_selection: str) -> Tuple[Optional[str], Optional[dict]]:
    """Get model path and config from the selected model"""
    models = get_available_models()

    # Check if it's in our models dict
    if model_selection in models:
        model_info = models[model_selection]
        logging.info(f"📁 Using local model: {model_selection}")
        return model_info["model_path"], model_info

    return None, None

def load_model_for_inference(model_path: str, model_info: dict):
    """Load CREStereo model for inference"""
    # Check if CUDA is available
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available.")

    # Use the first available CUDA device
    device = torch.device("cuda")

    try:
        # Create model
        max_disp = model_info.get("max_disp", 256)
        model = Model(max_disp=max_disp, mixed_precision=False, test_mode=True)

        # Load checkpoint
        ckpt = torch.load(model_path, map_location=device)
        model.load_state_dict(ckpt, strict=True)
        model.to(device)
        model.eval()

        logging.info("Loaded CREStereo model weights")

        # Memory optimizations
        torch.set_grad_enabled(False)
        logging.info("Applied memory optimizations")

        return model, device

    except Exception as e:
        logging.error(f"Model loading failed: {e}")
        raise RuntimeError(f"Failed to load model: {e}")


def get_cached_model(model_selection: str):
    """Get cached model or load a new one if the selection changed"""
    global _cached_model, _cached_device, _cached_model_selection

    # Get model paths from selection
    model_path, model_info = get_model_paths_from_selection(model_selection)

    if model_path is None or model_info is None:
        raise ValueError(f"Selected model not found: {model_selection}")

    # Check if we need to reload the model
    if _cached_model is None or _cached_model_selection != model_selection:

        # Clear previous model if it exists
        if _cached_model is not None:
            del _cached_model
            torch.cuda.empty_cache()
            gc.collect()

        logging.info(f"🚀 Loading model: {model_selection}")
        _cached_model, _cached_device = load_model_for_inference(model_path, model_info)
        _cached_model_selection = model_selection

        logging.info(f"✅ Model loaded successfully: {model_selection}")
    else:
        logging.info(f"✅ Using cached model: {model_selection}")

    return _cached_model, _cached_device


def clear_model_cache():
    """Clear the cached model to free memory"""
    global _cached_model, _cached_device, _cached_model_selection

    if _cached_model is not None:
        logging.info("Clearing model cache...")
        del _cached_model
        _cached_model = None
        _cached_device = None
        _cached_model_selection = None

        # Simple cleanup
        gc.collect()
        torch.cuda.empty_cache()
        logging.info("Model cache cleared")
    else:
        logging.info("No model in cache to clear")

def inference(left, right, model, device, n_iter=20):
    """Run CREStereo inference on stereo pair"""
    print("Model Forwarding...")
    imgL = left.transpose(2, 0, 1)
    imgR = right.transpose(2, 0, 1)
    imgL = np.ascontiguousarray(imgL[None, :, :, :])
    imgR = np.ascontiguousarray(imgR[None, :, :, :])

    imgL = torch.tensor(imgL.astype("float32")).to(device)
    imgR = torch.tensor(imgR.astype("float32")).to(device)

    # Use InputPadder to handle any image size
    padder = InputPadder(imgL.shape, divis_by=8)
    imgL_padded, imgR_padded = padder.pad(imgL, imgR)

    # Downsample for coarse prediction
    imgL_dw2 = F.interpolate(
        imgL_padded,
        size=(imgL_padded.shape[2] // 2, imgL_padded.shape[3] // 2),
        mode="bilinear",
        align_corners=True,
    )
    imgR_dw2 = F.interpolate(
        imgR_padded,
        size=(imgL_padded.shape[2] // 2, imgL_padded.shape[3] // 2),
        mode="bilinear",
        align_corners=True,
    )

    with torch.inference_mode():
        pred_flow_dw2 = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None)
        pred_flow = model(imgL_padded, imgR_padded, iters=n_iter, flow_init=pred_flow_dw2)

    # Unpad the result to original dimensions
    pred_flow = padder.unpad(pred_flow)
    pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()

    return pred_disp

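# Note on the two-pass scheme above (illustrative, not part of the original file):
# CREStereo is run twice - first on a half-resolution pair to get a coarse
# disparity field, which is then passed as flow_init to the full-resolution pass,
# so the second pass only has to refine. Schematically:
#     coarse = model(L_half, R_half, iters=20, flow_init=None)
#     fine   = model(L_full, R_full, iters=20, flow_init=coarse)
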
def vis_disparity(disparity_map, max_val=None):
    """Visualize disparity map"""
    if max_val is None:
        disp_vis = (disparity_map - disparity_map.min()) / (disparity_map.max() - disparity_map.min()) * 255.0
    else:
        disp_vis = np.clip(disparity_map / max_val * 255.0, 0, 255)

    disp_vis = disp_vis.astype("uint8")
    disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
    disp_vis = cv2.cvtColor(disp_vis, cv2.COLOR_BGR2RGB)
    return disp_vis

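# Worked example for the normalization above (illustrative): with disparities in
# [2, 50], a pixel with disparity 26 maps to (26 - 2) / (50 - 2) * 255 = 127.5,
# i.e. mid-range intensity, before the INFERNO colormap is applied.
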
def process_stereo_pair(model_selection: str, left_image: str, right_image: str,
                        progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], str]:
    """
    Main processing function for a stereo pair (with model caching)
    """
    logging.info("Starting stereo pair processing...")

    if left_image is None or right_image is None:
        return None, "❌ Please upload both left and right images."

    # Convert image paths to numpy arrays
    logging.info(f"Loading images: left={left_image}, right={right_image}")

    try:
        # Load left image
        if not os.path.exists(left_image):
            logging.error(f"Left image file does not exist: {left_image}")
            return None, f"❌ Left image file not found: {left_image}"

        logging.info(f"Loading left image from: {left_image}")
        left_img = cv2.imread(left_image)
        if left_img is not None:
            left_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB)
        else:
            # Try with imageio as a fallback
            left_img = imageio.imread(left_image)
            if len(left_img.shape) == 3 and left_img.shape[2] == 4:
                left_img = left_img[:, :, :3]

        # Load right image
        if not os.path.exists(right_image):
            logging.error(f"Right image file does not exist: {right_image}")
            return None, f"❌ Right image file not found: {right_image}"

        logging.info(f"Loading right image from: {right_image}")
        right_img = cv2.imread(right_image)
        if right_img is not None:
            right_img = cv2.cvtColor(right_img, cv2.COLOR_BGR2RGB)
        else:
            # Try with imageio as a fallback
            right_img = imageio.imread(right_image)
            if len(right_img.shape) == 3 and right_img.shape[2] == 4:
                right_img = right_img[:, :, :3]

        logging.info(f"Images loaded successfully - Left: {left_img.shape}, Right: {right_img.shape}")

    except Exception as e:
        logging.error(f"Failed to load images: {e}")
        return None, f"❌ Failed to load images: {str(e)}"

    try:
        # Get cached model
        variant_name = model_selection.split('(')[0].strip() if '(' in model_selection else model_selection
        progress(0.1, desc=f"Loading cached model ({variant_name})...")
        logging.info("🚀 Getting cached model...")
        model, device = get_cached_model(model_selection)
        logging.info("✅ Cached model loaded successfully")

        progress(0.2, desc="Preprocessing images...")

        # Validate input images
        if left_img.shape != right_img.shape:
            return None, "❌ Left and right images must have the same dimensions."

        H, W = left_img.shape[:2]

        progress(0.5, desc="Running inference...")

        # Process stereo pair
        torch.cuda.empty_cache()  # Clear any cached memory before inference

        disp_cpu = inference(left_img, right_img, model, device, n_iter=20)

        progress(0.8, desc="Creating visualization...")

        # Create visualization
        disparity_vis = vis_disparity(disp_cpu)
        result_image = disparity_vis

        progress(1.0, desc="Complete!")

        # Create status message
        valid_mask = ~np.isinf(disp_cpu)
        min_disp = disp_cpu[valid_mask].min() if valid_mask.any() else 0
        max_disp = disp_cpu[valid_mask].max() if valid_mask.any() else 0
        mean_disp = disp_cpu[valid_mask].mean() if valid_mask.any() else 0

        # Get model variant for status
        variant = variant_name

        # Check current memory usage
        try:
            current_memory = torch.cuda.memory_allocated(0) / 1024**3
            max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
            memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"
        except Exception:
            memory_info = ""

        status = f"""✅ Processing successful!
🔧 Model: {variant}{memory_info}
📊 Disparity Statistics:
• Range: {min_disp:.2f} - {max_disp:.2f}
• Mean: {mean_disp:.2f}
• Input size: {W}×{H}
• Valid pixels: {valid_mask.sum()}/{valid_mask.size}"""

        return result_image, status

    except Exception as e:
        logging.error(f"Processing failed: {e}")
        # Clean up GPU memory
        torch.cuda.empty_cache()
        gc.collect()
        return None, f"❌ Error: {str(e)}"

def process_with_depth(model_selection: str, left_image: str, right_image: str,
                       camera_matrix: str, baseline: float,
                       progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], Optional[str], Optional[str], str]:
    """
    Process stereo pair and generate depth map and point cloud (with model caching)
    """
    # Import Open3D
    global OPEN3D_AVAILABLE
    try:
        import open3d as o3d
        OPEN3D_AVAILABLE = True
    except ImportError as e:
        logging.warning(f"Open3D not available: {e}")
        OPEN3D_AVAILABLE = False
        return None, None, None, "❌ Open3D not available. Point cloud generation disabled."

    if left_image is None or right_image is None:
        return None, None, None, "❌ Please upload both left and right images."

    try:
        progress(0.1, desc="Parsing camera parameters...")

        # Parse camera matrix
        try:
            K_values = list(map(float, camera_matrix.strip().split()))
            if len(K_values) != 9:
                return None, None, None, "❌ Camera matrix must contain exactly 9 values."
            K = np.array(K_values).reshape(3, 3)
        except ValueError:
            return None, None, None, "❌ Invalid camera matrix format. Use space-separated numbers."

        if baseline <= 0:
            return None, None, None, "❌ Baseline must be positive."

        # First get disparity using the same process as the basic function
        disparity_result, status = process_stereo_pair(model_selection, left_image, right_image, progress)

        if disparity_result is None:
            return None, None, None, status

        # Load images again for depth processing
        left_img = cv2.imread(left_image)
        left_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB)

        # Get disparity from the model again (we need the raw values, not the visualization)
        model, device = get_cached_model(model_selection)
        disp_cpu = inference(left_img, cv2.cvtColor(cv2.imread(right_image), cv2.COLOR_BGR2RGB), model, device, n_iter=20)

        progress(0.6, desc="Converting to depth...")

        # Remove points that are not visible in the right image
        H, W = disp_cpu.shape
        yy, xx = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
        us_right = xx - disp_cpu
        invalid = us_right < 0
        disp_cpu[invalid] = np.inf

        # Convert to depth using the formula: depth = focal_length * baseline / disparity
        depth = K[0, 0] * baseline / disp_cpu

        # Visualize depth
        depth_vis = vis_disparity(depth, max_val=10.0)

        progress(0.8, desc="Generating point cloud...")

        # Generate point cloud
        fx, fy = K[0, 0], K[1, 1]
        cx, cy = K[0, 2], K[1, 2]

        # Create coordinate meshgrids
        u, v = np.meshgrid(np.arange(W), np.arange(H))

        # Convert to 3D coordinates
        valid_depth = ~np.isinf(depth)
        z = depth[valid_depth]               # Z coordinate (depth)
        x = (u[valid_depth] - cx) * z / fx   # X coordinate
        y = (v[valid_depth] - cy) * z / fy   # Y coordinate

        # Stack coordinates (X, Y, Z)
        points = np.stack([x, y, z], axis=-1)

        # Get corresponding colors
        colors = left_img[valid_depth]

        # Filter points by depth range
        depth_mask = (z > 0) & (z <= 10.0)
        valid_points = points[depth_mask]
        valid_colors = colors[depth_mask]

        if len(valid_points) == 0:
            return depth_vis, None, None, "⚠️ No valid points generated for point cloud."

        # Subsample points for better performance
        if len(valid_points) > 100000:
            indices = np.random.choice(len(valid_points), 100000, replace=False)
            valid_points = valid_points[indices]
            valid_colors = valid_colors[indices]

        # Transform coordinates for proper visualization
        transformed_points = valid_points.copy()
        transformed_points[:, 1] = -transformed_points[:, 1]  # Flip Y axis
        transformed_points[:, 2] = -transformed_points[:, 2]  # Flip Z axis

        # Generate point cloud
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(transformed_points)
        pcd.colors = o3d.utility.Vector3dVector(valid_colors / 255.0)

        # Save the point cloud to a temporary .ply so both the file download and the
        # Model3D viewer receive a path (the original returned None for both outputs,
        # leaving them empty; recent Gradio releases accept .ply in Model3D)
        ply_path = os.path.join(tempfile.mkdtemp(), "pointcloud.ply")
        o3d.io.write_point_cloud(ply_path, pcd)

        progress(1.0, desc="Complete!")

        # Check current memory usage
        try:
            current_memory = torch.cuda.memory_allocated(0) / 1024**3
            max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
            memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"
        except Exception:
            memory_info = ""

        variant = model_selection.split('(')[0].strip() if '(' in model_selection else model_selection

        status = f"""✅ Depth processing successful!
🔧 Model: {variant}{memory_info}
📊 Statistics:
• Valid points: {len(valid_points):,}
• Depth range: {z[depth_mask].min():.2f} - {z[depth_mask].max():.2f} m
• Baseline: {baseline} m
• Point cloud generated with {len(valid_points)} points
• 3D visualization available"""

        return depth_vis, ply_path, ply_path, status

    except Exception as e:
        logging.error(f"Depth processing failed: {e}")
        torch.cuda.empty_cache()
        gc.collect()
        return None, None, None, f"❌ Error: {str(e)}"

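# Worked example for the triangulation above (illustrative; fx and cx are the
# example1 fallback intrinsics, the disparity is made up):
# with fx = 754.67 px, baseline B = 0.063 m and a matched disparity of 38 px,
#     Z = fx * B / d = 754.67 * 0.063 / 38 ≈ 1.25 m
# and a pixel at u = 600 with cx = 489.38 back-projects to
#     X = (u - cx) * Z / fx = (600 - 489.38) * 1.25 / 754.67 ≈ 0.18 m.
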
def create_app() -> gr.Blocks:
    """Create the Gradio application"""

    # Get available models
    try:
        available_models = get_available_models()
        logging.info(f"Successfully got available models: {len(available_models)} found")
    except Exception as e:
        logging.error(f"Failed to get available models: {e}")
        available_models = {}

    with gr.Blocks(
        title="CREStereo - Stereo Depth Estimation",
        theme=gr.themes.Soft(),
        css="footer {visibility: hidden}",
        delete_cache=(60, 60)
    ) as app:

        gr.Markdown("""
        # 🔍 CREStereo: Practical Stereo Matching

        Upload a pair of **rectified** stereo images to get disparity estimation using CREStereo.

        ⚠️ **Important**: Images should be rectified (epipolar lines are horizontal) and undistorted.
        ⚡ **GPU Powered**: Runs on CUDA-enabled GPUs for fast inference.
        """)

        # Instructions section
        with gr.Accordion("📋 Instructions", open=False):
            gr.Markdown("""
            ## 🚀 How to Use This Demo

            ### 🖼️ Input Requirements
            1. **Image Format**: Upload images in JPEG or PNG format.
            2. **Image Size**: Both images must have the same size and resolution.
            3. **Rectification**: Ensure the images are rectified (epipolar lines are horizontal) and undistorted.
            4. **Camera Parameters**: For depth processing, provide the camera matrix and baseline distance.

            ### 📊 Using the Demo
            1. **Select Model**: Choose the CREStereo model variant
            2. **Upload Images**: Provide rectified stereo image pairs
            3. **Basic Processing**: Get a disparity visualization
            4. **Advanced Processing**: Generate depth maps and 3D point clouds (requires camera parameters)

            ### 📖 Original Work
            This demo is based on CREStereo: Practical Stereo Matching via Cascaded Recurrent Network with Adaptive Correlation.
            - **Paper**: [Practical Stereo Matching via Cascaded Recurrent Network with Adaptive Correlation](https://arxiv.org/abs/2203.11483)
            - **Official Repository**: [https://github.com/megvii-research/CREStereo](https://github.com/megvii-research/CREStereo)
            """)

        # Model selection
        with gr.Row():
            all_choices = list(available_models.keys())

            if not all_choices:
                all_choices = ["No models found - Please ensure crestereo_eth3d.pth is in the models/ directory"]

            default_model = all_choices[0] if all_choices else None

            model_selector = gr.Dropdown(
                choices=all_choices,
                value=default_model,
                label="🎯 Select Model",
                info="Choose the CREStereo model variant.",
                interactive=True
            )

        with gr.Tabs():
            # Basic stereo processing tab
            with gr.TabItem("🖼️ Basic Stereo Processing"):
                with gr.Row():
                    with gr.Column():
                        left_input = gr.Image(
                            label="📷 Left Image",
                            type="filepath",
                            height=300
                        )
                        right_input = gr.Image(
                            label="📷 Right Image",
                            type="filepath",
                            height=300
                        )

                        process_btn = gr.Button(
                            "🚀 Process Stereo Pair",
                            variant="primary",
                            size="lg"
                        )

                    with gr.Column():
                        output_image = gr.Image(
                            label="📊 Disparity Visualization",
                            height=400
                        )
                        status_text = gr.Textbox(
                            label="Status",
                            interactive=False,
                            lines=8
                        )

                # Example images
                examples_list = []

                # Example 1
                if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")):
                    examples_list.append([
                        os.path.join(current_dir, "assets", "example1", "left.png"),
                        os.path.join(current_dir, "assets", "example1", "right.png")
                    ])

                # Example 2
                if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")):
                    examples_list.append([
                        os.path.join(current_dir, "assets", "example2", "left.png"),
                        os.path.join(current_dir, "assets", "example2", "right.png")
                    ])

                if examples_list:
                    gr.Examples(
                        examples=examples_list,
                        inputs=[left_input, right_input],
                        label="📋 Example Images"
                    )

            # Advanced processing with depth
            with gr.TabItem("📐 Advanced Processing (Depth & Point Cloud)"):
                with gr.Row():
                    with gr.Column():
                        left_input_adv = gr.Image(
                            label="📷 Left Image",
                            type="filepath",
                            height=250
                        )
                        right_input_adv = gr.Image(
                            label="📷 Right Image",
                            type="filepath",
                            height=250
                        )

                        # Camera parameters
                        with gr.Group():
                            gr.Markdown("### 📹 Camera Parameters")
                            camera_matrix_input = gr.Textbox(
                                label="Camera Matrix (9 values: fx 0 cx 0 fy cy 0 0 1)",
                                value="",
                            )
                            baseline_input = gr.Number(
                                label="Baseline (meters)",
                                value=None,
                                minimum=0.001,
                                maximum=10.0,
                                step=0.001
                            )

                        process_depth_btn = gr.Button(
                            "🔬 Process with Depth",
                            variant="primary",
                            size="lg"
                        )

                    with gr.Column():
                        depth_output = gr.Image(
                            label="📏 Depth Visualization",
                            height=300
                        )
                        pointcloud_output = gr.File(
                            label="☁️ Point Cloud Download (.ply)",
                            file_types=[".ply"]
                        )
                        status_depth = gr.Textbox(
                            label="Status",
                            interactive=False,
                            lines=6
                        )

                # 3D Point Cloud Visualization
                with gr.Row():
                    pointcloud_3d = gr.Model3D(
                        label="🌐 3D Point Cloud Viewer",
                        clear_color=[0.0, 0.0, 0.0, 0.0],
                        height=400
                    )

                # Example images for advanced processing
                examples_advanced_list = []

                # Try to read camera parameters from the K.txt files
                # Example 1
                if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")):
                    k_file = os.path.join(current_dir, "assets", "example1", "K.txt")
                    camera_matrix_str = ""
                    baseline_val = 0.063  # default

                    if os.path.exists(k_file):
                        try:
                            with open(k_file, 'r') as f:
                                lines = f.readlines()
                                if len(lines) >= 1:
                                    camera_matrix_str = lines[0].strip()
                                if len(lines) >= 2:
                                    baseline_val = float(lines[1].strip())
                        except Exception:
                            camera_matrix_str = "754.6680908203125 0.0 489.3794860839844 0.0 754.6680908203125 265.16162109375 0.0 0.0 1.0"

                    examples_advanced_list.append([
                        os.path.join(current_dir, "assets", "example1", "left.png"),
                        os.path.join(current_dir, "assets", "example1", "right.png"),
                        camera_matrix_str,
                        baseline_val
                    ])

                # Example 2
                if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")):
                    k_file = os.path.join(current_dir, "assets", "example2", "K.txt")
                    camera_matrix_str = ""
                    baseline_val = 0.537  # default

                    if os.path.exists(k_file):
                        try:
                            with open(k_file, 'r') as f:
                                lines = f.readlines()
                                if len(lines) >= 1:
                                    camera_matrix_str = lines[0].strip()
                                if len(lines) >= 2:
                                    baseline_val = float(lines[1].strip())
                        except Exception:
                            camera_matrix_str = "1733.74 0.0 792.27 0.0 1733.74 541.89 0.0 0.0 1.0"

                    examples_advanced_list.append([
                        os.path.join(current_dir, "assets", "example2", "left.png"),
                        os.path.join(current_dir, "assets", "example2", "right.png"),
                        camera_matrix_str,
                        baseline_val
                    ])

                if examples_advanced_list:
                    gr.Examples(
                        examples=examples_advanced_list,
                        inputs=[left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
                        label="📋 Example Images with Camera Parameters"
                    )

        # Event handlers
        if available_models:
            process_btn.click(
                fn=process_stereo_pair,
                inputs=[model_selector, left_input, right_input],
                outputs=[output_image, status_text],
                show_progress=True
            )

            if OPEN3D_AVAILABLE:
                process_depth_btn.click(
                    fn=process_with_depth,
                    inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
                    outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth],
                    show_progress=True
                )
            else:
                process_depth_btn.click(
                    fn=lambda *args: (None, None, None, "❌ Open3D not available. Install with: pip install open3d"),
                    inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
                    outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth]
                )
        else:
            # No models available
            process_btn.click(
                fn=lambda *args: (None, "❌ No models available. Please ensure crestereo_eth3d.pth is in the models/ directory."),
                inputs=[model_selector, left_input, right_input],
                outputs=[output_image, status_text]
            )

            process_depth_btn.click(
                fn=lambda *args: (None, None, None, "❌ No models available. Please ensure crestereo_eth3d.pth is in the models/ directory."),
                inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
                outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth]
            )

        # Citation section at the bottom
        with gr.Accordion("📖 Citation", open=False):
            gr.Markdown("""
            ### 📄 Please Cite the Original Paper

            If you use this work in your research, please cite:

            ```bibtex
            @inproceedings{li2022practical,
              title={Practical Stereo Matching via Cascaded Recurrent Network with Adaptive Correlation},
              author={Li, Jiankun and Wang, Peisen and Xiong, Pengfei and Cai, Tao and Yan, Ziwei and Yang, Lei and Liu, Jiangyu and Fan, Haoqiang and Liu, Shuaicheng},
              booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
              pages={16263--16272},
              year={2022}
            }
            ```
            """)

        # Footer
        gr.Markdown("""
        ---
        ### 📝 Notes:
        - **Input images must be rectified stereo pairs** (epipolar lines are horizontal)
        - **⚡ GPU Acceleration**: Requires a CUDA-compatible GPU
        - **📦 Model Caching**: Models are cached for efficient repeated usage
        - For best results, use high-quality rectified stereo pairs
        - The model works on RGB images and supports various resolutions

        ### 🔗 References:
        - [CREStereo Paper](https://arxiv.org/abs/2203.11483)
        - [Original GitHub Repository](https://github.com/megvii-research/CREStereo)
        - [This PyTorch Implementation](https://github.com/ibaiGorordo/CREStereo-Pytorch)
        """)

    return app

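# Expected K.txt layout for the examples above (illustrative; it matches the
# parser, which reads 9 space-separated intrinsics on line 1 and the baseline
# in meters on line 2 - the values shown are the example1 fallbacks):
#     754.6680908203125 0.0 489.3794860839844 0.0 754.6680908203125 265.16162109375 0.0 0.0 1.0
#     0.063
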
def main():
    """Main function to launch the app"""

    logging.info("🚀 Starting CREStereo Gradio App...")

    # Parse command line arguments
    import argparse
    parser = argparse.ArgumentParser(description="CREStereo Gradio App")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")

    args = parser.parse_args()

    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    try:
        # Create and launch app
        logging.info("Creating Gradio app...")
        app = create_app()
        logging.info("✅ Gradio app created successfully")

        logging.info(f"Launching app on {args.host}:{args.port}")

        # Launch with appropriate settings
        app.launch(
            server_name=args.host,
            server_port=args.port,
            share=False,
            show_error=True,
            favicon_path=None,
            ssr_mode=False,
            allowed_paths=["./"]
        )
    except Exception as e:
        logging.error(f"Failed to launch app: {e}")
        raise


if __name__ == "__main__":
    main()
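# Example invocations (illustrative, not part of the original file):
#     python app.py                       # serve on 0.0.0.0:7860 (the defaults)
#     python app.py --port 8080 --debug   # custom port with verbose logging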
CREStereo_demo/models/.gitkeep
ADDED
File without changes
CREStereo_demo/models/crestereo_eth3d.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2271ab615015a73edd4759b0f7b25a4d82ffb654270b92d3811237da3d63aa6d
size 21763979
CREStereo_demo/nets/__init__.py
ADDED
@@ -0,0 +1 @@
from .crestereo import CREStereo as Model
CREStereo_demo/nets/attention/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .transformer import LocalFeatureTransformer
from .position_encoding import PositionEncodingSine
CREStereo_demo/nets/attention/linear_attention.py
ADDED
@@ -0,0 +1,81 @@
"""
Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
"""

import torch
from torch.nn import Module, Dropout


def elu_feature_map(x):
    return torch.nn.functional.elu(x) + 1


class LinearAttention(Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.feature_map = elu_feature_map
        self.eps = eps

    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
        """ Multi-Head linear attention proposed in "Transformers are RNNs"
        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
            q_mask: [N, L]
            kv_mask: [N, S]
        Returns:
            queried_values: (N, L, H, D)
        """
        Q = self.feature_map(queries)
        K = self.feature_map(keys)

        # set padded positions to zero
        if q_mask is not None:
            Q = Q * q_mask[:, :, None, None]
        if kv_mask is not None:
            K = K * kv_mask[:, :, None, None]
            values = values * kv_mask[:, :, None, None]

        v_length = values.size(1)
        values = values / v_length  # prevent fp16 overflow
        KV = torch.einsum("nshd,nshv->nhdv", K, values)  # (S,D)' @ S,V
        Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
        queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length

        return queried_values.contiguous()


class FullAttention(Module):
    def __init__(self, use_dropout=False, attention_dropout=0.1):
        super().__init__()
        self.use_dropout = use_dropout
        self.dropout = Dropout(attention_dropout)

    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
        """ Multi-head scaled dot-product attention, a.k.a full attention.
        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
            q_mask: [N, L]
            kv_mask: [N, S]
        Returns:
            queried_values: (N, L, H, D)
        """

        # Compute the unnormalized attention and apply the masks
        QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
        if kv_mask is not None:
            QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))

        # Compute the attention and the weighted average
        softmax_temp = 1. / queries.size(3)**.5  # 1 / sqrt(D)
        A = torch.softmax(softmax_temp * QK, dim=2)
        if self.use_dropout:
            A = self.dropout(A)

        queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)

        return queried_values.contiguous()
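# Shape note (illustrative, not part of the original file): LinearAttention never
# forms the L x S attention matrix. With queries [N, L, H, D] and keys/values
# [N, S, H, D], KV = einsum("nshd,nshv->nhdv") is only [N, H, D, D], so the cost
# grows linearly in sequence length instead of quadratically - e.g. for a
# hypothetical 40x80 feature map, L = S = 3200 but the per-head state stays a
# 32x32 matrix (D = 256 / 8 heads = 32 in this repo's configuration).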
CREStereo_demo/nets/attention/position_encoding.py
ADDED
@@ -0,0 +1,41 @@
import math
import torch
from torch import nn


class PositionEncodingSine(nn.Module):
    """
    This is a sinusoidal position encoding that generalizes to 2-dimensional images
    """

    def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=False):
        """
        Args:
            max_shape (tuple): for a 1/8 featmap, the max length of 256 corresponds to 2048 pixels
            temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
                the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
                on the final performance. For now, we keep both impls for backward compatibility.
                We will remove the buggy impl after re-training all variants of our released models.
        """
        super().__init__()
        pe = torch.zeros((d_model, *max_shape))
        y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
        x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
        if temp_bug_fix:
            div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
        else:  # a buggy implementation (for backward compatibility only)
            div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2))
        div_term = div_term[:, None, None]  # [C//4, 1, 1]
        pe[0::4, :, :] = torch.sin(x_position * div_term)
        pe[1::4, :, :] = torch.cos(x_position * div_term)
        pe[2::4, :, :] = torch.sin(y_position * div_term)
        pe[3::4, :, :] = torch.cos(y_position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)  # [1, C, H, W]

    def forward(self, x):
        """
        Args:
            x: [N, C, H, W]
        """
        return x + self.pe[:, :, :x.size(2), :x.size(3)].to(x.device)
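# Note on the two div_term variants above (illustrative, not part of the original
# file): with temp_bug_fix=True each frequency step uses the exponent
# -log(10000) / (d_model // 2); without the parentheses the expression parses as
# (-log(10000) / d_model) // 2, a floor division that collapses to -1.0 for
# d_model = 256, so the buggy variant produces much faster-decaying frequencies -
# the LoFTR issue referenced in the docstring.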
CREStereo_demo/nets/attention/transformer.py
ADDED
@@ -0,0 +1,100 @@
import copy
import torch
import torch.nn as nn
from .linear_attention import LinearAttention, FullAttention

# Ref: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py
class LoFTREncoderLayer(nn.Module):
    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear'):
        super(LoFTREncoderLayer, self).__init__()

        self.dim = d_model // nhead
        self.nhead = nhead

        # multi-head attention
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.attention = LinearAttention() if attention == 'linear' else FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network
        self.mlp = nn.Sequential(
            nn.Linear(d_model*2, d_model*2, bias=False),
            nn.ReLU(),
            nn.Linear(d_model*2, d_model, bias=False),
        )

        # norm and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, x_mask=None, source_mask=None):
        """
        Args:
            x (torch.Tensor): [N, L, C]
            source (torch.Tensor): [N, S, C]
            x_mask (torch.Tensor): [N, L] (optional)
            source_mask (torch.Tensor): [N, S] (optional)
        """
        bs = x.size(0)
        query, key, value = x, source, source

        # multi-head attention
        query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
        key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
        value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
        message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask)  # [N, L, (H, D)]
        message = self.merge(message.view(bs, -1, self.nhead*self.dim))  # [N, L, C]
        message = self.norm1(message)

        # feed-forward network
        message = self.mlp(torch.cat([x, message], dim=2))
        message = self.norm2(message)

        return x + message


class LocalFeatureTransformer(nn.Module):
    """A Local Feature Transformer (LoFTR) module."""

    def __init__(self, d_model, nhead, layer_names, attention):
        super(LocalFeatureTransformer, self).__init__()

        self.d_model = d_model
        self.nhead = nhead
        self.layer_names = layer_names
        encoder_layer = LoFTREncoderLayer(d_model, nhead, attention)
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feat0, feat1, mask0=None, mask1=None):
        """
        Args:
            feat0 (torch.Tensor): [N, L, C]
            feat1 (torch.Tensor): [N, S, C]
            mask0 (torch.Tensor): [N, L] (optional)
            mask1 (torch.Tensor): [N, S] (optional)
        """
        assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"

        for layer, name in zip(self.layers, self.layer_names):

            if name == 'self':
                feat0 = layer(feat0, feat0, mask0, mask0)
                feat1 = layer(feat1, feat1, mask1, mask1)
            elif name == 'cross':
                feat0 = layer(feat0, feat1, mask0, mask1)
                feat1 = layer(feat1, feat0, mask1, mask0)
            else:
                raise KeyError

        return feat0, feat1
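# Usage sketch (illustrative, not part of the original file): CREStereo builds one
# self-attention and one cross-attention stack from this module, e.g.
#     self_att = LocalFeatureTransformer(d_model=256, nhead=8,
#                                        layer_names=["self"], attention="linear")
#     f0, f1 = self_att(f0, f1)   # f0, f1: [N, H*W, 256] flattened 1/16 features
# "self" layers attend within each view; "cross" layers exchange information
# between the left and right views.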
CREStereo_demo/nets/corr.py
ADDED
@@ -0,0 +1,148 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import bilinear_sampler, coords_grid, manual_pad

class AGCL:
    """
    Implementation of the Adaptive Group Correlation Layer (AGCL).
    """

    def __init__(self, fmap1, fmap2, att=None):
        self.fmap1 = fmap1
        self.fmap2 = fmap2

        self.att = att

        self.coords = coords_grid(fmap1.shape[0], fmap1.shape[2], fmap1.shape[3], fmap1.device)

    def __call__(self, flow, extra_offset, small_patch=False, iter_mode=False):
        if iter_mode:
            corr = self.corr_iter(self.fmap1, self.fmap2, flow, small_patch)
        else:
            corr = self.corr_att_offset(
                self.fmap1, self.fmap2, flow, extra_offset, small_patch
            )
        return corr

    def get_correlation(self, left_feature, right_feature, psize=(3, 3), dilate=(1, 1)):

        N, C, H, W = left_feature.shape

        di_y, di_x = dilate[0], dilate[1]
        pady, padx = psize[0] // 2 * di_y, psize[1] // 2 * di_x

        right_pad = manual_pad(right_feature, pady, padx)

        corr_list = []
        for h in range(0, pady * 2 + 1, di_y):
            for w in range(0, padx * 2 + 1, di_x):
                right_crop = right_pad[:, :, h : h + H, w : w + W]
                assert right_crop.shape == left_feature.shape
                corr = torch.mean(left_feature * right_crop, dim=1, keepdims=True)
                corr_list.append(corr)

        corr_final = torch.cat(corr_list, dim=1)

        return corr_final

    def corr_iter(self, left_feature, right_feature, flow, small_patch):

        coords = self.coords + flow
        coords = coords.permute(0, 2, 3, 1)
        right_feature = bilinear_sampler(right_feature, coords)

        if small_patch:
            psize_list = [(3, 3), (3, 3), (3, 3), (3, 3)]
            dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
        else:
            psize_list = [(1, 9), (1, 9), (1, 9), (1, 9)]
            dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]

        N, C, H, W = left_feature.shape
        lefts = torch.split(left_feature, left_feature.shape[1]//4, dim=1)
        rights = torch.split(right_feature, right_feature.shape[1]//4, dim=1)

        corrs = []
        for i in range(len(psize_list)):
            corr = self.get_correlation(
                lefts[i], rights[i], psize_list[i], dilate_list[i]
            )
            corrs.append(corr)

        final_corr = torch.cat(corrs, dim=1)

        return final_corr

    def corr_att_offset(
        self, left_feature, right_feature, flow, extra_offset, small_patch
    ):

        N, C, H, W = left_feature.shape

        if self.att is not None:
            left_feature = left_feature.permute(0, 2, 3, 1).reshape(N, H * W, C)  # 'n c h w -> n (h w) c'
            right_feature = right_feature.permute(0, 2, 3, 1).reshape(N, H * W, C)  # 'n c h w -> n (h w) c'
            left_feature, right_feature = self.att(left_feature, right_feature)
            # 'n (h w) c -> n c h w'
            left_feature, right_feature = [
                x.reshape(N, H, W, C).permute(0, 3, 1, 2)
                for x in [left_feature, right_feature]
            ]

        lefts = torch.split(left_feature, left_feature.shape[1]//4, dim=1)
        rights = torch.split(right_feature, right_feature.shape[1]//4, dim=1)

        C = C // 4

        if small_patch:
            psize_list = [(3, 3), (3, 3), (3, 3), (3, 3)]
            dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
        else:
            psize_list = [(1, 9), (1, 9), (1, 9), (1, 9)]
            dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]

        search_num = 9
        extra_offset = extra_offset.reshape(N, search_num, 2, H, W).permute(0, 1, 3, 4, 2)  # [N, search_num, H, W, 2]

        corrs = []
        for i in range(len(psize_list)):
            left_feature, right_feature = lefts[i], rights[i]
            psize, dilate = psize_list[i], dilate_list[i]

            psizey, psizex = psize[0], psize[1]
            dilatey, dilatex = dilate[0], dilate[1]

            ry = psizey // 2 * dilatey
            rx = psizex // 2 * dilatex
            x_grid, y_grid = torch.meshgrid(torch.arange(-rx, rx + 1, dilatex, device=self.fmap1.device),
                                            torch.arange(-ry, ry + 1, dilatey, device=self.fmap1.device), indexing='xy')

            offsets = torch.stack((x_grid, y_grid))
            offsets = offsets.reshape(2, -1).permute(1, 0)
            for d in sorted((0, 2, 3)):
                offsets = offsets.unsqueeze(d)
            offsets = offsets.repeat_interleave(N, dim=0)
            offsets = offsets + extra_offset

            coords = self.coords + flow  # [N, 2, H, W]
            coords = coords.permute(0, 2, 3, 1)  # [N, H, W, 2]
            coords = torch.unsqueeze(coords, 1) + offsets
            coords = coords.reshape(N, -1, W, 2)  # [N, search_num*H, W, 2]

            right_feature = bilinear_sampler(
                right_feature, coords
            )  # [N, C, search_num*H, W]
            right_feature = right_feature.reshape(N, C, -1, H, W)  # [N, C, search_num, H, W]
            left_feature = left_feature.unsqueeze(2).repeat_interleave(right_feature.shape[2], dim=2)

            corr = torch.mean(left_feature * right_feature, dim=1)

            corrs.append(corr)

        final_corr = torch.cat(corrs, dim=1)

        return final_corr
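# Window note (illustrative, not part of the original file): the four feature
# groups are correlated over a local window whose shape depends on the iteration -
# (1, 9) patches sweep 9 horizontal candidates (the usual search direction for a
# rectified stereo pair), while (3, 3) patches add vertical tolerance. With 9
# candidates per group and 4 groups, each call returns a 36-channel correlation
# volume [N, 36, H, W], matching the update block's cor_planes = 4 * 9.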
CREStereo_demo/nets/crestereo.py
ADDED
@@ -0,0 +1,258 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .update import BasicUpdateBlock
from .extractor import BasicEncoder
from .corr import AGCL

from .attention import PositionEncodingSine, LocalFeatureTransformer

try:
    autocast = torch.cuda.amp.autocast
except:
    # dummy autocast for PyTorch < 1.6
    class autocast:
        def __init__(self, enabled):
            pass
        def __enter__(self):
            pass
        def __exit__(self, *args):
            pass

# Ref: https://github.com/princeton-vl/RAFT/blob/master/core/raft.py
class CREStereo(nn.Module):
    def __init__(self, max_disp=192, mixed_precision=False, test_mode=False):
        super(CREStereo, self).__init__()

        self.max_flow = max_disp
        self.mixed_precision = mixed_precision
        self.test_mode = test_mode

        self.hidden_dim = 128
        self.context_dim = 128
        self.dropout = 0

        self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=self.dropout)
        self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)

        # loftr
        self.self_att_fn = LocalFeatureTransformer(
            d_model=256, nhead=8, layer_names=["self"] * 1, attention="linear"
        )
        self.cross_att_fn = LocalFeatureTransformer(
            d_model=256, nhead=8, layer_names=["cross"] * 1, attention="linear"
        )

        # adaptive search
        self.search_num = 9
        self.conv_offset_16 = nn.Conv2d(
            256, self.search_num * 2, kernel_size=3, stride=1, padding=1
        )
        self.conv_offset_8 = nn.Conv2d(
            256, self.search_num * 2, kernel_size=3, stride=1, padding=1
        )
        self.range_16 = 1
        self.range_8 = 1

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def convex_upsample(self, flow, mask, rate=4):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        # print(flow.shape, mask.shape, rate)
        mask = mask.view(N, 1, 9, rate, rate, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(rate * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, rate*H, rate*W)

    def zero_init(self, fmap):
        N, C, H, W = fmap.shape
        _x = torch.zeros([N, 1, H, W], dtype=torch.float32)
        _y = torch.zeros([N, 1, H, W], dtype=torch.float32)
        zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device)
        return zero_flow

    def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False):
        """ Estimate optical flow between a pair of frames """

        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(enabled=self.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

        fmap1 = fmap1.float()
        fmap2 = fmap2.float()

        with autocast(enabled=self.mixed_precision):

            # 1/4 -> 1/8
            # feature
            fmap1_dw8 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2_dw8 = F.avg_pool2d(fmap2, 2, stride=2)

            # offset
            offset_dw8 = self.conv_offset_8(fmap1_dw8)
            offset_dw8 = self.range_8 * (torch.sigmoid(offset_dw8) - 0.5) * 2.0

            # context
            net, inp = torch.split(fmap1, [hdim, hdim], dim=1)
            net = torch.tanh(net)
            inp = F.relu(inp)
            net_dw8 = F.avg_pool2d(net, 2, stride=2)
            inp_dw8 = F.avg_pool2d(inp, 2, stride=2)

            # 1/4 -> 1/16
            # feature
            fmap1_dw16 = F.avg_pool2d(fmap1, 4, stride=4)
            fmap2_dw16 = F.avg_pool2d(fmap2, 4, stride=4)
            offset_dw16 = self.conv_offset_16(fmap1_dw16)
            offset_dw16 = self.range_16 * (torch.sigmoid(offset_dw16) - 0.5) * 2.0

            # context
            net_dw16 = F.avg_pool2d(net, 4, stride=4)
            inp_dw16 = F.avg_pool2d(inp, 4, stride=4)

            # positional encoding and self-attention
            pos_encoding_fn_small = PositionEncodingSine(
                d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
            )
            # 'n c h w -> n (h w) c'
            x_tmp = pos_encoding_fn_small(fmap1_dw16)
            fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
            # 'n c h w -> n (h w) c'
            x_tmp = pos_encoding_fn_small(fmap2_dw16)
            fmap2_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])

            fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)
            fmap1_dw16, fmap2_dw16 = [
                x.reshape(x.shape[0], image1.shape[2] // 16, -1, x.shape[2]).permute(0, 3, 1, 2)
                for x in [fmap1_dw16, fmap2_dw16]
            ]

        corr_fn = AGCL(fmap1, fmap2)
        corr_fn_dw8 = AGCL(fmap1_dw8, fmap2_dw8)
        corr_fn_att_dw16 = AGCL(fmap1_dw16, fmap2_dw16, att=self.cross_att_fn)

        # Cascaded refinement (1/16 + 1/8 + 1/4)
        predictions = []
        flow = None
        flow_up = None
        if flow_init is not None:
            scale = fmap1.shape[2] / flow_init.shape[2]
            flow = -scale * F.interpolate(
                flow_init,
                size=(fmap1.shape[2], fmap1.shape[3]),
                mode="bilinear",
                align_corners=True,
            )
        else:
            # zero initialization
            flow_dw16 = self.zero_init(fmap1_dw16)

            # Recurrent Update Module
            # RUM: 1/16
            for itr in range(iters // 2):
                if itr % 2 == 0:
                    small_patch = False
                else:
                    small_patch = True

                flow_dw16 = flow_dw16.detach()
                out_corrs = corr_fn_att_dw16(
                    flow_dw16, offset_dw16, small_patch=small_patch
                )

                with autocast(enabled=self.mixed_precision):
                    net_dw16, up_mask, delta_flow = self.update_block(
                        net_dw16, inp_dw16, out_corrs, flow_dw16
                    )

                flow_dw16 = flow_dw16 + delta_flow
                flow = self.convex_upsample(flow_dw16, up_mask, rate=4)
                flow_up = -4 * F.interpolate(
                    flow,
                    size=(4 * flow.shape[2], 4 * flow.shape[3]),
|
| 192 |
+
mode="bilinear",
|
| 193 |
+
align_corners=True,
|
| 194 |
+
)
|
| 195 |
+
predictions.append(flow_up)
|
| 196 |
+
|
| 197 |
+
scale = fmap1_dw8.shape[2] / flow.shape[2]
|
| 198 |
+
flow_dw8 = -scale * F.interpolate(
|
| 199 |
+
flow,
|
| 200 |
+
size=(fmap1_dw8.shape[2], fmap1_dw8.shape[3]),
|
| 201 |
+
mode="bilinear",
|
| 202 |
+
align_corners=True,
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
# RUM: 1/8
|
| 206 |
+
for itr in range(iters // 2):
|
| 207 |
+
if itr % 2 == 0:
|
| 208 |
+
small_patch = False
|
| 209 |
+
else:
|
| 210 |
+
small_patch = True
|
| 211 |
+
|
| 212 |
+
flow_dw8 = flow_dw8.detach()
|
| 213 |
+
out_corrs = corr_fn_dw8(flow_dw8, offset_dw8, small_patch=small_patch)
|
| 214 |
+
|
| 215 |
+
with autocast(enabled=self.mixed_precision):
|
| 216 |
+
net_dw8, up_mask, delta_flow = self.update_block(
|
| 217 |
+
net_dw8, inp_dw8, out_corrs, flow_dw8
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
flow_dw8 = flow_dw8 + delta_flow
|
| 221 |
+
flow = self.convex_upsample(flow_dw8, up_mask, rate=4)
|
| 222 |
+
flow_up = -2 * F.interpolate(
|
| 223 |
+
flow,
|
| 224 |
+
size=(2 * flow.shape[2], 2 * flow.shape[3]),
|
| 225 |
+
mode="bilinear",
|
| 226 |
+
align_corners=True,
|
| 227 |
+
)
|
| 228 |
+
predictions.append(flow_up)
|
| 229 |
+
|
| 230 |
+
scale = fmap1.shape[2] / flow.shape[2]
|
| 231 |
+
flow = -scale * F.interpolate(
|
| 232 |
+
flow,
|
| 233 |
+
size=(fmap1.shape[2], fmap1.shape[3]),
|
| 234 |
+
mode="bilinear",
|
| 235 |
+
align_corners=True,
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
# RUM: 1/4
|
| 239 |
+
for itr in range(iters):
|
| 240 |
+
if itr % 2 == 0:
|
| 241 |
+
small_patch = False
|
| 242 |
+
else:
|
| 243 |
+
small_patch = True
|
| 244 |
+
|
| 245 |
+
flow = flow.detach()
|
| 246 |
+
out_corrs = corr_fn(flow, None, small_patch=small_patch, iter_mode=True)
|
| 247 |
+
|
| 248 |
+
with autocast(enabled=self.mixed_precision):
|
| 249 |
+
net, up_mask, delta_flow = self.update_block(net, inp, out_corrs, flow)
|
| 250 |
+
|
| 251 |
+
flow = flow + delta_flow
|
| 252 |
+
flow_up = -self.convex_upsample(flow, up_mask, rate=4)
|
| 253 |
+
predictions.append(flow_up)
|
| 254 |
+
|
| 255 |
+
if self.test_mode:
|
| 256 |
+
return flow_up
|
| 257 |
+
|
| 258 |
+
return predictions
|
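For orientation, a minimal inference sketch for the class above (not part of the commit). The nets.crestereo import path, the checkpoint layout, and the two-pass warm start are assumptions; the cascaded design expects a coarse pass whose output seeds flow_init for the full-resolution pass.

# Hypothetical usage sketch (not part of the commit).
import torch
import torch.nn.functional as F
from nets.crestereo import CREStereo  # import path assumed

model = CREStereo(max_disp=256, test_mode=True)
ckpt = torch.load("models/crestereo_eth3d.pth", map_location="cpu")  # path assumed
model.load_state_dict(ckpt, strict=False)  # checkpoint key layout assumed; adjust if nested
model.eval()

left = torch.rand(1, 3, 480, 640) * 255   # stand-ins for rectified images in [0, 255]
right = torch.rand(1, 3, 480, 640) * 255
with torch.no_grad():
    # coarse pass at half resolution, then refine at full resolution
    l2 = F.interpolate(left, scale_factor=0.5, mode="bilinear", align_corners=True)
    r2 = F.interpolate(right, scale_factor=0.5, mode="bilinear", align_corners=True)
    flow_init = model(l2, r2, iters=20)                       # (1, 2, 240, 320)
    flow = model(left, right, iters=20, flow_init=flow_init)  # (1, 2, 480, 640)
disparity = flow[:, 0]  # the horizontal component is the disparity map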
CREStereo_demo/nets/extractor.py
ADDED
@@ -0,0 +1,123 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# Ref: https://github.com/princeton-vl/RAFT/blob/master/core/extractor.py
class ResidualBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes, affine=False)
            self.norm2 = nn.InstanceNorm2d(planes, affine=False)
            self.norm3 = nn.InstanceNorm2d(planes, affine=False)

        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()

        self.downsample = nn.Sequential(
            nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        x = self.downsample(x)

        return self.relu(x + y)


class BasicEncoder(nn.Module):
    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64, affine=False)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=1)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, x.shape[0] // 2, dim=0)

        return x
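A quick smoke test for the encoder (illustrative, not part of the commit; the nets.extractor import path is an assumption). Passing a list stacks both views along the batch dimension so they share weights, and the output stride is 4:

import torch
from nets.extractor import BasicEncoder  # import path assumed

enc = BasicEncoder(output_dim=256, norm_fn='instance')
left = torch.randn(1, 3, 256, 320)
right = torch.randn(1, 3, 256, 320)
f_left, f_right = enc([left, right])  # list input -> one batched forward pass
print(f_left.shape)  # torch.Size([1, 256, 64, 80]), i.e. H/4 x W/4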
CREStereo_demo/nets/update.py
ADDED
@@ -0,0 +1,91 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

#Ref: https://github.com/princeton-vl/RAFT/blob/master/core/update.py
class FlowHead(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))


class SepConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(SepConvGRU, self).__init__()
        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))

        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))

    def forward(self, h, x):
        # horizontal
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz1(hx))
        r = torch.sigmoid(self.convr1(hx))
        q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
        h = (1-z) * h + z * q

        # vertical
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz2(hx))
        r = torch.sigmoid(self.convr2(hx))
        q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
        h = (1-z) * h + z * q

        return h


class BasicMotionEncoder(nn.Module):
    def __init__(self, cor_planes):
        super(BasicMotionEncoder, self).__init__()

        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)

    def forward(self, flow, corr):
        cor = F.relu(self.convc1(corr))
        cor = F.relu(self.convc2(cor))
        flo = F.relu(self.convf1(flow))
        flo = F.relu(self.convf2(flo))

        cor_flo = torch.cat([cor, flo], dim=1)
        out = F.relu(self.conv(cor_flo))
        return torch.cat([out, flow], dim=1)


class BasicUpdateBlock(nn.Module):
    def __init__(self, hidden_dim, cor_planes, mask_size=8):
        super(BasicUpdateBlock, self).__init__()

        self.encoder = BasicMotionEncoder(cor_planes)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, mask_size**2 * 9, 1, padding=0))

    def forward(self, net, inp, corr, flow, upsample=True):
        # print(inp.shape, corr.shape, flow.shape)
        motion_features = self.encoder(flow, corr)
        # print(motion_features.shape, inp.shape)
        inp = torch.cat((inp, motion_features), dim=1)

        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = .25 * self.mask(net)
        return net, mask, delta_flow
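An illustrative shape check for the update block (not part of the commit; the import path is an assumption). cor_planes=4 * 9 matches how crestereo.py instantiates it, and mask_size=4 yields the 4x convex-upsampling mask:

import torch
from nets.update import BasicUpdateBlock  # import path assumed

block = BasicUpdateBlock(hidden_dim=128, cor_planes=4 * 9, mask_size=4)
net = torch.randn(1, 128, 60, 80)    # hidden state
inp = torch.randn(1, 128, 60, 80)    # context features
corr = torch.randn(1, 36, 60, 80)    # correlation volume (4 * 9 channels)
flow = torch.randn(1, 2, 60, 80)     # current flow estimate
net, mask, delta_flow = block(net, inp, corr, flow)
print(mask.shape, delta_flow.shape)  # (1, 144, 60, 80) and (1, 2, 60, 80)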
CREStereo_demo/nets/utils/__init__.py
ADDED
@@ -0,0 +1 @@
from .utils import bilinear_sampler, coords_grid, manual_pad
CREStereo_demo/nets/utils/utils.py
ADDED
@@ -0,0 +1,108 @@
import torch
import torch.nn.functional as F
import numpy as np

#Ref: https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py

def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    # img = F.grid_sample(img, grid, align_corners=True)
    img = bilinear_grid_sample(img, grid, align_corners=True)

    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img

def coords_grid(batch, ht, wd, device):
    coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device), indexing='ij')
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)

def manual_pad(x, pady, padx):
    pad = (padx, padx, pady, pady)
    return F.pad(x.clone().detach(), pad, "replicate")

# Ref: https://zenn.dev/pinto0309/scraps/7d4032067d0160
def bilinear_grid_sample(im, grid, align_corners=False):
    """Given an input and a flow-field grid, computes the output using input
    values and pixel locations from grid. Only bilinear interpolation is
    supported for sampling the input pixels.

    Args:
        im (torch.Tensor): Input feature map, shape (N, C, H, W)
        grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
        align_corners (bool): If set to True, the extrema (-1 and 1) are
            considered as referring to the center points of the input's
            corner pixels. If set to False, they are instead considered as
            referring to the corner points of the input's corner pixels,
            making the sampling more resolution agnostic.

    Returns:
        torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
    """
    n, c, h, w = im.shape
    gn, gh, gw, _ = grid.shape
    assert n == gn

    x = grid[:, :, :, 0]
    y = grid[:, :, :, 1]

    if align_corners:
        x = ((x + 1) / 2) * (w - 1)
        y = ((y + 1) / 2) * (h - 1)
    else:
        x = ((x + 1) * w - 1) / 2
        y = ((y + 1) * h - 1) / 2

    x = x.view(n, -1)
    y = y.view(n, -1)

    x0 = torch.floor(x).long()
    y0 = torch.floor(y).long()
    x1 = x0 + 1
    y1 = y0 + 1

    wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
    wb = ((x1 - x) * (y - y0)).unsqueeze(1)
    wc = ((x - x0) * (y1 - y)).unsqueeze(1)
    wd = ((x - x0) * (y - y0)).unsqueeze(1)

    # Apply grid_sample's default zero padding
    im_padded = torch.nn.functional.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
    padded_h = h + 2
    padded_w = w + 2
    # shift point positions to account for the padding
    x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

    # Clip coordinates to padded image size
    x0 = torch.where(x0 < 0, torch.tensor(0, device=im.device), x0)
    x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1, device=im.device), x0)
    x1 = torch.where(x1 < 0, torch.tensor(0, device=im.device), x1)
    x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1, device=im.device), x1)
    y0 = torch.where(y0 < 0, torch.tensor(0, device=im.device), y0)
    y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1, device=im.device), y0)
    y1 = torch.where(y1 < 0, torch.tensor(0, device=im.device), y1)
    y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1, device=im.device), y1)

    im_padded = im_padded.view(n, c, -1)

    x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)

    Ia = torch.gather(im_padded, 2, x0_y0)
    Ib = torch.gather(im_padded, 2, x0_y1)
    Ic = torch.gather(im_padded, 2, x1_y0)
    Id = torch.gather(im_padded, 2, x1_y1)

    return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
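As a sanity check (illustrative, not part of the commit), the gather-based sampler should agree with F.grid_sample for in-range coordinates; the manual version presumably exists so the graph stays exportable, per the referenced scrap:

import torch
import torch.nn.functional as F
from nets.utils.utils import bilinear_grid_sample  # import path assumed

im = torch.randn(1, 3, 32, 32)
grid = torch.rand(1, 16, 16, 2) * 2 - 1  # normalized coords in [-1, 1]
ours = bilinear_grid_sample(im, grid, align_corners=True)
ref = F.grid_sample(im, grid, mode='bilinear', align_corners=True)
print((ours - ref).abs().max())  # expected to be on the order of 1e-6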
FoundationStereo_demo/Utils.py
ADDED
@@ -0,0 +1,160 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.


import os, sys, time, pickle, itertools, datetime, imageio, logging, joblib, importlib, argparse
# Import torch and related modules only when needed inside functions to avoid CUDA init
# import torch, torchvision  # Moved to function-level imports
# import torch.nn.functional as F  # Moved to function-level imports
# import torch.nn as nn  # Moved to function-level imports
from functools import partial
import pandas as pd
# Import open3d only when needed to avoid CUDA conflicts
# import open3d as o3d  # Moved to function-level imports
import cv2
import numpy as np
# Removed transformations import to avoid ModuleNotFoundError
code_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(code_dir)


def set_logging_format(level=logging.INFO):
    importlib.reload(logging)
    FORMAT = '%(message)s'
    logging.basicConfig(level=level, format=FORMAT, datefmt='%m-%d|%H:%M:%S')

# Only call set_logging_format when explicitly needed, not during import
# set_logging_format()  # Commented out to avoid automatic execution


def set_seed(random_seed=0):
    import torch  # Import torch only when the function is called
    import random
    import numpy as np

    np.random.seed(random_seed)
    random.seed(random_seed)
    torch.manual_seed(random_seed)
    # Skip CUDA seeding to avoid initialization issues in ZeroGPU;
    # CUDA seeding should be done within a @spaces.GPU context
    try:
        # Only try CUDA operations if we're already in a CUDA context
        if hasattr(torch.cuda, '_initialized') and torch.cuda._initialized:
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(random_seed)
    except (RuntimeError, AttributeError):
        pass  # CUDA not initialized yet or not available
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def toOpen3dCloud(points, colors=None, normals=None):
    import open3d as o3d  # Import only when the function is called

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points.astype(np.float64))
    if colors is not None:
        if colors.max() > 1:
            colors = colors / 255.0
        cloud.colors = o3d.utility.Vector3dVector(colors.astype(np.float64))
    if normals is not None:
        cloud.normals = o3d.utility.Vector3dVector(normals.astype(np.float64))
    return cloud


def depth2xyzmap(depth: np.ndarray, K, uvs: np.ndarray = None, zmin=0.1):
    invalid_mask = (depth < zmin)
    H, W = depth.shape[:2]
    if uvs is None:
        vs, us = np.meshgrid(np.arange(0, H), np.arange(0, W), sparse=False, indexing='ij')
        vs = vs.reshape(-1)
        us = us.reshape(-1)
    else:
        uvs = uvs.round().astype(int)
        us = uvs[:, 0]
        vs = uvs[:, 1]
    zs = depth[vs, us]
    xs = (us - K[0, 2]) * zs / K[0, 0]
    ys = (vs - K[1, 2]) * zs / K[1, 1]
    pts = np.stack((xs.reshape(-1), ys.reshape(-1), zs.reshape(-1)), 1)  # (N,3)
    xyz_map = np.zeros((H, W, 3), dtype=np.float32)
    xyz_map[vs, us] = pts
    if invalid_mask.any():
        xyz_map[invalid_mask] = 0
    return xyz_map


def freeze_model(model):
    # This function works with any model passed to it;
    # no need to import torch at module level
    model = model.eval()
    for p in model.parameters():
        p.requires_grad = False
    for p in model.buffers():
        p.requires_grad = False
    return model


def get_resize_keep_aspect_ratio(H, W, divider=16, max_H=1232, max_W=1232):
    assert max_H % divider == 0
    assert max_W % divider == 0

    def round_by_divider(x):
        return int(np.ceil(x / divider) * divider)

    H_resize = round_by_divider(H)  #!NOTE KITTI width=1242
    W_resize = round_by_divider(W)
    if H_resize > max_H or W_resize > max_W:
        if H_resize > W_resize:
            W_resize = round_by_divider(W_resize * max_H / H_resize)
            H_resize = max_H
        else:
            H_resize = round_by_divider(H_resize * max_W / W_resize)
            W_resize = max_W
    return int(H_resize), int(W_resize)


def vis_disparity(disp, min_val=None, max_val=None, invalid_thres=np.inf, color_map=cv2.COLORMAP_TURBO, cmap=None, other_output=None):
    """
    @disp: np array (H,W)
    @invalid_thres: > thres is invalid
    """
    if other_output is None:
        other_output = {}  # avoid a shared mutable default argument
    disp = disp.copy()
    H, W = disp.shape[:2]
    invalid_mask = disp >= invalid_thres
    if (invalid_mask == 0).sum() == 0:
        other_output['min_val'] = None
        other_output['max_val'] = None
        return np.zeros((H, W, 3))
    if min_val is None:
        min_val = disp[invalid_mask == 0].min()
    if max_val is None:
        max_val = disp[invalid_mask == 0].max()
    other_output['min_val'] = min_val
    other_output['max_val'] = max_val
    vis = ((disp - min_val) / (max_val - min_val)).clip(0, 1) * 255
    if cmap is None:
        vis = cv2.applyColorMap(vis.clip(0, 255).astype(np.uint8), color_map)[..., ::-1]
    else:
        vis = cmap(vis.astype(np.uint8))[..., :3] * 255
    if invalid_mask.any():
        vis[invalid_mask] = 0
    return vis.astype(np.uint8)


def depth_uint8_decoding(depth_uint8, scale=1000):
    depth_uint8 = depth_uint8.astype(float)
    out = depth_uint8[..., 0] * 255 * 255 + depth_uint8[..., 1] * 255 + depth_uint8[..., 2]
    return out / float(scale)
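For context, an illustrative sketch (not part of the commit) of how these helpers chain from disparity to a point cloud; the intrinsics, baseline, and input arrays below are placeholders:

import numpy as np
from Utils import depth2xyzmap, toOpen3dCloud, vis_disparity  # run from FoundationStereo_demo

K = np.array([[700.0, 0.0, 320.0],
              [0.0, 700.0, 240.0],
              [0.0, 0.0, 1.0]])  # placeholder intrinsics
baseline = 0.12                  # placeholder baseline in meters
disp = np.random.uniform(1.0, 64.0, (480, 640)).astype(np.float32)
rgb = np.zeros((480, 640, 3), dtype=np.uint8)

depth = K[0, 0] * baseline / disp  # depth = fx * B / disparity
xyz_map = depth2xyzmap(depth, K)   # (H, W, 3) points in the camera frame
cloud = toOpen3dCloud(xyz_map.reshape(-1, 3), colors=rgb.reshape(-1, 3))
vis = vis_disparity(disp)          # uint8 color rendering of the disparity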
FoundationStereo_demo/app.py
ADDED
@@ -0,0 +1,1138 @@
import os
import sys
import logging
import tempfile
import zipfile
import gc
from pathlib import Path
from typing import Optional, Tuple, Union
import numpy as np
import cv2
import gradio as gr
import imageio

# Import spaces BEFORE torch to ensure proper ZeroGPU initialization
import spaces

# Import torch after spaces - avoid any CUDA calls during import
import torch

# Completely avoid CUDA operations during the import phase.
# Do not set the default tensor type or modify CUDA settings outside a GPU context.
# torch.set_default_tensor_type('torch.FloatTensor')  # Commented out - causes CUDA init

# Import other safe modules
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download, snapshot_download

# Do not modify CUDA settings during import - this can trigger CUDA initialization
# torch.backends.cudnn.enabled = False  # Commented out
# torch.backends.cudnn.benchmark = False  # Commented out

# Use the current directory as base (gradio_app folder)
current_dir = os.path.dirname(os.path.abspath(__file__))
base_dir = current_dir  # gradio_app folder

# Add the current directory to the path for local imports
sys.path.insert(0, current_dir)

# DO NOT import any local modules here that might use CUDA.
# All local module imports are done inside GPU-decorated functions.

# Import Open3D with error handling - avoid any CUDA operations
OPEN3D_AVAILABLE = False  # Will be set properly in GPU context
try:
    # Set Open3D to CPU mode to avoid CUDA initialization
    os.environ['OPEN3D_CPU_RENDERING'] = '1'
    # Don't import open3d here - do it inside GPU functions
    # import open3d as o3d
    OPEN3D_AVAILABLE = True  # Assume available; checked inside GPU context
except Exception as e:
    logging.warning(f"Open3D setup failed: {e}")
    OPEN3D_AVAILABLE = False

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Hugging Face model repository configuration
HF_REPO_ID = "shriarul5273/FoundationStereo_models"
MODEL_VARIANTS = {
    "11-33-40": {
        "display_name": "FoundationStereo (Low-cost variant - 11-33-40)",
        "model_file": "pretrained_models/11-33-40/model_best_bp2.pth",
        "config_file": "pretrained_models/11-33-40/cfg.yaml"
    },
    "23-51-11": {
        "display_name": "FoundationStereo (High-quality variant - 23-51-11)",
        "model_file": "pretrained_models/23-51-11/model_best_bp2.pth",
        "config_file": "pretrained_models/23-51-11/cfg.yaml"
    }
}
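# Illustrative note (assumption, not executed in this app): each variant above
# resolves to concrete files in HF_REPO_ID, so a single file can be fetched
# manually, e.g.
#   hf_hub_download(repo_id=HF_REPO_ID,
#                   filename=MODEL_VARIANTS["11-33-40"]["model_file"])
# returns the local cache path of that checkpoint.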
# Global variables for model caching
MODEL_PATH: str = None
CONFIG_PATH: str = None

# Model cache to avoid reloading when the selection doesn't change
_cached_model = None
_cached_device = None
_cached_model_selection = None


def aggressive_cleanup():
    """Perform basic cleanup - no CUDA operations outside GPU context"""
    import gc
    gc.collect()
    logging.info("Performed basic memory cleanup")


@spaces.GPU
def check_gpu_memory():
    """Check and log current GPU memory usage - only call within GPU context"""
    try:
        allocated = torch.cuda.memory_allocated(0) / 1024**3
        reserved = torch.cuda.memory_reserved(0) / 1024**3
        max_allocated = torch.cuda.max_memory_allocated(0) / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3

        logging.info(f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB, Max: {max_allocated:.2f}GB, Total: {total:.2f}GB")
        return allocated, reserved, max_allocated, total
    except RuntimeError as e:
        logging.warning(f"Failed to get GPU memory info: {e}")
        return None, None, None, None


def download_model_from_hf(variant: str, force_download: bool = False) -> Tuple[str, str]:
    """
    Download model and config files from the Hugging Face Hub.

    Args:
        variant: Model variant ("11-33-40" or "23-51-11")
        force_download: Force re-download even if files exist locally

    Returns:
        Tuple of (model_path, config_path)
    """
    if variant not in MODEL_VARIANTS:
        raise ValueError(f"Unknown model variant: {variant}. Available: {list(MODEL_VARIANTS.keys())}")

    variant_info = MODEL_VARIANTS[variant]

    try:
        if not force_download:
            logging.info(f"📦 Checking cache for model variant: {variant}")
        else:
            logging.info(f"🔄 Force downloading model variant: {variant}")

        # Download model file
        model_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=variant_info["model_file"],
            force_download=force_download,
            local_dir_use_symlinks=False
        )

        # Download config file
        config_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=variant_info["config_file"],
            force_download=force_download,
            local_dir_use_symlinks=False
        )

        if force_download:
            logging.info(f"✅ Successfully downloaded {variant} model files")
        else:
            logging.info(f"✅ Successfully loaded {variant} model files from cache")

        logging.debug(f"Model: {model_path}")
        logging.debug(f"Config: {config_path}")

        return model_path, config_path

    except Exception as e:
        logging.error(f"Failed to download model {variant}: {e}")
        raise RuntimeError(f"Failed to download model {variant} from Hugging Face: {e}")
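# Usage sketch (illustrative): download_model_from_hf("11-33-40") caches both
# the checkpoint and its cfg.yaml via huggingface_hub and returns their local
# paths; repeat calls with force_download=False resolve from the cache.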
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def get_available_models() -> dict:
|
| 159 |
+
"""Get all available models with their display names and download info"""
|
| 160 |
+
models = {}
|
| 161 |
+
|
| 162 |
+
# First check local models (legacy support)
|
| 163 |
+
search_dirs = [
|
| 164 |
+
os.path.join(current_dir, "pretrained_models"),
|
| 165 |
+
os.path.join(os.path.dirname(current_dir), "pretrained_models")
|
| 166 |
+
]
|
| 167 |
+
|
| 168 |
+
for search_dir in search_dirs:
|
| 169 |
+
if os.path.exists(search_dir):
|
| 170 |
+
for model_dir in os.listdir(search_dir):
|
| 171 |
+
model_path = os.path.join(search_dir, model_dir, "model_best_bp2.pth")
|
| 172 |
+
cfg_path = os.path.join(search_dir, model_dir, "cfg.yaml")
|
| 173 |
+
|
| 174 |
+
if os.path.exists(model_path) and os.path.exists(cfg_path):
|
| 175 |
+
# Create a descriptive name for the model
|
| 176 |
+
if model_dir == "11-33-40":
|
| 177 |
+
display_name = "FoundationStereo (Low-cost variant - 11-33-40) [Local]"
|
| 178 |
+
elif model_dir == "23-51-11":
|
| 179 |
+
display_name = "FoundationStereo (High-quality variant - 23-51-11) [Local]"
|
| 180 |
+
else:
|
| 181 |
+
display_name = f"FoundationStereo ({model_dir}) [Local]"
|
| 182 |
+
|
| 183 |
+
models[display_name] = {
|
| 184 |
+
"model_path": model_path,
|
| 185 |
+
"config_path": cfg_path,
|
| 186 |
+
"variant": model_dir,
|
| 187 |
+
"source": "local"
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
# Add Hugging Face models
|
| 191 |
+
for variant, info in MODEL_VARIANTS.items():
|
| 192 |
+
display_name = f"{info['display_name']} [Hugging Face]"
|
| 193 |
+
models[display_name] = {
|
| 194 |
+
"model_path": None, # Will be downloaded when needed
|
| 195 |
+
"config_path": None, # Will be downloaded when needed
|
| 196 |
+
"variant": variant,
|
| 197 |
+
"source": "huggingface"
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
return models
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def find_model_path() -> Tuple[Optional[str], Optional[str]]:
|
| 204 |
+
"""Find available model and config paths (legacy function for backward compatibility)"""
|
| 205 |
+
models = get_available_models()
|
| 206 |
+
if models:
|
| 207 |
+
# Prefer Hugging Face models over local ones
|
| 208 |
+
# First try to find HF low-cost variant
|
| 209 |
+
for display_name in models:
|
| 210 |
+
if "11-33-40" in display_name and "[Hugging Face]" in display_name:
|
| 211 |
+
return get_model_paths_from_selection(display_name)
|
| 212 |
+
|
| 213 |
+
# Then try local low-cost variant
|
| 214 |
+
for display_name in models:
|
| 215 |
+
if "11-33-40" in display_name:
|
| 216 |
+
return get_model_paths_from_selection(display_name)
|
| 217 |
+
|
| 218 |
+
# If no low-cost variant, return the first available
|
| 219 |
+
first_model_name = next(iter(models.keys()))
|
| 220 |
+
return get_model_paths_from_selection(first_model_name)
|
| 221 |
+
return None, None
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def get_model_paths_from_selection(model_selection: str) -> Tuple[Optional[str], Optional[str]]:
|
| 225 |
+
"""Get model and config paths from the selected model"""
|
| 226 |
+
models = get_available_models()
|
| 227 |
+
|
| 228 |
+
# Check if it's in our models dict
|
| 229 |
+
if model_selection in models:
|
| 230 |
+
model_info = models[model_selection]
|
| 231 |
+
|
| 232 |
+
# If it's a Hugging Face model, download it first (or get from cache)
|
| 233 |
+
if model_info["source"] == "huggingface":
|
| 234 |
+
variant = model_info["variant"]
|
| 235 |
+
try:
|
| 236 |
+
logging.info(f"📦 Retrieving {variant} model from cache...")
|
| 237 |
+
model_path, config_path = download_model_from_hf(variant, force_download=False)
|
| 238 |
+
return model_path, config_path
|
| 239 |
+
except Exception as e:
|
| 240 |
+
logging.error(f"Failed to get model {variant} from cache: {e}")
|
| 241 |
+
return None, None
|
| 242 |
+
else:
|
| 243 |
+
# Local model
|
| 244 |
+
logging.info(f"📁 Using local model: {model_selection}")
|
| 245 |
+
return model_info["model_path"], model_info["config_path"]
|
| 246 |
+
|
| 247 |
+
# Handle direct HF model selection (fallback)
|
| 248 |
+
elif "[Hugging Face]" in model_selection:
|
| 249 |
+
if "11-33-40" in model_selection:
|
| 250 |
+
variant = "11-33-40"
|
| 251 |
+
elif "23-51-11" in model_selection:
|
| 252 |
+
variant = "23-51-11"
|
| 253 |
+
else:
|
| 254 |
+
logging.error(f"Unknown HF model variant in: {model_selection}")
|
| 255 |
+
return None, None
|
| 256 |
+
|
| 257 |
+
try:
|
| 258 |
+
logging.info(f"📦 Retrieving {variant} model from cache...")
|
| 259 |
+
model_path, config_path = download_model_from_hf(variant, force_download=False)
|
| 260 |
+
return model_path, config_path
|
| 261 |
+
except Exception as e:
|
| 262 |
+
logging.error(f"Failed to get model {variant} from cache: {e}")
|
| 263 |
+
return None, None
|
| 264 |
+
|
| 265 |
+
return None, None
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def get_cached_model(model_selection: str):
|
| 269 |
+
"""Get cached model or load new one if selection changed"""
|
| 270 |
+
global _cached_model, _cached_device, _cached_model_selection
|
| 271 |
+
|
| 272 |
+
# Get model paths from selection
|
| 273 |
+
model_path, config_path = get_model_paths_from_selection(model_selection)
|
| 274 |
+
|
| 275 |
+
if model_path is None or config_path is None:
|
| 276 |
+
raise ValueError(f"Selected model not found: {model_selection}")
|
| 277 |
+
|
| 278 |
+
# Load model fresh for each inference (ZeroGPU optimized)
|
| 279 |
+
# Since models are pre-downloaded, this should be fast
|
| 280 |
+
logging.info(f"🚀 Loading cached model: {model_selection}")
|
| 281 |
+
model, device = load_model_for_inference(model_path, config_path)
|
| 282 |
+
|
| 283 |
+
logging.info(f"✅ Model loaded successfully from cache: {model_selection}")
|
| 284 |
+
return model, device
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def clear_model_cache():
|
| 288 |
+
"""Clear the cached model to free memory"""
|
| 289 |
+
global _cached_model, _cached_device, _cached_model_selection
|
| 290 |
+
|
| 291 |
+
if _cached_model is not None:
|
| 292 |
+
logging.info("Clearing model cache...")
|
| 293 |
+
del _cached_model
|
| 294 |
+
_cached_model = None
|
| 295 |
+
_cached_device = None
|
| 296 |
+
_cached_model_selection = None
|
| 297 |
+
|
| 298 |
+
# Simple cleanup
|
| 299 |
+
import gc
|
| 300 |
+
gc.collect()
|
| 301 |
+
logging.info("Model cache cleared")
|
| 302 |
+
else:
|
| 303 |
+
logging.info("No model in cache to clear")
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
@spaces.GPU
|
| 307 |
+
def load_model_for_inference(model_path: str, cfg_path: str):
|
| 308 |
+
"""Load model temporarily for inference (demo-style)"""
|
| 309 |
+
# Set CUDA settings safely within GPU context
|
| 310 |
+
torch.set_default_tensor_type('torch.cuda.FloatTensor') # Now safe to use CUDA tensors
|
| 311 |
+
torch.backends.cudnn.enabled = True
|
| 312 |
+
torch.backends.cudnn.benchmark = True
|
| 313 |
+
|
| 314 |
+
# Import these inside the function to avoid early CUDA initialization
|
| 315 |
+
try:
|
| 316 |
+
# Import selectively to avoid CUDA calls in Utils
|
| 317 |
+
from core.foundation_stereo import FoundationStereo
|
| 318 |
+
from omegaconf import OmegaConf
|
| 319 |
+
logging.info("Successfully imported required modules")
|
| 320 |
+
|
| 321 |
+
# Import set_logging_format safely
|
| 322 |
+
from Utils import set_logging_format
|
| 323 |
+
set_logging_format()
|
| 324 |
+
|
| 325 |
+
# Manual seed setting to avoid CUDA calls in Utils.set_seed
|
| 326 |
+
import random
|
| 327 |
+
random_seed = 0
|
| 328 |
+
np.random.seed(random_seed)
|
| 329 |
+
random.seed(random_seed)
|
| 330 |
+
torch.manual_seed(random_seed)
|
| 331 |
+
# CUDA seeding will be done after device is available
|
| 332 |
+
|
| 333 |
+
logging.info("Set logging format and seed")
|
| 334 |
+
except Exception as e:
|
| 335 |
+
logging.error(f"Failed to import modules: {e}")
|
| 336 |
+
raise RuntimeError(f"Import failed: {e}")
|
| 337 |
+
|
| 338 |
+
# Check if CUDA is available after ZeroGPU initialization
|
| 339 |
+
if not torch.cuda.is_available():
|
| 340 |
+
raise RuntimeError("CUDA is not available. ZeroGPU initialization may have failed.")
|
| 341 |
+
|
| 342 |
+
# Use the first available CUDA device
|
| 343 |
+
device = torch.device("cuda")
|
| 344 |
+
|
| 345 |
+
# Now set CUDA seed safely within GPU context
|
| 346 |
+
try:
|
| 347 |
+
torch.cuda.manual_seed_all(random_seed)
|
| 348 |
+
torch.backends.cudnn.deterministic = True
|
| 349 |
+
torch.backends.cudnn.benchmark = False
|
| 350 |
+
except Exception as e:
|
| 351 |
+
logging.warning(f"Could not set CUDA seed: {e}")
|
| 352 |
+
|
| 353 |
+
try:
|
| 354 |
+
# Load config
|
| 355 |
+
cfg = OmegaConf.load(cfg_path)
|
| 356 |
+
cfg.setdefault("vit_size", "vitl")
|
| 357 |
+
logging.info("Loaded config file")
|
| 358 |
+
|
| 359 |
+
# Create model
|
| 360 |
+
model = FoundationStereo(cfg).to(device)
|
| 361 |
+
model.eval()
|
| 362 |
+
logging.info("Created model")
|
| 363 |
+
|
| 364 |
+
# Load checkpoint
|
| 365 |
+
ckpt = torch.load(model_path, map_location=device)
|
| 366 |
+
model.load_state_dict(ckpt["model"], strict=True)
|
| 367 |
+
logging.info("Loaded model weights")
|
| 368 |
+
|
| 369 |
+
# Memory optimizations
|
| 370 |
+
torch.set_grad_enabled(False)
|
| 371 |
+
model.half() # Use half precision
|
| 372 |
+
logging.info("Applied memory optimizations")
|
| 373 |
+
|
| 374 |
+
return model, device
|
| 375 |
+
|
| 376 |
+
except Exception as e:
|
| 377 |
+
logging.error(f"Model loading failed: {e}")
|
| 378 |
+
raise RuntimeError(f"Failed to load model: {e}")
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
# Fixed with static duration
|
| 382 |
+
@spaces.GPU(duration=60) # Static 60 seconds for basic processing
|
| 383 |
+
def process_stereo_pair(model_selection: str, left_image: np.ndarray, right_image: np.ndarray,
|
| 384 |
+
progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], str]:
|
| 385 |
+
"""
|
| 386 |
+
Main processing function for stereo pair (with model caching)
|
| 387 |
+
"""
|
| 388 |
+
logging.info("Starting stereo pair processing...")
|
| 389 |
+
|
| 390 |
+
if left_image is None or right_image is None:
|
| 391 |
+
return None, "❌ Please upload both left and right images."
|
| 392 |
+
|
| 393 |
+
try:
|
| 394 |
+
# Import these inside to avoid early CUDA calls
|
| 395 |
+
logging.info("Importing required modules...")
|
| 396 |
+
from core.utils.utils import InputPadder
|
| 397 |
+
# Import vis_disparity safely - it shouldn't have CUDA calls but be careful
|
| 398 |
+
from Utils import vis_disparity
|
| 399 |
+
logging.info("✅ Successfully imported processing modules")
|
| 400 |
+
|
| 401 |
+
# Get cached model (will load if not cached or selection changed)
|
| 402 |
+
variant_name = model_selection.split('(')[1].split(')')[0] if '(' in model_selection else model_selection
|
| 403 |
+
progress(0.1, desc=f"Loading cached model ({variant_name})...")
|
| 404 |
+
logging.info("🚀 Getting cached model...")
|
| 405 |
+
model, device = get_cached_model(model_selection)
|
| 406 |
+
logging.info("✅ Cached model loaded successfully")
|
| 407 |
+
|
| 408 |
+
progress(0.2, desc="Preprocessing images...")
|
| 409 |
+
|
| 410 |
+
# Validate input images
|
| 411 |
+
if left_image.shape != right_image.shape:
|
| 412 |
+
return None, "❌ Left and right images must have the same dimensions."
|
| 413 |
+
|
| 414 |
+
H, W = left_image.shape[:2]
|
| 415 |
+
|
| 416 |
+
# Convert to torch tensors and ensure they are contiguous
|
| 417 |
+
img0 = torch.as_tensor(left_image).to(device).half()[None].permute(0,3,1,2).contiguous()
|
| 418 |
+
img1 = torch.as_tensor(right_image).to(device).half()[None].permute(0,3,1,2).contiguous()
|
| 419 |
+
|
| 420 |
+
# Pad images and ensure contiguity
|
| 421 |
+
padder = InputPadder(img0.shape, divis_by=32, force_square=False)
|
| 422 |
+
img0, img1 = padder.pad(img0, img1)
|
| 423 |
+
|
| 424 |
+
# Ensure padded tensors are contiguous
|
| 425 |
+
img0 = img0.contiguous()
|
| 426 |
+
img1 = img1.contiguous()
|
| 427 |
+
|
| 428 |
+
progress(0.5, desc="Running inference...")
|
| 429 |
+
|
| 430 |
+
# Process stereo pair with autocast and ensure clean memory state
|
| 431 |
+
torch.cuda.empty_cache() # Clear any cached memory before inference
|
| 432 |
+
|
| 433 |
+
try:
|
| 434 |
+
with torch.amp.autocast("cuda", enabled=True):
|
| 435 |
+
# Ensure tensors are in the right format for cuDNN
|
| 436 |
+
if not img0.is_contiguous():
|
| 437 |
+
img0 = img0.contiguous()
|
| 438 |
+
if not img1.is_contiguous():
|
| 439 |
+
img1 = img1.contiguous()
|
| 440 |
+
|
| 441 |
+
disp = model.forward(img0, img1, iters=32, test_mode=True)
|
| 442 |
+
except RuntimeError as e:
|
| 443 |
+
if "cuDNN" in str(e):
|
| 444 |
+
# Fallback: disable cuDNN optimizations and retry
|
| 445 |
+
logging.warning(f"cuDNN error encountered, retrying with fallback: {e}")
|
| 446 |
+
torch.backends.cudnn.enabled = False
|
| 447 |
+
try:
|
| 448 |
+
with torch.amp.autocast("cuda", enabled=True):
|
| 449 |
+
disp = model.forward(img0, img1, iters=32, test_mode=True)
|
| 450 |
+
finally:
|
| 451 |
+
torch.backends.cudnn.enabled = True # Re-enable for future use
|
| 452 |
+
else:
|
| 453 |
+
raise e
|
| 454 |
+
|
| 455 |
+
# Unpad and convert to numpy
|
| 456 |
+
disp = padder.unpad(disp.float())
|
| 457 |
+
disp_cpu = disp.data.cpu().numpy().reshape(H, W)

        progress(0.8, desc="Creating visualization...")

        # Create visualization - disparity only
        disparity_vis = vis_disparity(disp_cpu)
        result_image = disparity_vis

        progress(1.0, desc="Complete!")

        # Clean up intermediate tensors
        del img0, img1, disp

        # For ZeroGPU: clean up the model after inference
        del model
        torch.cuda.empty_cache()
        gc.collect()

        # Create status message
        valid_mask = disp_cpu != np.inf
        min_disp = disp_cpu[valid_mask].min() if valid_mask.any() else 0
        max_disp = disp_cpu[valid_mask].max() if valid_mask.any() else 0
        mean_disp = disp_cpu[valid_mask].mean() if valid_mask.any() else 0

        # Get the model variant for the status message
        variant = model_selection.split('(')[1].split(')')[0] if '(' in model_selection else "Unknown"

        # Check current memory usage (safely within the GPU context)
        current_memory = torch.cuda.memory_allocated(0) / 1024**3
        max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
        memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"

        status = f"""✅ Processing successful!
🔧 Model: {variant} (ZeroGPU){memory_info}
📊 Disparity Statistics:
• Range: {min_disp:.2f} - {max_disp:.2f}
• Mean: {mean_disp:.2f}
• Input size: {W}×{H}
• Valid pixels: {valid_mask.sum()}/{valid_mask.size}"""

        return result_image, status

    except Exception as e:
        logging.error(f"Processing failed: {e}")
        # Cleanup on error
        if 'img0' in locals():
            del img0
        if 'img1' in locals():
            del img1
        if 'disp' in locals():
            del disp
        if 'model' in locals():
            del model
        # Clean up GPU memory
        torch.cuda.empty_cache()
        gc.collect()
        return None, f"❌ Error: {str(e)}"


# Fixed with a static duration
@spaces.GPU(duration=120)  # Static 120 seconds for depth processing
def process_with_depth(model_selection: str, left_image: np.ndarray, right_image: np.ndarray,
                       camera_matrix: str, baseline: float,
                       progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], Optional[str], Optional[str], str]:
    """
    Process a stereo pair and generate a depth map and point cloud (with model caching)
    """
    # Import these inside the function to avoid early CUDA calls
    from core.utils.utils import InputPadder
    # Import vis_disparity safely within the GPU context
    from Utils import vis_disparity

    # Import Open3D inside the GPU context
    global OPEN3D_AVAILABLE
    try:
        import open3d as o3d
        OPEN3D_AVAILABLE = True
    except ImportError as e:
        logging.warning(f"Open3D not available: {e}")
        OPEN3D_AVAILABLE = False
        return None, None, None, "❌ Open3D not available. Point cloud generation disabled."

    if left_image is None or right_image is None:
        return None, None, None, "❌ Please upload both left and right images."

    try:
        progress(0.1, desc="Parsing camera parameters...")

        # Parse the camera matrix
        try:
            K_values = list(map(float, camera_matrix.strip().split()))
            if len(K_values) != 9:
                return None, None, None, "❌ Camera matrix must contain exactly 9 values."
            K = np.array(K_values).reshape(3, 3)
        except ValueError:
            return None, None, None, "❌ Invalid camera matrix format. Use space-separated numbers."

        if baseline <= 0:
            return None, None, None, "❌ Baseline must be positive."

        variant = model_selection.split('(')[1].split(')')[0] if '(' in model_selection else "Unknown"
        progress(0.2, desc=f"Loading cached model ({variant})...")

        # Get the cached model (will load if not cached or the selection changed)
        model, device = get_cached_model(model_selection)

        progress(0.4, desc="Running stereo inference...")

        # Get disparity using the same process as the basic function
        H, W = left_image.shape[:2]
        img0 = torch.as_tensor(left_image).to(device).half()[None].permute(0, 3, 1, 2).contiguous()
        img1 = torch.as_tensor(right_image).to(device).half()[None].permute(0, 3, 1, 2).contiguous()

        padder = InputPadder(img0.shape, divis_by=32, force_square=False)
        img0, img1 = padder.pad(img0, img1)

        # Ensure padded tensors are contiguous
        img0 = img0.contiguous()
        img1 = img1.contiguous()

        # Clear the cache and ensure a clean memory state before inference
        torch.cuda.empty_cache()

        try:
            with torch.amp.autocast("cuda", enabled=True):
                # Double-check tensor contiguity before cuDNN operations
                if not img0.is_contiguous():
                    img0 = img0.contiguous()
                if not img1.is_contiguous():
                    img1 = img1.contiguous()

                disp = model.forward(img0, img1, iters=32, test_mode=True)
        except RuntimeError as e:
            if "cuDNN" in str(e):
                # Fallback: disable cuDNN optimizations and retry
                logging.warning(f"cuDNN error encountered in depth processing, retrying with fallback: {e}")
                torch.backends.cudnn.enabled = False
                try:
                    with torch.amp.autocast("cuda", enabled=True):
                        disp = model.forward(img0, img1, iters=32, test_mode=True)
                finally:
                    torch.backends.cudnn.enabled = True  # Re-enable for future use
            else:
                raise

        disp = padder.unpad(disp.float())
        disp_cpu = disp.data.cpu().numpy().reshape(H, W)

        # Clean up intermediate tensors early
        del img0, img1, disp

        # For ZeroGPU: keep the model reference for the rest of the processing
        torch.cuda.empty_cache()

        progress(0.6, desc="Converting to depth...")

        # Remove points that are invisible in the right view (same as the original demo)
        yy, xx = np.meshgrid(np.arange(disp_cpu.shape[0]), np.arange(disp_cpu.shape[1]), indexing='ij')
        us_right = xx - disp_cpu
        invalid = us_right < 0
        disp_cpu[invalid] = np.inf

        # Convert to depth using the formula from the original demo
        depth = K[0, 0] * baseline / disp_cpu
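        # Worked example (intrinsics from Example 1 below, purely illustrative):
        # with fx ≈ 754.67 px and baseline = 0.063 m, a disparity of 20 px maps to
        # depth ≈ 754.67 * 0.063 / 20 ≈ 2.38 m; larger disparities mean closer points.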

        # Visualize depth (no rotation)
        depth_vis = vis_disparity(depth, max_val=10.0)

        progress(0.8, desc="Generating point cloud...")

        # Generate the point cloud with a proper coordinate transformation
        fx, fy = K[0, 0], K[1, 1]
        cx, cy = K[0, 2], K[1, 2]

        # Create coordinate meshgrids
        u, v = np.meshgrid(np.arange(W), np.arange(H))

        # Back-project to 3D coordinates (pinhole camera coordinate system)
        valid_depth = depth != np.inf
        z = depth[valid_depth]              # Z coordinate (depth)
        x = (u[valid_depth] - cx) * z / fx  # X coordinate
        y = (v[valid_depth] - cy) * z / fy  # Y coordinate

        # Stack coordinates (X, Y, Z)
        points = np.stack([x, y, z], axis=-1)
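        # Sanity check (illustrative): projecting back with u = fx * X / Z + cx and
        # v = fy * Y / Z + cy recovers the original pixel coordinates.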

        # Get the corresponding colors
        colors = left_image[valid_depth]

        # Filter points by depth range
        depth_mask = (z > 0) & (z <= 10.0)
        valid_points = points[depth_mask]
        valid_colors = colors[depth_mask]

        if len(valid_points) == 0:
            return depth_vis, None, None, "⚠️ No valid points generated for point cloud."

        # Subsample points for better 3D visualization performance
        if len(valid_points) > 100000:
            indices = np.random.choice(len(valid_points), 100000, replace=False)
            valid_points = valid_points[indices]
            valid_colors = valid_colors[indices]

        # Transform coordinates for a proper visualization orientation:
        # standard computer vision uses X right, Y down, Z forward;
        # for better 3D viewing we use X right, Y up, Z backward.
        transformed_points = valid_points.copy()
        transformed_points[:, 1] = -transformed_points[:, 1]  # Flip Y axis
        transformed_points[:, 2] = -transformed_points[:, 2]  # Flip Z axis

        # Create the point cloud from the transformed coordinates
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(transformed_points)
        pcd.colors = o3d.utility.Vector3dVector(valid_colors / 255.0)

        # Save the point cloud for download (.ply)
        temp_ply_file = tempfile.NamedTemporaryFile(delete=False, suffix='.ply')
        o3d.io.write_point_cloud(temp_ply_file.name, pcd)

        # Create an OBJ file for 3D visualization (better Gradio compatibility)
        temp_obj_file = tempfile.NamedTemporaryFile(delete=False, suffix='.obj')

        # Write the OBJ file with per-vertex colors
        with open(temp_obj_file.name, 'w') as f:
            f.write("# Point cloud generated from stereo depth\n")
            f.write(f"# Total points: {len(valid_points)}\n")

            # Write vertices with RGB colors in the 0-1 range
            for point, color in zip(transformed_points, valid_colors):
                r, g, b = np.clip(color / 255.0, 0, 1)
                f.write(f"v {point[0]:.6f} {point[1]:.6f} {point[2]:.6f} {r:.6f} {g:.6f} {b:.6f}\n")

        progress(1.0, desc="Complete!")
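        # Note: "v x y z r g b" per-vertex colors are a widely supported extension
        # to the OBJ format rather than part of the core specification.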

        # For ZeroGPU: clean up the model after inference
        del model
        torch.cuda.empty_cache()
        gc.collect()

        # Check current memory usage (safely within the GPU context)
        current_memory = torch.cuda.memory_allocated(0) / 1024**3
        max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
        memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"

        status = f"""✅ Depth processing successful!
🔧 Model: {variant} (ZeroGPU){memory_info}
📊 Statistics:
• Valid points: {len(valid_points):,}
• Depth range: {z.min():.2f} - {z.max():.2f} m
• Baseline: {baseline} m
• Point cloud saved with {len(valid_points)} points
• 3D visualization ready (corrected orientation)"""

        return depth_vis, temp_ply_file.name, temp_obj_file.name, status

    except Exception as e:
        logging.error(f"Depth processing failed: {e}")
        # Cleanup on error
        if 'img0' in locals():
            del img0
        if 'img1' in locals():
            del img1
        if 'disp' in locals():
            del disp
        if 'model' in locals():
            del model
        # Clean up GPU memory
        torch.cuda.empty_cache()
        gc.collect()
        return None, None, None, f"❌ Error: {str(e)}"


def preload_all_models():
    """Pre-download all Hugging Face models to the cache during startup"""
    logging.info("🔄 Pre-downloading all models to cache...")

    downloaded_models = {}

    for variant, info in MODEL_VARIANTS.items():
        try:
            logging.info(f"📥 Downloading {variant} model to cache...")
            model_path, config_path = download_model_from_hf(variant, force_download=False)
            downloaded_models[variant] = {
                "model_path": model_path,
                "config_path": config_path,
                "display_name": info["display_name"]
            }
            logging.info(f"✅ {variant} model cached successfully")
        except Exception as e:
            logging.warning(f"⚠️ Failed to download {variant} model: {e}")
            # Continue with the other models even if one fails

    logging.info(f"✅ Model pre-loading complete. {len(downloaded_models)}/{len(MODEL_VARIANTS)} models cached.")
    return downloaded_models
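
# A minimal sketch of the startup flow (illustrative; mirrors what create_app()
# does below): the models are fetched once into the local huggingface_hub cache,
# after which get_cached_model() resolves a dropdown selection to a ready model.
#
#   cached = preload_all_models()  # warms the cache at startup
#   model, device = get_cached_model(
#       "FoundationStereo (Low-cost variant - 11-33-40) [Hugging Face]")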


def create_app() -> gr.Blocks:
    """Create the Gradio application"""

    global MODEL_PATH, CONFIG_PATH

    # Debug: print the current directory and check for files
    print(f"Current directory: {current_dir}")
    print(f"Python working directory: {os.getcwd()}")

    # Pre-download all models to the cache
    try:
        cached_models = preload_all_models()
        logging.info(f"Pre-loaded {len(cached_models)} models to cache")
    except Exception as e:
        logging.error(f"Failed to pre-load models: {e}")
        cached_models = {}

    # Get the available models (safe: this only touches the file system)
    try:
        available_models = get_available_models()
        logging.info(f"Successfully got available models: {len(available_models)} found")
    except Exception as e:
        logging.error(f"Failed to get available models: {e}")
        available_models = {}

    # Find model and config paths (legacy); this should be safe as well
    try:
        MODEL_PATH, CONFIG_PATH = find_model_path()
        logging.info("Successfully found model paths")
    except Exception as e:
        logging.error(f"Failed to find model paths: {e}")
        MODEL_PATH, CONFIG_PATH = None, None

    with gr.Blocks(
        title="FoundationStereo - Stereo Depth Estimation",
        theme=gr.themes.Soft(),
        css="footer {visibility: hidden}",
        delete_cache=(60, 60)  # Delete cache after 60 seconds for ZeroGPU
    ) as app:

        gr.Markdown("""
        # 🔍 FoundationStereo: Zero-Shot Stereo Matching

        Upload a pair of **rectified** stereo images to get disparity estimation.

        ⚠️ **Important**: Images should be rectified (epipolar lines are horizontal) and undistorted.
        ⚡ **ZeroGPU Powered**: Runs on high-performance A100 GPUs for fast inference.
        📦 **Smart Caching**: All models are pre-downloaded for instant model switching.
        """)

        # Instructions section
        with gr.Accordion("📋 Instructions to Run This Repository", open=False):
            gr.Markdown("""
            ## 🚀 How to Run This Demo
            This is a **demo application** showcasing the FoundationStereo model for stereo matching.

            ### 🖼️ Input Requirements

            1. **Image Format**: Upload images in JPEG or PNG format.
            2. **Image Size**: Both images must have the same size and resolution.
            3. **Rectification**: Ensure images are rectified (epipolar lines are horizontal) and undistorted.
            4. **Camera Parameters**: For advanced processing, provide camera parameters (camera matrix and baseline).

            ### 📊 Using the Demo

            1. **Select Model**: Choose between the low-cost (11-33-40) and high-quality (23-51-11) variants
            2. **Upload Images**: Provide rectified stereo image pairs
            3. **Basic Processing**: Get a disparity visualization
            4. **Advanced Processing**: Generate depth maps and 3D point clouds (requires camera parameters)

            ### Original Work

            This demo is based on the original FoundationStereo research. Please visit the official resources:
            - **Paper**: [FoundationStereo: Zero-Shot Stereo Matching via Foundation Model](https://arxiv.org/abs/2501.09898)
            - **Project Page**: [https://nvlabs.github.io/FoundationStereo/](https://nvlabs.github.io/FoundationStereo/)
            - **Official Repository**: [https://github.com/NVlabs/FoundationStereo](https://github.com/NVlabs/FoundationStereo)

            **⚠️ Demo Notice**: This is a demonstration interface. For research and production use, please refer to the original repository and follow the official implementation guidelines.
            """)

        # Model selection
        with gr.Row():
            # Always include the Hugging Face models in the choices
            all_choices = list(available_models.keys())

            # If no models were found, add the HF models manually
            if not all_choices:
                all_choices = [
                    "FoundationStereo (Low-cost variant - 11-33-40) [Hugging Face]",
                    "FoundationStereo (High-quality variant - 23-51-11) [Hugging Face]"
                ]

            # Pick the default model (prefer the Hugging Face low-cost variant)
            default_model = None

            # First try the Hugging Face low-cost variant
            for name in all_choices:
                if "11-33-40" in name and "[Hugging Face]" in name:
                    default_model = name
                    break

            # If there is no HF low-cost variant, try any low-cost variant
            if default_model is None:
                for name in all_choices:
                    if "11-33-40" in name:
                        default_model = name
                        break

            # If there is no low-cost variant at all, use the first available
            if default_model is None:
                default_model = all_choices[0] if all_choices else None

            model_selector = gr.Dropdown(
                choices=all_choices,
                value=default_model,
                label="🎯 Select Model",
                info="Choose the FoundationStereo model variant. Hugging Face models download automatically.",
                interactive=True
            )

        with gr.Tabs():
            # Basic stereo processing tab
            with gr.TabItem("🖼️ Basic Stereo Processing"):
                with gr.Row():
                    with gr.Column():
                        left_input = gr.Image(
                            label="📷 Left Image",
                            type="numpy",
                            height=300
                        )
                        right_input = gr.Image(
                            label="📷 Right Image",
                            type="numpy",
                            height=300
                        )

                        process_btn = gr.Button(
                            "🚀 Process Stereo Pair",
                            variant="primary",
                            size="lg"
                        )

                    with gr.Column():
                        output_image = gr.Image(
                            label="📊 Disparity Visualization",
                            height=400
                        )
                        status_text = gr.Textbox(
                            label="Status",
                            interactive=False,
                            lines=8
                        )

                # Example images
                examples_list = []

                # Example 1
                if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")):
                    examples_list.append([
                        os.path.join(current_dir, "assets", "example1", "left.png"),
                        os.path.join(current_dir, "assets", "example1", "right.png")
                    ])

                # Example 2
                if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")):
                    examples_list.append([
                        os.path.join(current_dir, "assets", "example2", "left.png"),
                        os.path.join(current_dir, "assets", "example2", "right.png")
                    ])

                gr.Examples(
                    examples=examples_list,
                    inputs=[left_input, right_input],
                    label="📋 Example Images"
                )

            # Advanced processing with depth
            with gr.TabItem("📐 Advanced Processing (Depth & Point Cloud)"):
                with gr.Row():
                    with gr.Column():
                        left_input_adv = gr.Image(
                            label="📷 Left Image",
                            type="numpy",
                            height=250
                        )
                        right_input_adv = gr.Image(
                            label="📷 Right Image",
                            type="numpy",
                            height=250
                        )

                        # Camera parameters
                        with gr.Group():
                            gr.Markdown("### 📹 Camera Parameters")
                            camera_matrix_input = gr.Textbox(
                                label="Camera Matrix (9 values: fx 0 cx 0 fy cy 0 0 1)",
                                value=""
                            )
                            baseline_input = gr.Number(
                                label="Baseline (meters)",
                                value=None,
                                minimum=0.001,
                                maximum=10.0,
                                step=0.001
                            )

                        process_depth_btn = gr.Button(
                            "🔬 Process with Depth",
                            variant="primary",
                            size="lg"
                        )

                    with gr.Column():
                        depth_output = gr.Image(
                            label="📏 Depth Visualization",
                            height=300
                        )
                        pointcloud_output = gr.File(
                            label="☁️ Point Cloud Download (.ply)",
                            file_types=[".ply"]
                        )
                        status_depth = gr.Textbox(
                            label="Status",
                            interactive=False,
                            lines=6
                        )

                # 3D point cloud visualization
                with gr.Row():
                    pointcloud_3d = gr.Model3D(
                        label="🌐 3D Point Cloud Viewer",
                        clear_color=[0.0, 0.0, 0.0, 0.0],
                        height=400
                    )

                # Example images for advanced processing
                examples_advanced_list = []

                # Example 1 - camera parameters from K.txt
                if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")):
                    examples_advanced_list.append([
                        os.path.join(current_dir, "assets", "example1", "left.png"),
                        os.path.join(current_dir, "assets", "example1", "right.png"),
                        "754.6680908203125 0.0 489.3794860839844 0.0 754.6680908203125 265.16162109375 0.0 0.0 1.0",  # Camera matrix
                        0.063  # Baseline in meters
                    ])

                # Example 2 - camera parameters from K.txt
                if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")):
                    examples_advanced_list.append([
                        os.path.join(current_dir, "assets", "example2", "left.png"),
                        os.path.join(current_dir, "assets", "example2", "right.png"),
                        "1733.74 0.0 792.27 0.0 1733.74 541.89 0.0 0.0 1.0",  # Camera matrix
                        0.537  # Baseline in meters (converted from 536.62 mm)
                    ])

                gr.Examples(
                    examples=examples_advanced_list,
                    inputs=[left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
                    label="📋 Example Images with Camera Parameters"
                )

        # Event handlers - always enabled since the HF models are available
        process_btn.click(
            fn=process_stereo_pair,
            inputs=[model_selector, left_input, right_input],
            outputs=[output_image, status_text],
            show_progress=True
        )

        if OPEN3D_AVAILABLE:
            process_depth_btn.click(
                fn=process_with_depth,
                inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
                outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth],
                show_progress=True
            )
        else:
            process_depth_btn.click(
                fn=lambda *args: (None, None, None, "❌ Open3D not available. Install with: pip install open3d"),
                inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
                outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth]
            )

        # Citation section at the bottom
        with gr.Accordion("📖 Citation", open=False):
            gr.Markdown("""
            ### 📄 Please Cite the Original Paper

            If you use this work in your research, please cite:

            ```bibtex
            @article{wen2025stereo,
              title={FoundationStereo: Zero-Shot Stereo Matching},
              author={Bowen Wen and Matthew Trepte and Joseph Aribido and Jan Kautz and Orazio Gallo and Stan Birchfield},
              journal={CVPR},
              year={2025}
            }
            ```
            """)

        # Footer
        gr.Markdown(f"""
        ---
        ### 📝 Notes:
        - **Input images must be rectified stereo pairs** (epipolar lines are horizontal)
        - **🤗 Hugging Face Integration**: Models are automatically downloaded from `{HF_REPO_ID}`
        - **📦 Smart Caching**: All models are pre-downloaded and cached for instant switching
        - **⚡ ZeroGPU Acceleration**: Powered by high-performance A100 GPUs
        - For best results, use PNG images without lossy compression
        - The model works on RGB images but also supports monochrome/IR stereo pairs
        - **Optimized for Spaces**: Memory-efficient inference on shared infrastructure

        ### 🔗 References:
        - [FoundationStereo Paper](https://arxiv.org/abs/2501.09898)
        - [Project Website](https://nvlabs.github.io/FoundationStereo/)
        - [GitHub Repository](https://github.com/NVlabs/FoundationStereo)
        - [Hugging Face Models](https://huggingface.co/{HF_REPO_ID})
        """)

    return app


def main():
    """Main function to launch the app"""

    # Ensure no CUDA operations happen during startup
    if torch.cuda.is_available():
        logging.warning("CUDA detected during startup - this should not happen in ZeroGPU")

    logging.info("🚀 Starting FoundationStereo Gradio App...")

    # Parse command line arguments
    import argparse
    parser = argparse.ArgumentParser(description="FoundationStereo Gradio App")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
    parser.add_argument("--share", action="store_true", help="Create shareable link")
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")

    args = parser.parse_args()

    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    try:
        # Create and launch the app
        logging.info("Creating Gradio app...")
        app = create_app()
        logging.info("✅ Gradio app created successfully")

        logging.info(f"Launching app on {args.host}:{args.port}")
        if args.share:
            logging.info("Share link will be created")

        # For ZeroGPU compatibility, launch with appropriate settings
        app.launch(
            server_name=args.host,
            server_port=args.port,
            share=args.share,
            show_error=True,
            favicon_path=None,
            ssr_mode=False,  # Disable SSR for ZeroGPU compatibility
            allowed_paths=["./"]  # Allow access to local files
        )
    except Exception as e:
        logging.error(f"Failed to launch app: {e}")
        raise


if __name__ == "__main__":
    # Additional safety check for the ZeroGPU environment
    if 'SPACE_ID' in os.environ:
        logging.info("Running in Hugging Face Spaces environment")

    # Do not check the CUDA status during startup - this can trigger CUDA
    # initialization; it will be checked inside the @spaces.GPU decorated functions.
    logging.info("✅ CUDA status will be checked within GPU-decorated functions")

    main()
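
# Example local launch (illustrative invocation using the flags defined above):
#   python app.py --port 7860 --share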
FoundationStereo_demo/app_local.py
ADDED
@@ -0,0 +1,1292 @@
import os
import sys
import logging
import tempfile
import zipfile
import gc
from pathlib import Path
from typing import Optional, Tuple, Union

import numpy as np
import cv2
import gradio as gr
import imageio

import torch

# Set the default tensor type here if needed
# torch.set_default_tensor_type('torch.FloatTensor')

# Import other safe modules
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download, snapshot_download

# CUDA backend settings
# torch.backends.cudnn.enabled = False
# torch.backends.cudnn.benchmark = False

# Use the current directory as the base (gradio_app folder)
current_dir = os.path.dirname(os.path.abspath(__file__))
base_dir = current_dir  # gradio_app folder

# Add the current directory to the path for local imports
sys.path.insert(0, current_dir)

# Do NOT import any local modules here that might use CUDA;
# all local module imports are done inside functions.

# Import Open3D with error handling
OPEN3D_AVAILABLE = False
try:
    # Set Open3D to CPU mode to avoid CUDA initialization
    os.environ['OPEN3D_CPU_RENDERING'] = '1'
    # Don't import open3d here - do it inside functions
    # import open3d as o3d
    OPEN3D_AVAILABLE = True  # Assume available; checked again later
except Exception as e:
    logging.warning(f"Open3D setup failed: {e}")
    OPEN3D_AVAILABLE = False

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Hugging Face model repository configuration
HF_REPO_ID = "shriarul5273/FoundationStereo_models"
MODEL_VARIANTS = {
    "11-33-40": {
        "display_name": "FoundationStereo (Low-cost variant - 11-33-40)",
        "model_file": "pretrained_models/11-33-40/model_best_bp2.pth",
        "config_file": "pretrained_models/11-33-40/cfg.yaml"
    },
    "23-51-11": {
        "display_name": "FoundationStereo (High-quality variant - 23-51-11)",
        "model_file": "pretrained_models/23-51-11/model_best_bp2.pth",
        "config_file": "pretrained_models/23-51-11/cfg.yaml"
    }
}

# Global variables for model caching
MODEL_PATH: Optional[str] = None
CONFIG_PATH: Optional[str] = None

# Model cache to avoid reloading when the selection doesn't change
_cached_model = None
_cached_device = None
_cached_model_selection = None


def aggressive_cleanup():
    """Perform basic cleanup"""
    import gc
    gc.collect()
    logging.info("Performed basic memory cleanup")


def check_gpu_memory():
    """Check and log current GPU memory usage"""
    try:
        allocated = torch.cuda.memory_allocated(0) / 1024**3
        reserved = torch.cuda.memory_reserved(0) / 1024**3
        max_allocated = torch.cuda.max_memory_allocated(0) / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3

        logging.info(f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB, Max: {max_allocated:.2f}GB, Total: {total:.2f}GB")
        return allocated, reserved, max_allocated, total
    except RuntimeError as e:
        logging.warning(f"Failed to get GPU memory info: {e}")
        return None, None, None, None


def download_model_from_hf(variant: str, force_download: bool = False) -> Tuple[str, str]:
    """
    Download model and config files from the Hugging Face Hub

    Args:
        variant: Model variant ("11-33-40" or "23-51-11")
        force_download: Force a re-download even if the files exist locally

    Returns:
        Tuple of (model_path, config_path)
    """
    if variant not in MODEL_VARIANTS:
        raise ValueError(f"Unknown model variant: {variant}. Available: {list(MODEL_VARIANTS.keys())}")

    variant_info = MODEL_VARIANTS[variant]

    try:
        if not force_download:
            logging.info(f"📦 Checking cache for model variant: {variant}")
        else:
            logging.info(f"🔄 Force downloading model variant: {variant}")

        # Download the model file
        model_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=variant_info["model_file"],
            force_download=force_download,
            local_dir_use_symlinks=False
        )

        # Download the config file
        config_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=variant_info["config_file"],
            force_download=force_download,
            local_dir_use_symlinks=False
        )

        if force_download:
            logging.info(f"✅ Successfully downloaded {variant} model files")
        else:
            logging.info(f"✅ Successfully loaded {variant} model files from cache")

        logging.debug(f"Model: {model_path}")
        logging.debug(f"Config: {config_path}")

        return model_path, config_path

    except Exception as e:
        logging.error(f"Failed to download model {variant}: {e}")
        raise RuntimeError(f"Failed to download model {variant} from Hugging Face: {e}")


def get_available_models() -> dict:
    """Get all available models with their display names and download info"""
    models = {}

    # First check local models (legacy support)
    search_dirs = [
        os.path.join(current_dir, "pretrained_models"),
        os.path.join(os.path.dirname(current_dir), "pretrained_models")
    ]

    for search_dir in search_dirs:
        if os.path.exists(search_dir):
            for model_dir in os.listdir(search_dir):
                model_path = os.path.join(search_dir, model_dir, "model_best_bp2.pth")
                cfg_path = os.path.join(search_dir, model_dir, "cfg.yaml")

                if os.path.exists(model_path) and os.path.exists(cfg_path):
                    # Create a descriptive name for the model
                    if model_dir == "11-33-40":
                        display_name = "FoundationStereo (Low-cost variant - 11-33-40) [Local]"
                    elif model_dir == "23-51-11":
                        display_name = "FoundationStereo (High-quality variant - 23-51-11) [Local]"
                    else:
                        display_name = f"FoundationStereo ({model_dir}) [Local]"

                    models[display_name] = {
                        "model_path": model_path,
                        "config_path": cfg_path,
                        "variant": model_dir,
                        "source": "local"
                    }

    # Add the Hugging Face models
    for variant, info in MODEL_VARIANTS.items():
        display_name = f"{info['display_name']} [Hugging Face]"
        models[display_name] = {
            "model_path": None,   # Will be downloaded when needed
            "config_path": None,  # Will be downloaded when needed
            "variant": variant,
            "source": "huggingface"
        }

    return models


def find_model_path() -> Tuple[Optional[str], Optional[str]]:
    """Find available model and config paths (legacy function for backward compatibility)"""
    models = get_available_models()
    if models:
        # Prefer Hugging Face models over local ones:
        # first try to find the HF low-cost variant
        for display_name in models:
            if "11-33-40" in display_name and "[Hugging Face]" in display_name:
                return get_model_paths_from_selection(display_name)

        # Then try a local low-cost variant
        for display_name in models:
            if "11-33-40" in display_name:
                return get_model_paths_from_selection(display_name)

        # If there is no low-cost variant, return the first available
        first_model_name = next(iter(models.keys()))
        return get_model_paths_from_selection(first_model_name)
    return None, None


def get_model_paths_from_selection(model_selection: str) -> Tuple[Optional[str], Optional[str]]:
    """Get model and config paths from the selected model"""
    models = get_available_models()

    # Check whether it is in our models dict
    if model_selection in models:
        model_info = models[model_selection]

        # If it's a Hugging Face model, download it first (or get it from the cache)
        if model_info["source"] == "huggingface":
            variant = model_info["variant"]
            try:
                logging.info(f"📦 Retrieving {variant} model from cache...")
                model_path, config_path = download_model_from_hf(variant, force_download=False)
                return model_path, config_path
            except Exception as e:
                logging.error(f"Failed to get model {variant} from cache: {e}")
                return None, None
        else:
            # Local model
            logging.info(f"📁 Using local model: {model_selection}")
            return model_info["model_path"], model_info["config_path"]

    # Handle a direct HF model selection (fallback)
    elif "[Hugging Face]" in model_selection:
        if "11-33-40" in model_selection:
            variant = "11-33-40"
        elif "23-51-11" in model_selection:
            variant = "23-51-11"
        else:
            logging.error(f"Unknown HF model variant in: {model_selection}")
            return None, None

        try:
            logging.info(f"📦 Retrieving {variant} model from cache...")
            model_path, config_path = download_model_from_hf(variant, force_download=False)
            return model_path, config_path
        except Exception as e:
            logging.error(f"Failed to get model {variant} from cache: {e}")
            return None, None

    return None, None


def get_cached_model(model_selection: str):
    """Get the cached model, or load a new one if the selection changed"""
    global _cached_model, _cached_device, _cached_model_selection

    # Get the model paths from the selection
    model_path, config_path = get_model_paths_from_selection(model_selection)

    if model_path is None or config_path is None:
        raise ValueError(f"Selected model not found: {model_selection}")

    # Load the model fresh for each inference;
    # since the models are pre-downloaded, this should be fast
    logging.info(f"🚀 Loading cached model: {model_selection}")
    model, device = load_model_for_inference(model_path, config_path)

    logging.info(f"✅ Model loaded successfully from cache: {model_selection}")
    return model, device


def clear_model_cache():
    """Clear the cached model to free memory"""
    global _cached_model, _cached_device, _cached_model_selection

    if _cached_model is not None:
        logging.info("Clearing model cache...")
        del _cached_model
        _cached_model = None
        _cached_device = None
        _cached_model_selection = None

        # Simple cleanup
        import gc
        gc.collect()
        logging.info("Model cache cleared")
    else:
        logging.info("No model in cache to clear")


def load_model_for_inference(model_path: str, cfg_path: str):
    """Load the model temporarily for inference"""
    # Set CUDA settings
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # Import the required modules
    try:
        # Import selectively to avoid CUDA calls in Utils
        from core.foundation_stereo import FoundationStereo
        from omegaconf import OmegaConf
        logging.info("Successfully imported required modules")

        # Import set_logging_format safely
        from Utils import set_logging_format
        set_logging_format()

        # Seed manually to avoid CUDA calls in Utils.set_seed
        import random
        random_seed = 0
        np.random.seed(random_seed)
        random.seed(random_seed)
        torch.manual_seed(random_seed)
        # CUDA seeding is done after the device is available

        logging.info("Set logging format and seed")
    except Exception as e:
        logging.error(f"Failed to import modules: {e}")
        raise RuntimeError(f"Import failed: {e}")

    # Check that CUDA is available
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available.")

    # Use the first available CUDA device
    device = torch.device("cuda")

    # Set the CUDA seed
    try:
        torch.cuda.manual_seed_all(random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except Exception as e:
        logging.warning(f"Could not set CUDA seed: {e}")

    try:
        # Load the config
        cfg = OmegaConf.load(cfg_path)
        cfg.setdefault("vit_size", "vitl")
        logging.info("Loaded config file")

        # Create the model
        model = FoundationStereo(cfg).to(device)
        model.eval()
        logging.info("Created model")

        # Load the checkpoint
        ckpt = torch.load(model_path, map_location=device)
        model.load_state_dict(ckpt["model"], strict=True)
        logging.info("Loaded model weights")

        # Memory optimizations
        torch.set_grad_enabled(False)
        model.half()  # Use half precision
        logging.info("Applied memory optimizations")

        return model, device

    except Exception as e:
        logging.error(f"Model loading failed: {e}")
        raise RuntimeError(f"Failed to load model: {e}")


def process_stereo_pair(model_selection: str, left_image: str, right_image: str,
                        progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], str]:
    """
    Main processing function for a stereo pair (with model caching)
    """
    logging.info("Starting stereo pair processing...")

    if left_image is None or right_image is None:
        return None, "❌ Please upload both left and right images."

    # Convert image paths to numpy arrays
    logging.info(f"Loading images: left={left_image}, right={right_image}")

    try:
        # Load the left image
        if left_image is None:
            return None, "❌ Please upload a left image."

        # Check that the file exists first
        if not os.path.exists(left_image):
            logging.error(f"Left image file does not exist: {left_image}")
            return None, f"❌ Left image file not found: {left_image}"

        logging.info(f"Loading left image from: {left_image}")
        left_img = None

        # Try multiple loading methods
        try:
            # Method 1: OpenCV
            left_img = cv2.imread(left_image)
            if left_img is not None:
                left_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB)
                logging.info("Left image loaded with OpenCV")
        except Exception as e:
            logging.warning(f"OpenCV failed for left image: {e}")

        if left_img is None:
            try:
                # Method 2: PIL
                from PIL import Image
                with Image.open(left_image) as pil_img:
                    left_img = np.array(pil_img.convert('RGB'))
                logging.info("Left image loaded with PIL")
            except Exception as e:
                logging.warning(f"PIL failed for left image: {e}")

        if left_img is None:
            try:
                # Method 3: imageio
                left_img = imageio.imread(left_image)
                if len(left_img.shape) == 3 and left_img.shape[2] == 4:
                    # RGBA to RGB
                    left_img = left_img[:, :, :3]
                logging.info("Left image loaded with imageio")
            except Exception as e:
                logging.warning(f"imageio failed for left image: {e}")

        if left_img is None:
            return None, f"❌ Failed to load left image with any method: {left_image}"

        # Load the right image
        if right_image is None:
            return None, "❌ Please upload a right image."

        # Check that the file exists first
        if not os.path.exists(right_image):
            logging.error(f"Right image file does not exist: {right_image}")
            return None, f"❌ Right image file not found: {right_image}"

        logging.info(f"Loading right image from: {right_image}")
        right_img = None

        # Try multiple loading methods
        try:
            # Method 1: OpenCV
            right_img = cv2.imread(right_image)
            if right_img is not None:
                right_img = cv2.cvtColor(right_img, cv2.COLOR_BGR2RGB)
                logging.info("Right image loaded with OpenCV")
        except Exception as e:
            logging.warning(f"OpenCV failed for right image: {e}")

        if right_img is None:
            try:
                # Method 2: PIL
                from PIL import Image
                with Image.open(right_image) as pil_img:
                    right_img = np.array(pil_img.convert('RGB'))
                logging.info("Right image loaded with PIL")
            except Exception as e:
                logging.warning(f"PIL failed for right image: {e}")

        if right_img is None:
            try:
                # Method 3: imageio
                right_img = imageio.imread(right_image)
                if len(right_img.shape) == 3 and right_img.shape[2] == 4:
                    # RGBA to RGB
                    right_img = right_img[:, :, :3]
                logging.info("Right image loaded with imageio")
            except Exception as e:
                logging.warning(f"imageio failed for right image: {e}")

        if right_img is None:
            return None, f"❌ Failed to load right image with any method: {right_image}"

        # Update the variables
        left_image = left_img
        right_image = right_img

        logging.info(f"Images loaded successfully - Left: {left_image.shape}, Right: {right_image.shape}")

    except Exception as e:
        logging.error(f"Failed to load images: {e}")
        return None, f"❌ Failed to load images: {str(e)}"

    try:
        # Import these inside the function to avoid early CUDA calls
        logging.info("Importing required modules...")
        from core.utils.utils import InputPadder
        # Import vis_disparity safely - it shouldn't make CUDA calls, but be careful
        from Utils import vis_disparity
        logging.info("✅ Successfully imported processing modules")

        # Get the cached model (will load if not cached or the selection changed)
        variant_name = model_selection.split('(')[1].split(')')[0] if '(' in model_selection else model_selection
        progress(0.1, desc=f"Loading cached model ({variant_name})...")
        logging.info("🚀 Getting cached model...")
        model, device = get_cached_model(model_selection)
        logging.info("✅ Cached model loaded successfully")

        progress(0.2, desc="Preprocessing images...")

        # Validate the input images
        if left_image.shape != right_image.shape:
            return None, "❌ Left and right images must have the same dimensions."

        H, W = left_image.shape[:2]

        # Convert to torch tensors and ensure they are contiguous
        img0 = torch.as_tensor(left_image).to(device).half()[None].permute(0, 3, 1, 2).contiguous()
        img1 = torch.as_tensor(right_image).to(device).half()[None].permute(0, 3, 1, 2).contiguous()

        # Pad the images and ensure contiguity
        padder = InputPadder(img0.shape, divis_by=32, force_square=False)
        img0, img1 = padder.pad(img0, img1)

        # Ensure the padded tensors are contiguous
        img0 = img0.contiguous()
        img1 = img1.contiguous()

        progress(0.5, desc="Running inference...")

        # Process the stereo pair with autocast and a clean memory state
        torch.cuda.empty_cache()  # Clear any cached memory before inference

        try:
            with torch.amp.autocast("cuda", enabled=True):
                # Ensure the tensors are in the right format for cuDNN
                if not img0.is_contiguous():
                    img0 = img0.contiguous()
                if not img1.is_contiguous():
                    img1 = img1.contiguous()

                disp = model.forward(img0, img1, iters=32, test_mode=True)
        except RuntimeError as e:
            if "cuDNN" in str(e):
                # Fallback: disable cuDNN optimizations and retry
                logging.warning(f"cuDNN error encountered, retrying with fallback: {e}")
                torch.backends.cudnn.enabled = False
                try:
                    with torch.amp.autocast("cuda", enabled=True):
                        disp = model.forward(img0, img1, iters=32, test_mode=True)
                finally:
                    torch.backends.cudnn.enabled = True  # Re-enable for future use
            else:
                raise

        # Unpad and convert to numpy
        disp = padder.unpad(disp.float())
        disp_cpu = disp.data.cpu().numpy().reshape(H, W)

        progress(0.8, desc="Creating visualization...")

        # Create visualization - disparity only
        disparity_vis = vis_disparity(disp_cpu)
        result_image = disparity_vis

        progress(1.0, desc="Complete!")

        # Clean up intermediate tensors
        del img0, img1, disp

        # Clean up the model after inference
        del model
        torch.cuda.empty_cache()
        gc.collect()

        # Create the status message
        valid_mask = disp_cpu != np.inf
        min_disp = disp_cpu[valid_mask].min() if valid_mask.any() else 0
        max_disp = disp_cpu[valid_mask].max() if valid_mask.any() else 0
        mean_disp = disp_cpu[valid_mask].mean() if valid_mask.any() else 0

        # Get the model variant for the status message
        variant = model_selection.split('(')[1].split(')')[0] if '(' in model_selection else "Unknown"

        # Check current memory usage
        current_memory = torch.cuda.memory_allocated(0) / 1024**3
| 583 |
+
max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
|
| 584 |
+
memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"
|
| 585 |
+
|
| 586 |
+
status = f"""✅ Processing successful!
|
| 587 |
+
🔧 Model: {variant}{memory_info}
|
| 588 |
+
📊 Disparity Statistics:
|
| 589 |
+
• Range: {min_disp:.2f} - {max_disp:.2f}
|
| 590 |
+
• Mean: {mean_disp:.2f}
|
| 591 |
+
• Input size: {W}×{H}
|
| 592 |
+
• Valid pixels: {valid_mask.sum()}/{valid_mask.size}"""
|
| 593 |
+
|
| 594 |
+
return result_image, status
|
| 595 |
+
|
| 596 |
+
except Exception as e:
|
| 597 |
+
logging.error(f"Processing failed: {e}")
|
| 598 |
+
# Cleanup on error
|
| 599 |
+
if 'img0' in locals():
|
| 600 |
+
del img0
|
| 601 |
+
if 'img1' in locals():
|
| 602 |
+
del img1
|
| 603 |
+
if 'disp' in locals():
|
| 604 |
+
del disp
|
| 605 |
+
if 'model' in locals():
|
| 606 |
+
del model
|
| 607 |
+
# Clean up GPU memory
|
| 608 |
+
torch.cuda.empty_cache()
|
| 609 |
+
gc.collect()
|
| 610 |
+
return None, f"❌ Error: {str(e)}"
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
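
The cuDNN recovery in `process_stereo_pair` above reappears verbatim in `process_with_depth` below. A minimal sketch of the same retry pattern factored into a helper; `run_inference` is a hypothetical callable standing in for the `model.forward(...)` call and is not defined in this commit:

```python
import logging
import torch

def forward_with_cudnn_fallback(run_inference):
    """Run inference once; on a cuDNN RuntimeError, retry with cuDNN disabled."""
    try:
        with torch.amp.autocast("cuda", enabled=True):
            return run_inference()
    except RuntimeError as e:
        if "cuDNN" not in str(e):
            raise
        logging.warning(f"cuDNN error, retrying with cuDNN disabled: {e}")
        torch.backends.cudnn.enabled = False
        try:
            with torch.amp.autocast("cuda", enabled=True):
                return run_inference()
        finally:
            torch.backends.cudnn.enabled = True  # restore for later calls
```
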
+def process_with_depth(model_selection: str, left_image: str, right_image: str,
+                       camera_matrix: str, baseline: float,
+                       progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], Optional[str], Optional[str], str]:
+    """
+    Process stereo pair and generate depth map and point cloud (with model caching)
+    """
+    from core.utils.utils import InputPadder
+    from Utils import vis_disparity
+
+    # Import Open3D
+    global OPEN3D_AVAILABLE
+    try:
+        import open3d as o3d
+        OPEN3D_AVAILABLE = True
+    except ImportError as e:
+        logging.warning(f"Open3D not available: {e}")
+        OPEN3D_AVAILABLE = False
+        return None, None, None, "❌ Open3D not available. Point cloud generation disabled."
+
+    if left_image is None or right_image is None:
+        return None, None, None, "❌ Please upload both left and right images."
+
+    # Convert image paths to numpy arrays
+    logging.info(f"Loading images: left={left_image}, right={right_image}")
+
+    try:
+        # Load left image
+        if left_image is None:
+            return None, None, None, "❌ Left image is None."
+        if not os.path.exists(left_image):
+            return None, None, None, f"❌ Left image file does not exist: {left_image}"
+        left_img = None
+        # Try OpenCV
+        try:
+            left_img = cv2.imread(left_image)
+            if left_img is not None:
+                left_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB)
+        except Exception as e:
+            logging.warning(f"OpenCV failed for left image: {e}")
+        # Try PIL if OpenCV fails
+        if left_img is None:
+            try:
+                from PIL import Image
+                left_img = np.array(Image.open(left_image).convert('RGB'))
+            except Exception as e:
+                logging.warning(f"PIL failed for left image: {e}")
+        # Try imageio if PIL fails
+        if left_img is None:
+            try:
+                import imageio
+                left_img = imageio.imread(left_image)
+                if left_img.ndim == 2:
+                    left_img = np.stack([left_img]*3, axis=-1)
+                elif left_img.shape[2] == 4:
+                    left_img = left_img[..., :3]
+            except Exception as e:
+                logging.warning(f"imageio failed for left image: {e}")
+        if left_img is None:
+            return None, None, None, f"❌ Could not load left image: {left_image}"
+
+        # Load right image
+        if right_image is None:
+            return None, None, None, "❌ Right image is None."
+        if not os.path.exists(right_image):
+            return None, None, None, f"❌ Right image file does not exist: {right_image}"
+        right_img = None
+        # Try OpenCV
+        try:
+            right_img = cv2.imread(right_image)
+            if right_img is not None:
+                right_img = cv2.cvtColor(right_img, cv2.COLOR_BGR2RGB)
+        except Exception as e:
+            logging.warning(f"OpenCV failed for right image: {e}")
+        # Try PIL if OpenCV fails
+        if right_img is None:
+            try:
+                from PIL import Image
+                right_img = np.array(Image.open(right_image).convert('RGB'))
+            except Exception as e:
+                logging.warning(f"PIL failed for right image: {e}")
+        # Try imageio if PIL fails
+        if right_img is None:
+            try:
+                import imageio
+                right_img = imageio.imread(right_image)
+                if right_img.ndim == 2:
+                    right_img = np.stack([right_img]*3, axis=-1)
+                elif right_img.shape[2] == 4:
+                    right_img = right_img[..., :3]
+            except Exception as e:
+                logging.warning(f"imageio failed for right image: {e}")
+        if right_img is None:
+            return None, None, None, f"❌ Could not load right image: {right_image}"
+
+        # Update variables
+        left_image = left_img
+        right_image = right_img
+
+        logging.info(f"Images loaded successfully - Left: {left_image.shape}, Right: {right_image.shape}")
+
+    except Exception as e:
+        logging.error(f"Failed to load images: {e}")
+        return None, None, None, f"❌ Failed to load images: {str(e)}"
+
+    try:
+        progress(0.1, desc="Parsing camera parameters...")
+
+        # Parse camera matrix
+        try:
+            K_values = list(map(float, camera_matrix.strip().split()))
+            if len(K_values) != 9:
+                return None, None, None, "❌ Camera matrix must contain exactly 9 values."
+            K = np.array(K_values).reshape(3, 3)
+        except ValueError:
+            return None, None, None, "❌ Invalid camera matrix format. Use space-separated numbers."
+
+        if baseline <= 0:
+            return None, None, None, "❌ Baseline must be positive."
+
+        variant = model_selection.split('(')[1].split(')')[0] if '(' in model_selection else "Unknown"
+        progress(0.2, desc=f"Loading cached model ({variant})...")
+
+        # Get cached model (will load if not cached or selection changed)
+        model, device = get_cached_model(model_selection)
+
+        progress(0.4, desc="Running stereo inference...")
+
+        # Get disparity using the same process as the basic function
+        H, W = left_image.shape[:2]
+        img0 = torch.as_tensor(left_image).to(device).half()[None].permute(0,3,1,2).contiguous()
+        img1 = torch.as_tensor(right_image).to(device).half()[None].permute(0,3,1,2).contiguous()
+
+        padder = InputPadder(img0.shape, divis_by=32, force_square=False)
+        img0, img1 = padder.pad(img0, img1)
+
+        # Ensure padded tensors are contiguous
+        img0 = img0.contiguous()
+        img1 = img1.contiguous()
+
+        # Clear cache and ensure clean memory state before inference
+        torch.cuda.empty_cache()
+
+        try:
+            with torch.amp.autocast("cuda", enabled=True):
+                # Double-check tensor contiguity before cuDNN operations
+                if not img0.is_contiguous():
+                    img0 = img0.contiguous()
+                if not img1.is_contiguous():
+                    img1 = img1.contiguous()
+
+                disp = model.forward(img0, img1, iters=32, test_mode=True)
+        except RuntimeError as e:
+            if "cuDNN" in str(e):
+                # Fallback: disable cuDNN optimizations and retry
+                logging.warning(f"cuDNN error encountered in depth processing, retrying with fallback: {e}")
+                torch.backends.cudnn.enabled = False
+                try:
+                    with torch.amp.autocast("cuda", enabled=True):
+                        disp = model.forward(img0, img1, iters=32, test_mode=True)
+                finally:
+                    torch.backends.cudnn.enabled = True  # Re-enable for future use
+            else:
+                raise e
+
+        disp = padder.unpad(disp.float())
+        disp_cpu = disp.data.cpu().numpy().reshape(H, W)
+
+        # Clean up intermediate tensors early
+        del img0, img1, disp
+
+        # Keep model reference for rest of processing
+        torch.cuda.empty_cache()
+
+        progress(0.6, desc="Converting to depth...")
+
+        # Remove invisible points (same as in original demo)
+        yy, xx = np.meshgrid(np.arange(disp_cpu.shape[0]), np.arange(disp_cpu.shape[1]), indexing='ij')
+        us_right = xx - disp_cpu
+        invalid = us_right < 0
+        disp_cpu[invalid] = np.inf
+
+        # Convert to depth using the formula from the original demo
+        depth = K[0, 0] * baseline / disp_cpu
+
+        # Visualize depth (no rotation)
+        depth_vis = vis_disparity(depth, max_val=10.0)
+
+        progress(0.8, desc="Generating point cloud...")
+
+        # Generate point cloud with proper coordinate transformation
+        fx, fy = K[0, 0], K[1, 1]
+        cx, cy = K[0, 2], K[1, 2]
+
+        # Create coordinate meshgrids
+        u, v = np.meshgrid(np.arange(W), np.arange(H))
+
+        # Convert to 3D coordinates (proper camera coordinate system)
+        valid_depth = depth != np.inf
+        z = depth[valid_depth]  # Z coordinate (depth)
+        x = (u[valid_depth] - cx) * z / fx  # X coordinate
+        y = (v[valid_depth] - cy) * z / fy  # Y coordinate
+
+        # Stack coordinates (X, Y, Z)
+        points = np.stack([x, y, z], axis=-1)
+
+        # Get corresponding colors
+        colors = left_image[valid_depth]
+
+        # Filter points by depth range
+        depth_mask = (z > 0) & (z <= 10.0)
+        valid_points = points[depth_mask]
+        valid_colors = colors[depth_mask]
+
+        if len(valid_points) == 0:
+            return depth_vis, None, None, "⚠️ No valid points generated for point cloud."
+
+        # Subsample points for better 3D visualization performance
+        if len(valid_points) > 100000:
+            indices = np.random.choice(len(valid_points), 100000, replace=False)
+            valid_points = valid_points[indices]
+            valid_colors = valid_colors[indices]
+
+        # Transform coordinates for proper visualization orientation
+        # Standard computer vision: X right, Y down, Z forward
+        # For better 3D viewing: X right, Y up, Z backward
+        transformed_points = valid_points.copy()
+        transformed_points[:, 1] = -transformed_points[:, 1]  # Flip Y axis
+        transformed_points[:, 2] = -transformed_points[:, 2]  # Flip Z axis
+
+        # Generate point cloud using transformed coordinates
+        pcd = o3d.geometry.PointCloud()
+        pcd.points = o3d.utility.Vector3dVector(transformed_points)
+        pcd.colors = o3d.utility.Vector3dVector(valid_colors / 255.0)
+
+        progress(1.0, desc="Complete!")
+
+        # Clean up model after inference
+        del model
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        # Check current memory usage
+        current_memory = torch.cuda.memory_allocated(0) / 1024**3
+        max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
+        memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"
+
+        status = f"""✅ Depth processing successful!
+🔧 Model: {variant}{memory_info}
+📊 Statistics:
+• Valid points: {len(valid_points):,}
+• Depth range: {z.min():.2f} - {z.max():.2f} m
+• Baseline: {baseline} m
+• Point cloud generated with {len(valid_points)} points (not saved to file)
+• 3D visualization available (in-memory)"""
+
+        return depth_vis, None, None, status
+
+    except Exception as e:
+        logging.error(f"Depth processing failed: {e}")
+        # Cleanup on error
+        if 'img0' in locals():
+            del img0
+        if 'img1' in locals():
+            del img1
+        if 'disp' in locals():
+            del disp
+        if 'model' in locals():
+            del model
+        # Clean up GPU memory
+        torch.cuda.empty_cache()
+        gc.collect()
+        return None, None, None, f"❌ Error: {str(e)}"
+
+
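
`process_with_depth` above converts disparity to metric depth with Z = fx * baseline / d and back-projects pixels through the pinhole model. A self-contained sketch of that math; the intrinsics, baseline, and disparity value below are illustrative, not taken from the commit:

```python
import numpy as np

fx, cx, cy, baseline = 754.67, 489.38, 265.16, 0.063  # focal/principal point in px, baseline in m
disp = np.full((480, 960), 40.0)                      # fake disparity map (pixels)

depth = fx * baseline / disp                          # Z = fx * b / d, in meters

u, v = np.meshgrid(np.arange(disp.shape[1]), np.arange(disp.shape[0]))
X = (u - cx) * depth / fx                             # right (fy == fx in this example K)
Y = (v - cy) * depth / fx                             # down
points = np.stack([X, Y, depth], axis=-1)             # (H, W, 3) camera-frame points
print(points.shape, float(depth[0, 0]))               # (480, 960, 3), ~1.19 m
```
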
+def preload_all_models():
+    """Pre-download all Hugging Face models to cache during startup"""
+    logging.info("🔄 Pre-downloading all models to cache...")
+
+    downloaded_models = {}
+
+    for variant, info in MODEL_VARIANTS.items():
+        try:
+            logging.info(f"📥 Downloading {variant} model to cache...")
+            model_path, config_path = download_model_from_hf(variant, force_download=False)
+            downloaded_models[variant] = {
+                "model_path": model_path,
+                "config_path": config_path,
+                "display_name": info["display_name"]
+            }
+            logging.info(f"✅ {variant} model cached successfully")
+        except Exception as e:
+            logging.warning(f"⚠️ Failed to download {variant} model: {e}")
+            # Continue with other models even if one fails
+
+    logging.info(f"✅ Model pre-loading complete. {len(downloaded_models)}/{len(MODEL_VARIANTS)} models cached.")
+    return downloaded_models
+
+
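
`download_model_from_hf`, called above, is defined earlier in `app.py` and does not appear in this diff. For reference, a hedged sketch of how such a checkpoint fetch typically looks with `huggingface_hub`; the repo id and filenames are placeholders, not the ones this Space uses:

```python
from huggingface_hub import hf_hub_download

# Placeholders: substitute the repo id and filenames the Space actually uses.
ckpt_path = hf_hub_download(repo_id="some-user/some-stereo-models",
                            filename="model_best_bp2.pth")
cfg_path = hf_hub_download(repo_id="some-user/some-stereo-models",
                           filename="cfg.yaml")
print(ckpt_path, cfg_path)  # local HF cache paths; repeated calls hit the cache
```
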
+def create_app() -> gr.Blocks:
+    """Create the Gradio application"""
+
+    global MODEL_PATH, CONFIG_PATH
+
+    # Debug: Print current directory and check for files
+    print(f"Current directory: {current_dir}")
+    print(f"Python working directory: {os.getcwd()}")
+
+    # Pre-download all models to cache
+    try:
+        cached_models = preload_all_models()
+        logging.info(f"Pre-loaded {len(cached_models)} models to cache")
+    except Exception as e:
+        logging.error(f"Failed to pre-load models: {e}")
+        cached_models = {}
+
+    # Get available models (this should be safe as it only does file system operations)
+    try:
+        available_models = get_available_models()
+        logging.info(f"Successfully got available models: {len(available_models)} found")
+    except Exception as e:
+        logging.error(f"Failed to get available models: {e}")
+        available_models = {}
+
+    # Find model and config paths (legacy) - should be safe as well
+    try:
+        MODEL_PATH, CONFIG_PATH = find_model_path()
+        logging.info("Successfully found model paths")
+    except Exception as e:
+        logging.error(f"Failed to find model paths: {e}")
+        MODEL_PATH, CONFIG_PATH = None, None
+
+    with gr.Blocks(
+        title="FoundationStereo - Stereo Depth Estimation",
+        theme=gr.themes.Soft(),
+        css="footer {visibility: hidden}",
+        delete_cache=(60, 60)  # Delete cache after 60 seconds
+    ) as app:
+
+        gr.Markdown("""
+        # 🔍 FoundationStereo: Zero-Shot Stereo Matching
+
+        Upload a pair of **rectified** stereo images to get disparity estimation.
+
+        ⚠️ **Important**: Images should be rectified (epipolar lines are horizontal) and undistorted.
+        ⚡ **GPU Powered**: Runs on high-performance GPUs for fast inference.
+        📦 **Smart Caching**: All models are pre-downloaded for instant model switching.
+        """)
+
+        # Instructions section
+        with gr.Accordion("📋 Instructions to Run This Repository", open=False):
+            gr.Markdown("""
+            ## 🚀 How to Run This Demo
+            This is a **demo application** showcasing the FoundationStereo model for stereo disparity estimation.
+
+            ### 🖼️ Input Requirements
+
+            1. **Image Format**: Upload images in JPEG or PNG format.
+            2. **Image Size**: Both images must have the same resolution.
+            3. **Rectification**: Ensure images are rectified (epipolar lines are horizontal) and undistorted.
+            4. **Camera Parameters**: For advanced processing, provide camera parameters (camera matrix and baseline).
+
+            ### 📊 Using the Demo
+
+            1. **Select Model**: Choose between low-cost (11-33-40) or high-quality (23-51-11) variants
+            2. **Upload Images**: Provide rectified stereo image pairs
+            3. **Basic Processing**: Get disparity visualization
+            4. **Advanced Processing**: Generate depth maps and 3D point clouds (requires camera parameters)
+
+            ### Original Work
+
+            This demo is based on the original FoundationStereo research. Please visit the official resources:
+            - **Paper**: [FoundationStereo: Zero-Shot Stereo Matching via Foundation Model](https://arxiv.org/abs/2501.09898)
+            - **Project Page**: [https://nvlabs.github.io/FoundationStereo/](https://nvlabs.github.io/FoundationStereo/)
+            - **Official Repository**: [https://github.com/NVlabs/FoundationStereo](https://github.com/NVlabs/FoundationStereo)
+
+            **⚠️ Demo Notice**: This is a demonstration interface. For research and production use, please refer to the original repository and follow the official implementation guidelines.
+            """)
+
+        # Model selection
+        with gr.Row():
+            # Always include Hugging Face models in the choices
+            all_choices = list(available_models.keys())
+
+            # If no models found, add the HF models manually
+            if not all_choices:
+                all_choices = [
+                    "FoundationStereo (Low-cost variant - 11-33-40) [Hugging Face]",
+                    "FoundationStereo (High-quality variant - 23-51-11) [Hugging Face]"
+                ]
+
+            # Get default model (prefer Hugging Face low-cost variant)
+            default_model = None
+
+            # First try Hugging Face low-cost variant
+            for name in all_choices:
+                if "11-33-40" in name and "[Hugging Face]" in name:
+                    default_model = name
+                    break
+
+            # If no HF low-cost variant, try any low-cost variant
+            if default_model is None:
+                for name in all_choices:
+                    if "11-33-40" in name:
+                        default_model = name
+                        break
+
+            # If no low-cost variant, use first available
+            if default_model is None:
+                default_model = all_choices[0] if all_choices else None
+
+            model_selector = gr.Dropdown(
+                choices=all_choices,
+                value=default_model,
+                label="🎯 Select Model",
+                info="Choose the FoundationStereo model variant. Hugging Face models download automatically.",
+                interactive=True
+            )
+
+        with gr.Tabs():
+            # Basic stereo processing tab
+            with gr.TabItem("🖼️ Basic Stereo Processing"):
+                with gr.Row():
+                    with gr.Column():
+                        left_input = gr.Image(
+                            label="📷 Left Image",
+                            type="filepath",
+                            height=300
+                        )
+                        right_input = gr.Image(
+                            label="📷 Right Image",
+                            type="filepath",
+                            height=300
+                        )
+
+                        process_btn = gr.Button(
+                            "🚀 Process Stereo Pair",
+                            variant="primary",
+                            size="lg"
+                        )
+
+                    with gr.Column():
+                        output_image = gr.Image(
+                            label="📊 Disparity Visualization",
+                            height=400
+                        )
+                        status_text = gr.Textbox(
+                            label="Status",
+                            interactive=False,
+                            lines=8
+                        )
+
+                # Example images
+                examples_list = []
+
+                # Example 1
+                if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")):
+                    examples_list.append([
+                        os.path.join(current_dir, "assets", "example1", "left.png"),
+                        os.path.join(current_dir, "assets", "example1", "right.png")
+                    ])
+
+                # Example 2
+                if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")):
+                    examples_list.append([
+                        os.path.join(current_dir, "assets", "example2", "left.png"),
+                        os.path.join(current_dir, "assets", "example2", "right.png")
+                    ])
+
+                gr.Examples(
+                    examples=examples_list,
+                    inputs=[left_input, right_input],
+                    label="📋 Example Images"
+                )
+
+            # Advanced processing with depth
+            with gr.TabItem("📐 Advanced Processing (Depth & Point Cloud)"):
+                with gr.Row():
+                    with gr.Column():
+                        left_input_adv = gr.Image(
+                            label="📷 Left Image",
+                            type="filepath",
+                            height=250
+                        )
+                        right_input_adv = gr.Image(
+                            label="📷 Right Image",
+                            type="filepath",
+                            height=250
+                        )
+
+                        # Camera parameters
+                        with gr.Group():
+                            gr.Markdown("### 📹 Camera Parameters")
+                            camera_matrix_input = gr.Textbox(
+                                label="Camera Matrix (9 values: fx 0 cx 0 fy cy 0 0 1)",
+                                value="",
+                            )
+                            baseline_input = gr.Number(
+                                label="Baseline (meters)",
+                                value=None,
+                                minimum=0.001,
+                                maximum=10.0,
+                                step=0.001
+                            )
+
+                        process_depth_btn = gr.Button(
+                            "🔬 Process with Depth",
+                            variant="primary",
+                            size="lg"
+                        )
+
+                    with gr.Column():
+                        depth_output = gr.Image(
+                            label="📏 Depth Visualization",
+                            height=300
+                        )
+                        pointcloud_output = gr.File(
+                            label="☁️ Point Cloud Download (.ply)",
+                            file_types=[".ply"]
+                        )
+                        status_depth = gr.Textbox(
+                            label="Status",
+                            interactive=False,
+                            lines=6
+                        )
+
+                # 3D Point Cloud Visualization
+                with gr.Row():
+                    pointcloud_3d = gr.Model3D(
+                        label="🌐 3D Point Cloud Viewer",
+                        clear_color=[0.0, 0.0, 0.0, 0.0],
+                        height=400
+                    )
+
+                # Example images for advanced processing
+                examples_advanced_list = []
+
+                # Example 1 - Camera parameters from K.txt
+                if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")):
+                    examples_advanced_list.append([
+                        os.path.join(current_dir, "assets", "example1", "left.png"),
+                        os.path.join(current_dir, "assets", "example1", "right.png"),
+                        "754.6680908203125 0.0 489.3794860839844 0.0 754.6680908203125 265.16162109375 0.0 0.0 1.0",  # Camera matrix
+                        0.063  # Baseline in meters
+                    ])
+
+                # Example 2 - Camera parameters from K.txt
+                if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")):
+                    examples_advanced_list.append([
+                        os.path.join(current_dir, "assets", "example2", "left.png"),
+                        os.path.join(current_dir, "assets", "example2", "right.png"),
+                        "1733.74 0.0 792.27 0.0 1733.74 541.89 0.0 0.0 1.0",  # Camera matrix
+                        0.537  # Baseline in meters (converted from 536.62mm)
+                    ])
+
+                gr.Examples(
+                    examples=examples_advanced_list,
+                    inputs=[left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
+                    label="📋 Example Images with Camera Parameters"
+                )
+
+        # Event handlers - Always enable since we have HF models
+        process_btn.click(
+            fn=process_stereo_pair,
+            inputs=[model_selector, left_input, right_input],
+            outputs=[output_image, status_text],
+            show_progress=True
+        )
+
+        if OPEN3D_AVAILABLE:
+            process_depth_btn.click(
+                fn=process_with_depth,
+                inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
+                outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth],
+                show_progress=True
+            )
+        else:
+            process_depth_btn.click(
+                fn=lambda *args: (None, None, None, "❌ Open3D not available. Install with: pip install open3d"),
+                inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
+                outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth]
+            )
+
+        # Citation section at the bottom
+        with gr.Accordion("📖 Citation", open=False):
+            gr.Markdown("""
+            ### 📄 Please Cite the Original Paper
+
+            If you use this work in your research, please cite:
+
+            ```bibtex
+            @article{wen2025stereo,
+              title={FoundationStereo: Zero-Shot Stereo Matching},
+              author={Bowen Wen and Matthew Trepte and Joseph Aribido and Jan Kautz and Orazio Gallo and Stan Birchfield},
+              journal={CVPR},
+              year={2025}
+            }
+            ```
+            """)
+
+        # Footer
+        gr.Markdown(f"""
+        ---
+        ### 📝 Notes:
+        - **Input images must be rectified stereo pairs** (epipolar lines are horizontal)
+        - **🤗 Hugging Face Integration**: Models are automatically downloaded from `{HF_REPO_ID}`
+        - **📦 Smart Caching**: All models are pre-downloaded and cached for instant switching
+        - **⚡ GPU Acceleration**: Powered by high-performance GPUs
+        - For best results, use PNG images without lossy compression
+        - Model works on RGB images but also supports monochrome/IR stereo pairs
+        - **Optimized for Performance**: Memory-efficient inference
+
+        ### 🔗 References:
+        - [FoundationStereo Paper](https://arxiv.org/abs/2501.09898)
+        - [Project Website](https://nvlabs.github.io/FoundationStereo/)
+        - [GitHub Repository](https://github.com/NVlabs/FoundationStereo)
+        - [Hugging Face Models]({f"https://huggingface.co/{HF_REPO_ID}"})
+        """)
+
+    return app
+
+
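
Both processing paths pad inputs with `InputPadder(..., divis_by=32)` so the network's downsampling stages divide the resolution evenly. A rough standalone approximation of that step, padding on the right/bottom only (the project's `InputPadder` may distribute padding differently):

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(img, divis_by=32):
    """Zero-pad a (B,C,H,W) tensor on the right/bottom so H and W are multiples of divis_by."""
    _, _, h, w = img.shape
    pad_h = (divis_by - h % divis_by) % divis_by
    pad_w = (divis_by - w % divis_by) % divis_by
    return F.pad(img, (0, pad_w, 0, pad_h)), (h, w)

padded, (h, w) = pad_to_multiple(torch.zeros(1, 3, 270, 490))
print(padded.shape)             # torch.Size([1, 3, 288, 512])
unpadded = padded[..., :h, :w]  # crop back after inference
```
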
+def main():
+    """Main function to launch the app"""
+
+    # Ensure no CUDA operations during startup
+    if torch.cuda.is_available():
+        logging.warning("CUDA detected during startup")
+
+    logging.info("🚀 Starting FoundationStereo Gradio App...")
+
+    # Parse command line arguments
+    import argparse
+    parser = argparse.ArgumentParser(description="FoundationStereo Gradio App")
+    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
+    parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
+    parser.add_argument("--debug", action="store_true", help="Enable debug mode")
+
+    args = parser.parse_args()
+
+    if args.debug:
+        logging.getLogger().setLevel(logging.DEBUG)
+
+    try:
+        # Create and launch app
+        logging.info("Creating Gradio app...")
+        app = create_app()
+        logging.info("✅ Gradio app created successfully")
+
+        logging.info(f"Launching app on {args.host}:{args.port}")
+
+        # Launch with appropriate settings
+        app.launch(
+            server_name=args.host,
+            server_port=args.port,
+            share=False,
+            show_error=True,
+            favicon_path=None,
+            ssr_mode=False,  # Disable SSR for compatibility
+            allowed_paths=["./"]  # Allow access to local files
+        )
+    except Exception as e:
+        logging.error(f"Failed to launch app: {e}")
+        raise
+
+
+if __name__ == "__main__":
+    # Additional safety check for Spaces environment
+    if 'SPACE_ID' in os.environ:
+        logging.info("Running in Hugging Face Spaces environment")
+
+    # Do not check CUDA status during startup - this can trigger CUDA initialization
+    # The CUDA status will be checked inside the GPU decorated functions
+    logging.info("✅ CUDA status will be checked within GPU functions")
+
+    main()
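
`create_app` and `main` above follow the standard Gradio Blocks pattern: build components, wire `Button.click` to a processing function, then `launch`. The same wiring reduced to a minimal runnable sketch; the `echo` function is illustrative only:

```python
import gradio as gr

def echo(path_left, path_right):
    return f"left={path_left}, right={path_right}"

with gr.Blocks() as demo:
    left = gr.Image(type="filepath", label="Left")
    right = gr.Image(type="filepath", label="Right")
    btn = gr.Button("Run")
    out = gr.Textbox(label="Status")
    btn.click(fn=echo, inputs=[left, right], outputs=[out])

if __name__ == "__main__":
    demo.launch()
```
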
FoundationStereo_demo/core/extractor.py
ADDED
@@ -0,0 +1,371 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+
+import torch,logging,os,sys,urllib,warnings
+import torch.nn as nn
+import torch.nn.functional as F
+code_dir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(f'{code_dir}/../')
+from core.submodule import *
+from Utils import *
+import timm
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
+        super(ResidualBlock, self).__init__()
+
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
+        self.relu = nn.ReLU(inplace=True)
+
+        num_groups = planes // 8
+
+        if norm_fn == 'group':
+            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+            if not (stride == 1 and in_planes == planes):
+                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+        elif norm_fn == 'batch':
+            self.norm1 = nn.BatchNorm2d(planes)
+            self.norm2 = nn.BatchNorm2d(planes)
+            if not (stride == 1 and in_planes == planes):
+                self.norm3 = nn.BatchNorm2d(planes)
+
+        elif norm_fn == 'instance':
+            self.norm1 = nn.InstanceNorm2d(planes)
+            self.norm2 = nn.InstanceNorm2d(planes)
+            if not (stride == 1 and in_planes == planes):
+                self.norm3 = nn.InstanceNorm2d(planes)
+
+        elif norm_fn=='layer':
+            self.norm1 = LayerNorm2d(planes)
+            self.norm2 = LayerNorm2d(planes)
+            if not (stride == 1 and in_planes == planes):
+                self.norm3 = LayerNorm2d(planes)
+
+        elif norm_fn == 'none':
+            self.norm1 = nn.Sequential()
+            self.norm2 = nn.Sequential()
+            if not (stride == 1 and in_planes == planes):
+                self.norm3 = nn.Sequential()
+
+        if stride == 1 and in_planes == planes:
+            self.downsample = None
+
+        else:
+            self.downsample = nn.Sequential(
+                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
+
+
+    def forward(self, x):
+        y = x
+        y = self.conv1(y)
+        y = self.norm1(y)
+        y = self.relu(y)
+        y = self.conv2(y)
+        y = self.norm2(y)
+        y = self.relu(y)
+
+        if self.downsample is not None:
+            x = self.downsample(x)
+
+        return self.relu(x+y)
+
+
+
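
A quick shape check for the `ResidualBlock` above, as a sketch; with `norm_fn='batch'` it depends only on `torch`, whereas `norm_fn='layer'` would also pull in `LayerNorm2d` from `core.submodule`:

```python
import torch

block = ResidualBlock(in_planes=64, planes=96, norm_fn='batch', stride=2)
x = torch.randn(1, 64, 128, 160)
y = block(x)
print(y.shape)  # torch.Size([1, 96, 64, 80]): stride 2 halves H and W, and the
                # 1x1 downsample branch matches channels so the skip sum is valid
```
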
+class MultiBasicEncoder(nn.Module):
+    def __init__(self, output_dim=[128], norm_fn='batch', dropout=0.0, downsample=3):
+        super(MultiBasicEncoder, self).__init__()
+        self.norm_fn = norm_fn
+        self.downsample = downsample
+
+        if self.norm_fn == 'group':
+            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
+
+        elif self.norm_fn == 'batch':
+            self.norm1 = nn.BatchNorm2d(64)
+
+        elif self.norm_fn == 'instance':
+            self.norm1 = nn.InstanceNorm2d(64)
+
+        elif self.norm_fn=='layer':
+            self.norm1 = LayerNorm2d(64)
+
+        elif self.norm_fn == 'none':
+            self.norm1 = nn.Sequential()
+
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1 + (downsample > 2), padding=3)
+        self.relu1 = nn.ReLU(inplace=True)
+
+        self.in_planes = 64
+        self.layer1 = self._make_layer(64, stride=1)
+        self.layer2 = self._make_layer(96, stride=1 + (downsample > 1))
+        self.layer3 = self._make_layer(128, stride=1 + (downsample > 0))
+        self.layer4 = self._make_layer(128, stride=2)
+        self.layer5 = self._make_layer(128, stride=2)
+
+        output_list = []
+
+        for dim in output_dim:
+            conv_out = nn.Sequential(
+                ResidualBlock(128, 128, self.norm_fn, stride=1),
+                nn.Conv2d(128, dim[2], 3, padding=1))
+            output_list.append(conv_out)
+
+        self.outputs04 = nn.ModuleList(output_list)
+
+        output_list = []
+        for dim in output_dim:
+            conv_out = nn.Sequential(
+                ResidualBlock(128, 128, self.norm_fn, stride=1),
+                nn.Conv2d(128, dim[1], 3, padding=1))
+            output_list.append(conv_out)
+
+        self.outputs08 = nn.ModuleList(output_list)
+
+        output_list = []
+        for dim in output_dim:
+            conv_out = nn.Conv2d(128, dim[0], 3, padding=1)
+            output_list.append(conv_out)
+
+        self.outputs16 = nn.ModuleList(output_list)
+
+        if dropout > 0:
+            self.dropout = nn.Dropout2d(p=dropout)
+        else:
+            self.dropout = None
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
+                if m.weight is not None:
+                    nn.init.constant_(m.weight, 1)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def _make_layer(self, dim, stride=1):
+        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
+        layers = (layer1, layer2)
+
+        self.in_planes = dim
+        return nn.Sequential(*layers)
+
+    def forward(self, x, dual_inp=False, num_layers=3):
+
+        x = self.conv1(x)
+        x = self.norm1(x)
+        x = self.relu1(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        if dual_inp:
+            v = x
+            x = x[:(x.shape[0]//2)]
+
+        outputs04 = [f(x) for f in self.outputs04]
+        if num_layers == 1:
+            return (outputs04, v) if dual_inp else (outputs04,)
+
+        y = self.layer4(x)
+        outputs08 = [f(y) for f in self.outputs08]
+
+        if num_layers == 2:
+            return (outputs04, outputs08, v) if dual_inp else (outputs04, outputs08)
+
+        z = self.layer5(y)
+        outputs16 = [f(z) for f in self.outputs16]
+
+        return (outputs04, outputs08, outputs16, v) if dual_inp else (outputs04, outputs08, outputs16)
+
+
+
+class ContextNetDino(MultiBasicEncoder):
+    def __init__(self, args, output_dim=[128], norm_fn='batch', downsample=3):
+        nn.Module.__init__(self)
+        self.args = args
+        self.patch_size = 14
+        self.image_size = 518
+        self.vit_feat_dim = 384
+        code_dir = os.path.dirname(os.path.realpath(__file__))
+
+        self.out_dims = output_dim
+
+        self.norm_fn = norm_fn
+
+        if self.norm_fn == 'group':
+            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
+
+        elif self.norm_fn == 'batch':
+            self.norm1 = nn.BatchNorm2d(64)
+
+        elif self.norm_fn == 'instance':
+            self.norm1 = nn.InstanceNorm2d(64)
+
+        elif self.norm_fn=='layer':
+            self.norm1 = LayerNorm2d(64)
+
+        elif self.norm_fn == 'none':
+            self.norm1 = nn.Sequential()
+
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1 + (downsample > 2), padding=3)
+        self.relu1 = nn.ReLU(inplace=True)
+
+        self.in_planes = 64
+        self.layer1 = self._make_layer(64, stride=1)
+        self.layer2 = self._make_layer(96, stride=1 + (downsample > 1))
+        self.layer3 = self._make_layer(128, stride=1 + (downsample > 0))
+        self.layer4 = self._make_layer(128, stride=2)
+        self.layer5 = self._make_layer(128, stride=2)
+        self.down = nn.Sequential(
+            nn.Conv2d(128, 128, kernel_size=4, stride=4, padding=0),
+            nn.BatchNorm2d(128),
+        )
+        vit_dim = DepthAnythingFeature.model_configs[self.args.vit_size]['features']//2
+        self.conv2 = BasicConv(128+vit_dim, 128, kernel_size=3, padding=1)
+        self.norm = nn.BatchNorm2d(256)
+
+        output_list = []
+        for dim in output_dim:
+            conv_out = nn.Sequential(
+                ResidualBlock(128, 128, self.norm_fn, stride=1),
+                nn.Conv2d(128, dim[2], 3, padding=1))
+            output_list.append(conv_out)
+
+        self.outputs04 = nn.ModuleList(output_list)
+
+        output_list = []
+        for dim in output_dim:
+            conv_out = nn.Sequential(
+                ResidualBlock(128, 128, self.norm_fn, stride=1),
+                nn.Conv2d(128, dim[1], 3, padding=1))
+            output_list.append(conv_out)
+
+        self.outputs08 = nn.ModuleList(output_list)
+
+        output_list = []
+        for dim in output_dim:
+            conv_out = nn.Conv2d(128, dim[0], 3, padding=1)
+            output_list.append(conv_out)
+
+        self.outputs16 = nn.ModuleList(output_list)
+
+    def forward(self, x_in, vit_feat, dual_inp=False, num_layers=3):
+        B,C,H,W = x_in.shape
+        x = self.conv1(x_in)
+        x = self.norm1(x)
+        x = self.relu1(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+
+        divider = np.lcm(self.patch_size, 16)
+        H_resize, W_resize = get_resize_keep_aspect_ratio(H,W, divider=divider, max_H=1344, max_W=1344)
+        x = torch.cat([x, vit_feat], dim=1)
+        x = self.conv2(x)
+        outputs04 = [f(x) for f in self.outputs04]
+
+        y = self.layer4(x)
+        outputs08 = [f(y) for f in self.outputs08]
+
+        z = self.layer5(y)
+        outputs16 = [f(z) for f in self.outputs16]
+
+        return (outputs04, outputs08, outputs16)
+
+
+class DepthAnythingFeature(nn.Module):
+    model_configs = {
+        'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+        'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}
+    }
+
+    def __init__(self, encoder='vits'):
+        super().__init__()
+        from depth_anything.dpt import DepthAnything
+        self.encoder = encoder
+        depth_anything = DepthAnything(self.model_configs[encoder])
+        self.depth_anything = depth_anything
+
+        self.intermediate_layer_idx = {   #!NOTE For V2
+            'vits': [2, 5, 8, 11],
+            'vitb': [2, 5, 8, 11],
+            'vitl': [4, 11, 17, 23],
+            'vitg': [9, 19, 29, 39]
+        }
+
+
+    def forward(self, x):
+        """
+        @x: (B,C,H,W)
+        """
+        h, w = x.shape[-2:]
+        features = self.depth_anything.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
+
+
+        patch_size = self.depth_anything.pretrained.patch_size
+        patch_h, patch_w = h // patch_size, w // patch_size
+        out, path_1, path_2, path_3, path_4, disp = self.depth_anything.depth_head.forward(features, patch_h, patch_w, return_intermediate=True)
+
+        return {'out':out, 'path_1':path_1, 'path_2':path_2, 'path_3':path_3, 'path_4':path_4, 'features':features, 'disp':disp}  # path_1 is 1/2; path_2 is 1/4
+
+
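
`Feature` below resizes its input so that both the ViT patch size (14) and the CNN stride (16) divide the resolution, i.e. to a multiple of lcm(14, 16) = 112, capped at 1344 per side. A rough approximation of `get_resize_keep_aspect_ratio`, which lives in `Utils.py` and is not part of this diff (the real helper may round differently):

```python
import numpy as np

def resize_to_multiple(H, W, divider=int(np.lcm(14, 16)), max_side=1344):
    """Scale (H, W) to fit max_side, then round each side up to a multiple of divider."""
    scale = min(1.0, max_side / max(H, W))
    H_r = int(np.ceil(H * scale / divider) * divider)
    W_r = int(np.ceil(W * scale / divider) * divider)
    return H_r, W_r

print(resize_to_multiple(540, 960))  # (560, 1008): both divisible by 112
```
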
+class Feature(nn.Module):
+    def __init__(self, args):
+        super(Feature, self).__init__()
+        self.args = args
+        model = timm.create_model('edgenext_small', pretrained=True, features_only=False)
+        self.stem = model.stem
+        self.stages = model.stages
+        chans = [48, 96, 160, 304]
+        self.chans = chans
+        self.dino = DepthAnythingFeature(encoder=self.args.vit_size)
+        self.dino = freeze_model(self.dino)
+        vit_feat_dim = DepthAnythingFeature.model_configs[self.args.vit_size]['features']//2
+
+        self.deconv32_16 = Conv2x_IN(chans[3], chans[2], deconv=True, concat=True)
+        self.deconv16_8 = Conv2x_IN(chans[2]*2, chans[1], deconv=True, concat=True)
+        self.deconv8_4 = Conv2x_IN(chans[1]*2, chans[0], deconv=True, concat=True)
+        self.conv4 = nn.Sequential(
+            BasicConv(chans[0]*2+vit_feat_dim, chans[0]*2+vit_feat_dim, kernel_size=3, stride=1, padding=1, norm='instance'),
+            ResidualBlock(chans[0]*2+vit_feat_dim, chans[0]*2+vit_feat_dim, norm_fn='instance'),
+            ResidualBlock(chans[0]*2+vit_feat_dim, chans[0]*2+vit_feat_dim, norm_fn='instance'),
+        )
+
+        self.patch_size = 14
+        self.d_out = [chans[0]*2+vit_feat_dim, chans[1]*2, chans[2]*2, chans[3]]
+
+    def forward(self, x):
+        B,C,H,W = x.shape
+        divider = np.lcm(self.patch_size, 16)
+        H_resize, W_resize = get_resize_keep_aspect_ratio(H,W, divider=divider, max_H=1344, max_W=1344)
+        x_in_ = F.interpolate(x, size=(H_resize, W_resize), mode='bicubic', align_corners=False)
+        self.dino = self.dino.eval()
+        with torch.no_grad():
+            output = self.dino(x_in_)
+        vit_feat = output['out']
+        vit_feat = F.interpolate(vit_feat, size=(H//4,W//4), mode='bilinear', align_corners=True)
+        x = self.stem(x)
+        x4 = self.stages[0](x)
+        x8 = self.stages[1](x4)
+        x16 = self.stages[2](x8)
+        x32 = self.stages[3](x16)
+
+        x16 = self.deconv32_16(x32, x16)
+        x8 = self.deconv16_8(x16, x8)
+        x4 = self.deconv8_4(x8, x4)
+        x4 = torch.cat([x4, vit_feat], dim=1)
+        x4 = self.conv4(x4)
+        return [x4, x8, x16, x32], vit_feat
+
+
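
For orientation, the tensor flow through `Feature.forward` above with the `vits` config (so `vit_feat_dim = 64 // 2 = 32`); these shapes follow from the strides and `concat=True` fusions in the code and are a sketch, not captured output:

```python
# Hedged shape sketch for Feature.forward, 'vits' config, input (B, 3, H, W):
#   x4  from stages[0]: (B, 48,  H/4,  W/4)   -> after top-down fusion: (B, 96,  H/4,  W/4)
#   x8  from stages[1]: (B, 96,  H/8,  W/8)   ->                        (B, 192, H/8,  W/8)
#   x16 from stages[2]: (B, 160, H/16, W/16)  ->                        (B, 320, H/16, W/16)
#   x32 from stages[3]: (B, 304, H/32, W/32)
#   vit_feat: (B, 32, H/4, W/4), DepthAnything features resized bilinearly
#   returned x4 = conv4(cat(x4, vit_feat)): (B, 128, H/4, W/4), matching d_out[0]
```
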
FoundationStereo_demo/core/foundation_stereo.py
ADDED
@@ -0,0 +1,277 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+
+import torch,pdb,logging,timm
+import torchvision  # Add missing torchvision import
+import torch.nn as nn
+import torch.nn.functional as F
+import sys,os
+code_dir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(f'{code_dir}/../')
+from core.update import *
+from core.extractor import *
+from core.geometry import Combined_Geo_Encoding_Volume
+from core.submodule import *
+from core.utils.utils import *
+from Utils import *
+import time,huggingface_hub
+
+
+try:
+    autocast = torch.cuda.amp.autocast
+except:
+    class autocast:
+        def __init__(self, enabled):
+            pass
+        def __enter__(self):
+            pass
+        def __exit__(self, *args):
+            pass
+
+
+def normalize_image(img):
+    '''
+    @img: (B,C,H,W) in range 0-255, RGB order
+    '''
+    tf = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
+    normalized = tf(img/255.0)
+    return normalized.contiguous()  # Ensure contiguous tensor
+
+
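
`normalize_image` above is plain ImageNet normalization; a short sketch verifying it matches `(img / 255 - mean) / std` per channel:

```python
import torch
import torchvision

img = torch.randint(0, 256, (1, 3, 4, 4)).float()
tf = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
print(torch.allclose(tf(img / 255.0), (img / 255.0 - mean) / std))  # True
```
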
class hourglass(nn.Module):
|
| 48 |
+
def __init__(self, cfg, in_channels, feat_dims=None):
|
| 49 |
+
super().__init__()
|
| 50 |
+
self.cfg = cfg
|
| 51 |
+
self.conv1 = nn.Sequential(BasicConv(in_channels, in_channels*2, is_3d=True, bn=True, relu=True, kernel_size=3,
|
| 52 |
+
padding=1, stride=2, dilation=1),
|
| 53 |
+
Conv3dNormActReduced(in_channels*2, in_channels*2, kernel_size=3, kernel_disp=17))
|
| 54 |
+
|
| 55 |
+
self.conv2 = nn.Sequential(BasicConv(in_channels*2, in_channels*4, is_3d=True, bn=True, relu=True, kernel_size=3,
|
| 56 |
+
padding=1, stride=2, dilation=1),
|
| 57 |
+
Conv3dNormActReduced(in_channels*4, in_channels*4, kernel_size=3, kernel_disp=17))
|
| 58 |
+
|
| 59 |
+
self.conv3 = nn.Sequential(BasicConv(in_channels*4, in_channels*6, is_3d=True, bn=True, relu=True, kernel_size=3,
|
| 60 |
+
padding=1, stride=2, dilation=1),
|
| 61 |
+
Conv3dNormActReduced(in_channels*6, in_channels*6, kernel_size=3, kernel_disp=17))
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
self.conv3_up = BasicConv(in_channels*6, in_channels*4, deconv=True, is_3d=True, bn=True,
|
| 65 |
+
relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))
|
| 66 |
+
|
| 67 |
+
self.conv2_up = BasicConv(in_channels*4, in_channels*2, deconv=True, is_3d=True, bn=True,
|
| 68 |
+
relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))
|
| 69 |
+
|
| 70 |
+
self.conv1_up = BasicConv(in_channels*2, in_channels, deconv=True, is_3d=True, bn=True,
|
| 71 |
+
relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))
|
| 72 |
+
self.conv_out = nn.Sequential(
|
| 73 |
+
Conv3dNormActReduced(in_channels, in_channels, kernel_size=3, kernel_disp=17),
|
| 74 |
+
Conv3dNormActReduced(in_channels, in_channels, kernel_size=3, kernel_disp=17),
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
self.agg_0 = nn.Sequential(BasicConv(in_channels*8, in_channels*4, is_3d=True, kernel_size=1, padding=0, stride=1),
|
| 78 |
+
Conv3dNormActReduced(in_channels*4, in_channels*4, kernel_size=3, kernel_disp=17),
|
| 79 |
+
Conv3dNormActReduced(in_channels*4, in_channels*4, kernel_size=3, kernel_disp=17),)
|
| 80 |
+
|
| 81 |
+
self.agg_1 = nn.Sequential(BasicConv(in_channels*4, in_channels*2, is_3d=True, kernel_size=1, padding=0, stride=1),
|
| 82 |
+
Conv3dNormActReduced(in_channels*2, in_channels*2, kernel_size=3, kernel_disp=17),
|
| 83 |
+
Conv3dNormActReduced(in_channels*2, in_channels*2, kernel_size=3, kernel_disp=17))
|
| 84 |
+
self.atts = nn.ModuleDict({
|
| 85 |
+
"4": CostVolumeDisparityAttention(d_model=in_channels, nhead=4, dim_feedforward=in_channels, norm_first=False, num_transformer=4, max_len=self.cfg['max_disp']//16),
|
| 86 |
+
})
|
| 87 |
+
self.conv_patch = nn.Sequential(
|
| 88 |
+
nn.Conv3d(in_channels, in_channels, kernel_size=4, stride=4, padding=0, groups=in_channels),
|
| 89 |
+
nn.BatchNorm3d(in_channels),
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
self.feature_att_8 = FeatureAtt(in_channels*2, feat_dims[1])
|
| 93 |
+
self.feature_att_16 = FeatureAtt(in_channels*4, feat_dims[2])
|
| 94 |
+
self.feature_att_32 = FeatureAtt(in_channels*6, feat_dims[3])
|
| 95 |
+
self.feature_att_up_16 = FeatureAtt(in_channels*4, feat_dims[2])
|
| 96 |
+
self.feature_att_up_8 = FeatureAtt(in_channels*2, feat_dims[1])
|
| 97 |
+
|
| 98 |
+
def forward(self, x, features):
|
| 99 |
+
conv1 = self.conv1(x)
|
| 100 |
+
conv1 = self.feature_att_8(conv1, features[1])
|
| 101 |
+
|
| 102 |
+
conv2 = self.conv2(conv1)
|
| 103 |
+
conv2 = self.feature_att_16(conv2, features[2])
|
| 104 |
+
|
| 105 |
+
conv3 = self.conv3(conv2)
|
| 106 |
+
conv3 = self.feature_att_32(conv3, features[3])
|
| 107 |
+
|
| 108 |
+
conv3_up = self.conv3_up(conv3)
|
| 109 |
+
conv2 = torch.cat((conv3_up, conv2), dim=1)
|
| 110 |
+
conv2 = self.agg_0(conv2)
|
| 111 |
+
conv2 = self.feature_att_up_16(conv2, features[2])
|
| 112 |
+
|
| 113 |
+
conv2_up = self.conv2_up(conv2)
|
| 114 |
+
conv1 = torch.cat((conv2_up, conv1), dim=1)
|
| 115 |
+
conv1 = self.agg_1(conv1)
|
| 116 |
+
conv1 = self.feature_att_up_8(conv1, features[1])
|
| 117 |
+
|
| 118 |
+
conv = self.conv1_up(conv1)
|
| 119 |
+
x = self.conv_patch(x)
|
| 120 |
+
x = self.atts["4"](x)
|
| 121 |
+
x = F.interpolate(x, scale_factor=4, mode='trilinear', align_corners=False)
|
| 122 |
+
conv = conv + x
|
| 123 |
+
conv = self.conv_out(conv)
|
| 124 |
+
|
| 125 |
+
return conv
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
|
class FoundationStereo(nn.Module, huggingface_hub.PyTorchModelHubMixin):
    def __init__(self, args):
        super().__init__()
        self.args = args

        context_dims = args.hidden_dims
        self.cv_group = 8
        volume_dim = 28

        self.cnet = ContextNetDino(args, output_dim=[args.hidden_dims, context_dims], downsample=args.n_downsample)
        self.update_block = BasicSelectiveMultiUpdateBlock(self.args, self.args.hidden_dims[0], volume_dim=volume_dim)
        self.sam = SpatialAttentionExtractor()
        self.cam = ChannelAttentionEnhancement(self.args.hidden_dims[0])

        self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, kernel_size=3, padding=3//2) for i in range(self.args.n_gru_layers)])

        self.feature = Feature(args)
        self.proj_cmb = nn.Conv2d(self.feature.d_out[0], 12, kernel_size=1, padding=0)

        self.stem_2 = nn.Sequential(
            BasicConv_IN(3, 32, kernel_size=3, stride=2, padding=1),
            nn.Conv2d(32, 32, 3, 1, 1, bias=False),
            nn.InstanceNorm2d(32), nn.ReLU()
        )
        self.stem_4 = nn.Sequential(
            BasicConv_IN(32, 48, kernel_size=3, stride=2, padding=1),
            nn.Conv2d(48, 48, 3, 1, 1, bias=False),
            nn.InstanceNorm2d(48), nn.ReLU()
        )

        self.spx_2_gru = Conv2x(32, 32, True, bn=False)
        self.spx_gru = nn.Sequential(
            nn.ConvTranspose2d(2*32, 9, kernel_size=4, stride=2, padding=1),
        )

        self.corr_stem = nn.Sequential(
            nn.Conv3d(32, volume_dim, kernel_size=1),
            BasicConv(volume_dim, volume_dim, kernel_size=3, padding=1, is_3d=True),
            ResnetBasicBlock3D(volume_dim, volume_dim, kernel_size=3, stride=1, padding=1),
            ResnetBasicBlock3D(volume_dim, volume_dim, kernel_size=3, stride=1, padding=1),
        )
        self.corr_feature_att = FeatureAtt(volume_dim, self.feature.d_out[0])
        self.cost_agg = hourglass(cfg=self.args, in_channels=volume_dim, feat_dims=self.feature.d_out)
        self.classifier = nn.Sequential(
            BasicConv(volume_dim, volume_dim//2, kernel_size=3, padding=1, is_3d=True),
            ResnetBasicBlock3D(volume_dim//2, volume_dim//2, kernel_size=3, stride=1, padding=1),
            nn.Conv3d(volume_dim//2, 1, kernel_size=7, padding=3),
        )

        r = self.args.corr_radius
        dx = torch.linspace(-r, r, 2*r+1, requires_grad=False).reshape(1, 1, 2*r+1, 1)
        self.dx = dx


    def upsample_disp(self, disp, mask_feat_4, stem_2x):

        with autocast(enabled=self.args.mixed_precision):
            xspx = self.spx_2_gru(mask_feat_4, stem_2x)  # 1/2 resolution
            spx_pred = self.spx_gru(xspx)
            spx_pred = F.softmax(spx_pred, 1)
            up_disp = context_upsample(disp*4., spx_pred).unsqueeze(1)

        return up_disp.float()


    def forward(self, image1, image2, iters=12, flow_init=None, test_mode=False, low_memory=False, init_disp=None):
        """ Estimate disparity between pair of frames """
        B = len(image1)
        low_memory = low_memory or (self.args.get('low_memory', False))
        image1 = normalize_image(image1)
        image2 = normalize_image(image2)
        with autocast(enabled=self.args.mixed_precision):
            out, vit_feat = self.feature(torch.cat([image1, image2], dim=0))
            vit_feat = vit_feat[:B]
            features_left = [o[:B] for o in out]
            features_right = [o[B:] for o in out]
            stem_2x = self.stem_2(image1)

            gwc_volume = build_gwc_volume(features_left[0], features_right[0], self.args.max_disp//4, self.cv_group)  # Group-wise correlation volume (B, N_group, max_disp, H, W)
            left_tmp = self.proj_cmb(features_left[0])
            right_tmp = self.proj_cmb(features_right[0])
            concat_volume = build_concat_volume(left_tmp, right_tmp, maxdisp=self.args.max_disp//4)
            del left_tmp, right_tmp
            comb_volume = torch.cat([gwc_volume, concat_volume], dim=1)
            comb_volume = self.corr_stem(comb_volume)
            comb_volume = self.corr_feature_att(comb_volume, features_left[0])
            comb_volume = self.cost_agg(comb_volume, features_left)

            # Init disp from geometry encoding volume
            prob = F.softmax(self.classifier(comb_volume).squeeze(1), dim=1)  #(B, max_disp, H, W)
            if init_disp is None:
                init_disp = disparity_regression(prob, self.args.max_disp//4)  # Weighted sum of disparity

            cnet_list = self.cnet(image1, vit_feat=vit_feat, num_layers=self.args.n_gru_layers)  #(1/4, 1/8, 1/16)
            cnet_list = list(cnet_list)
            net_list = [torch.tanh(x[0]) for x in cnet_list]  # Hidden information
            inp_list = [torch.relu(x[1]) for x in cnet_list]  # Context information list of pyramid levels
            inp_list = [self.cam(x) * x for x in inp_list]
            att = [self.sam(x) for x in inp_list]

        geo_fn = Combined_Geo_Encoding_Volume(features_left[0].float(), features_right[0].float(), comb_volume.float(), num_levels=self.args.corr_levels, dx=self.dx)
        b, c, h, w = features_left[0].shape
        coords = torch.arange(w, dtype=torch.float, device=init_disp.device).reshape(1,1,w,1).repeat(b, h, 1, 1)  # (B,H,W,1) Horizontal only
        disp = init_disp.float()
        disp_preds = []

        # GRUs iterations to update disparity (1/4 resolution)
        for itr in range(iters):
            disp = disp.detach()
            geo_feat = geo_fn(disp, coords, low_memory=low_memory)
            with autocast(enabled=self.args.mixed_precision):
                net_list, mask_feat_4, delta_disp = self.update_block(net_list, inp_list, geo_feat, disp, att)

            disp = disp + delta_disp.float()
            if test_mode and itr < iters-1:
                continue

            # upsample predictions
            disp_up = self.upsample_disp(disp.float(), mask_feat_4.float(), stem_2x.float())
            disp_preds.append(disp_up)


        if test_mode:
            return disp_up

        return init_disp, disp_preds

    def run_hierachical(self, image1, image2, iters=12, test_mode=False, low_memory=False, small_ratio=0.5):
        B,_,H,W = image1.shape
        img1_small = F.interpolate(image1, scale_factor=small_ratio, align_corners=False, mode='bilinear')
        img2_small = F.interpolate(image2, scale_factor=small_ratio, align_corners=False, mode='bilinear')
        padder = InputPadder(img1_small.shape[-2:], divis_by=32, force_square=False)
        img1_small, img2_small = padder.pad(img1_small, img2_small)
        disp_small = self.forward(img1_small, img2_small, test_mode=True, iters=iters, low_memory=low_memory)
        disp_small = padder.unpad(disp_small.float())
        disp_small_up = F.interpolate(disp_small, size=(H,W), mode='bilinear', align_corners=True) * 1/small_ratio
        disp_small_up = disp_small_up.clip(0, None)

        padder = InputPadder(image1.shape[-2:], divis_by=32, force_square=False)
        image1, image2, disp_small_up = padder.pad(image1, image2, disp_small_up)
        disp_small_up += padder._pad[0]
        init_disp = F.interpolate(disp_small_up, scale_factor=0.25, mode='bilinear', align_corners=True) * 0.25  # Init disp will be 1/4
        disp = self.forward(image1, image2, iters=iters, test_mode=test_mode, low_memory=low_memory, init_disp=init_disp)
        disp = padder.unpad(disp.float())
        return disp

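# Usage sketch (illustrative, not part of this commit). Assumptions: this repo is on
# sys.path, a CUDA device is available, 'model.pth' is a hypothetical FoundationStereo
# checkpoint that stores its config next to the weights, and that config is an
# OmegaConf DictConfig -- forward() calls self.args.get(...), so a plain
# argparse.Namespace would not work here.
import torch
from omegaconf import OmegaConf
from core.foundation_stereo import FoundationStereo
from core.utils.utils import InputPadder

ckpt = torch.load('model.pth', map_location='cpu')
cfg = OmegaConf.create(ckpt['args'])          # assumed layout of the checkpoint dict
model = FoundationStereo(cfg).cuda().eval()
model.load_state_dict(ckpt['model'])

# Left/right: (1,3,H,W) float tensors, RGB, still in 0-255; normalization happens
# inside forward() via normalize_image().
left = torch.rand(1, 3, 480, 640).cuda() * 255
right = torch.rand(1, 3, 480, 640).cuda() * 255

padder = InputPadder(left.shape[-2:], divis_by=32, force_square=False)
left, right = padder.pad(left, right)
with torch.no_grad():
    disp = model(left, right, iters=32, test_mode=True)   # (1,1,H',W') disparity in pixels
disp = padder.unpad(disp.float())
# For very large images, model.run_hierachical(...) first solves a downscaled pair,
# then uses the upsampled result as init_disp for full-resolution refinement.
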
FoundationStereo_demo/core/geometry.py
ADDED
@@ -0,0 +1,77 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.


import torch,pdb,os,sys
import torch.nn.functional as F
from core.utils.utils import bilinear_sampler
code_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(f'{code_dir}/../')
from Utils import *

class Combined_Geo_Encoding_Volume:
    def __init__(self, init_fmap1, init_fmap2, geo_volume, num_levels=2, dx=None):
        self.num_levels = num_levels
        self.geo_volume_pyramid = []
        self.init_corr_pyramid = []
        self.dx = dx

        # all pairs correlation
        init_corr = Combined_Geo_Encoding_Volume.corr(init_fmap1, init_fmap2)

        b, h, w, _, w2 = init_corr.shape
        b, c, d, h, w = geo_volume.shape
        geo_volume = geo_volume.permute(0, 3, 4, 1, 2).reshape(b*h*w, c, 1, d).contiguous()

        init_corr = init_corr.reshape(b*h*w, 1, 1, w2)
        self.geo_volume_pyramid.append(geo_volume)
        self.init_corr_pyramid.append(init_corr)
        for i in range(self.num_levels-1):
            geo_volume = F.avg_pool2d(geo_volume, [1,2], stride=[1,2])
            self.geo_volume_pyramid.append(geo_volume)

        for i in range(self.num_levels-1):
            init_corr = F.avg_pool2d(init_corr, [1,2], stride=[1,2])
            self.init_corr_pyramid.append(init_corr)


    def __call__(self, disp, coords, low_memory=False):
        b, _, h, w = disp.shape
        self.dx = self.dx.to(disp.device)
        out_pyramid = []
        for i in range(self.num_levels):
            geo_volume = self.geo_volume_pyramid[i]
            x0 = self.dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i
            y0 = torch.zeros_like(x0)

            disp_lvl = torch.cat([x0,y0], dim=-1)
            geo_volume = bilinear_sampler(geo_volume, disp_lvl, low_memory=low_memory)
            geo_volume = geo_volume.reshape(b, h, w, -1)

            init_corr = self.init_corr_pyramid[i]
            init_x0 = coords.reshape(b*h*w, 1, 1, 1)/2**i - disp.reshape(b*h*w, 1, 1, 1) / 2**i + self.dx  # X on right image
            init_coords_lvl = torch.cat([init_x0,y0], dim=-1)
            init_corr = bilinear_sampler(init_corr, init_coords_lvl, low_memory=low_memory)
            init_corr = init_corr.reshape(b, h, w, -1)

            out_pyramid.append(geo_volume)
            out_pyramid.append(init_corr)
        out_pyramid = torch.cat(out_pyramid, dim=-1)
        return out_pyramid.permute(0, 3, 1, 2).contiguous()  #(B,C,H,W)


    @staticmethod
    def corr(fmap1, fmap2):
        B, D, H, W1 = fmap1.shape
        _, _, _, W2 = fmap2.shape
        fmap1 = fmap1.reshape(B, D, H, W1)
        fmap2 = fmap2.reshape(B, D, H, W2)
        with torch.cuda.amp.autocast(enabled=False):
            corr = torch.einsum('aijk,aijh->ajkh', F.normalize(fmap1.float(), dim=1), F.normalize(fmap2.float(), dim=1))
        corr = corr.reshape(B, H, W1, 1, W2)
        return corr

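# Channel budget of the lookup above (a sketch; corr_levels=2 and corr_radius=4 are
# assumed defaults here, volume_dim=28 comes from FoundationStereo.__init__).
# Each pyramid level taps 2*r+1 positions around the current disparity, from the
# 28-channel filtered cost volume plus the 1-channel raw correlation, so the feature
# handed to the update block has corr_levels * (2*r+1) * (28+1) channels -- exactly
# the cor_planes formula in core/update.py's BasicMotionEncoder.
corr_levels, corr_radius, volume_dim = 2, 4, 28
samples = 2 * corr_radius + 1
total = corr_levels * samples * (volume_dim + 1)
print(total)  # 522
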
FoundationStereo_demo/core/submodule.py
ADDED
@@ -0,0 +1,588 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.


import torch,pdb,os,sys
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
code_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(f'{code_dir}/../')
from Utils import *


def _is_contiguous(tensor: torch.Tensor) -> bool:
    if torch.jit.is_scripting():
        return tensor.is_contiguous()
    else:
        return tensor.is_contiguous(memory_format=torch.contiguous_format)


class LayerNorm2d(nn.LayerNorm):
    r""" https://huggingface.co/spaces/Roll20/pet_score/blob/b258ef28152ab0d5b377d9142a23346f863c1526/lib/timm/models/convnext.py#L85
    LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__(normalized_shape, eps=eps)

    def forward(self, x) -> torch.Tensor:
        """
        @x: (B,C,H,W)
        """
        if _is_contiguous(x):
            return F.layer_norm(x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2).contiguous()
        else:
            s, u = torch.var_mean(x, dim=1, keepdim=True)
            x = (x - u) * torch.rsqrt(s + self.eps)
            x = x * self.weight[:, None, None] + self.bias[:, None, None]
            return x



class BasicConv(nn.Module):

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, bn=True, relu=True, norm='batch', **kwargs):
        super(BasicConv, self).__init__()

        self.relu = relu
        self.use_bn = bn
        self.bn = nn.Identity()
        if is_3d:
            if deconv:
                self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs)
            else:
                self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
            if self.use_bn:
                if norm=='batch':
                    self.bn = nn.BatchNorm3d(out_channels)
                elif norm=='instance':
                    self.bn = nn.InstanceNorm3d(out_channels)
        else:
            if deconv:
                self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
            else:
                self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
            if self.use_bn:
                if norm=='batch':
                    self.bn = nn.BatchNorm2d(out_channels)
                elif norm=='instance':
                    self.bn = nn.InstanceNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        if self.use_bn:
            x = self.bn(x)
        if self.relu:
            x = nn.LeakyReLU()(x)  #, inplace=True)
        return x


class Conv3dNormActReduced(nn.Module):
    def __init__(self, C_in, C_out, hidden=None, kernel_size=3, kernel_disp=None, stride=1, norm=nn.BatchNorm3d):
        super().__init__()
        if kernel_disp is None:
            kernel_disp = kernel_size
        if hidden is None:
            hidden = C_out
        self.conv1 = nn.Sequential(
            nn.Conv3d(C_in, hidden, kernel_size=(1,kernel_size,kernel_size), padding=(0, kernel_size//2, kernel_size//2), stride=(1, stride, stride)),
            norm(hidden),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv3d(hidden, C_out, kernel_size=(kernel_disp, 1, 1), padding=(kernel_disp//2, 0, 0), stride=(stride, 1, 1)),
            norm(C_out),
            nn.ReLU(),
        )


    def forward(self, x):
        """
        @x: (B,C,D,H,W)
        """
        x = self.conv1(x)
        x = self.conv2(x)
        return x




class ResnetBasicBlock(nn.Module):
    def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=nn.BatchNorm2d, bias=False):
        super().__init__()
        self.norm_layer = norm_layer
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=stride, bias=bias, padding=padding)
        if self.norm_layer is not None:
            self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, stride=stride, bias=bias, padding=padding)
        if self.norm_layer is not None:
            self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride


    def forward(self, x):
        identity = x

        out = self.conv1(x)
        if self.norm_layer is not None:
            out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        if self.norm_layer is not None:
            out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)

        return out


class ResnetBasicBlock3D(nn.Module):
    def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=nn.BatchNorm3d, bias=False):
        super().__init__()
        self.norm_layer = norm_layer
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=kernel_size, stride=stride, bias=bias, padding=padding)
        if self.norm_layer is not None:
            self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=kernel_size, stride=stride, bias=bias, padding=padding)
        if self.norm_layer is not None:
            self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride


    def forward(self, x):
        identity = x

        out = self.conv1(x)
        if self.norm_layer is not None:
            out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        if self.norm_layer is not None:
            out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)

        return out


class FlashMultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value, attn_mask=None, window_size=(-1,-1)):
        """
        @query: (B,L,C)
        """
        B,L,C = query.shape
        Q = self.q_proj(query)
        K = self.k_proj(key)
        V = self.v_proj(value)

        Q = Q.view(Q.size(0), Q.size(1), self.num_heads, self.head_dim)
        K = K.view(K.size(0), K.size(1), self.num_heads, self.head_dim)
        V = V.view(V.size(0), V.size(1), self.num_heads, self.head_dim)

        attn_output = F.scaled_dot_product_attention(Q, K, V)

        attn_output = attn_output.reshape(B,L,-1)
        output = self.out_proj(attn_output)

        return output



class FlashAttentionTransformerEncoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, dim_feedforward, dropout=0.1, act=nn.GELU, norm=nn.LayerNorm):
        super().__init__()
        self.self_attn = FlashMultiheadAttention(embed_dim, num_heads)
        self.act = act()

        self.linear1 = nn.Linear(embed_dim, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, embed_dim)

        self.norm1 = norm(embed_dim)
        self.norm2 = norm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, window_size=(-1, -1)):
        src2 = self.self_attn(src, src, src, src_mask, window_size=window_size)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        src2 = self.linear2(self.dropout(self.act(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)

        return src



class UpsampleConv(nn.Module):
    def __init__(self, C_in, C_out, is_3d=False, kernel_size=3, bias=True, stride=1, padding=1):
        super().__init__()
        self.is_3d = is_3d
        if is_3d:
            self.conv = nn.Conv3d(C_in, C_out, kernel_size=kernel_size, stride=1, padding=kernel_size//2, bias=bias)
        else:
            self.conv = nn.Conv2d(C_in, C_out, kernel_size=kernel_size, stride=1, padding=kernel_size//2, bias=bias)

    def forward(self, x):
        if self.is_3d:
            mode = 'trilinear'
        else:
            mode = 'bilinear'
        x = F.interpolate(x, size=None, scale_factor=2, align_corners=False, mode=mode)
        x = self.conv(x)
        return x



class Conv2x(nn.Module):

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, bn=True, relu=True, keep_dispc=False):
        super(Conv2x, self).__init__()
        self.concat = concat
        self.is_3d = is_3d
        if deconv and is_3d:
            kernel = (4, 4, 4)
        elif deconv:
            kernel = 4
        else:
            kernel = 3

        if deconv and is_3d and keep_dispc:
            kernel = (1, 4, 4)
            stride = (1, 2, 2)
            padding = (0, 1, 1)
            self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=bn, relu=True, kernel_size=kernel, stride=stride, padding=padding)
        else:
            self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=bn, relu=True, kernel_size=kernel, stride=2, padding=1)

        if self.concat:
            mul = 2 if keep_concat else 1
            self.conv2 = BasicConv(out_channels*2, out_channels*mul, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1)
        else:
            self.conv2 = BasicConv(out_channels, out_channels, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1)

    def forward(self, x, rem):
        x = self.conv1(x)
        if x.shape != rem.shape:
            x = F.interpolate(x, size=(rem.shape[-2], rem.shape[-1]), mode='bilinear')
        if self.concat:
            x = torch.cat((x, rem), 1)
        else:
            x = x + rem
        x = self.conv2(x)
        return x


class BasicConv_IN(nn.Module):

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, IN=True, relu=True, **kwargs):
        super(BasicConv_IN, self).__init__()

        self.relu = relu
        self.use_in = IN
        if is_3d:
            if deconv:
                self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs)
            else:
                self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
            self.IN = nn.InstanceNorm3d(out_channels)
        else:
            if deconv:
                self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
            else:
                self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
            self.IN = nn.InstanceNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        if self.use_in:
            x = self.IN(x)
        if self.relu:
            x = nn.LeakyReLU()(x)  #, inplace=True)
        return x


class Conv2x_IN(nn.Module):

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, IN=True, relu=True, keep_dispc=False):
        super(Conv2x_IN, self).__init__()
        self.concat = concat
        self.is_3d = is_3d
        if deconv and is_3d:
            kernel = (4, 4, 4)
        elif deconv:
            kernel = 4
        else:
            kernel = 3

        if deconv and is_3d and keep_dispc:
            kernel = (1, 4, 4)
            stride = (1, 2, 2)
            padding = (0, 1, 1)
            self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=stride, padding=padding)
        else:
            self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=2, padding=1)

        if self.concat:
            mul = 2 if keep_concat else 1
            self.conv2 = ResnetBasicBlock(out_channels*2, out_channels*mul, kernel_size=3, stride=1, padding=1, norm_layer=nn.InstanceNorm2d)
        else:
            self.conv2 = BasicConv_IN(out_channels, out_channels, False, is_3d, IN, relu, kernel_size=3, stride=1, padding=1)

    def forward(self, x, rem):
        x = self.conv1(x)
        if x.shape != rem.shape:
            x = F.interpolate(x, size=(rem.shape[-2], rem.shape[-1]), mode='bilinear')
        if self.concat:
            x = torch.cat((x, rem), 1)
        else:
            x = x + rem
        x = self.conv2(x)
        return x


def groupwise_correlation(fea1, fea2, num_groups):
    B, C, H, W = fea1.shape
    assert C % num_groups == 0, f"C:{C}, num_groups:{num_groups}"
    channels_per_group = C // num_groups
    fea1 = fea1.reshape(B, num_groups, channels_per_group, H, W)
    fea2 = fea2.reshape(B, num_groups, channels_per_group, H, W)
    with torch.cuda.amp.autocast(enabled=False):
        cost = (F.normalize(fea1.float(), dim=2) * F.normalize(fea2.float(), dim=2)).sum(dim=2)  #!NOTE Divide first for numerical stability
    assert cost.shape == (B, num_groups, H, W)
    return cost

def build_gwc_volume(refimg_fea, targetimg_fea, maxdisp, num_groups, stride=1):
    """
    @refimg_fea: left image feature
    @targetimg_fea: right image feature
    """
    B, C, H, W = refimg_fea.shape
    volume = refimg_fea.new_zeros([B, num_groups, maxdisp, H, W])
    for i in range(maxdisp):
        if i > 0:
            volume[:, :, i, :, i:] = groupwise_correlation(refimg_fea[:, :, :, i:], targetimg_fea[:, :, :, :-i], num_groups)
        else:
            volume[:, :, i, :, :] = groupwise_correlation(refimg_fea, targetimg_fea, num_groups)
    volume = volume.contiguous()
    return volume



def build_concat_volume(refimg_fea, targetimg_fea, maxdisp):
    B, C, H, W = refimg_fea.shape
    volume = refimg_fea.new_zeros([B, 2 * C, maxdisp, H, W])
    for i in range(maxdisp):
        if i > 0:
            volume[:, :C, i, :, :] = refimg_fea[:, :, :, :]
            volume[:, C:, i, :, i:] = targetimg_fea[:, :, :, :-i]
        else:
            volume[:, :C, i, :, :] = refimg_fea
            volume[:, C:, i, :, :] = targetimg_fea
    volume = volume.contiguous()
    return volume

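# Shape sketch for the two volumes above (toy sizes; assumes this module's functions
# are in scope, e.g. run inside this file or after `from core.submodule import *`):
import torch
B, C, H, W, maxdisp, groups = 1, 16, 8, 12, 4, 8
left = torch.randn(B, C, H, W)
right = torch.randn(B, C, H, W)
gwc = build_gwc_volume(left, right, maxdisp, groups)   # (1,8,4,8,12): one correlation per group per disparity
cat = build_concat_volume(left, right, maxdisp)        # (1,32,4,8,12): left/right features stacked per disparity
assert gwc.shape == (B, groups, maxdisp, H, W)
assert cat.shape == (B, 2 * C, maxdisp, H, W)
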
+
|
| 427 |
+
|
| 428 |
+
def disparity_regression(x, maxdisp):
|
| 429 |
+
assert len(x.shape) == 4
|
| 430 |
+
disp_values = torch.arange(0, maxdisp, dtype=x.dtype, device=x.device)
|
| 431 |
+
disp_values = disp_values.reshape(1, maxdisp, 1, 1)
|
| 432 |
+
return torch.sum(x * disp_values, 1, keepdim=True)
|
| 433 |
+
|
| 434 |
+
|
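# Soft-argmin intuition for disparity_regression (sketch): with the probability mass
# split evenly between disparities 1 and 2, the regressed disparity is 1.5.
import torch
prob = torch.zeros(1, 4, 1, 1)     # (B, maxdisp, H, W), a softmax output in the model
prob[0, 1] = 0.5
prob[0, 2] = 0.5
disp = disparity_regression(prob, maxdisp=4)
assert torch.allclose(disp, torch.tensor(1.5))
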
class FeatureAtt(nn.Module):
    def __init__(self, cv_chan, feat_chan):
        super(FeatureAtt, self).__init__()

        self.feat_att = nn.Sequential(
            BasicConv(feat_chan, feat_chan//2, kernel_size=1, stride=1, padding=0),
            nn.Conv2d(feat_chan//2, cv_chan, 1)
        )

    def forward(self, cv, feat):
        '''
        @cv: cost volume (B,C,D,H,W)
        @feat: (B,C,H,W)
        '''
        feat_att = self.feat_att(feat).unsqueeze(2)  #(B,C,1,H,W)
        cv = torch.sigmoid(feat_att)*cv
        return cv

def context_upsample(disp_low, up_weights):
    """
    @disp_low: (b,1,h,w) 1/4 resolution
    @up_weights: (b,9,4*h,4*w) Image resolution
    """
    b, c, h, w = disp_low.shape

    disp_unfold = F.unfold(disp_low.reshape(b,c,h,w),3,1,1).reshape(b,-1,h,w)
    disp_unfold = F.interpolate(disp_unfold,(h*4,w*4),mode='nearest').reshape(b,9,h*4,w*4)

    disp = (disp_unfold*up_weights).sum(1)

    return disp

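# context_upsample in one line: every full-resolution pixel is a convex combination
# of the 3x3 neighborhood of its parent 1/4-resolution pixel. Sketch with uniform
# weights (the real model uses the softmaxed spx_pred weights instead):
import torch
disp_low = torch.arange(16.).reshape(1, 1, 4, 4)
up_w = torch.full((1, 9, 16, 16), 1 / 9.0)
disp_full = context_upsample(disp_low, up_w)   # (1,16,16): a local 3x3 box average here
assert disp_full.shape == (1, 16, 16)
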
+
class PositionalEmbedding(nn.Module):
|
| 470 |
+
def __init__(self, d_model, max_len=512):
|
| 471 |
+
super().__init__()
|
| 472 |
+
|
| 473 |
+
# Compute the positional encodings once in log space.
|
| 474 |
+
pe = torch.zeros(max_len, d_model).float()
|
| 475 |
+
pe.require_grad = False
|
| 476 |
+
|
| 477 |
+
position = torch.arange(0, max_len).float().unsqueeze(1) #(N,1)
|
| 478 |
+
div_term = (torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model)).exp()[None]
|
| 479 |
+
|
| 480 |
+
pe[:, 0::2] = torch.sin(position * div_term) #(N, d_model/2)
|
| 481 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 482 |
+
|
| 483 |
+
pe = pe.unsqueeze(0)
|
| 484 |
+
self.pe = pe
|
| 485 |
+
# self.register_buffer('pe', pe) #(1, max_len, D)
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def forward(self, x, resize_embed=False):
|
| 489 |
+
'''
|
| 490 |
+
@x: (B,N,D)
|
| 491 |
+
'''
|
| 492 |
+
self.pe = self.pe.to(x.device).to(x.dtype)
|
| 493 |
+
pe = self.pe
|
| 494 |
+
if pe.shape[1]<x.shape[1]:
|
| 495 |
+
if resize_embed:
|
| 496 |
+
pe = F.interpolate(pe.permute(0,2,1), size=x.shape[1], mode='linear', align_corners=False).permute(0,2,1)
|
| 497 |
+
else:
|
| 498 |
+
raise RuntimeError(f'x:{x.shape}, pe:{pe.shape}')
|
| 499 |
+
return x + pe[:, :x.size(1)]
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
class CostVolumeDisparityAttention(nn.Module):
|
| 504 |
+
def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1, act=nn.GELU, norm_first=False, num_transformer=6, max_len=512, resize_embed=False):
|
| 505 |
+
super().__init__()
|
| 506 |
+
self.resize_embed = resize_embed
|
| 507 |
+
self.sa = nn.ModuleList([])
|
| 508 |
+
for _ in range(num_transformer):
|
| 509 |
+
self.sa.append(FlashAttentionTransformerEncoderLayer(embed_dim=d_model, num_heads=nhead, dim_feedforward=dim_feedforward, act=act, dropout=dropout))
|
| 510 |
+
self.pos_embed0 = PositionalEmbedding(d_model, max_len=max_len)
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def forward(self, cv, window_size=(-1,-1)):
|
| 514 |
+
"""
|
| 515 |
+
@cv: (B,C,D,H,W) where D is max disparity
|
| 516 |
+
"""
|
| 517 |
+
x = cv
|
| 518 |
+
B,C,D,H,W = x.shape
|
| 519 |
+
x = x.permute(0,3,4,2,1).reshape(B*H*W, D, C)
|
| 520 |
+
x = self.pos_embed0(x, resize_embed=self.resize_embed) #!NOTE No resize since disparity is pre-determined
|
| 521 |
+
for i in range(len(self.sa)):
|
| 522 |
+
x = self.sa[i](x, window_size=window_size)
|
| 523 |
+
x = x.reshape(B,H,W,D,C).permute(0,4,3,1,2)
|
| 524 |
+
|
| 525 |
+
return x
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
class ChannelAttentionEnhancement(nn.Module):
|
| 530 |
+
def __init__(self, in_planes, ratio=16):
|
| 531 |
+
super(ChannelAttentionEnhancement, self).__init__()
|
| 532 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
| 533 |
+
self.max_pool = nn.AdaptiveMaxPool2d(1)
|
| 534 |
+
|
| 535 |
+
self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False),
|
| 536 |
+
nn.ReLU(),
|
| 537 |
+
nn.Conv2d(in_planes // 16, in_planes, 1, bias=False))
|
| 538 |
+
self.sigmoid = nn.Sigmoid()
|
| 539 |
+
|
| 540 |
+
def forward(self, x):
|
| 541 |
+
avg_out = self.fc(self.avg_pool(x))
|
| 542 |
+
max_out = self.fc(self.max_pool(x))
|
| 543 |
+
out = avg_out + max_out
|
| 544 |
+
return self.sigmoid(out)
|
| 545 |
+
|
| 546 |
+
class SpatialAttentionExtractor(nn.Module):
|
| 547 |
+
def __init__(self, kernel_size=7):
|
| 548 |
+
super(SpatialAttentionExtractor, self).__init__()
|
| 549 |
+
|
| 550 |
+
self.samconv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
|
| 551 |
+
self.sigmoid = nn.Sigmoid()
|
| 552 |
+
|
| 553 |
+
def forward(self, x):
|
| 554 |
+
avg_out = torch.mean(x, dim=1, keepdim=True)
|
| 555 |
+
max_out, _ = torch.max(x, dim=1, keepdim=True)
|
| 556 |
+
x = torch.cat([avg_out, max_out], dim=1)
|
| 557 |
+
x = self.samconv(x)
|
| 558 |
+
return self.sigmoid(x)
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
class EdgeNextConvEncoder(nn.Module):
|
| 563 |
+
def __init__(self, dim, layer_scale_init_value=1e-6, expan_ratio=4, kernel_size=7, norm='layer'):
|
| 564 |
+
super().__init__()
|
| 565 |
+
self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
|
| 566 |
+
if norm=='layer':
|
| 567 |
+
self.norm = LayerNorm2d(dim, eps=1e-6)
|
| 568 |
+
else:
|
| 569 |
+
self.norm = nn.Identity()
|
| 570 |
+
self.pwconv1 = nn.Linear(dim, expan_ratio * dim)
|
| 571 |
+
self.act = nn.GELU()
|
| 572 |
+
self.pwconv2 = nn.Linear(expan_ratio * dim, dim)
|
| 573 |
+
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) if layer_scale_init_value > 0 else None
|
| 574 |
+
|
| 575 |
+
def forward(self, x):
|
| 576 |
+
input = x
|
| 577 |
+
x = self.dwconv(x)
|
| 578 |
+
x = self.norm(x)
|
| 579 |
+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
| 580 |
+
x = self.pwconv1(x)
|
| 581 |
+
x = self.act(x)
|
| 582 |
+
x = self.pwconv2(x)
|
| 583 |
+
if self.gamma is not None:
|
| 584 |
+
x = self.gamma * x
|
| 585 |
+
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
| 586 |
+
|
| 587 |
+
x = input + x
|
| 588 |
+
return x
|
FoundationStereo_demo/core/update.py
ADDED
@@ -0,0 +1,159 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.


import torch,pdb,os,sys
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import einsum
code_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(f'{code_dir}/../')
from core.submodule import *
from core.extractor import *

class DispHead(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=256, output_dim=1):
        super(DispHead, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_dim, input_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            EdgeNextConvEncoder(input_dim, expan_ratio=4, kernel_size=7, norm=None),
            EdgeNextConvEncoder(input_dim, expan_ratio=4, kernel_size=7, norm=None),
            nn.Conv2d(input_dim, output_dim, 3, padding=1),
        )

    def forward(self, x):
        return self.conv(x)

class ConvGRU(nn.Module):
    def __init__(self, hidden_dim, input_dim, kernel_size=3):
        super(ConvGRU, self).__init__()
        self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
        self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
        self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)

    def forward(self, h, cz, cr, cq, *x_list):
        x = torch.cat(x_list, dim=1)
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz(hx) + cz)
        r = torch.sigmoid(self.convr(hx) + cr)
        q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)) + cq)
        h = (1-z) * h + z * q
        return h


class BasicMotionEncoder(nn.Module):
    def __init__(self, args, ngroup=8):
        super(BasicMotionEncoder, self).__init__()
        self.args = args
        cor_planes = args.corr_levels * (2*args.corr_radius + 1) * (ngroup+1)
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 256, 3, padding=1)
        self.convd1 = nn.Conv2d(1, 64, 7, padding=3)
        self.convd2 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv = nn.Conv2d(64+256, 128-1, 3, padding=1)

    def forward(self, disp, corr):
        cor = F.relu(self.convc1(corr))
        cor = F.relu(self.convc2(cor))
        disp_ = F.relu(self.convd1(disp))
        disp_ = F.relu(self.convd2(disp_))

        cor_disp = torch.cat([cor, disp_], dim=1)
        out = F.relu(self.conv(cor_disp))
        return torch.cat([out, disp], dim=1)

def pool2x(x):
    return F.avg_pool2d(x, 3, stride=2, padding=1)

def pool4x(x):
    return F.avg_pool2d(x, 5, stride=4, padding=1)

def interp(x, dest):
    interp_args = {'mode': 'bilinear', 'align_corners': True}
    return F.interpolate(x, dest.shape[2:], **interp_args)


class RaftConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=256, kernel_size=3):
        super().__init__()
        self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size // 2)
        self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size // 2)
        self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size // 2)

    def forward(self, h, x, hx):
        z = torch.sigmoid(self.convz(hx))
        r = torch.sigmoid(self.convr(hx))
        q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
        h = (1-z) * h + z * q
        return h


class SelectiveConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=256, small_kernel_size=1, large_kernel_size=3, patch_size=None):
        super(SelectiveConvGRU, self).__init__()
        self.conv0 = nn.Sequential(
            nn.Conv2d(input_dim, input_dim, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_dim+hidden_dim, input_dim+hidden_dim, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.small_gru = RaftConvGRU(hidden_dim, input_dim, small_kernel_size)
        self.large_gru = RaftConvGRU(hidden_dim, input_dim, large_kernel_size)

    def forward(self, att, h, *x):
        x = torch.cat(x, dim=1)
        x = self.conv0(x)
        hx = torch.cat([x, h], dim=1)
        hx = self.conv1(hx)
        h = self.small_gru(h, x, hx) * att + self.large_gru(h, x, hx) * (1 - att)

        return h


class BasicSelectiveMultiUpdateBlock(nn.Module):
    def __init__(self, args, hidden_dim=128, volume_dim=8):
        super().__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args, volume_dim)

        if args.n_gru_layers == 3:
            self.gru16 = SelectiveConvGRU(hidden_dim, hidden_dim * 2)
        if args.n_gru_layers >= 2:
            self.gru08 = SelectiveConvGRU(hidden_dim, hidden_dim * (args.n_gru_layers == 3) + hidden_dim * 2)
        self.gru04 = SelectiveConvGRU(hidden_dim, hidden_dim * (args.n_gru_layers > 1) + hidden_dim * 2)
        self.disp_head = DispHead(hidden_dim, 256)
        self.mask = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, net, inp, corr, disp, att):
        if self.args.n_gru_layers == 3:
            net[2] = self.gru16(att[2], net[2], inp[2], pool2x(net[1]))
        if self.args.n_gru_layers >= 2:
            if self.args.n_gru_layers > 2:
                net[1] = self.gru08(att[1], net[1], inp[1], pool2x(net[0]), interp(net[2], net[1]))
            else:
                net[1] = self.gru08(att[1], net[1], inp[1], pool2x(net[0]))

        motion_features = self.encoder(disp, corr)
        motion_features = torch.cat([inp[0], motion_features], dim=1)
        if self.args.n_gru_layers > 1:
            net[0] = self.gru04(att[0], net[0], motion_features, interp(net[1], net[0]))

        delta_disp = self.disp_head(net[0])

        # scale mask to balance gradients
        mask = .25 * self.mask(net[0])
        return net, mask, delta_disp
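# Sketch of one selective GRU update (toy sizes; dims assumed): the spatial attention
# map blends a 1x1-kernel GRU (sharp, detail-preserving) with a 3x3-kernel GRU
# (smoother) per pixel.
import torch
gru = SelectiveConvGRU(hidden_dim=128, input_dim=256)
h = torch.zeros(1, 128, 30, 40)     # hidden state at 1/4 resolution
inp = torch.randn(1, 256, 30, 40)   # context + motion features
att = torch.rand(1, 1, 30, 40)      # SpatialAttentionExtractor output in [0,1]
h = gru(att, h, inp)
assert h.shape == (1, 128, 30, 40)
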
FoundationStereo_demo/core/utils/utils.py
ADDED
@@ -0,0 +1,64 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.



import torch,pdb,logging
import torch.nn.functional as F
import numpy as np
from scipy import interpolate


class InputPadder:
    """ Pads images such that their dimensions are divisible by `divis_by` (8 by default) """
    def __init__(self, dims, mode='sintel', divis_by=8, force_square=False):
        self.ht, self.wd = dims[-2:]
        if force_square:
            max_side = max(self.ht, self.wd)
            pad_ht = ((max_side // divis_by) + 1) * divis_by - self.ht
            pad_wd = ((max_side // divis_by) + 1) * divis_by - self.wd
        else:
            pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
            pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
        if mode == 'sintel':
            self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
        else:
            self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]

    def pad(self, *inputs):
        assert all((x.ndim == 4) for x in inputs)
        # Ensure padded tensors are contiguous to avoid cuDNN issues
        return [F.pad(x, self._pad, mode='replicate').contiguous() for x in inputs]

    def unpad(self, x):
        assert x.ndim == 4
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
        # Ensure unpadded tensor is contiguous
        return x[..., c[0]:c[1], c[2]:c[3]].contiguous()


def bilinear_sampler(img, coords, mode='bilinear', mask=False, low_memory=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1,1], dim=-1)
    xgrid = 2*xgrid/(W-1) - 1  # Normalize to [-1,1]
    assert torch.unique(ygrid).numel() == 1 and H == 1  # This is a stereo problem
    grid = torch.cat([xgrid, ygrid], dim=-1).to(img.dtype).contiguous()
    img = F.grid_sample(img, grid, align_corners=True).contiguous()
    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float().contiguous()
    return img


def coords_grid(batch, ht, wd):
    coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)

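# InputPadder round trip (sketch): pad a 375x450 pair up to multiples of 32, run the
# model, then crop the output back to the original size.
import torch
left = torch.rand(1, 3, 375, 450)
right = torch.rand(1, 3, 375, 450)
padder = InputPadder(left.shape[-2:], divis_by=32)
left_p, right_p = padder.pad(left, right)
assert left_p.shape[-2:] == (384, 480)   # next multiples of 32
assert padder.unpad(left_p).shape == left.shape
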
FoundationStereo_demo/depth_anything/LICENSE.txt
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
FoundationStereo_demo/depth_anything/__init__.py
ADDED
@@ -0,0 +1,2 @@
# depth_anything package
# This file allows depth_anything to be imported as a package
FoundationStereo_demo/depth_anything/blocks.py
ADDED
@@ -0,0 +1,153 @@
import torch.nn as nn


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    scratch = nn.Module()

    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    if len(in_shape) >= 4:
        out_shape4 = out_shape

    if expand:
        out_shape1 = out_shape
        out_shape2 = out_shape*2
        out_shape3 = out_shape*4
        if len(in_shape) >= 4:
            out_shape4 = out_shape*8

    scratch.layer1_rn = nn.Conv2d(
        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer2_rn = nn.Conv2d(
        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer3_rn = nn.Conv2d(
        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    if len(in_shape) >= 4:
        scratch.layer4_rn = nn.Conv2d(
            in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
        )

    return scratch


class ResidualConvUnit(nn.Module):
    """Residual convolution module.
    """

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn

        self.groups = 1

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )

        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )

        if self.bn == True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """

        out = self.activation(x)
        out = self.conv1(out)
        if self.bn == True:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn == True:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)


class FeatureFusionBlock(nn.Module):
    """Feature fusion block.
    """

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features
        if self.expand == True:
            out_features = features // 2

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

        self.size = size

    def forward(self, *xs, size=None):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)

        output = self.resConfUnit2(output)

        if (size is None) and (self.size is None):
            modifier = {"scale_factor": 2}
        elif size is None:
            modifier = {"size": self.size}
        else:
            modifier = {"size": size}

        output = nn.functional.interpolate(
            output, **modifier, mode="bilinear", align_corners=self.align_corners
        )

        output = self.out_conv(output)

        return output
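A short sketch (my own, under assumed shapes) of the fusion pattern these blocks implement in a DPT-style decoder: a coarse decoder path and a lateral skip feature at the same resolution are summed, refined by the residual units, and upsampled 2x.

import torch
import torch.nn as nn

fuse = FeatureFusionBlock(features=64, activation=nn.ReLU(False), bn=False)
coarse = torch.randn(1, 64, 16, 16)   # output of the previous refinenet stage
skip = torch.randn(1, 64, 16, 16)     # lateral feature at the same resolution
out = fuse(coarse, skip)              # add -> ResidualConvUnit -> 2x bilinear upsample -> 1x1 conv
print(out.shape)                      # torch.Size([1, 64, 32, 32])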
FoundationStereo_demo/depth_anything/dpt.py
ADDED
@@ -0,0 +1,203 @@
import argparse
import torch,os,sys,pdb
import torch.nn as nn
import torch.nn.functional as F
code_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(f'{code_dir}/../')
from depth_anything.blocks import FeatureFusionBlock, _make_scratch


def _make_fusion_block(features, use_bn, size = None):
    return FeatureFusionBlock(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        size=size,
    )


class DPTHead(nn.Module):
    def __init__(self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False):
        super(DPTHead, self).__init__()

        self.nclass = nclass
        self.use_clstoken = use_clstoken

        self.projects = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channel,
                kernel_size=1,
                stride=1,
                padding=0,
            ) for out_channel in out_channels
        ])

        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(
                in_channels=out_channels[0],
                out_channels=out_channels[0],
                kernel_size=4,
                stride=4,
                padding=0),
            nn.ConvTranspose2d(
                in_channels=out_channels[1],
                out_channels=out_channels[1],
                kernel_size=2,
                stride=2,
                padding=0),
            nn.Identity(),
            nn.Conv2d(
                in_channels=out_channels[3],
                out_channels=out_channels[3],
                kernel_size=3,
                stride=2,
                padding=1)
        ])

        if use_clstoken:
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(
                    nn.Sequential(
                        nn.Linear(2 * in_channels, in_channels),
                        nn.GELU()))

        self.scratch = _make_scratch(
            out_channels,
            features,
            groups=1,
            expand=False,
        )

        self.scratch.stem_transpose = None

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        head_features_1 = features
        head_features_2 = 32

        if nclass > 1:
            self.scratch.output_conv = nn.Sequential(
                nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True),
                nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
            )
        else:
            self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)

            self.scratch.output_conv2 = nn.Sequential(
                nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True),
                nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
                nn.ReLU(True),
                nn.Identity(),
            )

    def forward(self, out_features, patch_h, patch_w, return_intermediate=False, patch_size=14):
        out = []
        for i, x in enumerate(out_features):
            if self.use_clstoken:
                x, cls_token = x[0], x[1]
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
            else:
                x = x[0]

            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[i](x)
            x = self.resize_layers[i](x)

            out.append(x)

        layer_1, layer_2, layer_3, layer_4 = out

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv1(path_1)
        out = F.interpolate(out, (int(patch_h * patch_size), int(patch_w * patch_size)), mode="bilinear", align_corners=True)
        if return_intermediate:
            depth = self.scratch.output_conv2(out)
            depth = F.relu(depth)
            disp = 1/depth
            disp[depth==0] = 0
            disp = disp/disp.max()
            return out, path_1, path_2, path_3, path_4, disp

        else:
            out = self.scratch.output_conv2(out)
            return out


class DPT_DINOv2(nn.Module):
    def __init__(self, encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024], use_bn=False, use_clstoken=False, pretrained_dino=False):
        super(DPT_DINOv2, self).__init__()

        assert encoder in ['vits', 'vitb', 'vitl']

        # in case the Internet connection is not stable, please load the DINOv2 locally
        # if localhub:
        #     self.pretrained = torch.hub.load('torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False)
        # else:
        self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder), pretrained=pretrained_dino)


        dim = self.pretrained.blocks[0].attn.qkv.in_features

        self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)

    def forward(self, x):
        h, w = x.shape[-2:]

        features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
        patch_size = self.pretrained.patch_size
        patch_h, patch_w = h // patch_size, w // patch_size
        output = self.depth_head(features, patch_h, patch_w, patch_size=patch_size, return_intermediate=True)
        return output


class DepthAnything(DPT_DINOv2):
    def __init__(self, config):
        super().__init__(**config)

    def forward(self, x):
        h, w = x.shape[-2:]

        features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
        patch_size = self.pretrained.patch_size
        patch_h, patch_w = h // patch_size, w // patch_size
        depth = self.depth_head(features, patch_h, patch_w, patch_size=patch_size)
        depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
        depth = F.relu(depth)

        return depth.squeeze(1)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--encoder",
        default="vits",
        type=str,
        choices=["vits", "vitb", "vitl"],
    )
    args = parser.parse_args()

    model = DepthAnything.from_pretrained("LiheYoung/depth_anything_{:}14".format(args.encoder))

    print(model)
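A rough usage sketch of DepthAnything as defined above (mine, not from the repo). torch.hub.load fetches the DINOv2 code on first run, so this needs network access; the config values below are illustrative assumptions, and input sides must be multiples of the ViT patch size (14).

import torch

config = dict(encoder='vits', features=64, out_channels=[48, 96, 192, 384],  # assumed values
              use_bn=False, use_clstoken=False, pretrained_dino=False)
model = DepthAnything(config).eval()
x = torch.randn(1, 3, 14 * 32, 14 * 32)   # B x 3 x H x W, H and W divisible by 14
with torch.no_grad():
    depth = model(x)                      # relative depth map, shape (1, 448, 448)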
FoundationStereo_demo/depth_anything/util/transform.py
ADDED
@@ -0,0 +1,248 @@
import random
from PIL import Image, ImageOps, ImageFilter
import torch
from torchvision import transforms
import torch.nn.functional as F

import numpy as np
import cv2
import math


def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
    """Resize the sample to ensure the given size. Keeps aspect ratio.

    Args:
        sample (dict): sample
        size (tuple): image size

    Returns:
        tuple: new size
    """
    shape = list(sample["disparity"].shape)

    if shape[0] >= size[0] and shape[1] >= size[1]:
        return sample

    scale = [0, 0]
    scale[0] = size[0] / shape[0]
    scale[1] = size[1] / shape[1]

    scale = max(scale)

    shape[0] = math.ceil(scale * shape[0])
    shape[1] = math.ceil(scale * shape[1])

    # resize
    sample["image"] = cv2.resize(
        sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
    )

    sample["disparity"] = cv2.resize(
        sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
    )
    sample["mask"] = cv2.resize(
        sample["mask"].astype(np.float32),
        tuple(shape[::-1]),
        interpolation=cv2.INTER_NEAREST,
    )
    sample["mask"] = sample["mask"].astype(bool)

    return tuple(shape)


class Resize(object):
    """Resize sample to given size (width, height).
    """

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as little as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as little as possible
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(
                    f"resize_method {self.__resize_method} not implemented"
                )

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, min_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, min_val=self.__width
            )
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, max_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, max_val=self.__width
            )
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        width, height = self.get_size(
            sample["image"].shape[1], sample["image"].shape[0]
        )

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            if "disparity" in sample:
                sample["disparity"] = cv2.resize(
                    sample["disparity"],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

            if "depth" in sample:
                sample["depth"] = cv2.resize(
                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
                )

            if "semseg_mask" in sample:
                # sample["semseg_mask"] = cv2.resize(
                #     sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
                # )
                sample["semseg_mask"] = F.interpolate(torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode='nearest').numpy()[0, 0]

            if "mask" in sample:
                sample["mask"] = cv2.resize(
                    sample["mask"].astype(np.float32),
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )
                # sample["mask"] = sample["mask"].astype(bool)

        # print(sample['image'].shape, sample['depth'].shape)
        return sample


class NormalizeImage(object):
    """Normalize image by given mean and std.
    """

    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        sample["image"] = (sample["image"] - self.__mean) / self.__std

        return sample


class PrepareForNet(object):
    """Prepare sample for usage as network input.
    """

    def __init__(self):
        pass

    def __call__(self, sample):
        image = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(image).astype(np.float32)

        if "mask" in sample:
            sample["mask"] = sample["mask"].astype(np.float32)
            sample["mask"] = np.ascontiguousarray(sample["mask"])

        if "depth" in sample:
            depth = sample["depth"].astype(np.float32)
            sample["depth"] = np.ascontiguousarray(depth)

        if "semseg_mask" in sample:
            sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
            sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])

        return sample
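A small sketch (my own) of chaining these transforms the way DPT-style pipelines typically do; the 518-pixel target and ImageNet mean/std are assumptions, not values stated in this file.

import numpy as np
from torchvision.transforms import Compose

preprocess = Compose([
    Resize(518, 518, resize_target=False, keep_aspect_ratio=True,
           ensure_multiple_of=14, resize_method="lower_bound"),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

sample = {"image": np.random.rand(480, 640, 3).astype(np.float32)}  # H x W x 3 in [0, 1]
sample = preprocess(sample)
print(sample["image"].shape)  # (3, 518, 686): sides are multiples of 14, at least 518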
assets/example1/K.txt
ADDED
@@ -0,0 +1,2 @@
754.6680908203125 0.0 489.3794860839844 0.0 754.6680908203125 265.16162109375 0.0 0.0 1.0
0.063
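A hypothetical reader (not in the repo) for this K.txt layout, assuming line 1 is the 3x3 intrinsic matrix in row-major order and line 2 the stereo baseline, presumably in meters (0.063 ≈ a 6.3 cm baseline); depth then follows as fx * baseline / disparity.

import numpy as np

def load_K(path="assets/example1/K.txt"):    # hypothetical helper
    with open(path) as f:
        lines = f.read().strip().split('\n')
    K = np.array(list(map(float, lines[0].split()))).reshape(3, 3)
    baseline = float(lines[1])               # assumed to be in meters
    return K, baseline

K, baseline = load_K()
fx = K[0, 0]                                 # focal length in pixels
disparity = 42.0                             # example disparity in pixels
depth = fx * baseline / disparity            # ≈ 1.13 m for this example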
assets/example1/left.png
ADDED (binary image stored with Git LFS; content not shown)
assets/example1/right.png
ADDED (binary image stored with Git LFS; content not shown)
assets/example2/K.txt
ADDED
@@ -0,0 +1,9 @@
cam0=[1733.74 0 792.27; 0 1733.74 541.89; 0 0 1]
cam1=[1733.74 0 792.27; 0 1733.74 541.89; 0 0 1]
doffs=0
baseline=536.62
width=1920
height=1080
ndisp=170
vmin=55
vmax=142
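This second file looks like Middlebury-style calibration (an interpretation based on the field names): cam0/cam1 are the rectified intrinsics, baseline is in millimeters, and ndisp/vmin/vmax bound the disparity search. Under that reading, depth in millimeters is f * baseline / (disparity + doffs):

f, baseline_mm, doffs = 1733.74, 536.62, 0.0      # values from the file above
disparity = 100.0                                 # pixels, inside [vmin, vmax]
depth_mm = f * baseline_mm / (disparity + doffs)  # ≈ 9304 mm for this example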
assets/example2/left.png
ADDED (binary image stored with Git LFS; content not shown)
assets/example2/right.png
ADDED (binary image stored with Git LFS; content not shown)