import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from utils import zero_init, EMANorm, create_rays, quaternion_to_matrix
import einops
from .render import gaussian_render


def inverse_sigmoid(x):
    if isinstance(x, torch.Tensor):
        return torch.log(x / (1 - x))
    else:
        return math.log(x / (1 - x))


def inverse_softplus(x, beta=1):
    if isinstance(x, torch.Tensor):
        return (torch.exp(beta * x) - 1).log() / beta
    else:
        return math.log(math.exp(beta * x) - 1) / beta
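# Note: these inverse activations are used below to initialize the gs_proj bias
# so that the *post-activation* defaults land at chosen values. Illustrative
# sanity checks (not part of the original code):
#   torch.sigmoid(torch.tensor(inverse_sigmoid(0.1)))  ~= 0.1    # initial opacity
#   F.softplus(torch.tensor(inverse_softplus(0.005)))  ~= 0.005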
import copy
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint  # imported explicitly for the checkpoint calls in decode()
import numpy as np
from .autoencoder_kl_wan import WanCausalConv3d, WanRMS_norm, unpatchify
class WANDecoderPixelAligned3DGSReconstructionModel(nn.Module):
    def __init__(self,
                 vae_model,
                 feat_dim,
                 # num_remove_decoder_up_blocks=0,
                 # num_points_per_pixel=4,
                 use_network_checkpointing=True,
                 use_render_checkpointing=True
                 ):
        super().__init__()
        self.decoder = copy.deepcopy(vae_model.decoder).requires_grad_(True)
        self.post_quant_conv = copy.deepcopy(vae_model.post_quant_conv).requires_grad_(True)
        self.extra_conv_in = WanCausalConv3d(feat_dim, self.decoder.conv_in.weight.shape[0], 3, padding=1)
        time_pad = self.extra_conv_in._padding[4]
        self.extra_conv_in.padding = (0, self.extra_conv_in._padding[2], self.extra_conv_in._padding[0])
        self.extra_conv_in._padding = (0, 0, 0, 0, 0, 0)
        self.extra_conv_in.weight = torch.nn.Parameter(self.extra_conv_in.weight[:, :, time_pad:].clone())
        with torch.no_grad():
            self.extra_conv_in.weight.data.zero_()
            self.extra_conv_in.bias.data.zero_()
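        # The block above rewires the causal 3D conv into an effectively
        # single-frame conv: the causal temporal padding is disabled and the
        # weight is sliced along the time axis, then zero-initialized so the
        # extra feature branch starts as a no-op. (Reading of the upstream
        # WanCausalConv3d internals; the `_padding` layout is assumed to be
        # (W_left, W_right, H_top, H_bottom, T_front, T_back).)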
        # remove one block
        # self.decoder.up_blocks = self.decoder.up_blocks[:-1]
        dims = [self.decoder.dim * u for u in [self.decoder.dim_mult[-1]] + self.decoder.dim_mult[::-1]]
        # self.decoder.up_blocks[-1].upsampler.mode = None
        # self.decoder.up_blocks[-1].upsampler.resample = nn.Identity()
        # self.decoder.up_blocks[-1].avg_shortcut = None
        self.decoder.norm_out = WanRMS_norm(dims[-1], images=False, bias=False)
        self.decoder.conv_out = nn.Identity()
        # add ema_norm for vae
        # for i_level in reversed(range(len(self.decoder.up_blocks))):
        #     if self.decoder.up_blocks[i_level].upsampler is not None:
        #         self.decoder.up_blocks[i_level].upsampler.resample = nn.Sequential(
        #             self.decoder.up_blocks[i_level].upsampler.resample,
        #         )
        self.patch_size = vae_model.config.patch_size
        # assert dims[-1] % 4 == 0
        self.gs_head = PixelAligned3DGS(dims[-1], num_points_per_pixel=2)
        del self.decoder.up_blocks[0].upsampler.time_conv
        del self.decoder.up_blocks[1].upsampler.time_conv
        self.decoder.conv_out = nn.Identity()
        self.network_checkpointing = use_network_checkpointing
        self.render_checkpointing = use_render_checkpointing
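    # Construction sketch (illustrative only; assumes a Wan-style VAE exposing
    # `decoder`, `post_quant_conv`, and `config.patch_size` -- the loader name
    # and feat_dim value below are hypothetical):
    #
    #   vae = load_wan_vae(...)  # any AutoencoderKLWan-compatible model
    #   model = WANDecoderPixelAligned3DGSReconstructionModel(vae, feat_dim=1024)
    #   gaussians = model(feats, z, cameras)                   # see forward() below
    #   images = model.render(gaussians, cameras, height=480, width=832)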
    def decode(self, feats, z):
        ## conv1
        x = self.decoder.conv_in(self.post_quant_conv(z)) + self.extra_conv_in(feats)
        ## middle
        if self.network_checkpointing and torch.is_grad_enabled():
            x = torch.utils.checkpoint.checkpoint(self.decoder.mid_block, x, None, [0], use_reentrant=False)
        else:
            x = self.decoder.mid_block(x, None, [0])
        ## upsamples
        for i, up_block in enumerate(self.decoder.up_blocks):
            if self.network_checkpointing and torch.is_grad_enabled():
                x = torch.utils.checkpoint.checkpoint(up_block, x, None, [0], True, use_reentrant=False)
            else:
                x = up_block(x, None, [0], first_chunk=True)
        # head
        x = self.decoder.norm_out(x)
        x = self.decoder.nonlinearity(x)
        x = self.decoder.conv_out(x)
        # if self.patch_size is not None:
        #     x = unpatchify(x, patch_size=self.patch_size)
        return x
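    # In decode() above, the positional `None, [0]` arguments passed to the Wan
    # decoder blocks correspond to its `feat_cache` / `feat_idx` streaming-cache
    # interface (an assumption based on the upstream Wan VAE code); the cache is
    # unused here because the latent is decoded in a single pass. Activation
    # checkpointing recomputes each block during backward to save memory.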
    def forward(self, feats, z, cameras):
        x = self.decode(feats, z).squeeze(2)
        gaussian_params = self.gs_head(x, cameras.flatten(0, 1)).unflatten(0, (cameras.shape[0], cameras.shape[1]))
        return gaussian_params

    # def forward(self, images, cameras, scene_chunk_lens):
    #     x, z, feats = self.encode(images)
    #     return self.reconstruct(x, z, feats, cameras, scene_chunk_lens)
    def render(self, gaussian_params, cameras, height, width, bg_mode='random'):
        cameras = cameras.to(torch.float32)
        test_c2ws = torch.eye(4, device=cameras.device)[None][None].repeat(cameras.shape[0], cameras.shape[1], 1, 1).float()
        test_c2ws[:, :, :3, :3] = quaternion_to_matrix(cameras[:, :, :4])
        test_c2ws[:, :, :3, 3] = cameras[:, :, 4:7]
        fx, fy, cx, cy = cameras[:, :, 7:11].split([1, 1, 1, 1], dim=-1)
        test_intr = torch.cat([fx * width, fy * height, cx * width, cy * height], dim=-1)
        return gaussian_render(gaussian_params, test_c2ws, test_intr, width, height,
                               use_checkpoint=self.render_checkpointing, sh_degree=self.gs_head.sh_degree, bg_mode=bg_mode)
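    # Camera encoding assumed by render() and create_rays (inferred from the
    # indexing above): cameras[..., :4] is a unit quaternion for the
    # camera-to-world rotation (convention per `quaternion_to_matrix` in utils),
    # cameras[..., 4:7] is the camera position, and cameras[..., 7:11] holds the
    # intrinsics (fx, fy, cx, cy) normalized by image width/height.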
from torch.autograd import Function


class _trunc_exp(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    def backward(ctx, g):
        x = ctx.saved_tensors[0]
        return g * torch.exp(x.clamp(-10, 10))


trunc_exp = _trunc_exp.apply
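# trunc_exp computes a plain exp in the forward pass but clamps the exponent to
# [-10, 10] when forming the gradient, preventing overflow from large raw depth
# logits -- a common trick in NeRF-style codebases.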
class PixelAligned3DGS(nn.Module):
    def __init__(
        self,
        embed_dim,
        sh_degree=2,
        use_mask=False,
        scale_range=(0, 16),  # related to pixel size
        num_points_per_pixel=1,
    ):
        super().__init__()
        self.sh_degree = sh_degree
        # sh, uv_offset, depth, opacity, scales, rotations
        # TODO: handle different sh_degree
        self.gaussian_channels = [3 * (self.sh_degree + 1) ** 2, 2, 1, 1, 3, 4, (1 if use_mask else 0)]
        self.gs_proj = nn.Conv2d(embed_dim, num_points_per_pixel * sum(self.gaussian_channels), 3, 1, 1)
        self.register_buffer("lrs_mul", torch.Tensor(
            [1] * 3 +  # sh 0
            [0.5] * 3 * ((self.sh_degree + 1) ** 2 - 1) +  # other sh
            [0.01] * 2 +  # uv_offset
            [1] * 1 +  # depth
            [1] * 1 +  # opacity
            [1] * 3 +  # scales
            [1] * 4 +  # rotations
            [0.1] * (1 if use_mask else 0)  # mask
        ).repeat(num_points_per_pixel), persistent=True)
        self.lrs_mul = self.lrs_mul / self.lrs_mul.max()
        self.use_mask = use_mask
        self.scale_range = scale_range
        with torch.no_grad():
            self.gs_proj.weight.data.zero_()
            self.gs_proj.bias = nn.Parameter(torch.Tensor(
                [0.0] * 3 * (self.sh_degree + 1) ** 2 +  # sh
                [0.0] * 2 +  # uv_offset
                [math.log(1)] * 1 +  # depth
                # [inverse_softplus(1)] * 1 +  # depth
                [inverse_sigmoid(0.1)] * 1 +  # opacity
                [inverse_sigmoid((1 - scale_range[0]) / (scale_range[1] - scale_range[0]))] * 3 +  # scales (default: 1, so the Gaussian scale equals the pixel size)
                # [inverse_softplus(0.005)] * 3 +  # scales
                [1., 0, 0, 0] +  # rotations
                [inverse_sigmoid(0.9)] * (1 if use_mask else 0)  # mask (default: 0.9)
            ).repeat(num_points_per_pixel) / self.lrs_mul)
        self.num_points_per_pixel = num_points_per_pixel
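    # lrs_mul is a per-channel multiplier applied to gs_proj's weight and bias
    # at forward time (see forward() below). Dividing the initial bias by the
    # same multiplier keeps the initial outputs unchanged, while the multiplier
    # effectively gives the higher-order SH, uv-offset, and mask channels a
    # smaller learning rate than the geometry channels.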
    def forward(self, x, cameras):
        x = x.to(torch.float32)
        cameras = cameras.to(torch.float32)
        BN, _, h, w = x.shape
        local_gaussian_params = F.conv2d(
            x,
            self.gs_proj.weight * self.lrs_mul[:, None, None, None],
            self.gs_proj.bias * self.lrs_mul,
            stride=1, padding=1,
        ).unflatten(1, (self.num_points_per_pixel, -1))
        # local_gaussian_params = F.conv2d(x, self.gs_proj.weight, self.gs_proj.bias, stride=1, padding=1).unflatten(1, (self.num_points_per_pixel, -1))
        # batch * n_frame, num_points_per_pixel, c, h, w -> batch * n_frame, num_points_per_pixel, h, w, c
        local_gaussian_params = local_gaussian_params.permute(0, 1, 3, 4, 2)
        features, uv_offset, depth, opacity, scales, rotations, mask = local_gaussian_params.split(self.gaussian_channels, dim=-1)
        rays_o, rays_d = create_rays(cameras[:, None].repeat(1, self.num_points_per_pixel, 1), uv_offset=uv_offset, h=h, w=w)
        depth = trunc_exp(depth)
        # depth = F.softplus(depth, beta=1)
        xyz = rays_o + depth * rays_d
        # features = features.unflatten(-1, (-1, 3))
        opacity = torch.sigmoid(opacity)
        if self.use_mask:
            if torch.is_grad_enabled():
                # Straight-through estimator: hard 0/1 mask in the forward pass,
                # soft sigmoid gradient in the backward pass.
                mask = torch.sigmoid(mask)
                hard_mask = (mask > torch.rand_like(mask)).float()
                opacity = opacity * (mask + (hard_mask - mask).detach())
            else:
                mask = torch.sigmoid(mask)
                hard_mask = (mask > torch.rand_like(mask)).float()
                opacity = opacity * hard_mask
        fx, fy = cameras[:, 7:9].split([1, 1], dim=-1)
        fx, fy = fx / w, fy / h
        pixel_size = torch.sqrt(fx.pow(2) + fy.pow(2))[:, None, None, None] * depth
        scales = (torch.sigmoid(scales) * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0]) * pixel_size
        # scales = F.softplus(scales, beta=1)
        # Normalization may not be required for gsplat rasterization?
        rotations = torch.nn.functional.normalize(rotations, dim=-1)
        gaussian_params = torch.cat([xyz, opacity, scales, rotations, features], dim=-1)
        return gaussian_params
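# Output layout of PixelAligned3DGS.forward, per pixel and per point:
#   [xyz (3), opacity (1), scales (3), rotation quaternion (4), SH features (3 * (sh_degree + 1)^2)]
# concatenated along the last dimension, with shape
# (batch * n_frames, num_points_per_pixel, H, W, C); gaussian_render is assumed
# to consume this packed format.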