import os
import pickle
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


class ClassificationPointNet(nn.Module):
    """
    PointNet-style network for binary classification of 6D point cloud patches.

    Takes 6D point clouds (x, y, z, r, g, b) and predicts a binary label
    (edge / not edge) as a single logit.
    """

    def __init__(self, input_dim=6, max_points=1024):
        super(ClassificationPointNet, self).__init__()
        self.max_points = max_points

        # Shared per-point MLP, implemented as 1x1 convolutions.
        self.conv1 = nn.Conv1d(input_dim, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 256, 1)
        self.conv4 = nn.Conv1d(256, 512, 1)
        self.conv5 = nn.Conv1d(512, 1024, 1)
        self.conv6 = nn.Conv1d(1024, 2048, 1)

        # Fully connected classification head on the pooled global feature.
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 128)
        self.fc5 = nn.Linear(128, 64)
        self.fc6 = nn.Linear(64, 1)

        # Batch normalization for each per-point conv layer.
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(256)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(1024)
        self.bn6 = nn.BatchNorm1d(2048)

        # Dropout between FC layers, heaviest in the middle of the head.
        self.dropout1 = nn.Dropout(0.3)
        self.dropout2 = nn.Dropout(0.4)
        self.dropout3 = nn.Dropout(0.5)
        self.dropout4 = nn.Dropout(0.4)
        self.dropout5 = nn.Dropout(0.3)

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: (batch_size, input_dim, max_points) tensor

        Returns:
            classification: (batch_size, 1) tensor of logits;
                apply sigmoid to obtain probabilities.
        """
        # Per-point feature extraction.
        x1 = F.relu(self.bn1(self.conv1(x)))
        x2 = F.relu(self.bn2(self.conv2(x1)))
        x3 = F.relu(self.bn3(self.conv3(x2)))
        x4 = F.relu(self.bn4(self.conv4(x3)))
        x5 = F.relu(self.bn5(self.conv5(x4)))
        x6 = F.relu(self.bn6(self.conv6(x5)))

        # Symmetric max pooling over the point dimension gives a
        # permutation-invariant global feature. Note that zero-padded
        # points are pooled together with real ones.
        global_features = torch.max(x6, 2)[0]

        # Classification head.
        x = F.relu(self.fc1(global_features))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = F.relu(self.fc3(x))
        x = self.dropout3(x)
        x = F.relu(self.fc4(x))
        x = self.dropout4(x)
        x = F.relu(self.fc5(x))
        x = self.dropout5(x)
        classification = self.fc6(x)

        return classification
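

# A minimal, illustrative shape check for the network above (assumed usage,
# not part of the pipeline): input is (batch, channels, points), output is
# one logit per patch.
#
#     model = ClassificationPointNet(input_dim=6, max_points=1024)
#     model.eval()                      # BatchNorm1d needs batch > 1 in train mode
#     dummy = torch.randn(2, 6, 1024)   # two random 6D patches
#     logits = model(dummy)             # shape (2, 1)
#     probs = torch.sigmoid(logits)     # probabilities in [0, 1]

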
class PatchClassificationDataset(Dataset):
    """
    Dataset for loading saved patches for PointNet classification training.
    """

    def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
        self.dataset_dir = dataset_dir
        self.max_points = max_points
        self.augment = augment

        # Collect all pickled patch files in the dataset directory.
        self.patch_files = []
        for file in os.listdir(dataset_dir):
            if file.endswith('.pkl'):
                self.patch_files.append(os.path.join(dataset_dir, file))

        print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")

    def __len__(self):
        return len(self.patch_files)

    def __getitem__(self, idx):
        """
        Load and process a patch for training.

        Returns:
            patch_data: (6, max_points) tensor of point cloud data
            label: scalar tensor for binary classification (0 or 1)
            valid_mask: (max_points,) boolean tensor indicating valid points
        """
        patch_file = self.patch_files[idx]

        with open(patch_file, 'rb') as f:
            patch_info = pickle.load(f)

        patch_6d = patch_info['patch_6d']
        label = patch_info.get('label', 0)

        num_points = patch_6d.shape[0]

        if num_points >= self.max_points:
            # Randomly subsample down to max_points.
            indices = np.random.choice(num_points, self.max_points, replace=False)
            patch_sampled = patch_6d[indices]
            valid_mask = np.ones(self.max_points, dtype=bool)
        else:
            # Zero-pad up to max_points and mark the padding as invalid.
            patch_sampled = np.zeros((self.max_points, 6))
            patch_sampled[:num_points] = patch_6d
            valid_mask = np.zeros(self.max_points, dtype=bool)
            valid_mask[:num_points] = True

        if self.augment:
            patch_sampled = self._augment_patch(patch_sampled, valid_mask)

        # Transpose to (channels, points) as expected by Conv1d.
        patch_tensor = torch.from_numpy(patch_sampled.T).float()
        label_tensor = torch.tensor(label, dtype=torch.float32)
        valid_mask_tensor = torch.from_numpy(valid_mask)

        return patch_tensor, label_tensor, valid_mask_tensor

    def _augment_patch(self, patch, valid_mask):
        """
        Apply data augmentation to the valid points of a patch.
        """
        valid_points = patch[valid_mask]

        if len(valid_points) == 0:
            return patch

        # Random rotation about the z-axis.
        angle = np.random.uniform(0, 2 * np.pi)
        cos_angle = np.cos(angle)
        sin_angle = np.sin(angle)
        rotation_matrix = np.array([
            [cos_angle, -sin_angle, 0],
            [sin_angle, cos_angle, 0],
            [0, 0, 1]
        ])

        valid_points[:, :3] = valid_points[:, :3] @ rotation_matrix.T

        # Small Gaussian jitter on the xyz coordinates.
        noise = np.random.normal(0, 0.01, valid_points[:, :3].shape)
        valid_points[:, :3] += noise

        # Random uniform scaling.
        scale = np.random.uniform(0.9, 1.1)
        valid_points[:, :3] *= scale

        # Write the augmented points back into the padded patch.
        patch[valid_mask] = valid_points
        return patch


def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
    """
    Save patches from the prediction pipeline to create a training dataset.

    Args:
        patches: List of patch dictionaries from generate_patches()
        dataset_dir: Directory to save the dataset
        entry_id: Unique identifier for this entry/image
    """
    os.makedirs(dataset_dir, exist_ok=True)

    for i, patch in enumerate(patches):
        filename = f"{entry_id}_patch_{i}.pkl"
        filepath = os.path.join(dataset_dir, filename)

        # Skip patches that have already been saved.
        if os.path.exists(filepath):
            continue

        with open(filepath, 'wb') as f:
            pickle.dump(patch, f)

    print(f"Saved {len(patches)} patches for entry {entry_id}")


def collate_fn(batch):
    """
    Collate function that drops samples with no valid points.
    Returns None if the entire batch is empty.
    """
    valid_batch = []
    for patch_data, label, valid_mask in batch:
        if valid_mask.sum() > 0:
            valid_batch.append((patch_data, label, valid_mask))

    if len(valid_batch) == 0:
        return None

    patch_data = torch.stack([item[0] for item in valid_batch])
    labels = torch.stack([item[1] for item in valid_batch])
    valid_masks = torch.stack([item[2] for item in valid_batch])

    return patch_data, labels, valid_masks


def init_weights(m):
    """Xavier initialization for conv/linear layers; identity for batch norm."""
    if isinstance(m, nn.Conv1d):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm1d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)


def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100,
                   batch_size: int = 32, lr: float = 0.001):
    """
    Train the ClassificationPointNet model on saved patches.

    Args:
        dataset_dir: Directory containing saved patch files
        model_save_path: Path to save the trained model
        epochs: Number of training epochs
        batch_size: Training batch size
        lr: Learning rate
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on device: {device}")

    dataset = PatchClassificationDataset(dataset_dir, max_points=1024, augment=True)
    print(f"Dataset loaded with {len(dataset)} samples")

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8,
                            collate_fn=collate_fn, drop_last=True)

    model = ClassificationPointNet(input_dim=6, max_points=1024)
    model.apply(init_weights)
    model.to(device)

    # Binary cross-entropy on raw logits; the sigmoid is folded into the loss.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for batch_idx, batch_data in enumerate(dataloader):
            # collate_fn returns None when every sample in the batch is empty.
            if batch_data is None:
                continue

            patch_data, labels, valid_masks = batch_data
            patch_data = patch_data.to(device)
            labels = labels.to(device).unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(patch_data)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            # Track running loss and accuracy.
            total_loss += loss.item()
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            num_batches += 1

            if batch_idx % 50 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
                      f"Loss: {loss.item():.6f}, "
                      f"Accuracy: {100 * correct / total:.2f}%")

        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        accuracy = 100 * correct / total if total > 0 else 0

        print(f"Epoch {epoch+1}/{epochs} completed, "
              f"Avg Loss: {avg_loss:.6f}, "
              f"Accuracy: {accuracy:.2f}%")

        scheduler.step()

        # Save a checkpoint after every epoch.
        checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch + 1,
            'loss': avg_loss,
            'accuracy': accuracy,
        }, checkpoint_path)

    # Save the final model.
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epochs,
    }, model_save_path)

    print(f"Model saved to {model_save_path}")
    return model


def load_pointnet_model(model_path: str, device: Optional[torch.device] = None) -> ClassificationPointNet:
    """
    Load a trained ClassificationPointNet model.

    Args:
        model_path: Path to the saved model
        device: Device to load the model on (auto-detected if None)

    Returns:
        Loaded ClassificationPointNet model in eval mode
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = ClassificationPointNet(input_dim=6, max_points=1024)

    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    model.to(device)
    model.eval()

    return model


def predict_class_from_patch(model: ClassificationPointNet, patch: Dict,
                             device: Optional[torch.device] = None) -> Tuple[int, float]:
    """
    Predict binary classification from a patch using a trained PointNet.

    Args:
        model: Trained ClassificationPointNet model
        patch: Dictionary containing patch data with 'patch_6d' key
        device: Device to run prediction on (auto-detected if None)

    Returns:
        Tuple of (predicted_class, probability):
            predicted_class: int (0 for not edge, 1 for edge)
            probability: float in [0, 1], the sigmoid probability of class 1
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    patch_6d = patch['patch_6d']

    # Apply the same sampling/padding scheme as the training dataset.
    max_points = 1024
    num_points = patch_6d.shape[0]

    if num_points >= max_points:
        indices = np.random.choice(num_points, max_points, replace=False)
        patch_sampled = patch_6d[indices]
    else:
        patch_sampled = np.zeros((max_points, 6))
        patch_sampled[:num_points] = patch_6d

    # Build a (1, 6, max_points) tensor matching the model's expected layout.
    patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0)
    patch_tensor = patch_tensor.to(device)

    with torch.no_grad():
        outputs = model(patch_tensor)
        probability = torch.sigmoid(outputs).item()
        predicted_class = int(probability > 0.5)

    return predicted_class, probability
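

# Minimal usage sketch. The paths below are placeholders, not part of the
# pipeline; replace them with your own dataset and checkpoint locations.
if __name__ == '__main__':
    dataset_dir = 'patch_dataset'           # hypothetical directory of .pkl patches
    model_path = 'pointnet_classifier.pth'  # hypothetical checkpoint path

    # Train on the saved patches, then reload the final checkpoint.
    train_pointnet(dataset_dir, model_path, epochs=100, batch_size=32, lr=0.001)
    model = load_pointnet_model(model_path)

    # Classify a single saved patch with the trained model.
    with open(os.path.join(dataset_dir, os.listdir(dataset_dir)[0]), 'rb') as f:
        patch = pickle.load(f)
    predicted_class, probability = predict_class_from_patch(model, patch)
    print(f"Predicted class: {predicted_class} (p={probability:.3f})")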