hoho / train.py
jskvrna's picture
Update to the hyperparameters, so they are the winning solution :)
32fd25b
"""
Evaluation script for the HoHo wireframe prediction model (models are loaded from checkpoint files; no training step is performed here).
This script loads the HoHo25k dataset, processes samples through a wireframe prediction pipeline
using PointNet models, and evaluates performance using HSS, F1, and IoU metrics. It supports
configurable thresholds, visualization of results, and saves detailed performance metrics to files.
Key features:
- Command-line argument support for model configuration
- PointNet-based vertex and edge prediction
- Real-time performance monitoring and visualization
- Comprehensive metric evaluation and result logging
- Support for CUDA acceleration when available
"""
from datasets import load_dataset
from hoho2025.vis import plot_all_modalities
from hoho2025.viz3d import *
import pycolmap
import tempfile,zipfile
import io
import open3d as o3d
import os
import argparse # Added for command-line arguments
import numpy as np # Make sure numpy is imported if not already implicitly
from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local, _plotly_rgb_to_normalized_o3d_color
from utils import read_colmap_rec, empty_solution
#from hoho2025.example_solutions import predict_wireframe
from hoho2025.metric_helper import hss
from predict import predict_wireframe, predict_wireframe_old
from tqdm import tqdm
from fast_pointnet_v2 import load_pointnet_model
from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
import torch
import time
# --- Argument Parsing ---
def _parse_bool(value):
    """Convert a 'true'/'false' command-line string (case-insensitive) to a bool.

    Raises:
        argparse.ArgumentTypeError: if the value is neither 'true' nor 'false'.
            Previously every unrecognised value (e.g. a typo such as 'ture',
            or 'yes' / '1') was silently treated as False.
    """
    text = str(value).lower()
    if text == 'true':
        return True
    if text == 'false':
        return False
    raise argparse.ArgumentTypeError(f"expected True or False, got {value!r}")


parser = argparse.ArgumentParser(description="Train and evaluate HoHo model with custom config.")
parser.add_argument('--vertex_threshold', type=float, default=0.59, help='Vertex threshold for prediction.')
parser.add_argument('--edge_threshold', type=float, default=0.65, help='Edge threshold for prediction.')
parser.add_argument('--only_predicted_connections', type=_parse_bool, default=True, help='Use only predicted connections (True/False).')
parser.add_argument('--max_samples', type=int, default=50000, help='Maximum number of samples to process.')
parser.add_argument('--results_dir', type=str, default="results", help='Directory to save result files.')
args = parser.parse_args()

# --- Configuration from Arguments ---
# This dict is what gets handed to predict_wireframe and echoed into the results file.
config = {
    'vertex_threshold': args.vertex_threshold,
    'edge_threshold': args.edge_threshold,
    'only_predicted_connections': args.only_predicted_connections,
}
print(f"Running with configuration: {config}")

# Create results directory if it doesn't exist
os.makedirs(args.results_dir, exist_ok=True)
# --- Dataset ---
# Replace the cache_dir placeholder with a local path before running.
ds = load_dataset(
    "usm3d/hoho25k",
    cache_dir="YOUR_CACHE_DIR_PATH/hoho25k/",
    trust_remote_code=True,
)

# --- Per-sample metric accumulators (one entry appended per evaluated sample) ---
scores_hss, scores_f1, scores_iou = [], [], []

# When True, an Open3D window is opened for every evaluated sample.
show_visu = True

# --- Models ---
# Use the GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# To evaluate a different checkpoint, change model_path in the calls below.
pnet_model = load_pointnet_model(model_path="pnet.pth", device=device, predict_score=True)
pnet_class_model = load_pointnet_class_model(model_path="pnet_class.pth", device=device)
# No voxel model is loaded; None is what gets passed to predict_wireframe.
voxel_model = None

# --- Loop bookkeeping ---
idx = 0                 # number of samples processed so far
prediction_times = []   # seconds spent in each successful prediction
# Evaluate the training split sample by sample: predict a wireframe, score it
# against the ground truth, optionally visualise it, and stop once
# args.max_samples samples have been processed.
for a in tqdm(ds['train'], desc="Processing dataset"):
    try:
        # perf_counter is a monotonic, high-resolution clock meant for measuring
        # intervals; time.time() can jump if the system clock is adjusted.
        start_time = time.perf_counter()
        # A copy of the sample is passed to the predictor
        # (presumably so the original entry stays untouched for scoring -- confirm).
        pred_vertices, pred_edges = predict_wireframe(a.copy(), pnet_model, voxel_model, pnet_class_model, config)
        prediction_time = time.perf_counter() - start_time
        prediction_times.append(prediction_time)
        # prediction_times is never empty here (we just appended to it), so the
        # running mean can be computed unconditionally.
        mean_time = np.mean(prediction_times)
        print(f"Prediction time: {prediction_time:.4f} seconds, Mean time: {mean_time:.4f} seconds")
    except Exception as e:  # Catch specific exceptions if possible, or log the error
        # Best effort: report the failure and fall back to empty_solution() so the
        # sample is still scored and the run continues. Failed samples do not
        # contribute to prediction_times.
        print(f"Error during prediction: {e}")
        pred_vertices, pred_edges = empty_solution()

    # Metric thresholds are fixed at 0.5 here, independent of the prediction
    # thresholds stored in config.
    score = hss(pred_vertices, pred_edges, a['wf_vertices'], a['wf_edges'], vert_thresh=0.5, edge_thresh=0.5)
    print(f"Score: {score}")
    scores_hss.append(score.hss)
    scores_f1.append(score.f1)
    scores_iou.append(score.iou)

    if show_visu:
        # Show the COLMAP reconstruction, the ground-truth wireframe and the BPO
        # cameras. The predicted wireframe is not part of the displayed geometry.
        colmap = read_colmap_rec(a['colmap_binary'])
        pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
        wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
        bpo_cams = plot_bpo_cameras_from_entry_local(None, a)
        visu_all = [pcd] + geometries + wireframe + bpo_cams
        # NOTE: per the Open3D documentation this call blocks until the window is closed.
        o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")

    idx += 1
    if idx >= args.max_samples:
        print(f"Reached max_samples limit: {args.max_samples}")
        break
# Make the end of the per-sample output easy to spot in the console.
for _ in range(10):
    print("END OF DATASET")

# --- Aggregate metrics (0.0 when no sample was scored) ---
mean_hss_val = np.mean(scores_hss) if scores_hss else 0.0
mean_f1_val = np.mean(scores_f1) if scores_f1 else 0.0
mean_iou_val = np.mean(scores_iou) if scores_iou else 0.0
# One HSS score is appended per evaluated sample, so this is the number of
# samples actually processed. It can be lower than args.max_samples when the
# dataset runs out first.
num_processed = len(scores_hss)
# None when every prediction failed, i.e. no timing was recorded.
mean_prediction_time = np.mean(prediction_times) if prediction_times else None

print(f"Mean HSS: {mean_hss_val:.4f}")
print(f"Mean F1: {mean_f1_val:.4f}")
print(f"Mean IoU: {mean_iou_val:.4f}")
print(f"Final Config: {config}")
if mean_prediction_time is not None:
    print(f"Overall Mean Prediction Time: {mean_prediction_time:.4f} seconds")

# --- Writing results to a file ---
# The file name encodes the configuration; '.' in the thresholds becomes 'p'
# (e.g. 0.59 -> 0p59).
vt_str = str(config['vertex_threshold']).replace('.', 'p')
et_str = str(config['edge_threshold']).replace('.', 'p')
opc_str = str(config['only_predicted_connections'])
results_filename = f"results_vt{vt_str}_et{et_str}_opc{opc_str}_samples{args.max_samples}.txt"
results_filepath = os.path.join(args.results_dir, results_filename)

with open(results_filepath, 'w', encoding='utf-8') as f:
    f.write(f"Configuration: {config}\n")
    # Kept unchanged for compatibility with existing result files: this line
    # records the configured limit, not the number of samples evaluated.
    f.write(f"Max Samples Processed: {args.max_samples}\n")
    # The real count, which the line above did not report.
    f.write(f"Samples Actually Processed: {num_processed}\n")
    f.write(f"Mean HSS: {mean_hss_val:.4f}\n")
    f.write(f"Mean F1: {mean_f1_val:.4f}\n")
    f.write(f"Mean IoU: {mean_iou_val:.4f}\n")
    if mean_prediction_time is not None:
        f.write(f"Overall Mean Prediction Time: {mean_prediction_time:.4f} seconds\n")

    # Per-sample scores, one value per line, in evaluation order.
    f.write("\nIndividual HSS Scores:\n")
    f.writelines(f"{s_hss:.4f}\n" for s_hss in scores_hss)
    f.write("\nIndividual F1 Scores:\n")
    f.writelines(f"{s_f1:.4f}\n" for s_f1 in scores_f1)
    f.write("\nIndividual IoU Scores:\n")
    f.writelines(f"{s_iou:.4f}\n" for s_iou in scores_iou)

print(f"Results saved to {results_filepath}")