"""
Training and evaluation script for the HoHo wireframe prediction model.

This script loads the HoHo25k dataset, processes samples through a wireframe prediction
pipeline using PointNet models, and evaluates performance with HSS, F1, and IoU metrics.
It supports configurable thresholds, visualization of results, and saves detailed
performance metrics to files.

Key features:
- Command-line argument support for model configuration
- PointNet-based vertex and edge prediction
- Real-time performance monitoring and visualization
- Comprehensive metric evaluation and result logging
- Support for CUDA acceleration when available
"""
from datasets import load_dataset
from hoho2025.vis import plot_all_modalities
from hoho2025.viz3d import *
import pycolmap
import tempfile, zipfile
import io
import open3d as o3d
import os
import argparse  # command-line argument parsing
import numpy as np  # aggregating timing and metric statistics

from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local, _plotly_rgb_to_normalized_o3d_color
from utils import read_colmap_rec, empty_solution

#from hoho2025.example_solutions import predict_wireframe
from hoho2025.metric_helper import hss
from predict import predict_wireframe, predict_wireframe_old
from tqdm import tqdm
from fast_pointnet_v2 import load_pointnet_model
from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
import torch
import time

# --- Argument Parsing ---
parser = argparse.ArgumentParser(description="Train and evaluate HoHo model with custom config.")
parser.add_argument('--vertex_threshold', type=float, default=0.59, help='Vertex threshold for prediction.')
parser.add_argument('--edge_threshold', type=float, default=0.65, help='Edge threshold for prediction.')
parser.add_argument('--only_predicted_connections', type=lambda x: (str(x).lower() == 'true'), default=True, help='Use only predicted connections (True/False).')
parser.add_argument('--max_samples', type=int, default=50000, help='Maximum number of samples to process.')
parser.add_argument('--results_dir', type=str, default="results", help='Directory to save result files.')
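
# Example invocation (hypothetical script name and threshold values; adjust to your setup):
#   python train_eval_hoho.py --vertex_threshold 0.6 --edge_threshold 0.7 \
#       --only_predicted_connections True --max_samples 100 --results_dir results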


args = parser.parse_args()

# --- Configuration from Arguments ---
config = {
    'vertex_threshold': args.vertex_threshold,
    'edge_threshold': args.edge_threshold,
    'only_predicted_connections': args.only_predicted_connections
}
print(f"Running with configuration: {config}")

# Create results directory if it doesn't exist
os.makedirs(args.results_dir, exist_ok=True)


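# --- Dataset Loading ---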
ds = load_dataset("usm3d/hoho25k", cache_dir="YOUR_CACHE_DIR_PATH/hoho25k/", trust_remote_code=True)
#ds = load_dataset("usm3d/hoho25k", cache_dir="YOUR_ALTERNATIVE_CACHE_DIR_PATH/hoho25k/", trust_remote_code=True)
#ds = ds.shuffle()
 
scores_hss = []
scores_f1 = []
scores_iou = []

show_visu = True

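# --- Device and Model Setup ---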
device = "cuda" if torch.cuda.is_available() else "cpu"

pnet_model = load_pointnet_model(model_path="pnet.pth", device=device, predict_score=True)
#pnet_model = load_pointnet_model(model_path="YOUR_MODEL_PATH/initial_epoch_100.pth", device=device, predict_score=True)
#pnet_model = None

#pnet_class_model = load_pointnet_class_model(model_path="YOUR_MODEL_PATH/initial_epoch_100.pth", device=device)
#pnet_class_model = load_pointnet_class_model_10d(model_path="YOUR_MODEL_PATH/initial_epoch_75.pth", device=device)
pnet_class_model = load_pointnet_class_model(model_path="pnet_class.pth", device=device)
#pnet_class_model = None

#voxel_model = load_3dcnn_model(model_path="YOUR_MODEL_PATH/initial_epoch_100.pth", device=device, predict_score=True)
voxel_model = None


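# --- Evaluation Loop ---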
idx = 0
prediction_times = []
for a in tqdm(ds['train'], desc="Processing dataset"):
    #plot_all_modalities(a)
    #pred_vertices, pred_edges = predict_wireframe_old(a)
    try:
        start_time = time.time()
        pred_vertices, pred_edges = predict_wireframe(a.copy(), pnet_model, voxel_model, pnet_class_model, config)
        #pred_vertices, pred_edges = predict_wireframe_old(a)
        end_time = time.time()
        prediction_time = end_time - start_time
        prediction_times.append(prediction_time)
        # prediction_times was just appended to, so it is never empty here
        mean_time = np.mean(prediction_times)
        print(f"Prediction time: {prediction_time:.4f} seconds, Mean time: {mean_time:.4f} seconds")
    except Exception as e:  # broad catch so a single failed sample does not abort the whole run
        print(f"Error during prediction: {e}")
        pred_vertices, pred_edges = empty_solution()

    score = hss(pred_vertices, pred_edges, a['wf_vertices'], a['wf_edges'], vert_thresh=0.5, edge_thresh=0.5)
    print(f"Score: {score}")
    scores_hss.append(score.hss)
    scores_f1.append(score.f1)
    scores_iou.append(score.iou)

    if show_visu:
        colmap = read_colmap_rec(a['colmap_binary'])
        pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
        wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
        #wireframe2 = plot_wireframe_local(None, pred_vertices, pred_edges, None, color='rgb(255, 0, 0)')
        bpo_cams = plot_bpo_cameras_from_entry_local(None, a)

        visu_all = [pcd] + geometries + wireframe + bpo_cams #+ wireframe2
        o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")

    idx += 1
    if idx >= args.max_samples:
        print(f"Reached max_samples limit: {args.max_samples}")
        break

for _ in range(10):  # print a visible separator marking the end of dataset processing
    print("END OF DATASET")

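# --- Summary Metrics ---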
mean_hss_val = np.mean(scores_hss) if scores_hss else 0.0
mean_f1_val = np.mean(scores_f1) if scores_f1 else 0.0
mean_iou_val = np.mean(scores_iou) if scores_iou else 0.0


print(f"Mean HSS: {mean_hss_val:.4f}")
print(f"Mean F1: {mean_f1_val:.4f}")
print(f"Mean IoU: {mean_iou_val:.4f}")
print(f"Final Config: {config}")
if prediction_times:
    print(f"Overall Mean Prediction Time: {np.mean(prediction_times):.4f} seconds")


# --- Writing results to a file ---
vt_str = str(config['vertex_threshold']).replace('.', 'p')
et_str = str(config['edge_threshold']).replace('.', 'p')
opc_str = str(config['only_predicted_connections'])

results_filename = f"results_vt{vt_str}_et{et_str}_opc{opc_str}_samples{args.max_samples}.txt"
results_filepath = os.path.join(args.results_dir, results_filename)

with open(results_filepath, 'w') as f:
    f.write(f"Configuration: {config}\n")
    f.write(f"Max Samples Processed: {args.max_samples}\n")
    f.write(f"Mean HSS: {mean_hss_val:.4f}\n")
    f.write(f"Mean F1: {mean_f1_val:.4f}\n")
    f.write(f"Mean IoU: {mean_iou_val:.4f}\n")
    if prediction_times:
        f.write(f"Overall Mean Prediction Time: {np.mean(prediction_times):.4f} seconds\n")
    f.write("\nIndividual HSS Scores:\n")
    for s_hss in scores_hss:
        f.write(f"{s_hss:.4f}\n")
    f.write("\nIndividual F1 Scores:\n")
    for s_f1 in scores_f1:
        f.write(f"{s_f1:.4f}\n")
    f.write("\nIndividual IoU Scores:\n")
    for s_iou in scores_iou:
        f.write(f"{s_iou:.4f}\n")


print(f"Results saved to {results_filepath}")