Improves edge prediction with class PointNet
Browse filesThe best settings so far:
Mean HSS: 0.2896
Mean F1: 0.3718
Mean IoU: 0.2466
{'vertex_threshold': 0.4, 'edge_threshold': 0.6, 'only_predicted_connections': False}
- predict.py +122 -8
- train.py +19 -8
predict.py
CHANGED
|
@@ -15,14 +15,15 @@ from fast_pointnet import save_patches_dataset, predict_vertex_from_patch
|
|
| 15 |
from fast_voxel import predict_vertex_from_patch_voxel
|
| 16 |
import time
|
| 17 |
from fast_pointnet_class import save_patches_dataset as save_patches_dataset_class
|
|
|
|
| 18 |
|
| 19 |
GENERATE_DATASET = False
|
| 20 |
DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
|
| 21 |
#DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom'
|
| 22 |
|
| 23 |
-
GENERATE_DATASET_EDGES =
|
| 24 |
-
|
| 25 |
-
EDGES_DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom_edges'
|
| 26 |
|
| 27 |
def convert_entry_to_human_readable(entry):
|
| 28 |
out = {}
|
|
@@ -953,7 +954,7 @@ def generate_edge_patches(frame):
|
|
| 953 |
all_patches = positive_patches + negative_patches
|
| 954 |
|
| 955 |
# Visualize edge patches
|
| 956 |
-
if
|
| 957 |
# Create plotter
|
| 958 |
plotter = pv.Plotter()
|
| 959 |
|
|
@@ -1013,6 +1014,95 @@ def generate_edge_patches(frame):
|
|
| 1013 |
|
| 1014 |
return all_patches
|
| 1015 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1016 |
def calculate_cylinder_overlap_volume(cyl1, cyl2):
|
| 1017 |
"""
|
| 1018 |
Calculate the intersection volume between two cylinders using numpy vectorization.
|
|
@@ -1113,21 +1203,28 @@ def calculate_cylinder_overlap_volume(cyl1, cyl2):
|
|
| 1113 |
|
| 1114 |
return max(0.0, overlap_volume)
|
| 1115 |
|
| 1116 |
-
def predict_wireframe(entry, pnet_model, voxel_model) -> Tuple[np.ndarray, List[int]]:
|
| 1117 |
"""
|
| 1118 |
Predict 3D wireframe from a dataset entry.
|
| 1119 |
"""
|
| 1120 |
good_entry = convert_entry_to_human_readable(entry)
|
| 1121 |
colmap_rec = good_entry['colmap_binary']
|
| 1122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1123 |
if GENERATE_DATASET_EDGES:
|
| 1124 |
patches = generate_edge_patches(good_entry)
|
| 1125 |
-
save_patches_dataset_class(patches, EDGES_DATASET_DIR, good_entry['order_id'])
|
| 1126 |
return empty_solution()
|
| 1127 |
|
| 1128 |
vert_edge_per_image = {}
|
| 1129 |
idxs_points = []
|
| 1130 |
all_connections = []
|
|
|
|
|
|
|
|
|
|
| 1131 |
for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
|
| 1132 |
good_entry['depth'],
|
| 1133 |
good_entry['K'],
|
|
@@ -1222,7 +1319,7 @@ def predict_wireframe(entry, pnet_model, voxel_model) -> Tuple[np.ndarray, List[
|
|
| 1222 |
|
| 1223 |
#visu_patch_and_pred(patch, pred_vertex, pred_dist, pred_class)
|
| 1224 |
|
| 1225 |
-
if pred_class >
|
| 1226 |
predicted_vertices.append(pred_vertex)
|
| 1227 |
else:
|
| 1228 |
predicted_vertices.append(np.array([0.0, 0.0, 0.0])) # Append a zero vertex if not predicted
|
|
@@ -1273,9 +1370,26 @@ def predict_wireframe(entry, pnet_model, voxel_model) -> Tuple[np.ndarray, List[
|
|
| 1273 |
|
| 1274 |
#print(f"Filtered vertices: {len(filtered_vertices)} from {len(predicted_vertices)}")
|
| 1275 |
#print(f"Filtered connections: {len(filtered_connections)} from {len(connections)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1276 |
|
| 1277 |
predicted_vertices = np.array(filtered_vertices)
|
| 1278 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1279 |
|
| 1280 |
return predicted_vertices, connections
|
| 1281 |
|
|
|
|
| 15 |
from fast_voxel import predict_vertex_from_patch_voxel
|
| 16 |
import time
|
| 17 |
from fast_pointnet_class import save_patches_dataset as save_patches_dataset_class
|
| 18 |
+
from fast_pointnet_class import predict_class_from_patch
|
| 19 |
|
| 20 |
GENERATE_DATASET = False
|
| 21 |
DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
|
| 22 |
#DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom'
|
| 23 |
|
| 24 |
+
GENERATE_DATASET_EDGES = False
|
| 25 |
+
EDGES_DATASET_DIR = '/home/skvrnjan/personal/hohocustom_edges/'
|
| 26 |
+
#EDGES_DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom_edges'
|
| 27 |
|
| 28 |
def convert_entry_to_human_readable(entry):
|
| 29 |
out = {}
|
|
|
|
| 954 |
all_patches = positive_patches + negative_patches
|
| 955 |
|
| 956 |
# Visualize edge patches
|
| 957 |
+
if True: # Set to True to enable visualization
|
| 958 |
# Create plotter
|
| 959 |
plotter = pv.Plotter()
|
| 960 |
|
|
|
|
| 1014 |
|
| 1015 |
return all_patches
|
| 1016 |
|
| 1017 |
+
def generate_edge_patches_forward(frame, pred_vertices):
|
| 1018 |
+
vertices = pred_vertices
|
| 1019 |
+
|
| 1020 |
+
cylinder_radius = 0.5
|
| 1021 |
+
|
| 1022 |
+
colmap = frame['colmap_binary']
|
| 1023 |
+
|
| 1024 |
+
# Create 6D point cloud from COLMAP data
|
| 1025 |
+
colmap_points_6d = []
|
| 1026 |
+
for pid, p3D in colmap.points3D.items():
|
| 1027 |
+
# Combine xyz coordinates and RGB color
|
| 1028 |
+
point_6d = np.concatenate([p3D.xyz, p3D.color / 255.0]) # Normalize color to [0,1]
|
| 1029 |
+
colmap_points_6d.append(point_6d)
|
| 1030 |
+
|
| 1031 |
+
colmap_points_6d = np.array(colmap_points_6d) if colmap_points_6d else np.empty((0, 6))
|
| 1032 |
+
|
| 1033 |
+
colmap_points_6d[:, 3:] = colmap_points_6d[:, 3:] * 2 - 1
|
| 1034 |
+
|
| 1035 |
+
# Extract 3D coordinates for faster vectorized operations
|
| 1036 |
+
colmap_points_3d = colmap_points_6d[:, :3]
|
| 1037 |
+
|
| 1038 |
+
forward_patches = []
|
| 1039 |
+
|
| 1040 |
+
# For each vertex pair, create a patch without label
|
| 1041 |
+
for i in range(len(vertices)):
|
| 1042 |
+
for j in range(i + 1, len(vertices)):
|
| 1043 |
+
start_vertex = vertices[i]
|
| 1044 |
+
end_vertex = vertices[j]
|
| 1045 |
+
|
| 1046 |
+
# Create line vector from start to end
|
| 1047 |
+
line_vector = end_vertex - start_vertex
|
| 1048 |
+
line_length = np.linalg.norm(line_vector)
|
| 1049 |
+
|
| 1050 |
+
# Normalize line vector
|
| 1051 |
+
line_direction = line_vector / line_length
|
| 1052 |
+
|
| 1053 |
+
# Extend the line by 25 cm (0.25 meters) on both ends for more context
|
| 1054 |
+
extension_length = 0.25 # 25 cm in meters
|
| 1055 |
+
extended_start = start_vertex - extension_length * line_direction
|
| 1056 |
+
extended_end = end_vertex + extension_length * line_direction
|
| 1057 |
+
extended_line_length = line_length + 2 * extension_length
|
| 1058 |
+
|
| 1059 |
+
# Vectorized distance calculation
|
| 1060 |
+
# Vector from extended start to all points
|
| 1061 |
+
start_to_points = colmap_points_3d - extended_start[np.newaxis, :]
|
| 1062 |
+
|
| 1063 |
+
# Project onto line direction to get distance along extended line
|
| 1064 |
+
projection_lengths = np.dot(start_to_points, line_direction)
|
| 1065 |
+
|
| 1066 |
+
# Filter points within extended line segment bounds
|
| 1067 |
+
within_bounds = (projection_lengths >= 0) & (projection_lengths <= extended_line_length)
|
| 1068 |
+
|
| 1069 |
+
# Find closest points on extended line segment for all points
|
| 1070 |
+
closest_points_on_line = extended_start[np.newaxis, :] + projection_lengths[:, np.newaxis] * line_direction[np.newaxis, :]
|
| 1071 |
+
|
| 1072 |
+
# Calculate perpendicular distances from points to line
|
| 1073 |
+
perpendicular_distances = np.linalg.norm(colmap_points_3d - closest_points_on_line, axis=1)
|
| 1074 |
+
|
| 1075 |
+
# Find points within cylinder
|
| 1076 |
+
within_cylinder = within_bounds & (perpendicular_distances <= cylinder_radius)
|
| 1077 |
+
|
| 1078 |
+
if np.sum(within_cylinder) <= 10:
|
| 1079 |
+
continue
|
| 1080 |
+
|
| 1081 |
+
points_in_cylinder = colmap_points_6d[within_cylinder]
|
| 1082 |
+
point_indices_in_cylinder = np.where(within_cylinder)[0]
|
| 1083 |
+
|
| 1084 |
+
# Center the patch at the midpoint of the original line (not extended)
|
| 1085 |
+
line_midpoint = (start_vertex + end_vertex) / 2
|
| 1086 |
+
|
| 1087 |
+
# Shift points to center around origin
|
| 1088 |
+
points_centered = points_in_cylinder.copy()
|
| 1089 |
+
points_centered[:, :3] -= line_midpoint
|
| 1090 |
+
|
| 1091 |
+
# Create edge patch without label
|
| 1092 |
+
edge_patch = {
|
| 1093 |
+
'patch_6d': points_centered,
|
| 1094 |
+
'connection': (i, j),
|
| 1095 |
+
'line_start': start_vertex - line_midpoint,
|
| 1096 |
+
'line_end': end_vertex - line_midpoint,
|
| 1097 |
+
'cylinder_radius': cylinder_radius,
|
| 1098 |
+
'point_indices': point_indices_in_cylinder,
|
| 1099 |
+
'center': line_midpoint
|
| 1100 |
+
}
|
| 1101 |
+
|
| 1102 |
+
forward_patches.append(edge_patch)
|
| 1103 |
+
|
| 1104 |
+
return forward_patches
|
| 1105 |
+
|
| 1106 |
def calculate_cylinder_overlap_volume(cyl1, cyl2):
|
| 1107 |
"""
|
| 1108 |
Calculate the intersection volume between two cylinders using numpy vectorization.
|
|
|
|
| 1203 |
|
| 1204 |
return max(0.0, overlap_volume)
|
| 1205 |
|
| 1206 |
+
def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config) -> Tuple[np.ndarray, List[int]]:
|
| 1207 |
"""
|
| 1208 |
Predict 3D wireframe from a dataset entry.
|
| 1209 |
"""
|
| 1210 |
good_entry = convert_entry_to_human_readable(entry)
|
| 1211 |
colmap_rec = good_entry['colmap_binary']
|
| 1212 |
|
| 1213 |
+
vertex_threshold = config.get('vertex_threshold', 0.5)
|
| 1214 |
+
edge_threshold = config.get('edge_threshold', 0.5)
|
| 1215 |
+
only_predicted_connections = config.get('only_predicted_connections', False)
|
| 1216 |
+
|
| 1217 |
if GENERATE_DATASET_EDGES:
|
| 1218 |
patches = generate_edge_patches(good_entry)
|
| 1219 |
+
#save_patches_dataset_class(patches, EDGES_DATASET_DIR, good_entry['order_id'])
|
| 1220 |
return empty_solution()
|
| 1221 |
|
| 1222 |
vert_edge_per_image = {}
|
| 1223 |
idxs_points = []
|
| 1224 |
all_connections = []
|
| 1225 |
+
|
| 1226 |
+
print(f"Processing {len(good_entry['gestalt'])} images")
|
| 1227 |
+
|
| 1228 |
for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
|
| 1229 |
good_entry['depth'],
|
| 1230 |
good_entry['K'],
|
|
|
|
| 1319 |
|
| 1320 |
#visu_patch_and_pred(patch, pred_vertex, pred_dist, pred_class)
|
| 1321 |
|
| 1322 |
+
if pred_class > vertex_threshold:
|
| 1323 |
predicted_vertices.append(pred_vertex)
|
| 1324 |
else:
|
| 1325 |
predicted_vertices.append(np.array([0.0, 0.0, 0.0])) # Append a zero vertex if not predicted
|
|
|
|
| 1370 |
|
| 1371 |
#print(f"Filtered vertices: {len(filtered_vertices)} from {len(predicted_vertices)}")
|
| 1372 |
#print(f"Filtered connections: {len(filtered_connections)} from {len(connections)}")
|
| 1373 |
+
|
| 1374 |
+
forward_patches = generate_edge_patches_forward(good_entry, filtered_vertices)
|
| 1375 |
+
new_connections = []
|
| 1376 |
+
if len(forward_patches) > 0:
|
| 1377 |
+
for patch in forward_patches:
|
| 1378 |
+
start_idx, end_idx = patch['connection']
|
| 1379 |
+
|
| 1380 |
+
pred_class, pred_score = predict_class_from_patch(pnet_class_model, patch, device='cuda')
|
| 1381 |
+
if pred_score > edge_threshold:
|
| 1382 |
+
new_connections.append((start_idx, end_idx))
|
| 1383 |
|
| 1384 |
predicted_vertices = np.array(filtered_vertices)
|
| 1385 |
+
|
| 1386 |
+
if only_predicted_connections:
|
| 1387 |
+
connections = new_connections
|
| 1388 |
+
else:
|
| 1389 |
+
connections = filtered_connections + new_connections
|
| 1390 |
+
|
| 1391 |
+
# Remove duplicates from connections
|
| 1392 |
+
connections = list(set(connections))
|
| 1393 |
|
| 1394 |
return predicted_vertices, connections
|
| 1395 |
|
train.py
CHANGED
|
@@ -16,11 +16,12 @@ from predict import predict_wireframe, predict_wireframe_old
|
|
| 16 |
from tqdm import tqdm
|
| 17 |
from fast_pointnet import load_pointnet_model
|
| 18 |
from fast_voxel import load_3dcnn_model
|
|
|
|
| 19 |
import torch
|
| 20 |
|
| 21 |
-
|
| 22 |
-
ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
|
| 23 |
-
ds = ds.shuffle()
|
| 24 |
|
| 25 |
scores_hss = []
|
| 26 |
scores_f1 = []
|
|
@@ -30,19 +31,24 @@ show_visu = False
|
|
| 30 |
|
| 31 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 32 |
|
| 33 |
-
|
| 34 |
-
pnet_model = None
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
#voxel_model = load_3dcnn_model(model_path="/home/skvrnjan/personal/hoho_voxel/initial_epoch_100.pth", device=device, predict_score=True)
|
| 37 |
voxel_model = None
|
| 38 |
|
|
|
|
|
|
|
| 39 |
idx = 0
|
| 40 |
for a in tqdm(ds['train'], desc="Processing dataset"):
|
| 41 |
#plot_all_modalities(a)
|
| 42 |
#pred_vertices, pred_edges = predict_wireframe_old(a)
|
| 43 |
-
#pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model)
|
| 44 |
try:
|
| 45 |
-
pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model)
|
| 46 |
#pred_vertices, pred_edges = predict_wireframe_old(a)
|
| 47 |
except:
|
| 48 |
pred_vertices, pred_edges = empty_solution()
|
|
@@ -53,7 +59,7 @@ for a in tqdm(ds['train'], desc="Processing dataset"):
|
|
| 53 |
scores_f1.append(score.f1)
|
| 54 |
scores_iou.append(score.iou)
|
| 55 |
|
| 56 |
-
if show_visu
|
| 57 |
colmap = read_colmap_rec(a['colmap_binary'])
|
| 58 |
pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
|
| 59 |
wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
|
|
@@ -63,9 +69,14 @@ for a in tqdm(ds['train'], desc="Processing dataset"):
|
|
| 63 |
visu_all = [pcd] + geometries + wireframe + bpo_cams + wireframe2
|
| 64 |
o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
for i in range(10):
|
| 67 |
print("END OF DATASET")
|
| 68 |
print(f"Mean HSS: {np.mean(scores_hss):.4f}")
|
| 69 |
print(f"Mean F1: {np.mean(scores_f1):.4f}")
|
| 70 |
print(f"Mean IoU: {np.mean(scores_iou):.4f}")
|
|
|
|
| 71 |
|
|
|
|
| 16 |
from tqdm import tqdm
|
| 17 |
from fast_pointnet import load_pointnet_model
|
| 18 |
from fast_voxel import load_3dcnn_model
|
| 19 |
+
from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
|
| 20 |
import torch
|
| 21 |
|
| 22 |
+
ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
|
| 23 |
+
#ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
|
| 24 |
+
#ds = ds.shuffle()
|
| 25 |
|
| 26 |
scores_hss = []
|
| 27 |
scores_f1 = []
|
|
|
|
| 31 |
|
| 32 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 33 |
|
| 34 |
+
pnet_model = load_pointnet_model(model_path="/home/skvrnjan/personal/hoho_pnet/initial_epoch_100.pth", device=device, predict_score=True)
|
| 35 |
+
#pnet_model = None
|
| 36 |
+
|
| 37 |
+
pnet_class_model = load_pointnet_class_model(model_path="/home/skvrnjan/personal/hoho_pnet_edges_v2/initial_epoch_100.pth", device=device)
|
| 38 |
+
#pnet_class_model = None
|
| 39 |
|
| 40 |
#voxel_model = load_3dcnn_model(model_path="/home/skvrnjan/personal/hoho_voxel/initial_epoch_100.pth", device=device, predict_score=True)
|
| 41 |
voxel_model = None
|
| 42 |
|
| 43 |
+
config = {'vertex_threshold': 0.4, 'edge_threshold': 0.7, 'only_predicted_connections': False}
|
| 44 |
+
|
| 45 |
idx = 0
|
| 46 |
for a in tqdm(ds['train'], desc="Processing dataset"):
|
| 47 |
#plot_all_modalities(a)
|
| 48 |
#pred_vertices, pred_edges = predict_wireframe_old(a)
|
| 49 |
+
#pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model, pnet_class_model)
|
| 50 |
try:
|
| 51 |
+
pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model, pnet_class_model, config)
|
| 52 |
#pred_vertices, pred_edges = predict_wireframe_old(a)
|
| 53 |
except:
|
| 54 |
pred_vertices, pred_edges = empty_solution()
|
|
|
|
| 59 |
scores_f1.append(score.f1)
|
| 60 |
scores_iou.append(score.iou)
|
| 61 |
|
| 62 |
+
if show_visu:
|
| 63 |
colmap = read_colmap_rec(a['colmap_binary'])
|
| 64 |
pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
|
| 65 |
wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
|
|
|
|
| 69 |
visu_all = [pcd] + geometries + wireframe + bpo_cams + wireframe2
|
| 70 |
o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
|
| 71 |
|
| 72 |
+
idx += 1
|
| 73 |
+
if idx >= 100: # Limit to first 10 samples for testing
|
| 74 |
+
break
|
| 75 |
+
|
| 76 |
for i in range(10):
|
| 77 |
print("END OF DATASET")
|
| 78 |
print(f"Mean HSS: {np.mean(scores_hss):.4f}")
|
| 79 |
print(f"Mean F1: {np.mean(scores_f1):.4f}")
|
| 80 |
print(f"Mean IoU: {np.mean(scores_iou):.4f}")
|
| 81 |
+
print(config)
|
| 82 |
|