ZhenweiWang committed (verified)
Commit 9c4386d · 1 parent: 2bbacf3

Update src/models/models/rasterization.py

Files changed (1):
  src/models/models/rasterization.py (+19 -76)
src/models/models/rasterization.py CHANGED
@@ -184,67 +184,29 @@ class GaussianSplatRenderer(nn.Module):
         # 1) Predict GS features from tokens, then convert to Gaussian parameters
         gs_feats_reshape = rearrange(gs_feats, "b s c h w -> (b s) c h w")
         gs_params = self.gs_head(gs_feats_reshape)
+
+        # 2) Select predicted cameras
+        Bx = images.shape[0]
+        pred_all_extrinsic, pred_all_intrinsic = self.prepare_prediction_cameras(predictions, S + V, hw=(H, W))
+        pred_all_extrinsic = pred_all_extrinsic.reshape(Bx, S + V, 4, 4)
+        pred_all_source_extrinsic = pred_all_extrinsic[:, :S]
+
+        scale_factor = 1.0
+        if context_predictions is not None:
+            pred_source_extrinsic, _ = self.prepare_prediction_cameras(context_predictions, S, hw=(H, W))
+            pred_source_extrinsic = pred_source_extrinsic.reshape(Bx, S, 4, 4)
+            scale_factor = pred_source_extrinsic[:, :, :3, 3].mean(dim=(1, 2), keepdim=True) / (
+                pred_all_source_extrinsic[:, :, :3, 3].mean(dim=(1, 2), keepdim=True) + 1e-6
+            )
 
-        # 2) Select cameras (predicted or GT), and organize supervision data (gt_colors, gt_depths, valid_masks)
-        if self.training:
-            if self.render_novel_views and V > 0:
-                pred_all_extrinsic, pred_all_intrinsic = self.prepare_cameras(views, S+V)
-                render_viewmats, render_Ks = pred_all_extrinsic, pred_all_intrinsic
-                render_images = images
-                gt_colors = render_images.permute(0, 1, 3, 4, 2)
-                gt_depths = views["depthmap"]  # [B, S+V, H, W]
-
-                gt_valid_masks_src = views["valid_mask"][:, :S]  # [B, S, H, W]
-                gt_valid_masks_tgt = views["valid_mask"][:, S:]  # [B, V, H, W]
-                unproject_masks = calculate_unprojected_mask(views, S)  # [B, V, H, W]
-                valid_masks = torch.cat([gt_valid_masks_src, (gt_valid_masks_tgt & unproject_masks)], dim=1)
-            else:
-                # Only render source views
-                render_viewmats, render_Ks = self.prepare_cameras(views, S)
-                render_images = views["img"][:, :S]
-                gt_colors = render_images.permute(0, 1, 3, 4, 2)
-                gt_depths = views["depthmap"][:, :S]
-                gt_valid_masks = views["valid_mask"][:, :S]
-                valid_masks = gt_valid_masks
-        else:
-            # Re-predict cameras for novel views and perform translation/scale alignment
-            Bx = images.shape[0]
-            pred_all_extrinsic, pred_all_intrinsic = self.prepare_prediction_cameras(predictions, S + V, hw=(H, W))
-            pred_all_extrinsic = pred_all_extrinsic.reshape(Bx, S + V, 4, 4)
-            pred_all_source_extrinsic = pred_all_extrinsic[:, :S]
-
-            scale_factor = 1.0
-            if context_predictions is not None:
-                pred_source_extrinsic, _ = self.prepare_prediction_cameras(context_predictions, S, hw=(H, W))
-                pred_source_extrinsic = pred_source_extrinsic.reshape(Bx, S, 4, 4)
-                scale_factor = pred_source_extrinsic[:, :, :3, 3].mean(dim=(1, 2), keepdim=True) / (
-                    pred_all_source_extrinsic[:, :, :3, 3].mean(dim=(1, 2), keepdim=True) + 1e-6
-                )
+        pred_all_extrinsic[..., :3, 3] = pred_all_extrinsic[..., :3, 3] * scale_factor
 
-            pred_all_extrinsic[..., :3, 3] = pred_all_extrinsic[..., :3, 3] * scale_factor
-
-            render_viewmats, render_Ks = pred_all_extrinsic, pred_all_intrinsic
-            render_images = images
-            gt_colors = render_images.permute(0, 1, 3, 4, 2)
-
-            # Handle pure inference case where views may not have ground truth data
-            gt_depths = views.get("depthmap")
-            valid_masks = None
-            if gt_depths is not None:
-                if views.get("gt_depth") is not None and views["gt_depth"]:
-                    unproject_masks = calculate_unprojected_mask(views, S)
-                    gt_valid_masks_src = views["valid_mask"][:, :S]  # [B, S, H, W]
-                    gt_valid_masks_tgt = views["valid_mask"][:, S:]  # [B, V, H, W]
-                    gt_valid_masks = torch.cat([gt_valid_masks_src, (gt_valid_masks_tgt & unproject_masks)], dim=1)
-                else:
-                    gt_valid_masks = views.get("valid_mask")
-                valid_masks = gt_valid_masks
+        render_viewmats, render_Ks = pred_all_extrinsic, pred_all_intrinsic
+        render_images = images
+        gt_colors = render_images.permute(0, 1, 3, 4, 2)
 
         # 3) Generate splats from gs_params + predictions, and perform voxel merging
-        if self.training and self.using_gtcamera_splat:
-            splats = self.prepare_splats(views, predictions, images, gs_params, S, V, position_from="gsdepth+gtcamera", debug=False)
-        else:
-            splats = self.prepare_splats(views, predictions, images, gs_params, S, V, position_from="gsdepth+predcamera", context_predictions=context_predictions, debug=False)
+        splats = self.prepare_splats(views, predictions, images, gs_params, S, V, position_from="gsdepth+predcamera", context_predictions=context_predictions, debug=False)
         splats_raw = {k: v.clone() for k, v in splats.items()}
 
         # Apply confidence filtering before pruning
@@ -255,19 +217,6 @@ class GaussianSplatRenderer(nn.Module):
         splats = self.prune_gs(splats, voxel_size=self.voxel_size)
 
         # 4) Rasterization rendering (training: chunked rendering + novel view valid mask correction; evaluation: view-by-view)
-        if self.training:
-            if self.render_novel_views and V > 0:
-                indices = np.arange(S+V)
-            else:
-                indices = np.arange(S)
-
-            render_viewmats = render_viewmats[:, indices]
-            render_Ks = render_Ks[:, indices]
-            gt_colors = gt_colors[:, indices]
-            if gt_depths is not None:
-                gt_depths = gt_depths[:, indices]
-            if valid_masks is not None:
-                valid_masks = valid_masks[:, indices]
 
         # Prevent OOM by using chunked rendering
         rendered_colors_list, rendered_depths_list, rendered_alphas_list = [], [], []
@@ -292,16 +241,10 @@ class GaussianSplatRenderer(nn.Module):
         rendered_depths = torch.cat(rendered_depths_list, dim=1)
         rendered_alphas = torch.cat(rendered_alphas_list, dim=1)
 
-        if self.training and self.render_novel_views and V > 0:
-            nvs_rendered_mask = rendered_alphas[:, S:, ..., 0].detach() > 0.1
-            valid_masks[:, S:] = nvs_rendered_mask & valid_masks[:, S:]
-
         # 5) return predictions
         predictions["rendered_colors"] = rendered_colors
         predictions["rendered_depths"] = rendered_depths
         predictions["gt_colors"] = gt_colors
-        predictions["gt_depths"] = gt_depths
-        predictions["valid_masks"] = valid_masks
         predictions["splats"] = splats
         predictions["splats_raw"] = splats_raw
         predictions["rendered_extrinsics"] = render_viewmats
 
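On step 3's voxel merging: `self.prune_gs(splats, voxel_size=self.voxel_size)` is the repo's own routine, but the general technique of deduplicating Gaussians on a voxel grid can be sketched as below. This is an illustrative stand-in, not the repo's implementation; the "means" key and the keep-first policy are assumptions:

import torch

def voxel_keep_indices(means: torch.Tensor, voxel_size: float) -> torch.Tensor:
    # means: [N, 3] Gaussian centers; returns the index of one splat per occupied voxel.
    coords = torch.floor(means / voxel_size).long()                  # [N, 3] voxel coords
    uniq, inverse = torch.unique(coords, dim=0, return_inverse=True)
    order = torch.arange(means.shape[0], device=means.device)
    first = torch.full((uniq.shape[0],), means.shape[0],
                       dtype=torch.long, device=means.device)        # sentinel = N
    first.scatter_reduce_(0, inverse, order, reduce="amin")          # first splat per voxel
    return first

# Hypothetical usage, assuming the splat dict stores centers under "means":
# keep = voxel_keep_indices(splats["means"], voxel_size)
# splats = {k: v[keep] for k, v in splats.items()}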