SAIL-Recon / eval /datasets /seven_scenes.py
hengli
first
b7f83b0
import glob
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from vggt.utils.load_fn import load_and_preprocess_images
from eval.utils.eval_utils import uniform_sample
class SevenScenesUnifiedDataset(Dataset):
def __init__(self, root_dir, scene_name="chess"):
self.scene_dir = os.path.join(root_dir, f"pgt_7scenes_{scene_name}")
self.train_seqs = os.path.join(self.scene_dir, "train")
self.test_seqs = os.path.join(self.scene_dir, "test")
self.test_samples = sorted(
glob.glob(os.path.join(self.test_seqs, "rgb", "*.png"))
)
self.train_samples = sorted(
glob.glob(os.path.join(self.train_seqs, "rgb", "*.png"))
)
self.all_samples = self.test_samples # + self.train_samples
# len_samples = len(self.all_samples)
# self.all_samples = self.all_samples[::len_samples//200]
def __len__(self):
return len(self.all_samples)
def __getitem__(self, idx):
return self._load_sample(self.all_samples[idx])
def get_train_sample(self, n=4):
uniform_sampled = uniform_sample(len(self.all_samples), n)
selected = [self.all_samples[i] for i in uniform_sampled]
return [self._load_sample(s) for s in selected]
def _load_sample(self, rgb_path):
img_name = os.path.basename(rgb_path)
color = load_and_preprocess_images([rgb_path])[0]
pose_path = (
rgb_path.replace("rgb", "poses")
.replace("color", "pose")
.replace(".png", ".txt")
)
pose = np.loadtxt(pose_path)
pose = torch.from_numpy(pose).float()
return dict(
img=color,
camera_pose=pose, # cam2world
dataset="7Scenes",
true_shape=torch.tensor([392, 518]),
label=img_name,
instance=img_name,
)