Spaces:
Running
on
Zero
Running
on
Zero
Update src/models/models/rasterization.py
Browse files
src/models/models/rasterization.py
CHANGED
|
@@ -184,67 +184,29 @@ class GaussianSplatRenderer(nn.Module):
|
|
| 184 |
# 1) Predict GS features from tokens, then convert to Gaussian parameters
|
| 185 |
gs_feats_reshape = rearrange(gs_feats, "b s c h w -> (b s) c h w")
|
| 186 |
gs_params = self.gs_head(gs_feats_reshape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
-
|
| 189 |
-
if self.training:
|
| 190 |
-
if self.render_novel_views and V > 0:
|
| 191 |
-
pred_all_extrinsic, pred_all_intrinsic = self.prepare_cameras(views, S+V)
|
| 192 |
-
render_viewmats, render_Ks = pred_all_extrinsic, pred_all_intrinsic
|
| 193 |
-
render_images = images
|
| 194 |
-
gt_colors = render_images.permute(0, 1, 3, 4, 2)
|
| 195 |
-
gt_depths = views["depthmap"] # [B, S+V, H, W]
|
| 196 |
-
|
| 197 |
-
gt_valid_masks_src = views["valid_mask"][:, :S] # [B, S, H, W]
|
| 198 |
-
gt_valid_masks_tgt = views["valid_mask"][:, S:] # [B, V, H, W]
|
| 199 |
-
unproject_masks = calculate_unprojected_mask(views, S) # [B, V, H, W]
|
| 200 |
-
valid_masks = torch.cat([gt_valid_masks_src, (gt_valid_masks_tgt & unproject_masks)], dim=1)
|
| 201 |
-
else:
|
| 202 |
-
# Only render source views
|
| 203 |
-
render_viewmats, render_Ks = self.prepare_cameras(views, S)
|
| 204 |
-
render_images = views["img"][:, :S]
|
| 205 |
-
gt_colors = render_images.permute(0, 1, 3, 4, 2)
|
| 206 |
-
gt_depths = views["depthmap"][:, :S]
|
| 207 |
-
gt_valid_masks = views["valid_mask"][:, :S]
|
| 208 |
-
valid_masks = gt_valid_masks
|
| 209 |
-
else:
|
| 210 |
-
# Re-predict cameras for novel views and perform translation/scale alignment
|
| 211 |
-
Bx = images.shape[0]
|
| 212 |
-
pred_all_extrinsic, pred_all_intrinsic = self.prepare_prediction_cameras(predictions, S + V, hw=(H, W))
|
| 213 |
-
pred_all_extrinsic = pred_all_extrinsic.reshape(Bx, S + V, 4, 4)
|
| 214 |
-
pred_all_source_extrinsic = pred_all_extrinsic[:, :S]
|
| 215 |
-
|
| 216 |
-
scale_factor = 1.0
|
| 217 |
-
if context_predictions is not None:
|
| 218 |
-
pred_source_extrinsic, _ = self.prepare_prediction_cameras(context_predictions, S, hw=(H, W))
|
| 219 |
-
pred_source_extrinsic = pred_source_extrinsic.reshape(Bx, S, 4, 4)
|
| 220 |
-
scale_factor = pred_source_extrinsic[:, :, :3, 3].mean(dim=(1, 2), keepdim=True) / (
|
| 221 |
-
pred_all_source_extrinsic[:, :, :3, 3].mean(dim=(1, 2), keepdim=True) + 1e-6
|
| 222 |
-
)
|
| 223 |
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
render_images = images
|
| 228 |
-
gt_colors = render_images.permute(0, 1, 3, 4, 2)
|
| 229 |
-
|
| 230 |
-
# Handle pure inference case where views may not have ground truth data
|
| 231 |
-
gt_depths = views.get("depthmap")
|
| 232 |
-
valid_masks = None
|
| 233 |
-
if gt_depths is not None:
|
| 234 |
-
if views.get("gt_depth") is not None and views["gt_depth"]:
|
| 235 |
-
unproject_masks = calculate_unprojected_mask(views, S)
|
| 236 |
-
gt_valid_masks_src = views["valid_mask"][:, :S] # [B, S, H, W]
|
| 237 |
-
gt_valid_masks_tgt = views["valid_mask"][:, S:] # [B, V, H, W]
|
| 238 |
-
gt_valid_masks = torch.cat([gt_valid_masks_src, (gt_valid_masks_tgt & unproject_masks)], dim=1)
|
| 239 |
-
else:
|
| 240 |
-
gt_valid_masks = views.get("valid_mask")
|
| 241 |
-
valid_masks = gt_valid_masks
|
| 242 |
|
| 243 |
# 3) Generate splats from gs_params + predictions, and perform voxel merging
|
| 244 |
-
|
| 245 |
-
splats = self.prepare_splats(views, predictions, images, gs_params, S, V, position_from="gsdepth+gtcamera", debug=False)
|
| 246 |
-
else:
|
| 247 |
-
splats = self.prepare_splats(views, predictions, images, gs_params, S, V, position_from="gsdepth+predcamera", context_predictions=context_predictions, debug=False)
|
| 248 |
splats_raw = {k: v.clone() for k, v in splats.items()}
|
| 249 |
|
| 250 |
# Apply confidence filtering before pruning
|
|
@@ -255,19 +217,6 @@ class GaussianSplatRenderer(nn.Module):
|
|
| 255 |
splats = self.prune_gs(splats, voxel_size=self.voxel_size)
|
| 256 |
|
| 257 |
# 4) Rasterization rendering (training: chunked rendering + novel view valid mask correction; evaluation: view-by-view)
|
| 258 |
-
if self.training:
|
| 259 |
-
if self.render_novel_views and V > 0:
|
| 260 |
-
indices = np.arange(S+V)
|
| 261 |
-
else:
|
| 262 |
-
indices = np.arange(S)
|
| 263 |
-
|
| 264 |
-
render_viewmats = render_viewmats[:, indices]
|
| 265 |
-
render_Ks = render_Ks[:, indices]
|
| 266 |
-
gt_colors = gt_colors[:, indices]
|
| 267 |
-
if gt_depths is not None:
|
| 268 |
-
gt_depths = gt_depths[:, indices]
|
| 269 |
-
if valid_masks is not None:
|
| 270 |
-
valid_masks = valid_masks[:, indices]
|
| 271 |
|
| 272 |
# Prevent OOM by using chunked rendering
|
| 273 |
rendered_colors_list, rendered_depths_list, rendered_alphas_list = [], [], []
|
|
@@ -292,16 +241,10 @@ class GaussianSplatRenderer(nn.Module):
|
|
| 292 |
rendered_depths = torch.cat(rendered_depths_list, dim=1)
|
| 293 |
rendered_alphas = torch.cat(rendered_alphas_list, dim=1)
|
| 294 |
|
| 295 |
-
if self.training and self.render_novel_views and V > 0:
|
| 296 |
-
nvs_rendered_mask = rendered_alphas[:, S:, ..., 0].detach() > 0.1
|
| 297 |
-
valid_masks[:, S:] = nvs_rendered_mask & valid_masks[:, S:]
|
| 298 |
-
|
| 299 |
# 5) return predictions
|
| 300 |
predictions["rendered_colors"] = rendered_colors
|
| 301 |
predictions["rendered_depths"] = rendered_depths
|
| 302 |
predictions["gt_colors"] = gt_colors
|
| 303 |
-
predictions["gt_depths"] = gt_depths
|
| 304 |
-
predictions["valid_masks"] = valid_masks
|
| 305 |
predictions["splats"] = splats
|
| 306 |
predictions["splats_raw"] = splats_raw
|
| 307 |
predictions["rendered_extrinsics"] = render_viewmats
|
|
|
|
| 184 |
# 1) Predict GS features from tokens, then convert to Gaussian parameters
|
| 185 |
gs_feats_reshape = rearrange(gs_feats, "b s c h w -> (b s) c h w")
|
| 186 |
gs_params = self.gs_head(gs_feats_reshape)
|
| 187 |
+
|
| 188 |
+
# 2) Select predicted cameras
|
| 189 |
+
Bx = images.shape[0]
|
| 190 |
+
pred_all_extrinsic, pred_all_intrinsic = self.prepare_prediction_cameras(predictions, S + V, hw=(H, W))
|
| 191 |
+
pred_all_extrinsic = pred_all_extrinsic.reshape(Bx, S + V, 4, 4)
|
| 192 |
+
pred_all_source_extrinsic = pred_all_extrinsic[:, :S]
|
| 193 |
+
|
| 194 |
+
scale_factor = 1.0
|
| 195 |
+
if context_predictions is not None:
|
| 196 |
+
pred_source_extrinsic, _ = self.prepare_prediction_cameras(context_predictions, S, hw=(H, W))
|
| 197 |
+
pred_source_extrinsic = pred_source_extrinsic.reshape(Bx, S, 4, 4)
|
| 198 |
+
scale_factor = pred_source_extrinsic[:, :, :3, 3].mean(dim=(1, 2), keepdim=True) / (
|
| 199 |
+
pred_all_source_extrinsic[:, :, :3, 3].mean(dim=(1, 2), keepdim=True) + 1e-6
|
| 200 |
+
)
|
| 201 |
|
| 202 |
+
pred_all_extrinsic[..., :3, 3] = pred_all_extrinsic[..., :3, 3] * scale_factor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
+
render_viewmats, render_Ks = pred_all_extrinsic, pred_all_intrinsic
|
| 205 |
+
render_images = images
|
| 206 |
+
gt_colors = render_images.permute(0, 1, 3, 4, 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
# 3) Generate splats from gs_params + predictions, and perform voxel merging
|
| 209 |
+
splats = self.prepare_splats(views, predictions, images, gs_params, S, V, position_from="gsdepth+predcamera", context_predictions=context_predictions, debug=False)
|
|
|
|
|
|
|
|
|
|
| 210 |
splats_raw = {k: v.clone() for k, v in splats.items()}
|
| 211 |
|
| 212 |
# Apply confidence filtering before pruning
|
|
|
|
| 217 |
splats = self.prune_gs(splats, voxel_size=self.voxel_size)
|
| 218 |
|
| 219 |
# 4) Rasterization rendering (training: chunked rendering + novel view valid mask correction; evaluation: view-by-view)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
# Prevent OOM by using chunked rendering
|
| 222 |
rendered_colors_list, rendered_depths_list, rendered_alphas_list = [], [], []
|
|
|
|
| 241 |
rendered_depths = torch.cat(rendered_depths_list, dim=1)
|
| 242 |
rendered_alphas = torch.cat(rendered_alphas_list, dim=1)
|
| 243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
# 5) return predictions
|
| 245 |
predictions["rendered_colors"] = rendered_colors
|
| 246 |
predictions["rendered_depths"] = rendered_depths
|
| 247 |
predictions["gt_colors"] = gt_colors
|
|
|
|
|
|
|
| 248 |
predictions["splats"] = splats
|
| 249 |
predictions["splats_raw"] = splats_raw
|
| 250 |
predictions["rendered_extrinsics"] = render_viewmats
|