Spaces:
Running
on
Zero
Running
on
Zero
| import glob | |
| import os | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import Dataset | |
| from vggt.utils.load_fn import load_and_preprocess_images | |
| from eval.utils.eval_utils import uniform_sample | |
class SevenScenesUnifiedDataset(Dataset):
    """7-Scenes evaluation dataset in the ``pgt_7scenes_<scene>`` layout.

    Layout implied by the paths this class builds::

        <root_dir>/pgt_7scenes_<scene_name>/
            train/rgb/*.png      train/poses/*.txt
            test/rgb/*.png       test/poses/*.txt

    Only the test split is indexable: ``all_samples`` is the sorted list of
    test RGB paths. Train RGB paths are collected into ``train_samples`` but
    are not used by any method of this class.

    Args:
        root_dir: Directory containing the ``pgt_7scenes_*`` scene folders.
        scene_name: Scene suffix appended to ``pgt_7scenes_``, e.g. ``"chess"``.
    """

    def __init__(self, root_dir, scene_name="chess"):
        self.scene_dir = os.path.join(root_dir, f"pgt_7scenes_{scene_name}")
        self.train_seqs = os.path.join(self.scene_dir, "train")
        self.test_seqs = os.path.join(self.scene_dir, "test")
        self.test_samples = sorted(
            glob.glob(os.path.join(self.test_seqs, "rgb", "*.png"))
        )
        self.train_samples = sorted(
            glob.glob(os.path.join(self.train_seqs, "rgb", "*.png"))
        )
        # Only the test split is indexed; train frames are collected above
        # but intentionally not appended here.
        self.all_samples = self.test_samples

    def __len__(self):
        """Return the number of indexable (test-split) frames."""
        return len(self.all_samples)

    def __getitem__(self, idx):
        """Load and return the sample dict for ``self.all_samples[idx]``."""
        return self._load_sample(self.all_samples[idx])

    def get_train_sample(self, n=4):
        """Load ``n`` frames selected by ``uniform_sample`` from ``all_samples``.

        Note: despite the name, frames come from ``self.all_samples`` (the
        test split), not from ``self.train_samples``. The index selection is
        delegated to ``uniform_sample(len(self.all_samples), n)``; presumably
        evenly spaced indices -- TODO confirm against ``eval_utils``.

        Returns:
            A list of sample dicts as produced by ``_load_sample``.
        """
        indices = uniform_sample(len(self.all_samples), n)
        return [self._load_sample(self.all_samples[i]) for i in indices]

    @staticmethod
    def _pose_path_for(rgb_path):
        """Map ``<split>/rgb/<frame>.color.png`` to ``<split>/poses/<frame>.pose.txt``.

        Only the last directory component and the file name are rewritten.
        Any ``"rgb"``, ``"color"`` or ``".png"`` substrings elsewhere in the
        path (e.g. a root directory such as ``/data/rgbd/``) are left intact.
        """
        rgb_dir, file_name = os.path.split(rgb_path)
        split_dir = os.path.dirname(rgb_dir)
        stem, _ext = os.path.splitext(file_name)
        pose_name = stem.replace("color", "pose") + ".txt"
        return os.path.join(split_dir, "poses", pose_name)

    def _load_sample(self, rgb_path):
        """Load one RGB frame and its matching pose file.

        Args:
            rgb_path: Path to a ``.png`` under a split's ``rgb`` directory.

        Returns:
            dict with keys:
              - ``img``: first element of ``load_and_preprocess_images([rgb_path])``.
              - ``camera_pose``: float32 tensor loaded with ``np.loadtxt`` from
                the matching ``poses/*.pose.txt`` (cam2world per the original
                annotation; presumably 4x4 -- confirm against the data).
              - ``dataset``: the constant ``"7Scenes"``.
              - ``true_shape``: hard-coded ``[392, 518]`` tensor.
              - ``label`` / ``instance``: the image file name.
        """
        img_name = os.path.basename(rgb_path)
        color = load_and_preprocess_images([rgb_path])[0]
        pose = torch.from_numpy(np.loadtxt(self._pose_path_for(rgb_path))).float()
        return dict(
            img=color,
            camera_pose=pose,  # cam2world
            dataset="7Scenes",
            # NOTE(review): hard-coded (H, W); assumes the preprocessor yields
            # 392x518 for 7-Scenes frames -- confirm, or derive from color.shape.
            true_shape=torch.tensor([392, 518]),
            label=img_name,
            instance=img_name,
        )