|
|
""" |
|
|
Training and evaluation script for HoHo wireframe prediction model. |
|
|
This script loads the HoHo25k dataset, processes samples through a wireframe prediction pipeline |
|
|
using PointNet models, and evaluates performance using HSS, F1, and IoU metrics. It supports |
|
|
configurable thresholds, visualization of results, and saves detailed performance metrics to files. |
|
|
Key features: |
|
|
- Command-line argument support for model configuration |
|
|
- PointNet-based vertex and edge prediction |
|
|
- Real-time performance monitoring and visualization |
|
|
- Comprehensive metric evaluation and result logging |
|
|
- Support for CUDA acceleration when available |
|
|
""" |
|
|
from datasets import load_dataset |
|
|
from hoho2025.vis import plot_all_modalities |
|
|
from hoho2025.viz3d import * |
|
|
import pycolmap |
|
|
import tempfile,zipfile |
|
|
import io |
|
|
import open3d as o3d |
|
|
import os |
|
|
import argparse |
|
|
import numpy as np |
|
|
|
|
|
from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local, _plotly_rgb_to_normalized_o3d_color |
|
|
from utils import read_colmap_rec, empty_solution |
|
|
|
|
|
|
|
|
from hoho2025.metric_helper import hss |
|
|
from predict import predict_wireframe, predict_wireframe_old |
|
|
from tqdm import tqdm |
|
|
from fast_pointnet_v2 import load_pointnet_model |
|
|
from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model |
|
|
import torch |
|
|
import time |
|
|
|
|
|
|
|
|
def _str_to_bool(value):
    """Convert a command-line token into a bool.

    Accepts ``true``/``false``, ``yes``/``no``, ``y``/``n`` and ``1``/``0``
    (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
            Rejecting it makes argparse print a usage error instead of
            silently treating a typo such as ``Ture`` as ``False``.
    """
    token = str(value).strip().lower()
    if token in ('true', 'yes', 'y', '1'):
        return True
    if token in ('false', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (true/false), got {value!r}"
    )


# Command-line interface: prediction thresholds, connection filtering, the
# sample limit and the output directory for the result files.
parser = argparse.ArgumentParser(description="Train and evaluate HoHo model with custom config.")
parser.add_argument('--vertex_threshold', type=float, default=0.59, help='Vertex threshold for prediction.')
parser.add_argument('--edge_threshold', type=float, default=0.65, help='Edge threshold for prediction.')
# The default is already a bool, so argparse does not run it through the
# converter; only values given on the command line are validated.
parser.add_argument('--only_predicted_connections', type=_str_to_bool, default=True, help='Use only predicted connections (True/False).')
parser.add_argument('--max_samples', type=int, default=50000, help='Maximum number of samples to process.')
parser.add_argument('--results_dir', type=str, default="results", help='Directory to save result files.')

# Parsed at import time from sys.argv; the rest of the script reads `args`.
args = parser.parse_args()
|
|
|
|
|
|
|
|
# Prediction settings copied from the parsed command-line arguments; this
# dict is what gets handed to predict_wireframe for every sample.
config = {
    option: getattr(args, option)
    for option in ('vertex_threshold', 'edge_threshold', 'only_predicted_connections')
}
print(f"Running with configuration: {config}")
|
|
|
|
|
|
|
|
# Create the output directory up front (no error if it already exists).
os.makedirs(args.results_dir, exist_ok=True)

# NOTE(review): cache_dir is a placeholder path and must be adapted locally.
# NOTE(review): trust_remote_code=True presumably allows the dataset
# repository to run its own loading code -- confirm the source is trusted.
ds = load_dataset(
    "usm3d/hoho25k",
    cache_dir="YOUR_CACHE_DIR_PATH/hoho25k/",
    trust_remote_code=True,
)

# One entry per evaluated sample, for each reported metric.
scores_hss, scores_f1, scores_iou = [], [], []

# When true, every processed sample is also shown in an Open3D window.
show_visu = True

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Model weights are read from files in the current working directory.
pnet_model = load_pointnet_model(model_path="pnet.pth", device=device, predict_score=True)
pnet_class_model = load_pointnet_class_model(model_path="pnet_class.pth", device=device)

# No voxel model is loaded; None is passed through to predict_wireframe.
voxel_model = None
|
|
|
|
|
|
|
|
# Number of samples processed so far, and the duration of every prediction
# that completed without raising.
idx = 0
prediction_times = []

# The limit check at the bottom of the loop only runs after a sample has been
# processed, so a non-positive --max_samples has to be handled up front to
# process nothing at all.
samples = ds['train'] if args.max_samples > 0 else []

for a in tqdm(samples, desc="Processing dataset"):
    try:
        # perf_counter() is a monotonic high-resolution clock, so the measured
        # interval is not affected by system clock adjustments.
        start_time = time.perf_counter()
        pred_vertices, pred_edges = predict_wireframe(
            a.copy(), pnet_model, voxel_model, pnet_class_model, config
        )
        end_time = time.perf_counter()
    except Exception as e:
        # Deliberately broad: one failing sample must not abort the whole
        # evaluation. It is scored with the empty solution instead, and its
        # time is not recorded.
        print(f"Error during prediction: {e}")
        pred_vertices, pred_edges = empty_solution()
    else:
        prediction_time = end_time - start_time
        prediction_times.append(prediction_time)
        # prediction_times is never empty here, so the running mean is
        # always defined.
        mean_time = np.mean(prediction_times)
        print(f"Prediction time: {prediction_time:.4f} seconds, Mean time: {mean_time:.4f} seconds")

    # Score the prediction against the ground-truth wireframe of the sample.
    score = hss(pred_vertices, pred_edges, a['wf_vertices'], a['wf_edges'], vert_thresh=0.5, edge_thresh=0.5)
    print(f"Score: {score}")
    scores_hss.append(score.hss)
    scores_f1.append(score.f1)
    scores_iou.append(score.iou)

    if show_visu:
        # Collect the COLMAP reconstruction, the ground-truth wireframe and
        # the cameras of this sample into one list of Open3D geometries.
        colmap = read_colmap_rec(a['colmap_binary'])
        pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
        wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
        bpo_cams = plot_bpo_cameras_from_entry_local(None, a)

        visu_all = [pcd] + geometries + wireframe + bpo_cams
        # NOTE(review): draw_geometries presumably blocks until the window is
        # closed, pausing the evaluation on every sample -- confirm.
        o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")

    idx += 1
    if idx >= args.max_samples:
        print(f"Reached max_samples limit: {args.max_samples}")
        break
|
|
|
|
|
# Loud end-of-run marker: the same banner line, printed ten times.
print("\n".join(["END OF DATASET"] * 10))

# Average each metric over all scored samples; fall back to 0.0 when no
# sample was scored so the report below still formats.
mean_hss_val, mean_f1_val, mean_iou_val = (
    np.mean(values) if values else 0.0
    for values in (scores_hss, scores_f1, scores_iou)
)

print(f"Mean HSS: {mean_hss_val:.4f}")
print(f"Mean F1: {mean_f1_val:.4f}")
print(f"Mean IoU: {mean_iou_val:.4f}")
print(f"Final Config: {config}")

# Timing is only reported when at least one prediction finished without error.
if len(prediction_times) > 0:
    overall_mean_time = np.mean(prediction_times)
    print(f"Overall Mean Prediction Time: {overall_mean_time:.4f} seconds")
|
|
|
|
|
|
|
|
|
|
|
# Encode the configuration into the file name; the decimal point of each
# threshold is replaced by 'p' (e.g. 0.59 -> 0p59).
vt_str, et_str = (
    str(config[key]).replace('.', 'p')
    for key in ('vertex_threshold', 'edge_threshold')
)
opc_str = str(config['only_predicted_connections'])

results_filename = f"results_vt{vt_str}_et{et_str}_opc{opc_str}_samples{args.max_samples}.txt"
results_filepath = os.path.join(args.results_dir, results_filename)

# Report layout: a summary header followed by one section of per-sample
# values for each metric, one value per line.
with open(results_filepath, 'w') as f:
    # NOTE(review): "Max Samples Processed" records the configured limit
    # (args.max_samples), not the number of samples actually scored.
    summary_lines = [
        f"Configuration: {config}\n",
        f"Max Samples Processed: {args.max_samples}\n",
        f"Mean HSS: {mean_hss_val:.4f}\n",
        f"Mean F1: {mean_f1_val:.4f}\n",
        f"Mean IoU: {mean_iou_val:.4f}\n",
    ]
    if prediction_times:
        summary_lines.append(f"Overall Mean Prediction Time: {np.mean(prediction_times):.4f} seconds\n")
    f.writelines(summary_lines)

    f.write("\nIndividual HSS Scores:\n")
    f.writelines(f"{s_hss:.4f}\n" for s_hss in scores_hss)

    f.write("\nIndividual F1 Scores:\n")
    f.writelines(f"{s_f1:.4f}\n" for s_f1 in scores_f1)

    f.write("\nIndividual IoU Scores:\n")
    f.writelines(f"{s_iou:.4f}\n" for s_iou in scores_iou)

print(f"Results saved to {results_filepath}")
|
|
|
|
|
|