ZhenweiWang committed on
Commit
073f8bf
·
verified ·
1 Parent(s): a795c20

Update src/models/models/worldmirror.py

Browse files
Files changed (1) hide show
  1. src/models/models/worldmirror.py +6 -23
src/models/models/worldmirror.py CHANGED
@@ -131,40 +131,27 @@ class WorldMirror(nn.Module, PyTorchModelHubMixin):
131
  imgs = views['img']
132
 
133
  # Enable conditional input during training if enabled, or during inference if any cond_flags are set
134
- use_cond = (
135
- (self.training and self.enable_cond) or
136
- (not self.training and sum(cond_flags) > 0)
137
- )
138
 
139
  # Extract priors and process features based on conditional input
140
- context_token_list = None
141
  if use_cond:
142
  priors = self.extract_priors(views)
143
  token_list, patch_start_idx = self.visual_geometry_transformer(
144
  imgs, priors, cond_flags=cond_flags
145
  )
146
- if self.enable_gs:
147
- cnums = views["context_nums"]
148
- context_priors = (priors[0][:,:cnums], priors[1][:,:cnums], priors[2][:,:cnums])
149
- context_token_list = self.visual_geometry_transformer(
150
- imgs[:,:cnums], context_priors, cond_flags=cond_flags
151
- )[0]
152
  else:
153
  token_list, patch_start_idx = self.visual_geometry_transformer(imgs)
154
- if self.enable_gs:
155
- cnums = views["context_nums"] if "context_nums" in views else imgs.shape[1]
156
- context_token_list = self.visual_geometry_transformer(imgs[:,:cnums])[0]
157
 
158
  # Execute predictions
159
  with torch.amp.autocast('cuda', enabled=False):
160
  # Generate all predictions
161
  preds = self._gen_all_preds(
162
- token_list, context_token_list, imgs, patch_start_idx, views
163
  )
164
 
165
  return preds
166
 
167
- def _gen_all_preds(self, token_list, context_token_list,
168
  imgs, patch_start_idx, views):
169
  """Generate all enabled predictions"""
170
  preds = {}
@@ -175,9 +162,7 @@ class WorldMirror(nn.Module, PyTorchModelHubMixin):
175
  cam_seq = self.cam_head(token_list)
176
  cam_params = cam_seq[-1]
177
  preds["camera_params"] = cam_params
178
- if context_token_list is not None:
179
- context_cam_params = self.cam_head(context_token_list)[-1]
180
- context_preds = {"camera_params": context_cam_params}
181
  ext_mat, int_mat = vector_to_camera_matrices(
182
  cam_params, image_hw=(imgs.shape[-2], imgs.shape[-1])
183
  )
@@ -216,9 +201,8 @@ class WorldMirror(nn.Module, PyTorchModelHubMixin):
216
 
217
  # 3D Gaussian Splatting
218
  if self.enable_gs:
219
- views['context_nums'] = imgs.shape[1] if "context_nums" not in views else views["context_nums"]
220
  gs_feat, gs_depth, gs_depth_conf = self.gs_head(
221
- context_token_list, images=imgs[:,:views["context_nums"]], patch_start_idx=patch_start_idx
222
  )
223
 
224
  preds["gs_depth"] = gs_depth
@@ -228,7 +212,6 @@ class WorldMirror(nn.Module, PyTorchModelHubMixin):
228
  images=imgs,
229
  predictions=preds,
230
  views=views,
231
- context_predictions=context_preds
232
  )
233
 
234
  return preds
@@ -246,7 +229,7 @@ class WorldMirror(nn.Module, PyTorchModelHubMixin):
246
  h, w = views['img'].shape[-2:]
247
 
248
  # Initialize prior variables
249
- poses = depths = rays = None
250
 
251
  # Extract camera pose
252
  if 'camera_pose' in views:
 
131
  imgs = views['img']
132
 
133
  # Enable conditional input during training if enabled, or during inference if any cond_flags are set
134
+ use_cond = sum(cond_flags) > 0
 
 
 
135
 
136
  # Extract priors and process features based on conditional input
 
137
  if use_cond:
138
  priors = self.extract_priors(views)
139
  token_list, patch_start_idx = self.visual_geometry_transformer(
140
  imgs, priors, cond_flags=cond_flags
141
  )
 
 
 
 
 
 
142
  else:
143
  token_list, patch_start_idx = self.visual_geometry_transformer(imgs)
 
 
 
144
 
145
  # Execute predictions
146
  with torch.amp.autocast('cuda', enabled=False):
147
  # Generate all predictions
148
  preds = self._gen_all_preds(
149
+ token_list, imgs, patch_start_idx, views
150
  )
151
 
152
  return preds
153
 
154
+ def _gen_all_preds(self, token_list,
155
  imgs, patch_start_idx, views):
156
  """Generate all enabled predictions"""
157
  preds = {}
 
162
  cam_seq = self.cam_head(token_list)
163
  cam_params = cam_seq[-1]
164
  preds["camera_params"] = cam_params
165
+
 
 
166
  ext_mat, int_mat = vector_to_camera_matrices(
167
  cam_params, image_hw=(imgs.shape[-2], imgs.shape[-1])
168
  )
 
201
 
202
  # 3D Gaussian Splatting
203
  if self.enable_gs:
 
204
  gs_feat, gs_depth, gs_depth_conf = self.gs_head(
205
+ token_list, images=imgs, patch_start_idx=patch_start_idx
206
  )
207
 
208
  preds["gs_depth"] = gs_depth
 
212
  images=imgs,
213
  predictions=preds,
214
  views=views,
 
215
  )
216
 
217
  return preds
 
229
  h, w = views['img'].shape[-2:]
230
 
231
  # Initialize prior variables
232
+ depths = rays = poses = None
233
 
234
  # Extract camera pose
235
  if 'camera_pose' in views: