# This file defines a PointNet-based model for binary classification of 6D point cloud patches.
# It includes the model architecture (ClassificationPointNet), a custom dataset class
# (PatchClassificationDataset) for loading and augmenting patches, functions for saving
# patches to create a dataset, a training loop (train_pointnet), a function to load
# a trained model (load_pointnet_model), and a function for predicting class labels
# from new patches (predict_class_from_patch).

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Tuple, Optional
import json


class ClassificationPointNet(nn.Module):
    """
    PointNet implementation for binary classification from 6D point cloud patches.
    Takes 6D point clouds (x, y, z, r, g, b) and predicts binary classification (edge/not edge).
    """

    def __init__(self, input_dim=6, max_points=1024):
        super(ClassificationPointNet, self).__init__()
        self.max_points = max_points

        # Point-wise MLPs for feature extraction (deeper network)
        self.conv1 = nn.Conv1d(input_dim, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 256, 1)
        self.conv4 = nn.Conv1d(256, 512, 1)
        self.conv5 = nn.Conv1d(512, 1024, 1)
        self.conv6 = nn.Conv1d(1024, 2048, 1)  # Additional layer

        # Classification head (deeper with more capacity)
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 128)
        self.fc5 = nn.Linear(128, 64)
        self.fc6 = nn.Linear(64, 1)  # Single output for binary classification

        # Batch normalization layers
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(256)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(1024)
        self.bn6 = nn.BatchNorm1d(2048)

        # Dropout layers
        self.dropout1 = nn.Dropout(0.3)
        self.dropout2 = nn.Dropout(0.4)
        self.dropout3 = nn.Dropout(0.5)
        self.dropout4 = nn.Dropout(0.4)
        self.dropout5 = nn.Dropout(0.3)

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: (batch_size, input_dim, max_points) tensor

        Returns:
            classification: (batch_size, 1) tensor of logits (apply sigmoid for probability)
        """
        batch_size = x.size(0)

        # Point-wise feature extraction
        x1 = F.relu(self.bn1(self.conv1(x)))
        x2 = F.relu(self.bn2(self.conv2(x1)))
        x3 = F.relu(self.bn3(self.conv3(x2)))
        x4 = F.relu(self.bn4(self.conv4(x3)))
        x5 = F.relu(self.bn5(self.conv5(x4)))
        x6 = F.relu(self.bn6(self.conv6(x5)))

        # Global max pooling
        global_features = torch.max(x6, 2)[0]  # (batch_size, 2048)

        # Classification head
        x = F.relu(self.fc1(global_features))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = F.relu(self.fc3(x))
        x = self.dropout3(x)
        x = F.relu(self.fc4(x))
        x = self.dropout4(x)
        x = F.relu(self.fc5(x))
        x = self.dropout5(x)
        classification = self.fc6(x)  # (batch_size, 1)

        return classification
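

# Illustrative only (not part of the original pipeline): a minimal shape check for the
# classifier. A batch of B patches enters as a (B, 6, 1024) tensor and comes out as
# (B, 1) raw logits; sigmoid turns those into edge probabilities.
def _demo_classifier_shapes():
    model = ClassificationPointNet(input_dim=6, max_points=1024)
    model.eval()  # deterministic dropout/batchnorm behavior for the demo
    dummy = torch.randn(2, 6, 1024)  # 2 patches, 6 channels (x, y, z, r, g, b), 1024 points
    with torch.no_grad():
        logits = model(dummy)           # (2, 1) raw logits
        probs = torch.sigmoid(logits)   # probabilities in [0, 1]
    return logits.shape, probs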


class PatchClassificationDataset(Dataset):
    """
    Dataset class for loading saved patches for PointNet classification training.
    """

    def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
        self.dataset_dir = dataset_dir
        self.max_points = max_points
        self.augment = augment

        # Load patch files
        self.patch_files = []
        for file in os.listdir(dataset_dir):
            if file.endswith('.pkl'):
                self.patch_files.append(os.path.join(dataset_dir, file))

        print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")

    def __len__(self):
        return len(self.patch_files)

    def __getitem__(self, idx):
        """
        Load and process a patch for training.

        Returns:
            patch_data: (6, max_points) tensor of point cloud data
            label: scalar tensor for binary classification (0 or 1)
            valid_mask: (max_points,) boolean tensor indicating valid points
        """
        patch_file = self.patch_files[idx]
        with open(patch_file, 'rb') as f:
            patch_info = pickle.load(f)

        patch_6d = patch_info['patch_6d']  # (N, 6)
        label = patch_info.get('label', 0)  # Get binary classification label (0 or 1)

        # Pad or sample points to max_points
        num_points = patch_6d.shape[0]
        if num_points >= self.max_points:
            # Randomly sample max_points
            indices = np.random.choice(num_points, self.max_points, replace=False)
            patch_sampled = patch_6d[indices]
            valid_mask = np.ones(self.max_points, dtype=bool)
        else:
            # Pad with zeros
            patch_sampled = np.zeros((self.max_points, 6))
            patch_sampled[:num_points] = patch_6d
            valid_mask = np.zeros(self.max_points, dtype=bool)
            valid_mask[:num_points] = True

        # Data augmentation
        if self.augment:
            patch_sampled = self._augment_patch(patch_sampled, valid_mask)

        # Convert to tensors and transpose for conv1d (channels first)
        patch_tensor = torch.from_numpy(patch_sampled.T).float()  # (6, max_points)
        label_tensor = torch.tensor(label, dtype=torch.float32)  # Float for BCE loss
        valid_mask_tensor = torch.from_numpy(valid_mask)

        return patch_tensor, label_tensor, valid_mask_tensor

    def _augment_patch(self, patch, valid_mask):
        """
        Apply data augmentation to the patch.
        """
        valid_points = patch[valid_mask]
        if len(valid_points) == 0:
            return patch

        # Random rotation around z-axis
        angle = np.random.uniform(0, 2 * np.pi)
        cos_angle = np.cos(angle)
        sin_angle = np.sin(angle)
        rotation_matrix = np.array([
            [cos_angle, -sin_angle, 0],
            [sin_angle, cos_angle, 0],
            [0, 0, 1]
        ])

        # Apply rotation to xyz coordinates
        valid_points[:, :3] = valid_points[:, :3] @ rotation_matrix.T

        # Random jittering
        noise = np.random.normal(0, 0.01, valid_points[:, :3].shape)
        valid_points[:, :3] += noise

        # Random scaling
        scale = np.random.uniform(0.9, 1.1)
        valid_points[:, :3] *= scale

        patch[valid_mask] = valid_points
        return patch
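

# Illustrative only: constructing the dataset and inspecting one item. "patch_dataset" is
# a placeholder path; any directory populated by save_patches_dataset() below would work.
def _demo_dataset_item(dataset_dir: str = "patch_dataset"):
    dataset = PatchClassificationDataset(dataset_dir, max_points=1024, augment=False)
    patch_tensor, label_tensor, valid_mask = dataset[0]
    # patch_tensor: (6, 1024) float tensor (channels first, ready for Conv1d)
    # label_tensor: scalar float tensor, 0.0 or 1.0 (BCE target)
    # valid_mask:   (1024,) bool tensor marking real (non-padded) points
    return patch_tensor.shape, label_tensor.item(), int(valid_mask.sum())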


def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
    """
    Save patches from the prediction pipeline to create a training dataset.

    Args:
        patches: List of patch dictionaries from generate_patches()
        dataset_dir: Directory to save the dataset
        entry_id: Unique identifier for this entry/image
    """
    os.makedirs(dataset_dir, exist_ok=True)

    for i, patch in enumerate(patches):
        # Create unique filename
        filename = f"{entry_id}_patch_{i}.pkl"
        filepath = os.path.join(dataset_dir, filename)

        # Skip if file already exists
        if os.path.exists(filepath):
            continue

        # Save patch data
        with open(filepath, 'wb') as f:
            pickle.dump(patch, f)

    print(f"Saved {len(patches)} patches for entry {entry_id}")


# Custom collate function for the dataloader: filters out invalid samples
# (patches with no valid points)
def collate_fn(batch):
    valid_batch = []
    for patch_data, label, valid_mask in batch:
        # Filter out invalid samples (no valid points)
        if valid_mask.sum() > 0:
            valid_batch.append((patch_data, label, valid_mask))

    if len(valid_batch) == 0:
        return None

    # Stack valid samples
    patch_data = torch.stack([item[0] for item in valid_batch])
    labels = torch.stack([item[1] for item in valid_batch])
    valid_masks = torch.stack([item[2] for item in valid_batch])

    return patch_data, labels, valid_masks


# Initialize weights using Xavier/Glorot initialization
def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm1d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
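

# Illustrative only: collate_fn drops samples whose valid_mask is all False and returns
# None when nothing in the batch survives; the training loop below skips such batches.
def _demo_collate_filtering():
    good = (torch.zeros(6, 1024), torch.tensor(1.0), torch.ones(1024, dtype=torch.bool))
    empty = (torch.zeros(6, 1024), torch.tensor(0.0), torch.zeros(1024, dtype=torch.bool))
    batch = collate_fn([good, empty])   # the empty patch is filtered out
    assert batch is not None and batch[0].shape[0] == 1
    assert collate_fn([empty]) is None  # an all-empty batch collapses to None
    return batch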


def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100,
                   batch_size: int = 32, lr: float = 0.001):
    """
    Train the ClassificationPointNet model on saved patches.

    Args:
        dataset_dir: Directory containing saved patch files
        model_save_path: Path to save the trained model
        epochs: Number of training epochs
        batch_size: Training batch size
        lr: Learning rate
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on device: {device}")

    # Create dataset and dataloader
    dataset = PatchClassificationDataset(dataset_dir, max_points=1024, augment=True)
    print(f"Dataset loaded with {len(dataset)} samples")

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            num_workers=8, collate_fn=collate_fn, drop_last=True)

    # Initialize model
    model = ClassificationPointNet(input_dim=6, max_points=1024)
    model.apply(init_weights)
    model.to(device)

    # Loss function and optimizer (BCE for binary classification)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    # Training loop
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for batch_idx, batch_data in enumerate(dataloader):
            if batch_data is None:  # Skip invalid batches
                continue

            patch_data, labels, valid_masks = batch_data
            patch_data = patch_data.to(device)  # (batch_size, 6, max_points)
            labels = labels.to(device).unsqueeze(1)  # (batch_size, 1)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(patch_data)  # (batch_size, 1)
            loss = criterion(outputs, labels)

            # Backward pass
            loss.backward()
            optimizer.step()

            # Statistics
            total_loss += loss.item()
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            num_batches += 1

            if batch_idx % 50 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
                      f"Loss: {loss.item():.6f}, "
                      f"Accuracy: {100 * correct / total:.2f}%")

        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        accuracy = 100 * correct / total if total > 0 else 0
        print(f"Epoch {epoch+1}/{epochs} completed, "
              f"Avg Loss: {avg_loss:.6f}, "
              f"Accuracy: {accuracy:.2f}%")

        scheduler.step()

        # Save model checkpoint every epoch
        checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch + 1,
            'loss': avg_loss,
            'accuracy': accuracy,
        }, checkpoint_path)

    # Save the trained model
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epochs,
    }, model_save_path)
    print(f"Model saved to {model_save_path}")

    return model


def load_pointnet_model(model_path: str, device: torch.device = None) -> ClassificationPointNet:
    """
    Load a trained ClassificationPointNet model.

    Args:
        model_path: Path to the saved model
        device: Device to load the model on

    Returns:
        Loaded ClassificationPointNet model
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = ClassificationPointNet(input_dim=6, max_points=1024)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    return model


def predict_class_from_patch(model: ClassificationPointNet, patch: Dict,
                             device: torch.device = None) -> Tuple[int, float]:
    """
    Predict binary classification from a patch using the trained PointNet.

    Args:
        model: Trained ClassificationPointNet model
        patch: Dictionary containing patch data with 'patch_6d' key
        device: Device to run prediction on

    Returns:
        Tuple of (predicted_class, confidence)
            predicted_class: int (0 for not edge, 1 for edge)
            confidence: probability of the positive (edge) class, in [0, 1]
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    patch_6d = patch['patch_6d']  # (N, 6)

    # Prepare input
    max_points = 1024
    num_points = patch_6d.shape[0]

    if num_points >= max_points:
        # Sample points
        indices = np.random.choice(num_points, max_points, replace=False)
        patch_sampled = patch_6d[indices]
    else:
        # Pad with zeros
        patch_sampled = np.zeros((max_points, 6))
        patch_sampled[:num_points] = patch_6d

    # Convert to tensor
    patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0)  # (1, 6, max_points)
    patch_tensor = patch_tensor.to(device)

    # Predict
    with torch.no_grad():
        outputs = model(patch_tensor)  # (1, 1)
        probability = torch.sigmoid(outputs).item()
        predicted_class = int(probability > 0.5)

    return predicted_class, probability
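

# Illustrative end-to-end usage (not invoked on import). The paths below are placeholders;
# dataset_dir is assumed to have been populated by save_patches_dataset() with *.pkl files,
# each containing at least a 'patch_6d' (N, 6) array and a binary 'label'.
if __name__ == "__main__":
    dataset_dir = "patch_dataset"           # hypothetical dataset directory
    model_path = "pointnet_classifier.pth"  # hypothetical checkpoint path

    # Train on the saved patches, then reload the checkpoint for inference.
    train_pointnet(dataset_dir, model_path, epochs=100, batch_size=32, lr=0.001)
    model = load_pointnet_model(model_path)

    # Classify a single patch loaded back from the dataset.
    example_file = next(f for f in os.listdir(dataset_dir) if f.endswith('.pkl'))
    with open(os.path.join(dataset_dir, example_file), 'rb') as f:
        patch = pickle.load(f)
    pred, prob = predict_class_from_patch(model, patch)
    print(f"Predicted class: {pred} (edge probability: {prob:.3f})")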