# fast_pointnet_class.py
# This file defines a PointNet-based model for binary classification of 6D point cloud patches.
# It includes the model architecture (ClassificationPointNet), a custom dataset class
# (PatchClassificationDataset) for loading and augmenting patches, functions for saving
# patches to create a dataset, a training loop (train_pointnet), a function to load
# a trained model (load_pointnet_model), and a function for predicting class labels
# from new patches (predict_class_from_patch).
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Tuple, Optional
class ClassificationPointNet(nn.Module):
"""
PointNet implementation for binary classification from 6D point cloud patches.
    Takes 6D point clouds (x, y, z, r, g, b) and predicts a binary edge / not-edge label.
"""
def __init__(self, input_dim=6, max_points=1024):
super(ClassificationPointNet, self).__init__()
self.max_points = max_points
# Point-wise MLPs for feature extraction (deeper network)
self.conv1 = nn.Conv1d(input_dim, 64, 1)
self.conv2 = nn.Conv1d(64, 128, 1)
self.conv3 = nn.Conv1d(128, 256, 1)
self.conv4 = nn.Conv1d(256, 512, 1)
self.conv5 = nn.Conv1d(512, 1024, 1)
self.conv6 = nn.Conv1d(1024, 2048, 1) # Additional layer
# Classification head (deeper with more capacity)
self.fc1 = nn.Linear(2048, 1024)
self.fc2 = nn.Linear(1024, 512)
self.fc3 = nn.Linear(512, 256)
self.fc4 = nn.Linear(256, 128)
self.fc5 = nn.Linear(128, 64)
self.fc6 = nn.Linear(64, 1) # Single output for binary classification
# Batch normalization layers
self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(128)
self.bn3 = nn.BatchNorm1d(256)
self.bn4 = nn.BatchNorm1d(512)
self.bn5 = nn.BatchNorm1d(1024)
self.bn6 = nn.BatchNorm1d(2048)
# Dropout layers
self.dropout1 = nn.Dropout(0.3)
self.dropout2 = nn.Dropout(0.4)
self.dropout3 = nn.Dropout(0.5)
self.dropout4 = nn.Dropout(0.4)
self.dropout5 = nn.Dropout(0.3)
def forward(self, x):
"""
Forward pass
Args:
x: (batch_size, input_dim, max_points) tensor
Returns:
            classification: (batch_size, 1) tensor of raw logits (apply sigmoid to obtain edge probabilities)
"""
batch_size = x.size(0)
# Point-wise feature extraction
x1 = F.relu(self.bn1(self.conv1(x)))
x2 = F.relu(self.bn2(self.conv2(x1)))
x3 = F.relu(self.bn3(self.conv3(x2)))
x4 = F.relu(self.bn4(self.conv4(x3)))
x5 = F.relu(self.bn5(self.conv5(x4)))
x6 = F.relu(self.bn6(self.conv6(x5)))
# Global max pooling
global_features = torch.max(x6, 2)[0] # (batch_size, 2048)
# Classification head
x = F.relu(self.fc1(global_features))
x = self.dropout1(x)
x = F.relu(self.fc2(x))
x = self.dropout2(x)
x = F.relu(self.fc3(x))
x = self.dropout3(x)
x = F.relu(self.fc4(x))
x = self.dropout4(x)
x = F.relu(self.fc5(x))
x = self.dropout5(x)
classification = self.fc6(x) # (batch_size, 1)
return classification
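
# Shape sanity check (illustrative sketch, not part of the pipeline): with the
# default settings a batch of B patches enters as (B, 6, 1024) and comes out as
# (B, 1) raw logits, e.g.
#   model = ClassificationPointNet(input_dim=6, max_points=1024)
#   logits = model(torch.randn(8, 6, 1024))   # -> shape (8, 1)
#   probs = torch.sigmoid(logits)             # edge probabilities in (0, 1)
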
class PatchClassificationDataset(Dataset):
"""
Dataset class for loading saved patches for PointNet classification training.
"""
def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
self.dataset_dir = dataset_dir
self.max_points = max_points
self.augment = augment
# Load patch files
self.patch_files = []
for file in os.listdir(dataset_dir):
if file.endswith('.pkl'):
self.patch_files.append(os.path.join(dataset_dir, file))
print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")
def __len__(self):
return len(self.patch_files)
def __getitem__(self, idx):
"""
Load and process a patch for training.
Returns:
patch_data: (6, max_points) tensor of point cloud data
label: scalar tensor for binary classification (0 or 1)
valid_mask: (max_points,) boolean tensor indicating valid points
"""
patch_file = self.patch_files[idx]
with open(patch_file, 'rb') as f:
patch_info = pickle.load(f)
patch_6d = patch_info['patch_6d'] # (N, 6)
label = patch_info.get('label', 0) # Get binary classification label (0 or 1)
# Pad or sample points to max_points
num_points = patch_6d.shape[0]
if num_points >= self.max_points:
# Randomly sample max_points
indices = np.random.choice(num_points, self.max_points, replace=False)
patch_sampled = patch_6d[indices]
valid_mask = np.ones(self.max_points, dtype=bool)
else:
# Pad with zeros
patch_sampled = np.zeros((self.max_points, 6))
patch_sampled[:num_points] = patch_6d
valid_mask = np.zeros(self.max_points, dtype=bool)
valid_mask[:num_points] = True
# Data augmentation
if self.augment:
patch_sampled = self._augment_patch(patch_sampled, valid_mask)
# Convert to tensors and transpose for conv1d (channels first)
patch_tensor = torch.from_numpy(patch_sampled.T).float() # (6, max_points)
label_tensor = torch.tensor(label, dtype=torch.float32) # Float for BCE loss
valid_mask_tensor = torch.from_numpy(valid_mask)
return patch_tensor, label_tensor, valid_mask_tensor
def _augment_patch(self, patch, valid_mask):
"""
Apply data augmentation to the patch.
"""
valid_points = patch[valid_mask]
if len(valid_points) == 0:
return patch
# Random rotation around z-axis
angle = np.random.uniform(0, 2 * np.pi)
cos_angle = np.cos(angle)
sin_angle = np.sin(angle)
rotation_matrix = np.array([
[cos_angle, -sin_angle, 0],
[sin_angle, cos_angle, 0],
[0, 0, 1]
])
# Apply rotation to xyz coordinates
valid_points[:, :3] = valid_points[:, :3] @ rotation_matrix.T
# Random jittering
noise = np.random.normal(0, 0.01, valid_points[:, :3].shape)
valid_points[:, :3] += noise
# Random scaling
scale = np.random.uniform(0.9, 1.1)
valid_points[:, :3] *= scale
patch[valid_mask] = valid_points
return patch
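
# Usage sketch (the directory name is a placeholder; it must contain *.pkl
# patch files such as those written by save_patches_dataset below):
#   dataset = PatchClassificationDataset('patch_dataset', max_points=1024, augment=True)
#   patch_tensor, label, valid_mask = dataset[0]   # (6, 1024), scalar float, (1024,) bool
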
def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
"""
Save patches from prediction pipeline to create a training dataset.
Args:
patches: List of patch dictionaries from generate_patches()
dataset_dir: Directory to save the dataset
entry_id: Unique identifier for this entry/image
"""
os.makedirs(dataset_dir, exist_ok=True)
for i, patch in enumerate(patches):
# Create unique filename
filename = f"{entry_id}_patch_{i}.pkl"
filepath = os.path.join(dataset_dir, filename)
# Skip if file already exists
if os.path.exists(filepath):
continue
# Save patch data
with open(filepath, 'wb') as f:
pickle.dump(patch, f)
print(f"Saved {len(patches)} patches for entry {entry_id}")
# Custom collate function used by the training DataLoader to drop samples with no valid points
def collate_fn(batch):
valid_batch = []
for patch_data, label, valid_mask in batch:
# Filter out invalid samples (no valid points)
if valid_mask.sum() > 0:
valid_batch.append((patch_data, label, valid_mask))
if len(valid_batch) == 0:
return None
# Stack valid samples
patch_data = torch.stack([item[0] for item in valid_batch])
labels = torch.stack([item[1] for item in valid_batch])
valid_masks = torch.stack([item[2] for item in valid_batch])
return patch_data, labels, valid_masks
# Initialize weights using Xavier/Glorot initialization
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm1d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32,
lr: float = 0.001):
"""
Train the ClassificationPointNet model on saved patches.
Args:
dataset_dir: Directory containing saved patch files
model_save_path: Path to save the trained model
epochs: Number of training epochs
batch_size: Training batch size
lr: Learning rate
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Training on device: {device}")
# Create dataset and dataloader
dataset = PatchClassificationDataset(dataset_dir, max_points=1024, augment=True)
print(f"Dataset loaded with {len(dataset)} samples")
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
collate_fn=collate_fn, drop_last=True)
# Initialize model
model = ClassificationPointNet(input_dim=6, max_points=1024)
model.apply(init_weights)
model.to(device)
# Loss function and optimizer (BCE for binary classification)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
# Training loop
model.train()
for epoch in range(epochs):
total_loss = 0.0
correct = 0
total = 0
num_batches = 0
for batch_idx, batch_data in enumerate(dataloader):
if batch_data is None: # Skip invalid batches
continue
patch_data, labels, valid_masks = batch_data
patch_data = patch_data.to(device) # (batch_size, 6, max_points)
labels = labels.to(device).unsqueeze(1) # (batch_size, 1)
# Forward pass
optimizer.zero_grad()
outputs = model(patch_data) # (batch_size, 1)
loss = criterion(outputs, labels)
# Backward pass
loss.backward()
optimizer.step()
# Statistics
total_loss += loss.item()
predicted = (torch.sigmoid(outputs) > 0.5).float()
total += labels.size(0)
correct += (predicted == labels).sum().item()
num_batches += 1
if batch_idx % 50 == 0:
print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
f"Loss: {loss.item():.6f}, "
f"Accuracy: {100 * correct / total:.2f}%")
avg_loss = total_loss / num_batches if num_batches > 0 else 0
accuracy = 100 * correct / total if total > 0 else 0
print(f"Epoch {epoch+1}/{epochs} completed, "
f"Avg Loss: {avg_loss:.6f}, "
f"Accuracy: {accuracy:.2f}%")
scheduler.step()
# Save model checkpoint every epoch
checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch + 1,
'loss': avg_loss,
'accuracy': accuracy,
}, checkpoint_path)
# Save the trained model
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epochs,
}, model_save_path)
print(f"Model saved to {model_save_path}")
return model
def load_pointnet_model(model_path: str, device: Optional[torch.device] = None) -> ClassificationPointNet:
"""
Load a trained ClassificationPointNet model.
Args:
model_path: Path to the saved model
device: Device to load the model on
Returns:
Loaded ClassificationPointNet model
"""
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ClassificationPointNet(input_dim=6, max_points=1024)
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
return model
def predict_class_from_patch(model: ClassificationPointNet, patch: Dict, device: Optional[torch.device] = None) -> Tuple[int, float]:
"""
Predict binary classification from a patch using trained PointNet.
Args:
model: Trained ClassificationPointNet model
patch: Dictionary containing patch data with 'patch_6d' key
device: Device to run prediction on
Returns:
        Tuple of (predicted_class, confidence):
            predicted_class: int, 0 for not edge, 1 for edge
            confidence: float in (0, 1), the sigmoid probability of the edge class
                (returned as-is even when the predicted class is 0)
"""
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
patch_6d = patch['patch_6d'] # (N, 6)
# Prepare input
max_points = 1024
num_points = patch_6d.shape[0]
if num_points >= max_points:
# Sample points
indices = np.random.choice(num_points, max_points, replace=False)
patch_sampled = patch_6d[indices]
else:
# Pad with zeros
patch_sampled = np.zeros((max_points, 6))
patch_sampled[:num_points] = patch_6d
# Convert to tensor
patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0) # (1, 6, max_points)
patch_tensor = patch_tensor.to(device)
# Predict
with torch.no_grad():
outputs = model(patch_tensor) # (1, 1)
probability = torch.sigmoid(outputs).item()
predicted_class = int(probability > 0.5)
return predicted_class, probability
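

# Minimal end-to-end sketch of how the pieces above fit together. The paths are
# placeholders and generate_patches() (which produces real patch dictionaries)
# lives elsewhere in the pipeline, so this only exercises saving a synthetic
# patch, running prediction on it, and (commented out) training/loading.
if __name__ == "__main__":
    # Synthetic patch: 500 points with xyz and rgb drawn uniformly from [0, 1).
    dummy_patch = {
        'patch_6d': np.random.rand(500, 6).astype(np.float32),
        'label': 1,
    }

    # Write it into a throwaway dataset directory (placeholder path).
    save_patches_dataset([dummy_patch], dataset_dir='patch_dataset', entry_id='demo')

    # An untrained model is enough to exercise the prediction path end to end.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ClassificationPointNet(input_dim=6, max_points=1024).to(device).eval()
    predicted_class, confidence = predict_class_from_patch(model, dummy_patch, device)
    print(f"Predicted class: {predicted_class}, edge probability: {confidence:.3f}")

    # Full training run (uncomment once a real dataset directory exists):
    # model = train_pointnet('patch_dataset', 'pointnet_classifier.pth', epochs=100, batch_size=32)
    # model = load_pointnet_model('pointnet_classifier.pth', device)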