import glob
import os
from pathlib import Path

import torch
from PIL import Image
from torch.utils.data import Dataset

from vggt.utils.load_fn import load_and_preprocess_images
from eval.utils.eval_utils import uniform_sample

class TumDatasetAll(Dataset):
    """Dataset serving all RGB frames of a single TUM RGB-D scene."""

    def __init__(self, root_dir, scene_name="rgbd_dataset_freiburg1_360"):
        self.scene_name = scene_name
        # Frames live under <root_dir>/<scene_name>/rgb/*.png in the standard TUM layout.
        self.scene_all_dir = os.path.join(root_dir, scene_name, "rgb")
        self.test_samples = sorted(glob.glob(os.path.join(self.scene_all_dir, "*.png")))
        self.all_samples = self.test_samples

    def __len__(self):
        return len(self.all_samples)

    def __getitem__(self, idx):
        return self._load_sample(self.all_samples[idx])

    def get_train_sample(self, n=4):
        """Return roughly n frames spread evenly across the sequence."""
        gap = len(self.all_samples) // n
        gap = max(gap, 1)  # Ensure at least one sample is selected
        gap = min(gap, len(self.all_samples))  # Ensure gap does not exceed length
        if gap == 1:
            # Too few frames for striding to help: sample n indices uniformly instead.
            uniform_sampled = uniform_sample(len(self.all_samples), n)
            selected = [self.all_samples[i] for i in uniform_sampled]
        else:
            selected = self.all_samples[::gap]
        if len(selected) > n:
            # Striding can overshoot n, so thin the selection back down to n frames.
            uniform_sampled = uniform_sample(len(selected), n)
            selected = [selected[i] for i in uniform_sampled]
        # The freiburg1_floor sequence additionally keeps every 5th of the last 20 frames.
        if self.scene_name == "rgbd_dataset_freiburg1_floor":
            selected += self.all_samples[-20::5]
        return [self._load_sample(s) for s in selected]

    def _load_sample(self, rgb_path):
        img_name = os.path.basename(rgb_path)
        # load_and_preprocess_images returns a batch; take the single preprocessed frame.
        color = load_and_preprocess_images([rgb_path])[0]
        return dict(
            img=color,
            dataset="tnt_all",
            true_shape=torch.tensor([392, 518]),
            label=img_name,
            instance=img_name,
        )
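

# Minimal usage sketch (hypothetical paths, not part of the original file): it assumes
# the standard TUM RGB-D directory layout, i.e. PNG frames under <root_dir>/<scene_name>/rgb/.
if __name__ == "__main__":
    dataset = TumDatasetAll(root_dir="/path/to/tum_rgbd", scene_name="rgbd_dataset_freiburg1_360")
    print(f"{len(dataset)} frames in {dataset.scene_name}")
    train_views = dataset.get_train_sample(n=4)  # roughly evenly spaced key frames
    for view in train_views:
        print(view["label"], tuple(view["img"].shape))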