Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| from omegaconf import OmegaConf | |
| import torch | |
| from torch import nn | |
| from .utils.misc import instantiate_from_config | |
| from ..utils import default, exists | |
def load_model(config_path=None, ckpt_path=None):
    """Instantiate the shape VAE from its YAML config and return it in eval mode.

    Args:
        config_path: Path of the YAML file read with ``OmegaConf.load``.
            Defaults to ``shapevae-256.yaml`` in this module's directory.
        ckpt_path: Checkpoint path forwarded to ``instantiate_from_config``.
            Defaults to
            ``./ckpt/checkpoints/aligned_shape_latents/shapevae-256.ckpt``,
            which is relative to the current working directory.

    Returns:
        The object built by ``instantiate_from_config`` after ``.eval()``
        has been called on it.
    """
    if config_path is None:
        config_path = os.path.join(os.path.dirname(__file__), "shapevae-256.yaml")
    if ckpt_path is None:
        ckpt_path = "./ckpt/checkpoints/aligned_shape_latents/shapevae-256.ckpt"

    model_config = OmegaConf.load(config_path)

    # Unwrap a top-level "model" section when the file has one. Key
    # membership is used instead of hasattr() so the check does not depend on
    # how attribute access on the config object treats missing keys.
    if "model" in model_config:
        model_config = model_config.model

    model = instantiate_from_config(model_config, ckpt_path=ckpt_path)

    # No device placement happens here; the caller decides where the model lives.
    return model.eval()
class ShapeConditioner(nn.Module):
    """Produce shape-conditioning tensors from the model built by ``load_model``.

    Args:
        dim_latent: Keyword-only. Stored on ``self.dim_latent`` after being
            resolved with ``default(dim_latent, self.dim_model_out)``.
    """

    def __init__(
        self,
        *,
        dim_latent = None
    ):
        super().__init__()
        self.model = load_model()

        # NOTE(review): 768 is presumably the feature width of the loaded
        # model's output -- confirm against shapevae-256.yaml.
        self.dim_model_out = 768
        self.dim_latent = default(dim_latent, self.dim_model_out)

    def forward(
        self,
        shape = None,
        shape_embed = None,
    ):
        """Encode ``shape``, or pass a precomputed ``shape_embed`` through.

        The two arguments are alternatives: the assertion below rejects calls
        where ``exists(shape)`` and ``exists(shape_embed)`` agree.

        Args:
            shape: Input handed to ``self.model.encode_latents``.
            shape_embed: Precomputed embedding; returned unchanged when given.

        Returns:
            A ``(shape_head, shape_embed)`` tuple. ``shape_head`` is index 0
            along axis 1 of the encoder output (kept as a length-1 slice), or
            ``None`` when ``shape_embed`` was supplied by the caller, since
            nothing is encoded in that case.
        """
        assert exists(shape) ^ exists(shape_embed), \
            "pass exactly one of `shape` or `shape_embed`"

        # Bound up front so the precomputed-embedding path returns a
        # well-formed tuple instead of failing on an unassigned name.
        shape_head = None

        if not exists(shape_embed):
            point_feature = self.model.encode_latents(shape)

            # Index 0 along axis 1 is split off as the head; the remaining
            # entries are projected and used for the embedding.
            shape_head = point_feature[:, 0:1]
            point_tokens = point_feature[:, 1:]
            shape_latents = self.model.to_shape_latents(point_tokens)

            # Join raw features and their latents along the last axis.
            shape_embed = torch.cat([point_tokens, shape_latents], dim=-1)

        return shape_head, shape_embed