Spaces:
Running
on
Zero
Running
on
Zero
| import glob | |
| import os | |
| from pathlib import Path | |
| import torch | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| from vggt.utils.load_fn import load_and_preprocess_images | |
| from eval.utils.eval_utils import uniform_sample | |
| class TumDatasetAll(Dataset): | |
| def __init__(self, root_dir, scene_name="rgbd_dataset_freiburg1_360"): | |
| self.scene_name = scene_name | |
| self.scene_all_dir = os.path.join(root_dir, f"{scene_name}", "rgb") | |
| self.test_samples = sorted(glob.glob(os.path.join(self.scene_all_dir, "*.png"))) | |
| self.all_samples = self.test_samples | |
| def __len__(self): | |
| return len(self.all_samples) | |
| def __getitem__(self, idx): | |
| return self._load_sample(self.all_samples[idx]) | |
| def get_train_sample(self, n=4): | |
| gap = len(self.all_samples) // n | |
| gap = max(gap, 1) # Ensure at least one sample is selected | |
| gap = min(gap, len(self.all_samples)) # Ensure gap does not exceed length | |
| if gap == 1: | |
| uniform_sampled = uniform_sample(len(self.all_samples), n) | |
| selected = [self.all_samples[i] for i in uniform_sampled] | |
| else: | |
| selected = self.all_samples[::gap] | |
| if len(selected) > n: | |
| uniform_sampled = uniform_sample(len(selected), n) | |
| selected = [selected[i] for i in uniform_sampled] | |
| if self.scene_name == "rgbd_dataset_freiburg1_floor": | |
| selected += self.all_samples[-20::5] | |
| return [self._load_sample(s) for s in selected] | |
| def _load_sample(self, rgb_path): | |
| img_name = os.path.basename(rgb_path) | |
| color = load_and_preprocess_images([rgb_path])[0] | |
| return dict( | |
| img=color, | |
| dataset="tnt_all", | |
| true_shape=torch.tensor([392, 518]), | |
| label=img_name, | |
| instance=img_name, | |
| ) | |