Spaces:
Runtime error
Runtime error
| # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual | |
| # property and proprietary rights in and to this material, related | |
| # documentation and any modifications thereto. Any use, reproduction, | |
| # disclosure or distribution of this material and related documentation | |
| # without an express license agreement from NVIDIA CORPORATION or | |
| # its affiliates is strictly prohibited. | |
| import torch | |
| import nvdiffrast.torch as dr | |
| from . import util | |
| from . import renderutils as ru | |
| from . import light | |
| from .texture import Texture2D | |
| # ============================================================================================== | |
| # Helper functions | |
| # ============================================================================================== | |
| def interpolate(attr, rast, attr_idx, rast_db=None): | |
| return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') | |
| # ============================================================================================== | |
| # pixel shader | |
| # ============================================================================================== | |
| def shade( | |
| gb_pos, | |
| gb_geometric_normal, | |
| gb_normal, | |
| gb_tangent, | |
| gb_tex_pos, | |
| gb_texc, | |
| gb_texc_deriv, | |
| w2c, | |
| view_pos, | |
| lgt, | |
| material, | |
| bsdf, | |
| feat, | |
| two_sided_shading, | |
| delta_xy_interp=None, | |
| dino_pred=None, | |
| class_vector=None, | |
| im_features_map=None, | |
| mvp=None | |
| ): | |
| ################################################################################ | |
| # Texture lookups | |
| ################################################################################ | |
| perturbed_nrm = None | |
| # Combined texture, used for MLPs because lookups are expensive | |
| # all_tex_jitter = material.sample(gb_tex_pos + torch.normal(mean=0, std=0.01, size=gb_tex_pos.shape, device="cuda"), feat=feat) | |
| if isinstance(material, Texture2D): | |
| all_tex = material.sample(gb_texc, gb_texc_deriv) | |
| elif material is not None: | |
| if im_features_map is None: | |
| all_tex = material.sample(gb_tex_pos, feat=feat) | |
| else: | |
| all_tex = material.sample(gb_tex_pos, feat=feat, feat_map=im_features_map, mvp=mvp, w2c=w2c, deform_xyz=gb_pos) | |
| else: | |
| all_tex = torch.ones(*gb_pos.shape[:-1], 9, device=gb_pos.device) | |
| kd, ks, perturbed_nrm = all_tex[..., :3], all_tex[..., 3:6], all_tex[..., 6:9] | |
| # Compute albedo (kd) gradient, used for material regularizer | |
| # kd_grad = torch.sum(torch.abs(all_tex_jitter[..., :-6] - all_tex[..., :-6]), dim=-1, keepdim=True) / | |
| if dino_pred is not None and class_vector is None: | |
| # DOR: predive the dino value using x,y,z, we would concatenate the label vector. | |
| # trained together, generated image as the supervision for the one-hot-vector. | |
| dino_feat_im_pred = dino_pred.sample(gb_tex_pos) | |
| # dino_feat_im_pred = dino_pred.sample(gb_tex_pos.detach()) | |
| if dino_pred is not None and class_vector is not None: | |
| dino_feat_im_pred = dino_pred.sample(gb_tex_pos, feat=class_vector) | |
| # else: | |
| # kd_jitter = material['kd'].sample(gb_texc + torch.normal(mean=0, std=0.005, size=gb_texc.shape, device="cuda"), gb_texc_deriv) | |
| # kd = material['kd'].sample(gb_texc, gb_texc_deriv) | |
| # ks = material['ks'].sample(gb_texc, gb_texc_deriv)[..., 0:3] # skip alpha | |
| # if 'normal' in material: | |
| # perturbed_nrm = material['normal'].sample(gb_texc, gb_texc_deriv) | |
| # kd_grad = torch.sum(torch.abs(kd_jitter[..., 0:3] - kd[..., 0:3]), dim=-1, keepdim=True) / 3 | |
| # Separate kd into alpha and color, default alpha = 1 | |
| # alpha = kd[..., 3:4] if kd.shape[-1] == 4 else torch.ones_like(kd[..., 0:1]) | |
| # kd = kd[..., 0:3] | |
| alpha = torch.ones_like(kd[..., 0:1]) | |
| ################################################################################ | |
| # Normal perturbation & normal bend | |
| ################################################################################ | |
| if material is None or isinstance(material, Texture2D) or not material.perturb_normal: | |
| perturbed_nrm = None | |
| gb_normal = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=two_sided_shading, opengl=True, use_python=True) | |
| # if two_sided_shading: | |
| # view_vec = util.safe_normalize(view_pos - gb_pos, -1) | |
| # gb_normal = torch.where(torch.sum(gb_geometric_normal * view_vec, -1, keepdim=True) > 0, gb_geometric_normal, -gb_geometric_normal) | |
| # else: | |
| # gb_normal = gb_geometric_normal | |
| b, h, w, _ = gb_normal.shape | |
| cam_normal = util.safe_normalize(torch.matmul(gb_normal.view(b, -1, 3), w2c[:,:3,:3].transpose(2,1))).view(b, h, w, 3) | |
| ################################################################################ | |
| # Evaluate BSDF | |
| ################################################################################ | |
| assert bsdf is not None or material.bsdf is not None, "Material must specify a BSDF type" | |
| bsdf = bsdf if bsdf is not None else material.bsdf | |
| shading = None | |
| if bsdf == 'pbr': | |
| if isinstance(lgt, light.EnvironmentLight): | |
| shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=True) | |
| else: | |
| assert False, "Invalid light type" | |
| elif bsdf == 'diffuse': | |
| if lgt is None: | |
| shaded_col = kd | |
| elif isinstance(lgt, light.EnvironmentLight): | |
| shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=False) | |
| # elif isinstance(lgt, light.DirectionalLight): | |
| # shaded_col, shading = lgt.shade(feat, kd, cam_normal) | |
| # else: | |
| # assert False, "Invalid light type" | |
| else: | |
| shaded_col, shading = lgt.shade(feat, kd, cam_normal) | |
| elif bsdf == 'normal': | |
| shaded_col = (gb_normal + 1.0) * 0.5 | |
| elif bsdf == 'geo_normal': | |
| shaded_col = (gb_geometric_normal + 1.0) * 0.5 | |
| elif bsdf == 'tangent': | |
| shaded_col = (gb_tangent + 1.0) * 0.5 | |
| elif bsdf == 'kd': | |
| shaded_col = kd | |
| elif bsdf == 'ks': | |
| shaded_col = ks | |
| else: | |
| assert False, "Invalid BSDF '%s'" % bsdf | |
| # Return multiple buffers | |
| buffers = { | |
| 'kd' : torch.cat((kd, alpha), dim=-1), | |
| 'shaded' : torch.cat((shaded_col, alpha), dim=-1), | |
| # 'kd_grad' : torch.cat((kd_grad, alpha), dim=-1), | |
| # 'occlusion' : torch.cat((ks[..., :1], alpha), dim=-1), | |
| } | |
| if dino_pred is not None: | |
| buffers['dino_feat_im_pred'] = torch.cat((dino_feat_im_pred, alpha), dim=-1) | |
| if delta_xy_interp is not None: | |
| buffers['flow'] = torch.cat((delta_xy_interp, alpha), dim=-1) | |
| if shading is not None: | |
| buffers['shading'] = torch.cat((shading, alpha), dim=-1) | |
| return buffers | |
| # ============================================================================================== | |
| # Render a depth slice of the mesh (scene), some limitations: | |
| # - Single light | |
| # - Single material | |
| # ============================================================================================== | |
| def render_layer( | |
| rast, | |
| rast_deriv, | |
| mesh, | |
| w2c, | |
| view_pos, | |
| material, | |
| lgt, | |
| resolution, | |
| spp, | |
| msaa, | |
| bsdf, | |
| feat, | |
| prior_mesh, | |
| two_sided_shading, | |
| render_flow, | |
| delta_xy=None, | |
| dino_pred=None, | |
| class_vector=None, | |
| im_features_map=None, | |
| mvp=None | |
| ): | |
| full_res = [resolution[0]*spp, resolution[1]*spp] | |
| if prior_mesh is None: | |
| prior_mesh = mesh | |
| ################################################################################ | |
| # Rasterize | |
| ################################################################################ | |
| # Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution | |
| if spp > 1 and msaa: | |
| rast_out_s = util.scale_img_nhwc(rast, resolution, mag='nearest', min='nearest') | |
| rast_out_deriv_s = util.scale_img_nhwc(rast_deriv, resolution, mag='nearest', min='nearest') * spp | |
| else: | |
| rast_out_s = rast | |
| rast_out_deriv_s = rast_deriv | |
| if render_flow: | |
| delta_xy_interp, _ = interpolate(delta_xy, rast_out_s, mesh.t_pos_idx[0].int()) | |
| else: | |
| delta_xy_interp = None | |
| ################################################################################ | |
| # Interpolate attributes | |
| ################################################################################ | |
| # Interpolate world space position | |
| gb_pos, _ = interpolate(mesh.v_pos, rast_out_s, mesh.t_pos_idx[0].int()) | |
| # Compute geometric normals. We need those because of bent normals trick (for bump mapping) | |
| v0 = mesh.v_pos[:, mesh.t_pos_idx[0, :, 0], :] | |
| v1 = mesh.v_pos[:, mesh.t_pos_idx[0, :, 1], :] | |
| v2 = mesh.v_pos[:, mesh.t_pos_idx[0, :, 2], :] | |
| face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0, dim=-1)) | |
| num_faces = face_normals.shape[1] | |
| face_normal_indices = (torch.arange(0, num_faces, dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3) | |
| gb_geometric_normal, _ = interpolate(face_normals, rast_out_s, face_normal_indices.int()) | |
| # Compute tangent space | |
| assert mesh.v_nrm is not None and mesh.v_tng is not None | |
| gb_normal, _ = interpolate(mesh.v_nrm, rast_out_s, mesh.t_nrm_idx[0].int()) | |
| gb_tangent, _ = interpolate(mesh.v_tng, rast_out_s, mesh.t_tng_idx[0].int()) # Interpolate tangents | |
| # Texture coordinate | |
| assert mesh.v_tex is not None | |
| gb_texc, gb_texc_deriv = interpolate(mesh.v_tex, rast_out_s, mesh.t_tex_idx[0].int(), rast_db=rast_out_deriv_s) | |
| ################################################################################ | |
| # Shade | |
| ################################################################################ | |
| gb_tex_pos, _ = interpolate(prior_mesh.v_pos, rast_out_s, mesh.t_pos_idx[0].int()) | |
| buffers = shade(gb_pos, gb_geometric_normal, gb_normal, gb_tangent, gb_tex_pos, gb_texc, gb_texc_deriv, w2c, view_pos, lgt, material, bsdf, feat=feat, two_sided_shading=two_sided_shading, delta_xy_interp=delta_xy_interp, dino_pred=dino_pred, class_vector=class_vector, im_features_map=im_features_map, mvp=mvp) | |
| ################################################################################ | |
| # Prepare output | |
| ################################################################################ | |
| # Scale back up to visibility resolution if using MSAA | |
| if spp > 1 and msaa: | |
| for key in buffers.keys(): | |
| buffers[key] = util.scale_img_nhwc(buffers[key], full_res, mag='nearest', min='nearest') | |
| # Return buffers | |
| return buffers | |
| # ============================================================================================== | |
| # Render a depth peeled mesh (scene), some limitations: | |
| # - Single light | |
| # - Single material | |
| # ============================================================================================== | |
| def render_mesh( | |
| ctx, | |
| mesh, | |
| mtx_in, | |
| w2c, | |
| view_pos, | |
| material, | |
| lgt, | |
| resolution, | |
| spp = 1, | |
| num_layers = 1, | |
| msaa = False, | |
| background = None, | |
| bsdf = None, | |
| feat = None, | |
| prior_mesh = None, | |
| two_sided_shading = True, | |
| render_flow = False, | |
| dino_pred = None, | |
| class_vector = None, | |
| num_frames = None, | |
| im_features_map = None | |
| ): | |
| def prepare_input_vector(x): | |
| x = torch.tensor(x, dtype=torch.float32, device='cuda') if not torch.is_tensor(x) else x | |
| return x[:, None, None, :] if len(x.shape) == 2 else x | |
| def composite_buffer(key, layers, background, antialias): | |
| accum = background | |
| for buffers, rast in reversed(layers): | |
| alpha = (rast[..., -1:] > 0).float() * buffers[key][..., -1:] | |
| accum = torch.lerp(accum, torch.cat((buffers[key][..., :-1], torch.ones_like(buffers[key][..., -1:])), dim=-1), alpha) | |
| if antialias: | |
| accum = dr.antialias(accum.contiguous(), rast, v_pos_clip, mesh.t_pos_idx[0].int()) | |
| return accum | |
| assert mesh.t_pos_idx.shape[1] > 0, "Got empty training triangle mesh (unrecoverable discontinuity)" | |
| assert background is None or (background.shape[1] == resolution[0] and background.shape[2] == resolution[1]) | |
| full_res = [resolution[0] * spp, resolution[1] * spp] | |
| # Convert numpy arrays to torch tensors | |
| mtx_in = torch.tensor(mtx_in, dtype=torch.float32, device='cuda') if not torch.is_tensor(mtx_in) else mtx_in | |
| view_pos = prepare_input_vector(view_pos) # Shape: (B, 1, 1, 3) | |
| # clip space transform | |
| v_pos_clip = ru.xfm_points(mesh.v_pos, mtx_in, use_python=True) | |
| # render flow | |
| if render_flow: | |
| v_pos_clip2 = v_pos_clip[..., :2] / v_pos_clip[..., -1:] | |
| v_pos_clip2 = v_pos_clip2.view(-1, num_frames, *v_pos_clip2.shape[1:]) | |
| delta_xy = v_pos_clip2[:, 1:] - v_pos_clip2[:, :-1] | |
| delta_xy = torch.cat([delta_xy, torch.zeros_like(delta_xy[:, :1])], dim=1) | |
| delta_xy = delta_xy.view(-1, *delta_xy.shape[2:]) | |
| else: | |
| delta_xy = None | |
| # Render all layers front-to-back | |
| layers = [] | |
| with dr.DepthPeeler(ctx, v_pos_clip, mesh.t_pos_idx[0].int(), full_res) as peeler: | |
| for _ in range(num_layers): | |
| rast, db = peeler.rasterize_next_layer() | |
| rendered = render_layer(rast, db, mesh, w2c, view_pos, material, lgt, resolution, spp, msaa, bsdf, feat=feat, prior_mesh=prior_mesh, two_sided_shading=two_sided_shading, render_flow=render_flow, delta_xy=delta_xy, dino_pred=dino_pred, class_vector=class_vector, im_features_map=im_features_map, mvp=mtx_in) | |
| layers += [(rendered, rast)] | |
| # Setup background | |
| if background is not None: | |
| if spp > 1: | |
| background = util.scale_img_nhwc(background, full_res, mag='nearest', min='nearest') | |
| background = torch.cat((background, torch.zeros_like(background[..., 0:1])), dim=-1) | |
| else: | |
| background = torch.zeros(1, full_res[0], full_res[1], 4, dtype=torch.float32, device='cuda') | |
| # Composite layers front-to-back | |
| out_buffers = {} | |
| for key in layers[0][0].keys(): | |
| antialias = key in ['shaded', 'dino_feat_im_pred', 'flow'] | |
| bg = background if key in ['shaded'] else torch.zeros_like(layers[0][0][key]) | |
| accum = composite_buffer(key, layers, bg, antialias) | |
| # Downscale to framebuffer resolution. Use avg pooling | |
| out_buffers[key] = util.avg_pool_nhwc(accum, spp) if spp > 1 else accum | |
| return out_buffers | |
| # ============================================================================================== | |
| # Render UVs | |
| # ============================================================================================== | |
| def render_uv(ctx, mesh, resolution, mlp_texture, feat=None, prior_shape=None): | |
| # clip space transform | |
| uv_clip = mesh.v_tex * 2.0 - 1.0 | |
| # pad to four component coordinate | |
| uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[...,0:1]), torch.ones_like(uv_clip[...,0:1])), dim = -1) | |
| # rasterize | |
| rast, _ = dr.rasterize(ctx, uv_clip4, mesh.t_tex_idx[0].int(), resolution) | |
| # Interpolate world space position | |
| if prior_shape is not None: | |
| gb_pos, _ = interpolate(prior_shape.v_pos, rast, mesh.t_pos_idx[0].int()) | |
| else: | |
| gb_pos, _ = interpolate(mesh.v_pos, rast, mesh.t_pos_idx[0].int()) | |
| # Sample out textures from MLP | |
| all_tex = mlp_texture.sample(gb_pos, feat=feat) | |
| assert all_tex.shape[-1] == 9 or all_tex.shape[-1] == 10, "Combined kd_ks_normal must be 9 or 10 channels" | |
| perturbed_nrm = all_tex[..., -3:] | |
| return (rast[..., -1:] > 0).float(), all_tex[..., :-6], all_tex[..., -6:-3], util.safe_normalize(perturbed_nrm) | |