Enhances FastPointNet architecture for improved accuracy.
Deepens the FastPointNet architecture by adding more convolutional and fully connected layers.
Increases the parameter count and introduces residual-like and skip connections for improved feature extraction and gradient flow, which yields better generalization and prediction accuracy.
Introduces global average pooling in addition to global max pooling and adjusts dropout rates to avoid overfitting.
Adds a filtering mechanism to the dataset loader.
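The filtering relies on the custom collate_fn (see the fast_pointnet.py diff below), which returns None when every sample in a batch is invalid, so any loop consuming the dataloader needs a skip guard. A minimal sketch of that guard, with a hypothetical loop body:

from torch.utils.data import DataLoader

# collate_fn (from fast_pointnet.py) returns None when a whole batch is
# filtered out, so the training loop must skip those batches.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True,
                        collate_fn=collate_fn, drop_last=True)
for batch in dataloader:
    if batch is None:
        continue  # no valid samples in this batch
    patch_data, targets, valid_masks, distances = batch
    # ... forward pass, loss, optimizer step ...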
- fast_pointnet.py +100 -68
- predict.py +58 -5
- train.py +11 -4
fast_pointnet.py
CHANGED
@@ -12,6 +12,7 @@ class FastPointNet(nn.Module):
     """
     Fast PointNet implementation for 3D vertex prediction from point cloud patches.
     Takes 7D point clouds (x,y,z,r,g,b,filtered_flag) and predicts 3D vertex coordinates.
+    Enhanced with deeper architecture and more parameters for better generalization.
     """

     def __init__(self, input_dim=7, output_dim=3, max_points=1024, predict_score=True):
@@ -19,34 +20,46 @@ class FastPointNet(nn.Module):
         self.max_points = max_points
         self.predict_score = predict_score

-        # …
-        self.conv1 = nn.Conv1d(input_dim, …
-        self.conv2 = nn.Conv1d(…
-        self.conv3 = nn.Conv1d(…
+        # Enhanced point-wise MLPs with deeper architecture
+        self.conv1 = nn.Conv1d(input_dim, 128, 1)
+        self.conv2 = nn.Conv1d(128, 256, 1)
+        self.conv3 = nn.Conv1d(256, 512, 1)
+        self.conv4 = nn.Conv1d(512, 1024, 1)

-        # …
-        self.…
-        self.…
+        # Additional layers for better feature extraction
+        self.conv5 = nn.Conv1d(1024, 1024, 1)
+        self.conv6 = nn.Conv1d(1024, 2048, 1)

-        # …
-        self.…
+        # Larger shared features
+        self.shared_fc1 = nn.Linear(2048, 1024)
+        self.shared_fc2 = nn.Linear(1024, 512)

-        # …
-        self.pos_fc1 = nn.Linear(512, …
-        self.pos_fc2 = nn.Linear(…
+        # Enhanced position prediction head
+        self.pos_fc1 = nn.Linear(512, 512)
+        self.pos_fc2 = nn.Linear(512, 256)
+        self.pos_fc3 = nn.Linear(256, 128)
+        self.pos_fc4 = nn.Linear(128, output_dim)

-        # …
+        # Enhanced score prediction head
         if self.predict_score:
-            self.score_fc1 = nn.Linear(512, …
-            self.score_fc2 = nn.Linear(…
-            self.score_fc3 = nn.Linear(…
-
-
-
-
-        self.…
-        self.…
+            self.score_fc1 = nn.Linear(512, 512)
+            self.score_fc2 = nn.Linear(512, 256)
+            self.score_fc3 = nn.Linear(256, 128)
+            self.score_fc4 = nn.Linear(128, 64)
+            self.score_fc5 = nn.Linear(64, 1)
+
+        # Batch normalization layers
+        self.bn1 = nn.BatchNorm1d(128)
+        self.bn2 = nn.BatchNorm1d(256)
+        self.bn3 = nn.BatchNorm1d(512)
+        self.bn4 = nn.BatchNorm1d(1024)
         self.bn5 = nn.BatchNorm1d(1024)
+        self.bn6 = nn.BatchNorm1d(2048)
+
+        # Dropout with different rates
+        self.dropout_light = nn.Dropout(0.2)
+        self.dropout_medium = nn.Dropout(0.3)
+        self.dropout_heavy = nn.Dropout(0.4)

     def forward(self, x):
         """
@@ -61,32 +74,47 @@
         """
         batch_size = x.size(0)

-        # …
-
-
-
-
-
-
-
-
-
-        # …
-
-
-
-
-
-
-
+        # Enhanced point-wise feature extraction with residual-like connections
+        x1 = F.relu(self.bn1(self.conv1(x)))
+        x2 = F.relu(self.bn2(self.conv2(x1)))
+        x3 = F.relu(self.bn3(self.conv3(x2)))
+        x4 = F.relu(self.bn4(self.conv4(x3)))
+        x5 = F.relu(self.bn5(self.conv5(x4)))
+        x6 = F.relu(self.bn6(self.conv6(x5)))
+
+        # Global max pooling with additional global average pooling
+        max_pool = torch.max(x6, 2)[0]  # (batch_size, 2048)
+        avg_pool = torch.mean(x6, 2)  # (batch_size, 2048)
+
+        # Combine max and average pooling for richer global features
+        global_features = max_pool + avg_pool  # (batch_size, 2048)
+
+        # Enhanced shared features with residual connection
+        shared1 = F.relu(self.shared_fc1(global_features))
+        shared1 = self.dropout_light(shared1)
+        shared2 = F.relu(self.shared_fc2(shared1))
+        shared_features = self.dropout_medium(shared2)
+
+        # Enhanced position prediction with skip connections
+        pos1 = F.relu(self.pos_fc1(shared_features))
+        pos1 = self.dropout_light(pos1)
+        pos2 = F.relu(self.pos_fc2(pos1))
+        pos2 = self.dropout_medium(pos2)
+        pos3 = F.relu(self.pos_fc3(pos2))
+        pos3 = self.dropout_light(pos3)
+        position = self.pos_fc4(pos3)

         if self.predict_score:
-            # …
-
-
-
-
-
+            # Enhanced score prediction
+            score1 = F.relu(self.score_fc1(shared_features))
+            score1 = self.dropout_light(score1)
+            score2 = F.relu(self.score_fc2(score1))
+            score2 = self.dropout_medium(score2)
+            score3 = F.relu(self.score_fc3(score2))
+            score3 = self.dropout_light(score3)
+            score4 = F.relu(self.score_fc4(score3))
+            score4 = self.dropout_light(score4)
+            score = F.relu(self.score_fc5(score4))  # Ensure positive distance

             return position, score
         else:
@@ -235,6 +263,25 @@ def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):

     print(f"Saved {len(patches)} patches for entry {entry_id}")

+# Create dataloader with custom collate function to filter invalid samples
+def collate_fn(batch):
+    valid_batch = []
+    for patch_data, target, valid_mask, distance in batch:
+        # Filter out invalid samples (no valid points or dummy targets)
+        if valid_mask.sum() > 0 and not torch.all(target == 0):
+            valid_batch.append((patch_data, target, valid_mask, distance))
+
+    if len(valid_batch) == 0:
+        return None
+
+    # Stack valid samples
+    patch_data = torch.stack([item[0] for item in valid_batch])
+    targets = torch.stack([item[1] for item in valid_batch])
+    valid_masks = torch.stack([item[2] for item in valid_batch])
+    distances = torch.stack([item[3] for item in valid_batch])
+
+    return patch_data, targets, valid_masks, distances
+
 def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32, lr: float = 0.001,
                    score_weight: float = 0.1):
     """
@@ -252,28 +299,9 @@ def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32, lr: float = 0.001,
     print(f"Training on device: {device}")

     # Create dataset and dataloader
-    dataset = PatchDataset(dataset_dir, max_points=1024, augment=…
+    dataset = PatchDataset(dataset_dir, max_points=1024, augment=False)
     print(f"Dataset loaded with {len(dataset)} samples")

-    # Create dataloader with custom collate function to filter invalid samples
-    def collate_fn(batch):
-        valid_batch = []
-        for patch_data, target, valid_mask, distance in batch:
-            # Filter out invalid samples (no valid points or dummy targets)
-            if valid_mask.sum() > 0 and not torch.all(target == 0):
-                valid_batch.append((patch_data, target, valid_mask, distance))
-
-        if len(valid_batch) == 0:
-            return None
-
-        # Stack valid samples
-        patch_data = torch.stack([item[0] for item in valid_batch])
-        targets = torch.stack([item[1] for item in valid_batch])
-        valid_masks = torch.stack([item[2] for item in valid_batch])
-        distances = torch.stack([item[3] for item in valid_batch])
-
-        return patch_data, targets, valid_masks, distances
-
     dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
                             collate_fn=collate_fn, drop_last=True)

@@ -401,7 +429,7 @@ def load_pointnet_model(model_path: str, device: torch.device = None, predict_sc…

     return model

-def predict_vertex_from_patch(model: FastPointNet, patch_7d: np.ndarray, device: torch.device = None) -> Tuple[np.ndarray, float]:
+def predict_vertex_from_patch(model: FastPointNet, patch: np.ndarray, device: torch.device = None) -> Tuple[np.ndarray, float]:
     """
     Predict 3D vertex coordinates and confidence score from a patch using trained PointNet.

@@ -418,7 +446,7 @@ def predict_vertex_from_patch(model: FastPointNet, patch_7d: np.ndarray, device:…
     if device is None:
         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

-
+    patch_7d = patch['patch_7d']  # (N, 7)

     # Prepare input
     max_points = 1024
@@ -443,6 +471,10 @@ def predict_vertex_from_patch(model: FastPointNet, patch_7d: np.ndarray, device:…
         position, score = model(patch_tensor)
         position = position.cpu().numpy().squeeze()
         score = score.cpu().numpy().squeeze()
+
+        offset = patch['offset']
+        position -= offset
+
         return position, score
     else:
         position = model(patch_tensor)
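A quick shape sanity check for the enhanced model, as a minimal sketch (assumes the module is importable as fast_pointnet; the shapes follow from the layer sizes in the diff above):

import torch
from fast_pointnet import FastPointNet

model = FastPointNet(input_dim=7, output_dim=3, max_points=1024, predict_score=True)
model.eval()  # eval mode so the BatchNorm layers use running statistics

# Input is (batch, channels, points): x,y,z,r,g,b,filtered_flag per point
x = torch.randn(2, 7, 1024)
with torch.no_grad():
    position, score = model(x)
print(position.shape)  # torch.Size([2, 3])
print(score.shape)     # torch.Size([2, 1])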
predict.py
CHANGED
@@ -11,7 +11,7 @@ import cv2
 import open3d as o3d
 from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local
 import pyvista as pv
-from fast_pointnet import save_patches_dataset
+from fast_pointnet import save_patches_dataset, predict_vertex_from_patch

 GENERATE_DATASET = True
 #DATASET_DIR = '/home/skvrnjan/personal/hohocustom/'
@@ -388,7 +388,53 @@ def create_3d_wireframe_single_image(vertices: List[dict],
     return vertices_3d


-def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:
+def visu_patch_and_pred(patch, pred):
+    # Create plotter
+    plotter = pv.Plotter()
+
+    # Create point cloud for this patch
+    offset = patch.get('offset', None)  # Offset if available
+    patch_points_3d = np.array(patch['patch_7d'][:, :3])
+    patch_points_3d = patch_points_3d - offset
+    patch_cloud = pv.PolyData(patch_points_3d)
+
+    point_idxs = patch['filtered_point_ids']  # List of point indices that are filtered
+    patch_point_ids = patch['point_ids']  # Assuming the 7th column contains point IDs
+    assigned_gt_vertex = patch.get('assigned_gt_vertex', None)  # GT vertex if available
+    initial_pred = patch.get('initial_pred', None)  # Initial prediction if available
+    initial_pred = initial_pred - offset
+
+    assigned_gt_vertex = assigned_gt_vertex - offset
+
+    # Color points: red for filtered points, blue for other points
+    patch_point_colors = []
+    for i, pid in enumerate(patch_point_ids):
+        if pid in point_idxs:
+            patch_point_colors.append([255, 0, 0])  # Red for filtered points
+        else:
+            patch_point_colors.append([0, 0, 255])  # Blue for other points
+
+    patch_cloud["colors"] = np.array(patch_point_colors)
+    plotter.add_mesh(patch_cloud, scalars="colors", rgb=True, point_size=8, render_points_as_spheres=True)
+
+    # Create sphere to visualize GT vertex if available
+    if assigned_gt_vertex is not None:
+        gt_sphere = pv.Sphere(radius=0.1, center=assigned_gt_vertex)
+        plotter.add_mesh(gt_sphere, color="green", opacity=0.5)
+
+    if initial_pred is not None:
+        # Create sphere to visualize initial prediction
+        pred_sphere = pv.Sphere(radius=0.1, center=initial_pred)
+        plotter.add_mesh(pred_sphere, color="orange", opacity=0.5)
+
+    if pred is not None:
+        # Create sphere to visualize predicted vertex
+        pred_sphere = pv.Sphere(radius=0.1, center=pred)
+        plotter.add_mesh(pred_sphere, color="red", opacity=0.5)
+
+    plotter.show(title=f"Patch x")
+
+def predict_wireframe(entry, pnet_model) -> Tuple[np.ndarray, List[int]]:
     """
     Predict 3D wireframe from a dataset entry.
     """
@@ -421,6 +467,13 @@ def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:

             continue

+        for patch in patches:
+            pred_vertex, pred_dist = predict_vertex_from_patch(pnet_model, patch, device='cuda')
+            visu_patch_and_pred(patch, pred_vertex)
+
+        x = 0
+
+
         vertices, connections, vertices_3d = vertices_ours, connections_ours, vertices_3d_ours
         # Get 2D vertices and edges first
         #vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th=25.)
@@ -908,7 +961,7 @@ def generate_patches(colmap_rec, filtered_points_idxs, frame, filtered_vertices,…
             if pid in point_idxs:
                 patch_7d[i, 6] = 1.0
             else:
-                patch_7d[i, 6] = …
+                patch_7d[i, 6] = -1.0

         if filtered_vertices[group_idx] is not None:
             initial_pred = filtered_vertices[group_idx] + offset
@@ -961,7 +1014,7 @@ def generate_patches(colmap_rec, filtered_points_idxs, frame, filtered_vertices,…
             plotter.add_mesh(pred_sphere, color="orange", opacity=0.5)

         plotter.show(title=f"Patch {group_idx}")
-
+
     return patches

 def our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id_substring, ade_seg, depth, K=None, R=None, t=None, frame=None):
@@ -991,7 +1044,7 @@ def our_get_vertices_and_edges(gest_seg_np, colmap_rec, img_id_substring, ade_seg, depth, K=None, R=None, t=None, frame=None):

     if len(uv) == 0:
         print(f"No points projected into image bounds for {img_id_substring} using K,R,t.")
-        return [], [], []
+        return [], [], [], []

     house_mask = get_house_mask(ade_seg)

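The patch dictionaries passed to predict_vertex_from_patch and visu_patch_and_pred are produced by generate_patches; a minimal sketch of the keys the code above relies on (illustrative dummy values, not the exact construction):

import numpy as np

# Hypothetical patch dict; key names are taken from the diff above.
patch = {
    'patch_7d': np.zeros((1024, 7), dtype=np.float32),  # x,y,z,r,g,b,filtered_flag
    'offset': np.zeros(3, dtype=np.float32),  # local-frame offset; predictions subtract it
    'point_ids': [],            # one COLMAP point id per row of patch_7d
    'filtered_point_ids': [],   # ids whose filtered_flag is 1.0
    'assigned_gt_vertex': None, # optional GT vertex (training data only)
    'initial_pred': None,       # optional geometric initial estimate
}
# predict_vertex_from_patch runs the network on patch['patch_7d'] and maps the
# prediction back to the original frame by subtracting patch['offset'].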
train.py
CHANGED
@@ -5,6 +5,7 @@ import pycolmap
 import tempfile,zipfile
 import io
 import open3d as o3d
+import os

 from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local, _plotly_rgb_to_normalized_o3d_color
 from utils import read_colmap_rec, empty_solution
@@ -13,22 +14,28 @@ from utils import read_colmap_rec, empty_solution
 from hoho2025.metric_helper import hss
 from predict import predict_wireframe
 from tqdm import tqdm
+from fast_pointnet import load_pointnet_model
+import torch

-ds = load_dataset("usm3d/hoho25k", cache_dir=…
+ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
 ds = ds.shuffle()

 scores_hss = []
 scores_f1 = []
 scores_iou = []

-show_visu = …
+show_visu = True
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+pnet_model = load_pointnet_model(model_path="/home/skvrnjan/personal/hoho_pnet/initial_epoch_100.pth", device=device, predict_score=True)

 idx = 0
 for a in tqdm(ds['train'], desc="Processing dataset"):
     #plot_all_modalities(a)
-    #pred_vertices, pred_edges = predict_wireframe(a)
+    #pred_vertices, pred_edges = predict_wireframe(a, pnet_model)
     try:
-        pred_vertices, pred_edges = predict_wireframe(a)
+        pred_vertices, pred_edges = predict_wireframe(a, pnet_model)
     except:
         pred_vertices, pred_edges = empty_solution()
