jskvrna committed on
Commit
d9fc230
·
1 Parent(s): 5291a52

Improves edge prediction with class PointNet

Browse files

The best settings so far:

Mean HSS: 0.2896
Mean F1: 0.3718
Mean IoU: 0.2466
{'vertex_threshold': 0.4, 'edge_threshold': 0.6, 'only_predicted_connections': False}

Files changed (2) hide show
  1. predict.py +122 -8
  2. train.py +19 -8
predict.py CHANGED
@@ -15,14 +15,15 @@ from fast_pointnet import save_patches_dataset, predict_vertex_from_patch
15
  from fast_voxel import predict_vertex_from_patch_voxel
16
  import time
17
  from fast_pointnet_class import save_patches_dataset as save_patches_dataset_class
 
18
 
19
  GENERATE_DATASET = False
20
  DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
21
  #DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom'
22
 
23
- GENERATE_DATASET_EDGES = True
24
- #EDGES_DATASET_DIR = '/home/skvrnjan/personal/hohocustom_edges/'
25
- EDGES_DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom_edges'
26
 
27
  def convert_entry_to_human_readable(entry):
28
  out = {}
@@ -953,7 +954,7 @@ def generate_edge_patches(frame):
953
  all_patches = positive_patches + negative_patches
954
 
955
  # Visualize edge patches
956
- if False: # Set to True to enable visualization
957
  # Create plotter
958
  plotter = pv.Plotter()
959
 
@@ -1013,6 +1014,95 @@ def generate_edge_patches(frame):
1013
 
1014
  return all_patches
1015
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1016
  def calculate_cylinder_overlap_volume(cyl1, cyl2):
1017
  """
1018
  Calculate the intersection volume between two cylinders using numpy vectorization.
@@ -1113,21 +1203,28 @@ def calculate_cylinder_overlap_volume(cyl1, cyl2):
1113
 
1114
  return max(0.0, overlap_volume)
1115
 
1116
- def predict_wireframe(entry, pnet_model, voxel_model) -> Tuple[np.ndarray, List[int]]:
1117
  """
1118
  Predict 3D wireframe from a dataset entry.
1119
  """
1120
  good_entry = convert_entry_to_human_readable(entry)
1121
  colmap_rec = good_entry['colmap_binary']
1122
 
 
 
 
 
1123
  if GENERATE_DATASET_EDGES:
1124
  patches = generate_edge_patches(good_entry)
1125
- save_patches_dataset_class(patches, EDGES_DATASET_DIR, good_entry['order_id'])
1126
  return empty_solution()
1127
 
1128
  vert_edge_per_image = {}
1129
  idxs_points = []
1130
  all_connections = []
 
 
 
1131
  for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
1132
  good_entry['depth'],
1133
  good_entry['K'],
@@ -1222,7 +1319,7 @@ def predict_wireframe(entry, pnet_model, voxel_model) -> Tuple[np.ndarray, List[
1222
 
1223
  #visu_patch_and_pred(patch, pred_vertex, pred_dist, pred_class)
1224
 
1225
- if pred_class > 0.5:
1226
  predicted_vertices.append(pred_vertex)
1227
  else:
1228
  predicted_vertices.append(np.array([0.0, 0.0, 0.0])) # Append a zero vertex if not predicted
@@ -1273,9 +1370,26 @@ def predict_wireframe(entry, pnet_model, voxel_model) -> Tuple[np.ndarray, List[
1273
 
1274
  #print(f"Filtered vertices: {len(filtered_vertices)} from {len(predicted_vertices)}")
1275
  #print(f"Filtered connections: {len(filtered_connections)} from {len(connections)}")
 
 
 
 
 
 
 
 
 
 
1276
 
1277
  predicted_vertices = np.array(filtered_vertices)
1278
- connections = filtered_connections
 
 
 
 
 
 
 
1279
 
1280
  return predicted_vertices, connections
1281
 
 
15
  from fast_voxel import predict_vertex_from_patch_voxel
16
  import time
17
  from fast_pointnet_class import save_patches_dataset as save_patches_dataset_class
18
+ from fast_pointnet_class import predict_class_from_patch
19
 
20
  GENERATE_DATASET = False
21
  DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
22
  #DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom'
23
 
24
+ GENERATE_DATASET_EDGES = False
25
+ EDGES_DATASET_DIR = '/home/skvrnjan/personal/hohocustom_edges/'
26
+ #EDGES_DATASET_DIR = '/mnt/personal/skvrnjan/hohocustom_edges'
27
 
28
  def convert_entry_to_human_readable(entry):
29
  out = {}
 
954
  all_patches = positive_patches + negative_patches
955
 
956
  # Visualize edge patches
957
+ if True: # Set to True to enable visualization
958
  # Create plotter
959
  plotter = pv.Plotter()
960
 
 
1014
 
1015
  return all_patches
1016
 
1017
def generate_edge_patches_forward(frame, pred_vertices, cylinder_radius=0.5,
                                  extension_length=0.25, min_points=10):
    """
    Build unlabeled edge-candidate patches for every pair of predicted vertices.

    For each vertex pair (i, j) with i < j, the segment between the two
    vertices is extended by ``extension_length`` on both ends and every COLMAP
    point lying inside the cylinder of radius ``cylinder_radius`` around that
    extended segment is collected. The collected points are re-centered on the
    midpoint of the original (non-extended) segment.

    Args:
        frame: Mapping with a ``'colmap_binary'`` entry exposing ``points3D``,
            a dict whose values carry ``xyz`` and ``color`` attributes.
            Colors are divided by 255 and then mapped to [-1, 1]
            (presumably 8-bit RGB -- confirm against the COLMAP reader).
        pred_vertices: Sequence of 3D vertices (list of arrays, nested lists
            or an ``(N, 3)`` array).
        cylinder_radius: Radius of the cylinder around each candidate edge.
        extension_length: Extra length added to both ends of the segment to
            give the patch more context (0.25 corresponds to 25 cm when the
            reconstruction is in meters).
        min_points: A candidate is discarded unless strictly more than this
            many points fall inside its cylinder.

    Returns:
        List of dicts with keys ``'patch_6d'`` (centered xyz + color),
        ``'connection'`` (the ``(i, j)`` vertex indices), ``'line_start'``,
        ``'line_end'`` (segment end points relative to the midpoint),
        ``'cylinder_radius'``, ``'point_indices'`` (indices into the COLMAP
        point list) and ``'center'`` (the segment midpoint).
    """
    colmap = frame['colmap_binary']

    # Create the 6D point cloud (xyz + color scaled to [0, 1]) from COLMAP data.
    cloud_rows = [
        np.concatenate([p3D.xyz, p3D.color / 255.0])
        for p3D in colmap.points3D.values()
    ]
    colmap_points_6d = np.array(cloud_rows) if cloud_rows else np.empty((0, 6))

    # Map the color channels from [0, 1] to [-1, 1].
    colmap_points_6d[:, 3:] = colmap_points_6d[:, 3:] * 2 - 1

    # Extract 3D coordinates for faster vectorized operations.
    colmap_points_3d = colmap_points_6d[:, :3]

    # Accept lists/tuples as well as arrays; an empty input becomes (0, 3).
    vertices = np.asarray(pred_vertices, dtype=float).reshape(-1, 3)

    # No pair can be formed, or no cylinder can ever hold enough points.
    if len(vertices) < 2 or len(colmap_points_6d) <= min_points:
        return []

    forward_patches = []

    for i in range(len(vertices)):
        start_vertex = vertices[i]
        for j in range(i + 1, len(vertices)):
            end_vertex = vertices[j]

            line_vector = end_vertex - start_vertex
            line_length = np.linalg.norm(line_vector)

            # Coincident (or non-finite) vertices define no direction; skip
            # them explicitly instead of dividing by zero and relying on NaN
            # comparisons to reject the pair.
            if not np.isfinite(line_length) or line_length == 0.0:
                continue

            line_direction = line_vector / line_length

            # Extend the segment on both ends for more context.
            extended_start = start_vertex - extension_length * line_direction
            extended_line_length = line_length + 2 * extension_length

            # Signed distance of every point along the extended segment.
            start_to_points = colmap_points_3d - extended_start[np.newaxis, :]
            projection_lengths = np.dot(start_to_points, line_direction)
            within_bounds = (projection_lengths >= 0) & (projection_lengths <= extended_line_length)

            # Perpendicular distance of every point to the segment's line.
            closest_points_on_line = (
                extended_start[np.newaxis, :]
                + projection_lengths[:, np.newaxis] * line_direction[np.newaxis, :]
            )
            perpendicular_distances = np.linalg.norm(
                colmap_points_3d - closest_points_on_line, axis=1
            )

            within_cylinder = within_bounds & (perpendicular_distances <= cylinder_radius)

            # Too little support in the point cloud to classify this candidate.
            if np.count_nonzero(within_cylinder) <= min_points:
                continue

            # Center the patch at the midpoint of the original line (not extended).
            line_midpoint = (start_vertex + end_vertex) / 2

            # Boolean indexing returns a copy, so the cloud itself is untouched.
            points_centered = colmap_points_6d[within_cylinder]
            points_centered[:, :3] -= line_midpoint

            forward_patches.append({
                'patch_6d': points_centered,
                'connection': (i, j),
                'line_start': start_vertex - line_midpoint,
                'line_end': end_vertex - line_midpoint,
                'cylinder_radius': cylinder_radius,
                'point_indices': np.where(within_cylinder)[0],
                'center': line_midpoint,
            })

    return forward_patches
1105
+
1106
  def calculate_cylinder_overlap_volume(cyl1, cyl2):
1107
  """
1108
  Calculate the intersection volume between two cylinders using numpy vectorization.
 
1203
 
1204
  return max(0.0, overlap_volume)
1205
 
1206
+ def predict_wireframe(entry, pnet_model, voxel_model, pnet_class_model, config) -> Tuple[np.ndarray, List[int]]:
1207
  """
1208
  Predict 3D wireframe from a dataset entry.
1209
  """
1210
  good_entry = convert_entry_to_human_readable(entry)
1211
  colmap_rec = good_entry['colmap_binary']
1212
 
1213
+ vertex_threshold = config.get('vertex_threshold', 0.5)
1214
+ edge_threshold = config.get('edge_threshold', 0.5)
1215
+ only_predicted_connections = config.get('only_predicted_connections', False)
1216
+
1217
  if GENERATE_DATASET_EDGES:
1218
  patches = generate_edge_patches(good_entry)
1219
+ #save_patches_dataset_class(patches, EDGES_DATASET_DIR, good_entry['order_id'])
1220
  return empty_solution()
1221
 
1222
  vert_edge_per_image = {}
1223
  idxs_points = []
1224
  all_connections = []
1225
+
1226
+ print(f"Processing {len(good_entry['gestalt'])} images")
1227
+
1228
  for i, (gest, depth, K, R, t, img_id, ade_seg) in enumerate(zip(good_entry['gestalt'],
1229
  good_entry['depth'],
1230
  good_entry['K'],
 
1319
 
1320
  #visu_patch_and_pred(patch, pred_vertex, pred_dist, pred_class)
1321
 
1322
+ if pred_class > vertex_threshold:
1323
  predicted_vertices.append(pred_vertex)
1324
  else:
1325
  predicted_vertices.append(np.array([0.0, 0.0, 0.0])) # Append a zero vertex if not predicted
 
1370
 
1371
  #print(f"Filtered vertices: {len(filtered_vertices)} from {len(predicted_vertices)}")
1372
  #print(f"Filtered connections: {len(filtered_connections)} from {len(connections)}")
1373
+
1374
+ forward_patches = generate_edge_patches_forward(good_entry, filtered_vertices)
1375
+ new_connections = []
1376
+ if len(forward_patches) > 0:
1377
+ for patch in forward_patches:
1378
+ start_idx, end_idx = patch['connection']
1379
+
1380
+ pred_class, pred_score = predict_class_from_patch(pnet_class_model, patch, device='cuda')
1381
+ if pred_score > edge_threshold:
1382
+ new_connections.append((start_idx, end_idx))
1383
 
1384
  predicted_vertices = np.array(filtered_vertices)
1385
+
1386
+ if only_predicted_connections:
1387
+ connections = new_connections
1388
+ else:
1389
+ connections = filtered_connections + new_connections
1390
+
1391
+ # Remove duplicates from connections
1392
+ connections = list(set(connections))
1393
 
1394
  return predicted_vertices, connections
1395
 
train.py CHANGED
@@ -16,11 +16,12 @@ from predict import predict_wireframe, predict_wireframe_old
16
  from tqdm import tqdm
17
  from fast_pointnet import load_pointnet_model
18
  from fast_voxel import load_3dcnn_model
 
19
  import torch
20
 
21
- #ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
22
- ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
23
- ds = ds.shuffle()
24
 
25
  scores_hss = []
26
  scores_f1 = []
@@ -30,19 +31,24 @@ show_visu = False
30
 
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
32
 
33
- #pnet_model = load_pointnet_model(model_path="/home/skvrnjan/personal/hoho_pnet/initial_epoch_100.pth", device=device, predict_score=True)
34
- pnet_model = None
 
 
 
35
 
36
  #voxel_model = load_3dcnn_model(model_path="/home/skvrnjan/personal/hoho_voxel/initial_epoch_100.pth", device=device, predict_score=True)
37
  voxel_model = None
38
 
 
 
39
  idx = 0
40
  for a in tqdm(ds['train'], desc="Processing dataset"):
41
  #plot_all_modalities(a)
42
  #pred_vertices, pred_edges = predict_wireframe_old(a)
43
- #pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model)
44
  try:
45
- pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model)
46
  #pred_vertices, pred_edges = predict_wireframe_old(a)
47
  except:
48
  pred_vertices, pred_edges = empty_solution()
@@ -53,7 +59,7 @@ for a in tqdm(ds['train'], desc="Processing dataset"):
53
  scores_f1.append(score.f1)
54
  scores_iou.append(score.iou)
55
 
56
- if show_visu and score.hss < 0.1:
57
  colmap = read_colmap_rec(a['colmap_binary'])
58
  pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
59
  wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
@@ -63,9 +69,14 @@ for a in tqdm(ds['train'], desc="Processing dataset"):
63
  visu_all = [pcd] + geometries + wireframe + bpo_cams + wireframe2
64
  o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
65
 
 
 
 
 
66
  for i in range(10):
67
  print("END OF DATASET")
68
  print(f"Mean HSS: {np.mean(scores_hss):.4f}")
69
  print(f"Mean F1: {np.mean(scores_f1):.4f}")
70
  print(f"Mean IoU: {np.mean(scores_iou):.4f}")
 
71
 
 
16
  from tqdm import tqdm
17
  from fast_pointnet import load_pointnet_model
18
  from fast_voxel import load_3dcnn_model
19
+ from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
20
  import torch
21
 
22
+ ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
23
+ #ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
24
+ #ds = ds.shuffle()
25
 
26
  scores_hss = []
27
  scores_f1 = []
 
31
 
32
  device = "cuda" if torch.cuda.is_available() else "cpu"
33
 
34
+ pnet_model = load_pointnet_model(model_path="/home/skvrnjan/personal/hoho_pnet/initial_epoch_100.pth", device=device, predict_score=True)
35
+ #pnet_model = None
36
+
37
+ pnet_class_model = load_pointnet_class_model(model_path="/home/skvrnjan/personal/hoho_pnet_edges_v2/initial_epoch_100.pth", device=device)
38
+ #pnet_class_model = None
39
 
40
  #voxel_model = load_3dcnn_model(model_path="/home/skvrnjan/personal/hoho_voxel/initial_epoch_100.pth", device=device, predict_score=True)
41
  voxel_model = None
42
 
43
+ config = {'vertex_threshold': 0.4, 'edge_threshold': 0.7, 'only_predicted_connections': False}
44
+
45
  idx = 0
46
  for a in tqdm(ds['train'], desc="Processing dataset"):
47
  #plot_all_modalities(a)
48
  #pred_vertices, pred_edges = predict_wireframe_old(a)
49
+ #pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model, pnet_class_model)
50
  try:
51
+ pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model, pnet_class_model, config)
52
  #pred_vertices, pred_edges = predict_wireframe_old(a)
53
  except:
54
  pred_vertices, pred_edges = empty_solution()
 
59
  scores_f1.append(score.f1)
60
  scores_iou.append(score.iou)
61
 
62
+ if show_visu:
63
  colmap = read_colmap_rec(a['colmap_binary'])
64
  pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
65
  wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
 
69
  visu_all = [pcd] + geometries + wireframe + bpo_cams + wireframe2
70
  o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
71
 
72
+ idx += 1
73
+ if idx >= 100: # Limit to first 100 samples for testing
74
+ break
75
+
76
  for i in range(10):
77
  print("END OF DATASET")
78
  print(f"Mean HSS: {np.mean(scores_hss):.4f}")
79
  print(f"Mean F1: {np.mean(scores_f1):.4f}")
80
  print(f"Mean IoU: {np.mean(scores_iou):.4f}")
81
+ print(config)
82