import copy
import math

import numpy as np
import einops

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint  # for activation checkpointing in decode()
from torch.autograd import Function

from utils import zero_init, EMANorm, create_rays, quaternion_to_matrix

from .render import gaussian_render
from .autoencoder_kl_wan import WanCausalConv3d, WanRMS_norm, unpatchify

def inverse_sigmoid(x):
    """Inverse of torch.sigmoid; accepts a tensor or a Python float."""
    if isinstance(x, torch.Tensor):
        return torch.log(x / (1 - x))
    return math.log(x / (1 - x))

def inverse_softplus(x, beta=1):
    """Inverse of F.softplus(x, beta=beta); accepts a tensor or a Python float."""
    if isinstance(x, torch.Tensor):
        return (torch.exp(beta * x) - 1).log() / beta
    return math.log(math.exp(beta * x) - 1) / beta
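
# Quick sanity check (hypothetical usage, not part of the training code):
# both inverses should round-trip through their forward activations.
#   torch.sigmoid(inverse_sigmoid(torch.tensor(0.1)))  # ~0.1
#   F.softplus(inverse_softplus(torch.tensor(2.0)))    # ~2.0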


class WANDecoderPixelAligned3DGSReconstructionModel(nn.Module):
    def __init__(self,
                 vae_model,
                 feat_dim,
                 use_network_checkpointing=True,
                 use_render_checkpointing=True):
        super().__init__()

        # Fine-tune a full copy of the pretrained VAE decoder.
        self.decoder = copy.deepcopy(vae_model.decoder).requires_grad_(True)
        self.post_quant_conv = copy.deepcopy(vae_model.post_quant_conv).requires_grad_(True)

        # Extra input conv that injects the conditioning features alongside
        # the latents. The decoder only sees a single temporal chunk here, so
        # strip the causal temporal padding and the padded weight slices.
        self.extra_conv_in = WanCausalConv3d(feat_dim, self.decoder.conv_in.weight.shape[0], 3, padding=1)

        time_pad = self.extra_conv_in._padding[4]
        self.extra_conv_in.padding = (0, self.extra_conv_in._padding[2], self.extra_conv_in._padding[0])
        self.extra_conv_in._padding = (0, 0, 0, 0, 0, 0)
        self.extra_conv_in.weight = nn.Parameter(self.extra_conv_in.weight[:, :, time_pad:].clone())

        # Zero-init so that, at the start of training, the decoder behaves
        # exactly like the pretrained VAE decoder.
        with torch.no_grad():
            self.extra_conv_in.weight.data.zero_()
            self.extra_conv_in.bias.data.zero_()

        # Channel widths of the decoder stages, from the bottleneck upwards.
        dims = [self.decoder.dim * u for u in [self.decoder.dim_mult[-1]] + self.decoder.dim_mult[::-1]]

        # Replace the RGB output head: keep the final normalization but hand
        # the feature map straight to the Gaussian head instead of pixels.
        self.decoder.norm_out = WanRMS_norm(dims[-1], images=False, bias=False)
        self.decoder.conv_out = nn.Identity()

        self.patch_size = vae_model.config.patch_size
        self.gs_head = PixelAligned3DGS(dims[-1], num_points_per_pixel=2)

        # Temporal upsampling is unused when decoding a single chunk, so the
        # time convs of the first two up blocks can be dropped.
        del self.decoder.up_blocks[0].upsampler.time_conv
        del self.decoder.up_blocks[1].upsampler.time_conv

        self.network_checkpointing = use_network_checkpointing
        self.render_checkpointing = use_render_checkpointing
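
        # Construction sketch (hypothetical; `vae` is a pretrained WAN VAE
        # exposing .decoder, .post_quant_conv and .config.patch_size, and
        # feat_dim=1024 is an assumed feature width):
        #   model = WANDecoderPixelAligned3DGSReconstructionModel(vae, feat_dim=1024)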
    
    def decode(self, feats, z):
        ## input convs: latents through post_quant_conv + conv_in, plus the
        ## zero-initialized feature branch
        x = self.decoder.conv_in(self.post_quant_conv(z)) + self.extra_conv_in(feats)

        ## middle
        if self.network_checkpointing and torch.is_grad_enabled():
            x = torch.utils.checkpoint.checkpoint(self.decoder.mid_block, x, None, [0], use_reentrant=False)
        else:
            x = self.decoder.mid_block(x, None, [0])

        ## upsamples
        for up_block in self.decoder.up_blocks:
            if self.network_checkpointing and torch.is_grad_enabled():
                x = torch.utils.checkpoint.checkpoint(up_block, x, None, [0], True, use_reentrant=False)
            else:
                x = up_block(x, None, [0], first_chunk=True)

        # head
        x = self.decoder.norm_out(x)
        x = self.decoder.nonlinearity(x)
        x = self.decoder.conv_out(x)

        # if self.patch_size is not None:
        #     x = unpatchify(x, patch_size=self.patch_size)

        return x

    def forward(self, feats, z, cameras):
        # Decode one latent chunk, drop the singleton time axis, then predict
        # pixel-aligned Gaussians for every (batch, view) pair.
        x = self.decode(feats, z).squeeze(2)

        gaussian_params = self.gs_head(x, cameras.flatten(0, 1)).unflatten(0, (cameras.shape[0], cameras.shape[1]))

        return gaussian_params
    
    @torch.amp.autocast(device_type='cuda', enabled=False)
    def render(self, gaussian_params, cameras, height, width, bg_mode='random'):
        cameras = cameras.to(torch.float32)

        # Cameras pack [quaternion (4), translation (3), fx, fy, cx, cy, ...]
        # with intrinsics normalized by the image size.
        c2ws = torch.eye(4, device=cameras.device)[None, None].repeat(cameras.shape[0], cameras.shape[1], 1, 1).float()
        c2ws[:, :, :3, :3] = quaternion_to_matrix(cameras[:, :, :4])
        c2ws[:, :, :3, 3] = cameras[:, :, 4:7]

        fx, fy, cx, cy = cameras[:, :, 7:11].split([1, 1, 1, 1], dim=-1)
        intr = torch.cat([fx * width, fy * height, cx * width, cy * height], dim=-1)

        return gaussian_render(gaussian_params, c2ws, intr, width, height, use_checkpoint=self.render_checkpointing, sh_degree=self.gs_head.sh_degree, bg_mode=bg_mode)
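
    # End-to-end sketch (hypothetical shapes; not part of the module):
    #   feats, z : conditioning features and VAE latents for one chunk
    #   cameras  : (B, V, 11) as [qw qx qy qz, tx ty tz, fx fy cx cy]
    #   gs = model(feats, z, cameras)
    #   img = model.render(gs, cameras, height=480, width=832)  # assumed sizes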


class _trunc_exp(Function):
    """exp with a truncated gradient: the forward pass is exact, while the
    backward pass clamps the input to [-10, 10] so huge activations cannot
    produce exploding gradients."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    def backward(ctx, g):
        x = ctx.saved_tensors[0]
        return g * torch.exp(x.clamp(-10, 10))

trunc_exp = _trunc_exp.apply
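
# Gradient behaviour sketch (hypothetical values): the forward result matches
# torch.exp, but the gradient saturates at exp(10).
#   x = torch.tensor([20.0], requires_grad=True)
#   trunc_exp(x).backward()   # x.grad == exp(10.0), not exp(20.0)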

class PixelAligned3DGS(nn.Module):
    def __init__(
            self, 
            embed_dim, 
            sh_degree=2,
            use_mask=False, 
            scale_range=(0, 16),  # scales are in multiples of the pixel footprint
            num_points_per_pixel=1,
        ):
        super().__init__()

        self.sh_degree = sh_degree

        # Per-pixel channels: sh, uv_offset, depth, opacity, scales,
        # rotations, and an optional mask.
        # TODO: handle different sh_degree
        self.gaussian_channels = [3 * (self.sh_degree + 1) ** 2, 2, 1, 1, 3, 4, (1 if use_mask else 0)]

        self.gs_proj = nn.Conv2d(embed_dim, num_points_per_pixel * sum(self.gaussian_channels), 3, 1, 1)

        # Per-channel learning-rate multipliers, folded into the conv weights
        # and bias in forward().
        self.register_buffer("lrs_mul", torch.tensor(
                [1.0] * 3 + # sh degree 0
                [0.5] * 3 * ((self.sh_degree + 1) ** 2 - 1) + # higher-order sh
                [0.01] * 2 + # uv_offset
                [1.0] * 1 + # depth
                [1.0] * 1 + # opacity
                [1.0] * 3 + # scales
                [1.0] * 4 + # rotations
                [0.1] * (1 if use_mask else 0) # mask
            ).repeat(num_points_per_pixel), persistent=True)

        self.lrs_mul = self.lrs_mul / self.lrs_mul.max()
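
        # Why this works (explanatory note): forward() uses
        # gs_proj.weight * lrs_mul as the effective weight, so the gradient on
        # the stored weight is scaled by lrs_mul per output channel. Channels
        # with small multipliers (e.g. 0.01 for uv_offset) therefore see a
        # proportionally smaller effective learning rate.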

        self.use_mask = use_mask

        self.scale_range = scale_range

        # Zero-init the weights and set the bias so the initial prediction is
        # a sensible default Gaussian. The bias is stored pre-divided by
        # lrs_mul because forward() multiplies it back.
        with torch.no_grad():
            self.gs_proj.weight.data.zero_()
            self.gs_proj.bias = nn.Parameter(torch.tensor(
                [0.0] * 3 * (self.sh_degree + 1) ** 2 + # sh
                [0.0] * 2 + # uv_offset
                [math.log(1)] * 1 + # depth (trunc_exp(0) = 1)
                [inverse_sigmoid(0.1)] * 1 + # opacity (default: 0.1)
                [inverse_sigmoid((1 - scale_range[0]) / (scale_range[1] - scale_range[0]))] * 3 + # scales (default 1, i.e. Gaussian scale equals the pixel size)
                [1., 0, 0, 0] + # rotations (identity quaternion)
                [inverse_sigmoid(0.9)] * (1 if use_mask else 0) # mask (default: 0.9)
            ).repeat(num_points_per_pixel) / self.lrs_mul)

        self.num_points_per_pixel = num_points_per_pixel

    @torch.amp.autocast(device_type='cuda', enabled=False)
    def forward(self, x, cameras):
        
        x = x.to(torch.float32)
        cameras = cameras.to(torch.float32)

        BN, _, h, w = x.shape

        # Apply the projection with the per-channel lr multipliers folded in.
        local_gaussian_params = F.conv2d(
            x,
            self.gs_proj.weight * self.lrs_mul[:, None, None, None],
            self.gs_proj.bias * self.lrs_mul,
            stride=1, padding=1,
        ).unflatten(1, (self.num_points_per_pixel, -1))

        # batch * n_frame, num_points_per_pixel, c, h, w -> batch * n_frame, num_points_per_pixel, h, w, c
        local_gaussian_params = local_gaussian_params.permute(0, 1, 3, 4, 2)

        features, uv_offset, depth, opacity, scales, rotations, mask = local_gaussian_params.split(self.gaussian_channels, dim=-1)

        rays_o, rays_d = create_rays(cameras[:, None].repeat(1, self.num_points_per_pixel, 1), uv_offset=uv_offset, h=h, w=w)

        # Depth is predicted in log space; positions lie along the camera rays.
        depth = trunc_exp(depth)
        xyz = rays_o + depth * rays_d

        opacity = torch.sigmoid(opacity)
        if self.use_mask:
            # Stochastic hard mask with a straight-through gradient during
            # training; at inference the hard mask is applied directly.
            mask = torch.sigmoid(mask)
            hard_mask = (mask > torch.rand_like(mask)).float()
            if torch.is_grad_enabled():
                opacity = opacity * (mask + (hard_mask - mask).detach())
            else:
                opacity = opacity * hard_mask

        # Express scales in multiples of the (depth-dependent) pixel footprint,
        # bounded by scale_range via a sigmoid.
        fx, fy = cameras[:, 7:9].split([1, 1], dim=-1)
        fx, fy = fx / w, fy / h
        pixel_size = torch.sqrt(fx.pow(2) + fy.pow(2))[:, None, None, None] * depth
        scales = (torch.sigmoid(scales) * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0]) * pixel_size

        # gsplat may normalize quaternions internally, but normalizing here is
        # cheap and keeps the parameters well-defined.
        rotations = F.normalize(rotations, dim=-1)

        # Output layout per Gaussian: [xyz (3), opacity (1), scales (3),
        # rotations (4), sh features (3 * (sh_degree + 1) ** 2)].
        gaussian_params = torch.cat([xyz, opacity, scales, rotations, features], dim=-1)

        return gaussian_params
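

# Minimal smoke test (a sketch, not part of the model). Shapes and camera
# values are hypothetical, and it assumes utils.create_rays is importable.
if __name__ == "__main__":
    head = PixelAligned3DGS(embed_dim=64, sh_degree=2, num_points_per_pixel=2)
    feats = torch.randn(1, 64, 32, 32)
    # One camera: identity quaternion, zero translation, normalized intrinsics.
    cam = torch.tensor([[1.0, 0, 0, 0, 0, 0, 0, 1.0, 1.0, 0.5, 0.5]])
    params = head(feats, cam)
    # Expected: (1, 2, 32, 32, 3 + 1 + 3 + 4 + 27) = (1, 2, 32, 32, 38)
    print(params.shape)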