File size: 7,563 Bytes
33113fd c22c8c5 6d115a4 4949317 c22c8c5 e36bb23 2b67e89 c23c0b3 33113fd 6d115a4 1d64568 c22c8c5 4949317 32fd25b f911a86 ce848f5 4949317 9518589 33113fd a70b55e c22c8c5 33113fd 6d115a4 c23c0b3 9518589 c6e00e0 d9fc230 9518589 c6e00e0 c22c8c5 9518589 e36bb23 d9fc230 c22c8c5 1d64568 c1c37b0 657c8f1 e36bb23 b594fed c22c8c5 1d64568 e36bb23 1d64568 4949317 c22c8c5 e36bb23 d9fc230 1904e97 c22c8c5 33113fd c22c8c5 33113fd e36bb23 c22c8c5 d9fc230 4949317 1d64568 d9fc230 4949317 c22c8c5 4949317 9518589 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
"""
Evaluation script for the HoHo wireframe prediction model (pretrained checkpoints are loaded; no training is performed).
This script loads the HoHo25k dataset, processes samples through a wireframe prediction pipeline
using PointNet models, and evaluates performance using HSS, F1, and IoU metrics. It supports
configurable thresholds, visualization of results, and saves detailed performance metrics to files.
Key features:
- Command-line argument support for model configuration
- PointNet-based vertex and edge prediction
- Real-time performance monitoring and visualization
- Comprehensive metric evaluation and result logging
- Support for CUDA acceleration when available
"""
from datasets import load_dataset
from hoho2025.vis import plot_all_modalities
from hoho2025.viz3d import *
import pycolmap
import tempfile,zipfile
import io
import open3d as o3d
import os
import argparse # Added for command-line arguments
import numpy as np # Make sure numpy is imported if not already implicitly
from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local, _plotly_rgb_to_normalized_o3d_color
from utils import read_colmap_rec, empty_solution
#from hoho2025.example_solutions import predict_wireframe
from hoho2025.metric_helper import hss
from predict import predict_wireframe, predict_wireframe_old
from tqdm import tqdm
from fast_pointnet_v2 import load_pointnet_model
from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
import torch
import time
# --- Argument Parsing ---
def _str2bool(value):
    """Parse a command-line boolean, rejecting unrecognized spellings.

    Accepts true/false, yes/no and 1/0 (case-insensitive, surrounding
    whitespace ignored). Anything else makes argparse report a usage error
    instead of being silently treated as False, which is what a plain
    ``str(x).lower() == 'true'`` comparison would do for a typo such as
    ``ture`` or a spelling such as ``1``.
    """
    text = str(value).strip().lower()
    if text in ("true", "yes", "1"):
        return True
    if text in ("false", "no", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected true/false, got {value!r}")


def _positive_int(value):
    """Parse ``--max_samples`` as a strictly positive integer.

    Values <= 0 are rejected: the evaluation loop only checks the limit after
    scoring a sample, so such values would still process one sample.
    """
    number = int(value)  # ValueError is reported by argparse as an invalid value
    if number <= 0:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value!r}")
    return number


parser = argparse.ArgumentParser(description="Train and evaluate HoHo model with custom config.")
parser.add_argument('--vertex_threshold', type=float, default=0.59, help='Vertex threshold for prediction.')
parser.add_argument('--edge_threshold', type=float, default=0.65, help='Edge threshold for prediction.')
# Non-string defaults are not passed through `type`, so the default stays the bool True.
parser.add_argument('--only_predicted_connections', type=_str2bool, default=True, help='Use only predicted connections (True/False).')
parser.add_argument('--max_samples', type=_positive_int, default=50000, help='Maximum number of samples to process.')
parser.add_argument('--results_dir', type=str, default="results", help='Directory to save result files.')
args = parser.parse_args()
# --- Configuration from Arguments ---
# The prediction config mirrors three CLI options one-to-one; building it from
# the option names keeps keys and attribute names in sync (insertion order is
# the order of the tuple below).
_CONFIG_KEYS = ('vertex_threshold', 'edge_threshold', 'only_predicted_connections')
config = {key: getattr(args, key) for key in _CONFIG_KEYS}
print(f"Running with configuration: {config}")

# Ensure the output directory exists before the evaluation starts
# (no error if it is already there).
os.makedirs(args.results_dir, exist_ok=True)
# --- Dataset ---
# NOTE(review): trust_remote_code=True lets the dataset repository run its own
# loading code on this machine; only use it with a trusted dataset source.
# TODO: replace the placeholder cache directory with a real path.
ds = load_dataset(
    "usm3d/hoho25k",
    cache_dir="YOUR_CACHE_DIR_PATH/hoho25k/",
    trust_remote_code=True,
)

# --- Per-sample metric accumulators ---
scores_hss, scores_f1, scores_iou = [], [], []
# When True, the main loop opens an Open3D window for every evaluated sample.
show_visu = True

# --- Models ---
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
pnet_model = load_pointnet_model(model_path="pnet.pth", device=device, predict_score=True)
pnet_class_model = load_pointnet_class_model(model_path="pnet_class.pth", device=device)
voxel_model = None  # no 3D-CNN model is loaded; None is passed to the predictor

# --- Loop state ---
idx = 0
prediction_times = []
# --- Main evaluation loop ---
# Running sum of prediction times: the per-sample mean is then O(1) instead of
# re-averaging the whole, ever-growing list on every iteration (O(n^2) total).
total_prediction_time = 0.0
for a in tqdm(ds['train'], desc="Processing dataset"):
    try:
        # perf_counter is monotonic and high-resolution, suited to durations
        # (time.time() can jump if the wall clock is adjusted).
        start_time = time.perf_counter()
        # A shallow copy of the sample is passed, so top-level changes made by
        # the predictor do not leak into `a`, which is still used below.
        pred_vertices, pred_edges = predict_wireframe(a.copy(), pnet_model, voxel_model, pnet_class_model, config)
    except Exception as e:  # best-effort: a failing sample is scored as an empty wireframe
        print(f"Error during prediction: {e}")
        pred_vertices, pred_edges = empty_solution()
    else:
        # Only successful predictions contribute to the timing statistics.
        prediction_time = time.perf_counter() - start_time
        prediction_times.append(prediction_time)
        total_prediction_time += prediction_time
        mean_time = total_prediction_time / len(prediction_times)
        print(f"Prediction time: {prediction_time:.4f} seconds, Mean time: {mean_time:.4f} seconds")

    score = hss(pred_vertices, pred_edges, a['wf_vertices'], a['wf_edges'], vert_thresh=0.5, edge_thresh=0.5)
    print(f"Score: {score}")
    scores_hss.append(score.hss)
    scores_f1.append(score.f1)
    scores_iou.append(score.iou)

    if show_visu:
        # Ground-truth wireframe over the COLMAP reconstruction and BPO cameras.
        # To overlay the prediction as well, add:
        #   plot_wireframe_local(None, pred_vertices, pred_edges, None, color='rgb(255, 0, 0)')
        colmap = read_colmap_rec(a['colmap_binary'])
        pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
        wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
        bpo_cams = plot_bpo_cameras_from_entry_local(None, a)
        visu_all = [pcd] + geometries + wireframe + bpo_cams
        o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")

    idx += 1
    if idx >= args.max_samples:
        print(f"Reached max_samples limit: {args.max_samples}")
        break
# --- Console summary ---
# Ten banner lines make the end of the (long) per-sample log easy to spot.
print("\n".join(["END OF DATASET"] * 10))

# Mean of each metric; 0.0 when no sample was scored.
mean_hss_val, mean_f1_val, mean_iou_val = (
    np.mean(metric_scores) if metric_scores else 0.0
    for metric_scores in (scores_hss, scores_f1, scores_iou)
)
for metric_name, metric_mean in (("HSS", mean_hss_val), ("F1", mean_f1_val), ("IoU", mean_iou_val)):
    print(f"Mean {metric_name}: {metric_mean:.4f}")
print(f"Final Config: {config}")

if prediction_times:
    overall_mean_time = np.mean(prediction_times)
    print(f"Overall Mean Prediction Time: {overall_mean_time:.4f} seconds")
# --- Writing results to a file ---
# Thresholds are encoded in the file name ('.' -> 'p') so runs with different
# configurations write to different files.
vt_str = str(config['vertex_threshold']).replace('.', 'p')
et_str = str(config['edge_threshold']).replace('.', 'p')
opc_str = str(config['only_predicted_connections'])
results_filename = f"results_vt{vt_str}_et{et_str}_opc{opc_str}_samples{args.max_samples}.txt"
results_filepath = os.path.join(args.results_dir, results_filename)

# --max_samples is only a cap and may exceed the dataset size, so the number of
# samples that were actually scored is recorded separately.
num_evaluated = len(scores_hss)

with open(results_filepath, 'w', encoding='utf-8') as f:
    f.write(f"Configuration: {config}\n")
    # Kept for compatibility with existing result files: this is the cap, not a count.
    f.write(f"Max Samples Processed: {args.max_samples}\n")
    f.write(f"Mean HSS: {mean_hss_val:.4f}\n")
    f.write(f"Mean F1: {mean_f1_val:.4f}\n")
    f.write(f"Mean IoU: {mean_iou_val:.4f}\n")
    if prediction_times:
        f.write(f"Overall Mean Prediction Time: {np.mean(prediction_times):.4f} seconds\n")
    f.write(f"Samples Evaluated: {num_evaluated}\n")
    # One section per metric, one score per line, in evaluation order.
    for metric_name, metric_scores in (("HSS", scores_hss), ("F1", scores_f1), ("IoU", scores_iou)):
        f.write(f"\nIndividual {metric_name} Scores:\n")
        f.writelines(f"{s:.4f}\n" for s in metric_scores)
print(f"Results saved to {results_filepath}")
|