""" Training and evaluation script for HoHo wireframe prediction model. This script loads the HoHo25k dataset, processes samples through a wireframe prediction pipeline using PointNet models, and evaluates performance using HSS, F1, and IoU metrics. It supports configurable thresholds, visualization of results, and saves detailed performance metrics to files. Key features: - Command-line argument support for model configuration - PointNet-based vertex and edge prediction - Real-time performance monitoring and visualization - Comprehensive metric evaluation and result logging - Support for CUDA acceleration when available """ from datasets import load_dataset from hoho2025.vis import plot_all_modalities from hoho2025.viz3d import * import pycolmap import tempfile,zipfile import io import open3d as o3d import os import argparse # Added for command-line arguments import numpy as np # Make sure numpy is imported if not already implicitly from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local, _plotly_rgb_to_normalized_o3d_color from utils import read_colmap_rec, empty_solution #from hoho2025.example_solutions import predict_wireframe from hoho2025.metric_helper import hss from predict import predict_wireframe, predict_wireframe_old from tqdm import tqdm from fast_pointnet_v2 import load_pointnet_model from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model import torch import time # --- Argument Parsing --- parser = argparse.ArgumentParser(description="Train and evaluate HoHo model with custom config.") parser.add_argument('--vertex_threshold', type=float, default=0.59, help='Vertex threshold for prediction.') parser.add_argument('--edge_threshold', type=float, default=0.65, help='Edge threshold for prediction.') parser.add_argument('--only_predicted_connections', type=lambda x: (str(x).lower() == 'true'), default=True, help='Use only predicted connections (True/False).') parser.add_argument('--max_samples', type=int, default=50000, help='Maximum number of samples to process.') parser.add_argument('--results_dir', type=str, default="results", help='Directory to save result files.') args = parser.parse_args() # --- Configuration from Arguments --- config = { 'vertex_threshold': args.vertex_threshold, 'edge_threshold': args.edge_threshold, 'only_predicted_connections': args.only_predicted_connections } print(f"Running with configuration: {config}") # Create results directory if it doesn't exist os.makedirs(args.results_dir, exist_ok=True) ds = load_dataset("usm3d/hoho25k", cache_dir="YOUR_CACHE_DIR_PATH/hoho25k/", trust_remote_code=True) #ds = load_dataset("usm3d/hoho25k", cache_dir="YOUR_ALTERNATIVE_CACHE_DIR_PATH/hoho25k/", trust_remote_code=True) #ds = ds.shuffle() scores_hss = [] scores_f1 = [] scores_iou = [] show_visu = True device = "cuda" if torch.cuda.is_available() else "cpu" #pnet_model = load_pointnet_model(model_path="pnet.pth", device=device, predict_score=True) pnet_model = load_pointnet_model(model_path="pnet.pth", device=device, predict_score=True) #pnet_model = load_pointnet_model(model_path="YOUR_MODEL_PATH/initial_epoch_100.pth", device=device, predict_score=True) #pnet_model = None #pnet_class_model = load_pointnet_class_model(model_path="YOUR_MODEL_PATH/initial_epoch_100.pth", device=device) #pnet_class_model = load_pointnet_class_model_10d(model_path="YOUR_MODEL_PATH/initial_epoch_75.pth", device=device) pnet_class_model = load_pointnet_class_model(model_path="pnet_class.pth", device=device) 
#pnet_class_model = None
#voxel_model = load_3dcnn_model(model_path="YOUR_MODEL_PATH/initial_epoch_100.pth", device=device, predict_score=True)
voxel_model = None

idx = 0
prediction_times = []

for a in tqdm(ds['train'], desc="Processing dataset"):
    #plot_all_modalities(a)
    #pred_vertices, pred_edges = predict_wireframe_old(a)
    try:
        start_time = time.time()
        pred_vertices, pred_edges = predict_wireframe(a.copy(), pnet_model, voxel_model, pnet_class_model, config)
        #pred_vertices, pred_edges = predict_wireframe_old(a)
        prediction_time = time.time() - start_time
        prediction_times.append(prediction_time)
        mean_time = np.mean(prediction_times)
        print(f"Prediction time: {prediction_time:.4f} seconds, Mean time: {mean_time:.4f} seconds")
    except Exception as e:
        # Fall back to an empty solution so the evaluation loop can continue.
        print(f"Error during prediction: {e}")
        pred_vertices, pred_edges = empty_solution()

    score = hss(pred_vertices, pred_edges, a['wf_vertices'], a['wf_edges'], vert_thresh=0.5, edge_thresh=0.5)
    print(f"Score: {score}")
    scores_hss.append(score.hss)
    scores_f1.append(score.f1)
    scores_iou.append(score.iou)

    if show_visu:
        colmap = read_colmap_rec(a['colmap_binary'])
        pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
        wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
        #wireframe2 = plot_wireframe_local(None, pred_vertices, pred_edges, None, color='rgb(255, 0, 0)')
        bpo_cams = plot_bpo_cameras_from_entry_local(None, a)
        visu_all = [pcd] + geometries + wireframe + bpo_cams  #+ wireframe2
        o3d.visualization.draw_geometries(
            visu_all,
            window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")

    idx += 1
    if idx >= args.max_samples:
        print(f"Reached max_samples limit: {args.max_samples}")
        break

# Print a clearly visible end-of-run marker in the console output.
for i in range(10):
    print("END OF DATASET")

mean_hss_val = np.mean(scores_hss) if scores_hss else 0.0
mean_f1_val = np.mean(scores_f1) if scores_f1 else 0.0
mean_iou_val = np.mean(scores_iou) if scores_iou else 0.0

print(f"Mean HSS: {mean_hss_val:.4f}")
print(f"Mean F1: {mean_f1_val:.4f}")
print(f"Mean IoU: {mean_iou_val:.4f}")
print(f"Final Config: {config}")
if prediction_times:
    print(f"Overall Mean Prediction Time: {np.mean(prediction_times):.4f} seconds")

# --- Write results to a file ---
vt_str = str(config['vertex_threshold']).replace('.', 'p')
et_str = str(config['edge_threshold']).replace('.', 'p')
opc_str = str(config['only_predicted_connections'])
results_filename = f"results_vt{vt_str}_et{et_str}_opc{opc_str}_samples{args.max_samples}.txt"
results_filepath = os.path.join(args.results_dir, results_filename)
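# Example (derived from the argparse defaults above): vertex_threshold=0.59,
# edge_threshold=0.65, only_predicted_connections=True, max_samples=50000
# produce a file named results_vt0p59_et0p65_opcTrue_samples50000.txt
# inside the --results_dir directory.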
with open(results_filepath, 'w') as f:
    f.write(f"Configuration: {config}\n")
    f.write(f"Max Samples Processed: {args.max_samples}\n")
    f.write(f"Mean HSS: {mean_hss_val:.4f}\n")
    f.write(f"Mean F1: {mean_f1_val:.4f}\n")
    f.write(f"Mean IoU: {mean_iou_val:.4f}\n")
    if prediction_times:
        f.write(f"Overall Mean Prediction Time: {np.mean(prediction_times):.4f} seconds\n")
    f.write("\nIndividual HSS Scores:\n")
    for s_hss in scores_hss:
        f.write(f"{s_hss:.4f}\n")
    f.write("\nIndividual F1 Scores:\n")
    for s_f1 in scores_f1:
        f.write(f"{s_f1:.4f}\n")
    f.write("\nIndividual IoU Scores:\n")
    for s_iou in scores_iou:
        f.write(f"{s_iou:.4f}\n")

print(f"Results saved to {results_filepath}")
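# Usage sketch (the script filename below is a placeholder; the cache and model
# paths above must also be adapted to your setup):
#   python evaluate_hoho.py --vertex_threshold 0.59 --edge_threshold 0.65 \
#       --only_predicted_connections true --max_samples 100 --results_dir results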