jskvrna committed on
Commit
6d115a4
·
1 Parent(s): e6c805b

Enhances FastPointNet architecture for improved accuracy.

Browse files

Deepens the FastPointNet architecture by adding more convolutional layers and fully connected layers.

Increases the number of parameters and introduces residual-like and skip connections for improved feature extraction and gradient flow, which results in better generalization and prediction accuracy.

Introduces global average pooling in addition to global max pooling and adjusts dropout rates to avoid overfitting.

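Moves the filtering mechanism for invalid samples (the custom `collate_fn`) out of `train_pointnet` to module level, where the dataset loader picks it up.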

Files changed (3) hide show
  1. fast_pointnet.py +100 -68
  2. predict.py +58 -5
  3. train.py +11 -4
fast_pointnet.py CHANGED
@@ -12,6 +12,7 @@ class FastPointNet(nn.Module):
12
  """
13
  Fast PointNet implementation for 3D vertex prediction from point cloud patches.
14
  Takes 7D point clouds (x,y,z,r,g,b,filtered_flag) and predicts 3D vertex coordinates.
 
15
  """
16
 
17
  def __init__(self, input_dim=7, output_dim=3, max_points=1024, predict_score=True):
@@ -19,34 +20,46 @@ class FastPointNet(nn.Module):
19
  self.max_points = max_points
20
  self.predict_score = predict_score
21
 
22
- # Point-wise MLPs
23
- self.conv1 = nn.Conv1d(input_dim, 64, 1)
24
- self.conv2 = nn.Conv1d(64, 128, 1)
25
- self.conv3 = nn.Conv1d(128, 256, 1)
 
26
 
27
- # Global feature extraction
28
- self.conv4 = nn.Conv1d(256, 512, 1)
29
- self.conv5 = nn.Conv1d(512, 1024, 1)
30
 
31
- # Shared features
32
- self.shared_fc = nn.Linear(1024, 512)
 
33
 
34
- # Position prediction head
35
- self.pos_fc1 = nn.Linear(512, 256)
36
- self.pos_fc2 = nn.Linear(256, output_dim)
 
 
37
 
38
- # Score prediction head (predicts distance to GT)
39
  if self.predict_score:
40
- self.score_fc1 = nn.Linear(512, 256)
41
- self.score_fc2 = nn.Linear(256, 128)
42
- self.score_fc3 = nn.Linear(128, 1) # Single score output
43
-
44
- self.dropout = nn.Dropout(0.3)
45
- self.bn1 = nn.BatchNorm1d(64)
46
- self.bn2 = nn.BatchNorm1d(128)
47
- self.bn3 = nn.BatchNorm1d(256)
48
- self.bn4 = nn.BatchNorm1d(512)
 
 
49
  self.bn5 = nn.BatchNorm1d(1024)
 
 
 
 
 
 
50
 
51
  def forward(self, x):
52
  """
@@ -61,32 +74,47 @@ class FastPointNet(nn.Module):
61
  """
62
  batch_size = x.size(0)
63
 
64
- # Point-wise feature extraction
65
- x = F.relu(self.bn1(self.conv1(x)))
66
- x = F.relu(self.bn2(self.conv2(x)))
67
- x = F.relu(self.bn3(self.conv3(x)))
68
- x = F.relu(self.bn4(self.conv4(x)))
69
- x = F.relu(self.bn5(self.conv5(x)))
70
-
71
- # Global max pooling
72
- x = torch.max(x, 2)[0] # (batch_size, 1024)
73
-
74
- # Shared features
75
- shared_features = F.relu(self.shared_fc(x))
76
- shared_features = self.dropout(shared_features)
77
-
78
- # Position prediction
79
- pos_features = F.relu(self.pos_fc1(shared_features))
80
- pos_features = self.dropout(pos_features)
81
- position = self.pos_fc2(pos_features)
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  if self.predict_score:
84
- # Score prediction (distance to GT)
85
- score_features = F.relu(self.score_fc1(shared_features))
86
- score_features = self.dropout(score_features)
87
- score_features = F.relu(self.score_fc2(score_features))
88
- score_features = self.dropout(score_features)
89
- score = F.relu(self.score_fc3(score_features)) # Ensure positive distance
 
 
 
 
90
 
91
  return position, score
92
  else:
@@ -235,6 +263,25 @@ def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
235
 
236
  print(f"Saved {len(patches)} patches for entry {entry_id}")
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32, lr: float = 0.001,
239
  score_weight: float = 0.1):
240
  """
@@ -252,28 +299,9 @@ def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, ba
252
  print(f"Training on device: {device}")
253
 
254
  # Create dataset and dataloader
255
- dataset = PatchDataset(dataset_dir, max_points=1024, augment=True)
256
  print(f"Dataset loaded with {len(dataset)} samples")
257
 
258
- # Create dataloader with custom collate function to filter invalid samples
259
- def collate_fn(batch):
260
- valid_batch = []
261
- for patch_data, target, valid_mask, distance in batch:
262
- # Filter out invalid samples (no valid points or dummy targets)
263
- if valid_mask.sum() > 0 and not torch.all(target == 0):
264
- valid_batch.append((patch_data, target, valid_mask, distance))
265
-
266
- if len(valid_batch) == 0:
267
- return None
268
-
269
- # Stack valid samples
270
- patch_data = torch.stack([item[0] for item in valid_batch])
271
- targets = torch.stack([item[1] for item in valid_batch])
272
- valid_masks = torch.stack([item[2] for item in valid_batch])
273
- distances = torch.stack([item[3] for item in valid_batch])
274
-
275
- return patch_data, targets, valid_masks, distances
276
-
277
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
278
  collate_fn=collate_fn, drop_last=True)
279
 
@@ -401,7 +429,7 @@ def load_pointnet_model(model_path: str, device: torch.device = None, predict_sc
401
 
402
  return model
403
 
404
- def predict_vertex_from_patch(model: FastPointNet, patch_7d: np.ndarray, device: torch.device = None) -> Tuple[np.ndarray, float]:
405
  """
406
  Predict 3D vertex coordinates and confidence score from a patch using trained PointNet.
407
 
@@ -418,7 +446,7 @@ def predict_vertex_from_patch(model: FastPointNet, patch_7d: np.ndarray, device:
418
  if device is None:
419
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
420
 
421
- model.eval()
422
 
423
  # Prepare input
424
  max_points = 1024
@@ -443,6 +471,10 @@ def predict_vertex_from_patch(model: FastPointNet, patch_7d: np.ndarray, device:
443
  position, score = model(patch_tensor)
444
  position = position.cpu().numpy().squeeze()
445
  score = score.cpu().numpy().squeeze()
 
 
 
 
446
  return position, score
447
  else:
448
  position = model(patch_tensor)
 
12
  """
13
  Fast PointNet implementation for 3D vertex prediction from point cloud patches.
14
  Takes 7D point clouds (x,y,z,r,g,b,filtered_flag) and predicts 3D vertex coordinates.
15
+ Enhanced with deeper architecture and more parameters for better generalization.
16
  """
17
 
18
  def __init__(self, input_dim=7, output_dim=3, max_points=1024, predict_score=True):
 
20
  self.max_points = max_points
21
  self.predict_score = predict_score
22
 
23
+ # Enhanced point-wise MLPs with deeper architecture
24
+ self.conv1 = nn.Conv1d(input_dim, 128, 1)
25
+ self.conv2 = nn.Conv1d(128, 256, 1)
26
+ self.conv3 = nn.Conv1d(256, 512, 1)
27
+ self.conv4 = nn.Conv1d(512, 1024, 1)
28
 
29
+ # Additional layers for better feature extraction
30
+ self.conv5 = nn.Conv1d(1024, 1024, 1)
31
+ self.conv6 = nn.Conv1d(1024, 2048, 1)
32
 
33
+ # Larger shared features
34
+ self.shared_fc1 = nn.Linear(2048, 1024)
35
+ self.shared_fc2 = nn.Linear(1024, 512)
36
 
37
+ # Enhanced position prediction head
38
+ self.pos_fc1 = nn.Linear(512, 512)
39
+ self.pos_fc2 = nn.Linear(512, 256)
40
+ self.pos_fc3 = nn.Linear(256, 128)
41
+ self.pos_fc4 = nn.Linear(128, output_dim)
42
 
43
+ # Enhanced score prediction head
44
  if self.predict_score:
45
+ self.score_fc1 = nn.Linear(512, 512)
46
+ self.score_fc2 = nn.Linear(512, 256)
47
+ self.score_fc3 = nn.Linear(256, 128)
48
+ self.score_fc4 = nn.Linear(128, 64)
49
+ self.score_fc5 = nn.Linear(64, 1)
50
+
51
+ # Batch normalization layers
52
+ self.bn1 = nn.BatchNorm1d(128)
53
+ self.bn2 = nn.BatchNorm1d(256)
54
+ self.bn3 = nn.BatchNorm1d(512)
55
+ self.bn4 = nn.BatchNorm1d(1024)
56
  self.bn5 = nn.BatchNorm1d(1024)
57
+ self.bn6 = nn.BatchNorm1d(2048)
58
+
59
+ # Dropout with different rates
60
+ self.dropout_light = nn.Dropout(0.2)
61
+ self.dropout_medium = nn.Dropout(0.3)
62
+ self.dropout_heavy = nn.Dropout(0.4)
63
 
64
  def forward(self, x):
65
  """
 
74
  """
75
  batch_size = x.size(0)
76
 
77
+ # Enhanced point-wise feature extraction with residual-like connections
78
+ x1 = F.relu(self.bn1(self.conv1(x)))
79
+ x2 = F.relu(self.bn2(self.conv2(x1)))
80
+ x3 = F.relu(self.bn3(self.conv3(x2)))
81
+ x4 = F.relu(self.bn4(self.conv4(x3)))
82
+ x5 = F.relu(self.bn5(self.conv5(x4)))
83
+ x6 = F.relu(self.bn6(self.conv6(x5)))
84
+
85
+ # Global max pooling with additional global average pooling
86
+ max_pool = torch.max(x6, 2)[0] # (batch_size, 2048)
87
+ avg_pool = torch.mean(x6, 2) # (batch_size, 2048)
88
+
89
+ # Combine max and average pooling for richer global features
90
+ global_features = max_pool + avg_pool # (batch_size, 2048)
91
+
92
+ # Enhanced shared features with residual connection
93
+ shared1 = F.relu(self.shared_fc1(global_features))
94
+ shared1 = self.dropout_light(shared1)
95
+ shared2 = F.relu(self.shared_fc2(shared1))
96
+ shared_features = self.dropout_medium(shared2)
97
+
98
+ # Enhanced position prediction with skip connections
99
+ pos1 = F.relu(self.pos_fc1(shared_features))
100
+ pos1 = self.dropout_light(pos1)
101
+ pos2 = F.relu(self.pos_fc2(pos1))
102
+ pos2 = self.dropout_medium(pos2)
103
+ pos3 = F.relu(self.pos_fc3(pos2))
104
+ pos3 = self.dropout_light(pos3)
105
+ position = self.pos_fc4(pos3)
106
 
107
  if self.predict_score:
108
+ # Enhanced score prediction
109
+ score1 = F.relu(self.score_fc1(shared_features))
110
+ score1 = self.dropout_light(score1)
111
+ score2 = F.relu(self.score_fc2(score1))
112
+ score2 = self.dropout_medium(score2)
113
+ score3 = F.relu(self.score_fc3(score2))
114
+ score3 = self.dropout_light(score3)
115
+ score4 = F.relu(self.score_fc4(score3))
116
+ score4 = self.dropout_light(score4)
117
+ score = F.relu(self.score_fc5(score4)) # Ensure positive distance
118
 
119
  return position, score
120
  else:
 
263
 
264
  print(f"Saved {len(patches)} patches for entry {entry_id}")
265
 
266
+ # Create dataloader with custom collate function to filter invalid samples
267
+ def collate_fn(batch):
268
+ valid_batch = []
269
+ for patch_data, target, valid_mask, distance in batch:
270
+ # Filter out invalid samples (no valid points or dummy targets)
271
+ if valid_mask.sum() > 0 and not torch.all(target == 0):
272
+ valid_batch.append((patch_data, target, valid_mask, distance))
273
+
274
+ if len(valid_batch) == 0:
275
+ return None
276
+
277
+ # Stack valid samples
278
+ patch_data = torch.stack([item[0] for item in valid_batch])
279
+ targets = torch.stack([item[1] for item in valid_batch])
280
+ valid_masks = torch.stack([item[2] for item in valid_batch])
281
+ distances = torch.stack([item[3] for item in valid_batch])
282
+
283
+ return patch_data, targets, valid_masks, distances
284
+
285
  def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32, lr: float = 0.001,
286
  score_weight: float = 0.1):
287
  """
 
299
  print(f"Training on device: {device}")
300
 
301
  # Create dataset and dataloader
302
+ dataset = PatchDataset(dataset_dir, max_points=1024, augment=False)
303
  print(f"Dataset loaded with {len(dataset)} samples")
304
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
306
  collate_fn=collate_fn, drop_last=True)
307
 
 
429
 
430
  return model
431
 
432
+ def predict_vertex_from_patch(model: FastPointNet, patch: np.ndarray, device: torch.device = None) -> Tuple[np.ndarray, float]:
433
  """
434
  Predict 3D vertex coordinates and confidence score from a patch using trained PointNet.
435
 
 
446
  if device is None:
447
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
448
 
449
+ patch_7d = patch['patch_7d'] # (N, 7)
450
 
451
  # Prepare input
452
  max_points = 1024
 
471
  position, score = model(patch_tensor)
472
  position = position.cpu().numpy().squeeze()
473
  score = score.cpu().numpy().squeeze()
474
+
475
+ offset = patch['offset']
476
+ position -= offset
477
+
478
  return position, score
479
  else:
480
  position = model(patch_tensor)
predict.py CHANGED
@@ -11,7 +11,7 @@ import cv2
11
  import open3d as o3d
12
  from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local
13
  import pyvista as pv
14
- from fast_pointnet import save_patches_dataset
15
 
16
  GENERATE_DATASET = True
17
  #DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
@@ -388,7 +388,53 @@ def create_3d_wireframe_single_image(vertices: List[dict],
388
  return vertices_3d
389
 
390
 
391
- def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
  """
393
  Predict 3D wireframe from a dataset entry.
394
  """
@@ -421,6 +467,13 @@ def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:
421
 
422
  continue
423
 
 
 
 
 
 
 
 
424
  vertices, connections, vertices_3d = vertices_ours, connections_ours, vertices_3d_ours
425
  # Get 2D vertices and edges first
426
  #vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th=25.)
@@ -908,7 +961,7 @@ def generate_patches(colmap_rec, filtered_points_idxs, frame, filtered_vertices,
908
  if pid in point_idxs:
909
  patch_7d[i, 6] = 1.0
910
  else:
911
- patch_7d[i, 6] = 0.0
912
 
913
  if filtered_vertices[group_idx] is not None:
914
  initial_pred = filtered_vertices[group_idx] + offset
@@ -961,7 +1014,7 @@ def generate_patches(colmap_rec, filtered_points_idxs, frame, filtered_vertices,
961
  plotter.add_mesh(pred_sphere, color="orange", opacity=0.5)
962
 
963
  plotter.show(title=f"Patch {group_idx}")
964
-
965
  return patches
966
 
967
  def our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id_substring, ade_seg, depth, K=None, R=None, t=None, frame=None):
@@ -991,7 +1044,7 @@ def our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id_substring, ade_se
991
 
992
  if len(uv) == 0:
993
  print(f"No points projected into image bounds for {img_id_substring} using K,R,t.")
994
- return [], [], []
995
 
996
  house_mask = get_house_mask(ade_seg)
997
 
 
11
  import open3d as o3d
12
  from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local
13
  import pyvista as pv
14
+ from fast_pointnet import save_patches_dataset, predict_vertex_from_patch
15
 
16
  GENERATE_DATASET = True
17
  #DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
 
388
  return vertices_3d
389
 
390
 
391
+ def visu_patch_and_pred(patch, pred):
392
+ # Create plotter
393
+ plotter = pv.Plotter()
394
+
395
+ # Create point cloud for this patch
396
+ offset = patch.get('offset', None) # Offset if available
397
+ patch_points_3d = np.array(patch['patch_7d'][:, :3])
398
+ patch_points_3d = patch_points_3d - offset
399
+ patch_cloud = pv.PolyData(patch_points_3d)
400
+
401
+ point_idxs = patch['filtered_point_ids'] # List of point indices that are filtered
402
+ patch_point_ids = patch['point_ids'] # Assuming the 7th column contains point IDs
403
+ assigned_gt_vertex = patch.get('assigned_gt_vertex', None) # GT vertex if available
404
+ initial_pred = patch.get('initial_pred', None) # Initial prediction if available
405
+ initial_pred = initial_pred - offset
406
+
407
+ assigned_gt_vertex = assigned_gt_vertex - offset
408
+
409
+ # Color points: red for filtered points, blue for other points
410
+ patch_point_colors = []
411
+ for i, pid in enumerate(patch_point_ids):
412
+ if pid in point_idxs:
413
+ patch_point_colors.append([255, 0, 0]) # Red for filtered points
414
+ else:
415
+ patch_point_colors.append([0, 0, 255]) # Blue for other points
416
+
417
+ patch_cloud["colors"] = np.array(patch_point_colors)
418
+ plotter.add_mesh(patch_cloud, scalars="colors", rgb=True, point_size=8, render_points_as_spheres=True)
419
+
420
+ # Create sphere to visualize GT vertex if available
421
+ if assigned_gt_vertex is not None:
422
+ gt_sphere = pv.Sphere(radius=0.1, center=assigned_gt_vertex)
423
+ plotter.add_mesh(gt_sphere, color="green", opacity=0.5)
424
+
425
+ if initial_pred is not None:
426
+ # Create sphere to visualize initial prediction
427
+ pred_sphere = pv.Sphere(radius=0.1, center=initial_pred)
428
+ plotter.add_mesh(pred_sphere, color="orange", opacity=0.5)
429
+
430
+ if pred is not None:
431
+ # Create sphere to visualize predicted vertex
432
+ pred_sphere = pv.Sphere(radius=0.1, center=pred)
433
+ plotter.add_mesh(pred_sphere, color="red", opacity=0.5)
434
+
435
+ plotter.show(title=f"Patch x")
436
+
437
+ def predict_wireframe(entry, pnet_model) -> Tuple[np.ndarray, List[int]]:
438
  """
439
  Predict 3D wireframe from a dataset entry.
440
  """
 
467
 
468
  continue
469
 
470
+ for patch in patches:
471
+ pred_vertex, pred_dist = predict_vertex_from_patch(pnet_model, patch, device='cuda')
472
+ visu_patch_and_pred(patch, pred_vertex)
473
+
474
+ x = 0
475
+
476
+
477
  vertices, connections, vertices_3d = vertices_ours, connections_ours, vertices_3d_ours
478
  # Get 2D vertices and edges first
479
  #vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th=25.)
 
961
  if pid in point_idxs:
962
  patch_7d[i, 6] = 1.0
963
  else:
964
+ patch_7d[i, 6] = -1.0
965
 
966
  if filtered_vertices[group_idx] is not None:
967
  initial_pred = filtered_vertices[group_idx] + offset
 
1014
  plotter.add_mesh(pred_sphere, color="orange", opacity=0.5)
1015
 
1016
  plotter.show(title=f"Patch {group_idx}")
1017
+
1018
  return patches
1019
 
1020
  def our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id_substring, ade_seg, depth, K=None, R=None, t=None, frame=None):
 
1044
 
1045
  if len(uv) == 0:
1046
  print(f"No points projected into image bounds for {img_id_substring} using K,R,t.")
1047
+ return [], [], [], []
1048
 
1049
  house_mask = get_house_mask(ade_seg)
1050
 
train.py CHANGED
@@ -5,6 +5,7 @@ import pycolmap
5
  import tempfile,zipfile
6
  import io
7
  import open3d as o3d
 
8
 
9
  from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local, _plotly_rgb_to_normalized_o3d_color
10
  from utils import read_colmap_rec, empty_solution
@@ -13,22 +14,28 @@ from utils import read_colmap_rec, empty_solution
13
  from hoho2025.metric_helper import hss
14
  from predict import predict_wireframe
15
  from tqdm import tqdm
 
 
16
 
17
- ds = load_dataset("usm3d/hoho25k", cache_dir='/home/skvrnjan/personal/hoho25k', trust_remote_code=True)
18
  ds = ds.shuffle()
19
 
20
  scores_hss = []
21
  scores_f1 = []
22
  scores_iou = []
23
 
24
- show_visu = False
 
 
 
 
25
 
26
  idx = 0
27
  for a in tqdm(ds['train'], desc="Processing dataset"):
28
  #plot_all_modalities(a)
29
- #pred_vertices, pred_edges = predict_wireframe(a)
30
  try:
31
- pred_vertices, pred_edges = predict_wireframe(a)
32
  except:
33
  pred_vertices, pred_edges = empty_solution()
34
 
 
5
  import tempfile,zipfile
6
  import io
7
  import open3d as o3d
8
+ import os
9
 
10
  from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local, _plotly_rgb_to_normalized_o3d_color
11
  from utils import read_colmap_rec, empty_solution
 
14
  from hoho2025.metric_helper import hss
15
  from predict import predict_wireframe
16
  from tqdm import tqdm
17
+ from fast_pointnet import load_pointnet_model
18
+ import torch
19
 
20
+ ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
21
  ds = ds.shuffle()
22
 
23
  scores_hss = []
24
  scores_f1 = []
25
  scores_iou = []
26
 
27
+ show_visu = True
28
+
29
+ device = "cuda" if torch.cuda.is_available() else "cpu"
30
+
31
+ pnet_model = load_pointnet_model(model_path="/home/skvrnjan/personal/hoho_pnet/initial_epoch_100.pth", device=device, predict_score=True)
32
 
33
  idx = 0
34
  for a in tqdm(ds['train'], desc="Processing dataset"):
35
  #plot_all_modalities(a)
36
+ #pred_vertices, pred_edges = predict_wireframe(a, pnet_model)
37
  try:
38
+ pred_vertices, pred_edges = predict_wireframe(a, pnet_model)
39
  except:
40
  pred_vertices, pred_edges = empty_solution()
41