"""Command-line entry point for a ``GaussianSplatting2D`` model.

Builds a descriptive log-directory name from the parsed configuration, stores
it on ``args.log_dir``, then either renders from the model (when ``args.eval``
is set) or optimizes it.
"""

import argparse

import torch

from model import GaussianSplatting2D
from utils.misc_utils import load_cfg


def get_gaussian_cfg(args: argparse.Namespace) -> str:
    """Return the Gaussian-related part of the experiment folder name.

    The string encodes the number of Gaussians, the initial scale (and whether
    it is an inverse scale), the per-attribute bit precisions, the optional
    top-k setting and the initialization mode / random ratio.

    Side effect: when ``args.quantize`` is false, ``args.pos_bits``,
    ``args.scale_bits``, ``args.rot_bits`` and ``args.feat_bits`` are all
    overwritten with 32 on ``args`` itself, so anything that reads ``args``
    afterwards sees those values too.

    Args:
        args: Parsed command-line namespace.

    Returns:
        The Gaussian configuration fragment, e.g. ``num-..._inv-scale-..._bits-...``.

    Raises:
        ValueError: If any bit precision is below 4 or above 32.
    """
    gaussian_cfg = f"num-{args.num_gaussians:d}"
    if args.disable_inverse_scale:
        gaussian_cfg += f"_scale-{args.init_scale:.1f}"
    else:
        gaussian_cfg += f"_inv-scale-{args.init_scale:.1f}"

    if not args.quantize:
        # Quantization disabled: record full 32-bit precision for every attribute.
        args.pos_bits, args.scale_bits, args.rot_bits, args.feat_bits = 32, 32, 32, 32
    bits = (args.pos_bits, args.scale_bits, args.rot_bits, args.feat_bits)
    if min(bits) < 4 or max(bits) > 32:
        raise ValueError(
            f"Bit precision must be between 4 and 32 but got: {args.pos_bits:d}, {args.scale_bits:d}, {args.rot_bits:d}, {args.feat_bits:d}"
        )
    gaussian_cfg += f"_bits-{args.pos_bits:d}-{args.scale_bits:d}-{args.rot_bits:d}-{args.feat_bits:d}"

    if not args.disable_topk_norm:
        gaussian_cfg += f"_top-{args.topk:d}"
    # Only the first character of ``init_mode`` goes into the name.
    gaussian_cfg += f"_{args.init_mode[0]}-{args.init_random_ratio:.1f}"
    return gaussian_cfg


def get_log_dir(args: argparse.Namespace) -> str:
    """Return the log directory ``{log_root}/{exp_name}/{folder}`` for this run.

    ``folder`` joins the Gaussian configuration from ``get_gaussian_cfg``
    (which may modify ``args``, see its docstring) with the loss weights, plus
    optional tags for downsampling, learning-rate decay and progressive
    optimization.

    Args:
        args: Parsed command-line namespace.

    Returns:
        The log directory path as a string (no directory is created here).
    """
    gaussian_cfg = get_gaussian_cfg(args)
    loss_cfg = f"l1-{args.l1_loss_ratio:.1f}_l2-{args.l2_loss_ratio:.1f}_ssim-{args.ssim_loss_ratio:.1f}"
    folder = f"{gaussian_cfg}_{loss_cfg}"
    if args.downsample:
        folder += f"_ds-{args.downsample_ratio:.1f}"
    if not args.disable_lr_schedule:
        folder += f"_decay-{args.max_decay_times:d}-{args.decay_ratio:.1f}"
    if not args.disable_prog_optim:
        folder += "_prog"
    return f"{args.log_root}/{args.exp_name}/{folder}"


def main(args: argparse.Namespace) -> None:
    """Set ``args.log_dir``, build the model, then render or optimize it.

    Args:
        args: Parsed command-line namespace; ``args.log_dir`` is added to it.
    """
    # Set before constructing the model because ``args`` is handed to it
    # (presumably the constructor reads ``log_dir`` — confirm in ``model``).
    args.log_dir = get_log_dir(args)
    image_gs = GaussianSplatting2D(args)
    if args.eval:
        image_gs.render(render_height=args.render_height)
    else:
        image_gs.optimize()


if __name__ == "__main__":
    # Cache any torch.hub downloads under a project-local directory.
    torch.hub.set_dir("models/torch")
    parser = argparse.ArgumentParser()
    # NOTE(review): load_cfg presumably registers the CLI options and their
    # defaults from the YAML file and returns the parser — confirm in utils.
    parser = load_cfg(cfg_path="cfgs/default.yaml", parser=parser)
    arguments = parser.parse_args()
    main(arguments)