""" End-to-End Voxel-Based Vertex Detection Pipeline This file implements a complete pipeline for detecting wireframe vertices from 3D point clouds using a voxel-based deep learning approach. The pipeline includes: 1. Data preprocessing: Converting 14D point clouds into 3D voxel grids with averaged features 2. Ground truth generation: Creating binary vertex labels and refinement targets from wireframe vertices 3. Model architecture: VoxelUNet with encoder-decoder structure and 1x1x1 bottleneck for vertex detection 4. Training: Combined loss function with BCE, Dice loss, and MSE for offset regression 5. Inference: Predicting vertex locations from new point clouds with visualization Key components: - Voxelization with configurable grid size and metric voxel size - Per-voxel MLP before convolutional processing - Gaussian smoothing of ground truth labels - Refinement prediction for sub-voxel accuracy - PyVista-based visualization for results analysis Usage: - Set inference=False to train a new model - Set inference=True to run predictions on existing data """ import os import pickle import torch import torch.nn as nn import torch.optim as optim import numpy as np from typing import Dict, Any, Tuple, List from torch.utils.data import Dataset, DataLoader import glob import pyvista as pv import torch # [Previous code from the existing document remains unchanged up to CombinedLoss class] # ... (save_data, load_data, get_data_files, voxelize_points, create_ground_truth, VoxelUNet, VoxelDataset) ... def save_data(dict_to_save: Dict[str, Any], filename: str, data_folder: str = "data") -> None: """Save dictionary data to pickle file""" os.makedirs(data_folder, exist_ok=True) filepath = os.path.join(data_folder, f"{filename}.pkl") with open(filepath, 'wb') as f: pickle.dump(dict_to_save, f) #print(f"Data saved to {filepath}") def load_data(filepath: str) -> Dict[str, Any]: """Load dictionary data from pickle file""" with open(filepath, 'rb') as f: data = pickle.load(f) #print(f"Data loaded from {filepath}") return data def get_data_files(data_folder: str = "data", pattern: str = "*.pkl") -> List[str]: """Get list of data files from folder""" search_pattern = os.path.join(data_folder, pattern) files = glob.glob(search_pattern) #print(f"Found {len(files)} data files in {data_folder}") return files def voxelize_points(points: np.ndarray, grid_size_xy: int = 64, voxel_size_metric: float = 0.25 ) -> Tuple[torch.Tensor, np.ndarray, Dict[str, Any]]: """ Voxelize 14D point cloud into a 3D grid with a fixed number of voxels and fixed metric voxel size. The Z dimension of the grid will also have `grid_size_xy` voxels, forming a cubic grid. The point cloud is centered within this metric grid. Points outside are discarded. Features from points falling into the same voxel are averaged. Args: points: (N, 14) array where first 3 dims are xyz (original coordinates). grid_size_xy: Number of voxels along X and Y dimensions (and Z). voxel_size_metric: The physical size of each voxel (e.g., 0.5 units). Returns: voxel_grid: (NUM_FEATURES, dim_z, dim_y, dim_x) tensor with averaged features. voxel_indices_for_points: (N_points_in_grid, 3) integer voxel indices (z, y, x) for each input point that falls within the grid. scale_info: Dict with transformation parameters: 'grid_origin_metric': Real-world metric coordinate of the corner of voxel [0,0,0] (x,y,z). 'voxel_size_metric': The metric size of a voxel. 'grid_dims_voxels': Tuple (dim_x, dim_y, dim_z) representing number of voxels. 'pc_centroid_metric': Centroid of the input point cloud (x,y,z). """ NUM_FEATURES = 14 dim_x = grid_size_xy dim_y = grid_size_xy dim_z = grid_size_xy # Assuming cubic grid if dim_z == 0: dim_z = 1 # Ensure at least one voxel in Z grid_dims_voxels = np.array([dim_x, dim_y, dim_z], dtype=int) def _get_empty_return(reason: str = ""): voxel_grid_empty = torch.zeros(NUM_FEATURES, grid_dims_voxels[2], grid_dims_voxels[1], grid_dims_voxels[0], dtype=torch.float32) voxel_indices_empty = np.empty((0, 3), dtype=int) scale_info_empty = { 'grid_origin_metric': np.zeros(3, dtype=float), 'voxel_size_metric': voxel_size_metric, 'grid_dims_voxels': tuple(grid_dims_voxels.tolist()), 'pc_centroid_metric': np.zeros(3, dtype=float), } return voxel_grid_empty, voxel_indices_empty, scale_info_empty if points.shape[0] == 0: return _get_empty_return("Initial empty point cloud") xyz = points[:, :3] features_other = points[:, 3:] pc_centroid_metric = xyz.mean(axis=0) grid_metric_span = grid_dims_voxels * voxel_size_metric grid_origin_metric = pc_centroid_metric - (grid_metric_span / 2.0) # Voxel grid to store summed features voxel_grid_sum = torch.zeros(NUM_FEATURES, grid_dims_voxels[2], grid_dims_voxels[1], grid_dims_voxels[0], dtype=torch.float32) # Counter for points per voxel point_counts_in_voxel = torch.zeros(grid_dims_voxels[2], grid_dims_voxels[1], grid_dims_voxels[0], dtype=torch.int32) continuous_voxel_coords = (xyz - grid_origin_metric) / voxel_size_metric voxel_indices_for_points_zyx_order = [] for i in range(points.shape[0]): current_point_continuous_coord_xyz = continuous_voxel_coords[i] # Using np.round for voxel assignment (assigns to nearest voxel center) voxel_idx_int_xyz = np.round(current_point_continuous_coord_xyz).astype(int) idx_x, idx_y, idx_z = voxel_idx_int_xyz[0], voxel_idx_int_xyz[1], voxel_idx_int_xyz[2] if not (0 <= idx_x < grid_dims_voxels[0] and \ 0 <= idx_y < grid_dims_voxels[1] and \ 0 <= idx_z < grid_dims_voxels[2]): continue # Point is outside the grid voxel_indices_for_points_zyx_order.append([idx_z, idx_y, idx_x]) assigned_voxel_center_grid_idx_space = np.array([idx_x, idx_y, idx_z], dtype=float) + 0.5 offset_xyz_in_grid_units = current_point_continuous_coord_xyz - assigned_voxel_center_grid_idx_space # Accumulate features in voxel_grid_sum voxel_grid_sum[0, idx_z, idx_y, idx_x] += offset_xyz_in_grid_units[0] # dx voxel_grid_sum[1, idx_z, idx_y, idx_x] += offset_xyz_in_grid_units[1] # dy voxel_grid_sum[2, idx_z, idx_y, idx_x] += offset_xyz_in_grid_units[2] # dz if NUM_FEATURES > 3: current_point_other_features = features_other[i] voxel_grid_sum[3:, idx_z, idx_y, idx_x] += torch.tensor(current_point_other_features, dtype=torch.float32) point_counts_in_voxel[idx_z, idx_y, idx_x] += 1 # Averaging step # Initialize the final voxel_grid which will store averaged features voxel_grid = torch.zeros_like(voxel_grid_sum) # Prepare counts for division, ensuring no division by zero. # Convert counts to float for division. counts_for_division = point_counts_in_voxel.float() # For voxels with 0 points, counts_for_division is 0.0. # To avoid 0/0 = NaN, set these counts to 1.0. Since voxel_grid_sum is 0 there, # the result of 0.0 / 1.0 will be 0.0, which is correct. counts_for_division[counts_for_division == 0] = 1.0 # Perform averaging: # voxel_grid_sum is (C, D, H, W) # counts_for_division.unsqueeze(0) is (1, D, H, W), broadcasting correctly. voxel_grid = voxel_grid_sum / counts_for_division.unsqueeze(0) final_voxel_indices_for_points_zyx = np.array(voxel_indices_for_points_zyx_order, dtype=int) if voxel_indices_for_points_zyx_order else np.empty((0,3), dtype=int) scale_info = { 'grid_origin_metric': grid_origin_metric, 'voxel_size_metric': voxel_size_metric, 'grid_dims_voxels': tuple(grid_dims_voxels.tolist()), 'pc_centroid_metric': pc_centroid_metric, } return voxel_grid, final_voxel_indices_for_points_zyx, scale_info def create_ground_truth(vertices: np.ndarray, scale_info: Dict[str, Any] ) -> Tuple[torch.Tensor, torch.Tensor]: """ Create ground truth voxel labels and refinement targets using metric voxelization info. The grid dimensions are taken from scale_info. Args: vertices: (M, 3) vertex coordinates in original metric space. scale_info: Dict from voxelize_points. Requires: 'grid_origin_metric', 'voxel_size_metric', 'grid_dims_voxels'. Returns: vertex_labels: (dim_z, dim_y, dim_x) binary labels (1.0 for voxel containing a vertex). refinement_targets: (3, dim_z, dim_y, dim_x) offset (dx,dy,dz) from voxel cell center in grid units. Range approx [-0.5, 0.5). """ grid_origin_metric = scale_info['grid_origin_metric'] # (ox, oy, oz) voxel_size_metric = scale_info['voxel_size_metric'] # grid_dims_voxels is (num_voxels_x, num_voxels_y, num_voxels_z) grid_dims_voxels = np.array(scale_info['grid_dims_voxels']) dim_x, dim_y, dim_z = grid_dims_voxels[0], grid_dims_voxels[1], grid_dims_voxels[2] # Labels tensor: (dim_z, dim_y, dim_x) vertex_labels = torch.zeros(dim_z, dim_y, dim_x, dtype=torch.float32) # Refinement targets tensor: (3, dim_z, dim_y, dim_x) for (dx, dy, dz) offsets refinement_targets = torch.zeros(3, dim_z, dim_y, dim_x, dtype=torch.float32) if vertices.shape[0] == 0: return vertex_labels, refinement_targets # Convert vertex metric coordinates to continuous voxel coordinates # (potentially fractional and outside [0, dim-1]) continuous_voxel_coords_vertices = (vertices - grid_origin_metric) / voxel_size_metric for i in range(vertices.shape[0]): # v_continuous_coord_xyz is (vx, vy, vz) for the current vertex in continuous voxel space v_continuous_coord_xyz = continuous_voxel_coords_vertices[i] # Integer voxel index (ix, iy, iz) by flooring v_idx_int_xyz = np.floor(v_continuous_coord_xyz).astype(int) # Clip to be within grid boundaries [0, dim-1] idx_x = np.clip(v_idx_int_xyz[0], 0, dim_x - 1) idx_y = np.clip(v_idx_int_xyz[1], 0, dim_y - 1) idx_z = np.clip(v_idx_int_xyz[2], 0, dim_z - 1) # Set label for this voxel (using z, y, x order for tensor access) vertex_labels[idx_z, idx_y, idx_x] = 1.0 # Calculate refinement offset: # Center of the *assigned* (clipped) voxel in continuous grid index space assigned_voxel_center_grid_idx_space = np.array([idx_x, idx_y, idx_z], dtype=float) + 0.5 # Offset of the vertex from its *assigned* voxel center, in grid units. offset_xyz_grid_units = v_continuous_coord_xyz - assigned_voxel_center_grid_idx_space # Store dx, dy, dz in channels 0, 1, 2 respectively # refinement_targets is (3, Z, Y, X) refinement_targets[0, idx_z, idx_y, idx_x] = offset_xyz_grid_units[0] # dx refinement_targets[1, idx_z, idx_y, idx_x] = offset_xyz_grid_units[1] # dy refinement_targets[2, idx_z, idx_y, idx_x] = offset_xyz_grid_units[2] # dz return vertex_labels, refinement_targets class VoxelUNet(nn.Module): """Encoder-decoder network with a 1x1x1 bottleneck for voxel-based vertex detection. Includes a per-voxel MLP before the first convolutional block.""" def __init__(self, in_channels: int = 14, base_channels: int = 32, bottleneck_expansion: int = 2, mlp_hidden_factor: int = 2): super(VoxelUNet, self).__init__() bc = base_channels # Per-voxel MLP # The MLP transforms input features per voxel before the convolutional encoder. # Input to MLP: in_channels # Output of MLP: base_channels (bc) mlp_hidden_dim = in_channels * mlp_hidden_factor # Intermediate dimension for the MLP self.voxel_mlp = nn.Sequential( nn.Linear(in_channels, mlp_hidden_dim), nn.ReLU(inplace=True), nn.Linear(mlp_hidden_dim, bc) # Output of MLP has 'base_channels' features ) # Encoder # self.enc1 now takes 'base_channels' as input from the MLP. self.enc1 = self._conv_block(bc, bc) # bc self.enc2 = self._conv_block(bc, bc * 2) # bc*2 self.enc3 = self._conv_block(bc * 2, bc * 4) # bc*4 self.enc4 = self._conv_block(bc * 4, bc * 8) # bc*8 self.enc5 = self._conv_block(bc * 8, bc * 16) # bc*16 self.pool = nn.MaxPool3d(2) # Bottleneck self.adaptive_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) bottleneck_in_channels = bc * 16 # Width of the bottleneck vector (number of channels after 1x1x1 pooling) bottleneck_width = bottleneck_in_channels * bottleneck_expansion self.bottleneck = nn.Sequential( nn.Conv3d(bottleneck_in_channels, bottleneck_width, kernel_size=1, padding=0, bias=True), nn.ReLU(inplace=True), # Second 1x1 conv to add more capacity/non-linearity in the bottleneck nn.Conv3d(bottleneck_width, bottleneck_width, kernel_size=1, padding=0, bias=True), nn.ReLU(inplace=True) ) # Decoder # Input channels for decoder blocks are adjusted as skip connections are removed. self.dec5 = self._conv_block(bottleneck_width, bc * 16) # Input from upsampled bottleneck self.up4 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) self.dec4 = self._conv_block(bc * 16, bc * 8) # Input from dec5 output self.up3 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) self.dec3 = self._conv_block(bc * 8, bc * 4) # Input from dec4 output self.up2 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) self.dec2 = self._conv_block(bc * 4, bc * 2) # Input from dec3 output self.up1 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) self.dec1 = self._conv_block(bc * 2, bc) # Input from dec2 output # Output heads #self.vertex_head = nn.Conv3d(bc, 1, kernel_size=1) self.vertex_head = nn.Sequential( nn.Conv3d(bc, bc // 2, kernel_size=1), nn.ReLU(inplace=True), nn.Conv3d(bc // 2, bc // 4, kernel_size=1), nn.ReLU(inplace=True), nn.Conv3d(bc // 4, 1, kernel_size=1) ) self.refinement_head = nn.Conv3d(bc, 3, kernel_size=1) self.tanh = nn.Tanh() # For refinement head def _conv_block(self, in_channels: int, out_channels: int) -> nn.Sequential: # Standard convolutional block with two 3x3 convolutions # Using bias=False because BatchNorm3d is used after each convolution return nn.Sequential( nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm3d(out_channels), nn.ReLU(inplace=True), nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm3d(out_channels), nn.ReLU(inplace=True) ) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # x shape: (B, C_in_raw, D, H, W) # Per-voxel MLP B, C_in_raw, D, H, W = x.shape # Permute to (B, D, H, W, C_in_raw) for nn.Linear x_permuted = x.permute(0, 2, 3, 4, 1).contiguous() # Flatten spatial dimensions: (B*D*H*W, C_in_raw) x_flattened = x_permuted.view(-1, C_in_raw) # Apply MLP: (B*D*H*W, C_mlp_out) where C_mlp_out is base_channels (bc) mlp_out_flattened = self.voxel_mlp(x_flattened) C_mlp_out = mlp_out_flattened.shape[-1] # Should be self.base_channels # Reshape back to (B, D, H, W, C_mlp_out) x_mlp_reshaped = mlp_out_flattened.view(B, D, H, W, C_mlp_out) # Permute back to (B, C_mlp_out, D, H, W) for 3D convolutions x_processed = x_mlp_reshaped.permute(0, 4, 1, 2, 3).contiguous() # Encoder path e1 = self.enc1(x_processed) # Output spatial: S, Output channels: bc p1 = self.pool(e1) # Output spatial: S/2 e2 = self.enc2(p1) # Output spatial: S/2, Output channels: bc*2 p2 = self.pool(e2) # Output spatial: S/4 e3 = self.enc3(p2) # Output spatial: S/4, Output channels: bc*4 p3 = self.pool(e3) # Output spatial: S/8 e4 = self.enc4(p3) # Output spatial: S/8, Output channels: bc*8 p4 = self.pool(e4) # Output spatial: S/16 e5 = self.enc5(p4) # Output spatial: S/16, Output channels: bc*16 p5 = self.pool(e5) # Output spatial: S/32, Channels: bc*16 (input to bottleneck path) # Bottleneck b_pooled = self.adaptive_pool(p5) # Output spatial: 1x1x1, Output channels: bc*16 b = self.bottleneck(b_pooled) # Output spatial: 1x1x1, Output channels: bottleneck_width # Decoder path # Upsample bottleneck output to match spatial dimensions of e5 (S/16) u5_from_b = nn.functional.interpolate(b, size=e5.shape[2:], mode='trilinear', align_corners=True) d5 = self.dec5(u5_from_b) # Output spatial: S/16, Output channels: bc*16 u4 = self.up4(d5) # Output spatial: S/8 d4 = self.dec4(u4) # Output spatial: S/8, Output channels: bc*8 u3 = self.up3(d4) # Output spatial: S/4 d3 = self.dec3(u3) # Output spatial: S/4, Output channels: bc*4 u2 = self.up2(d3) # Output spatial: S/2 d2 = self.dec2(u2) # Output spatial: S/2, Output channels: bc*2 u1 = self.up1(d2) # Output spatial: S d1 = self.dec1(u1) # Output spatial: S, Output channels: bc # Output heads vertex_logits = self.vertex_head(d1) refinement = self.tanh(self.refinement_head(d1)) * 0.5 # Output range [-0.5, 0.5] return vertex_logits, refinement class VoxelDataset(Dataset): def __init__(self, data_files: List[str], voxel_size: float = 0.1, grid_size: int = 64): self.data_files = data_files self.voxel_size = voxel_size self.grid_size = grid_size def __len__(self): return len(self.data_files) def __getitem__(self, idx): data = load_data(self.data_files[idx]) voxel_grid, _, scale_info = voxelize_points( data['pcloud_14d'], self.grid_size, self.voxel_size ) wf_vertices_np = np.array(data['wf_vertices']) vertex_labels, refinement_targets = create_ground_truth( wf_vertices_np, scale_info ) return voxel_grid, vertex_labels, refinement_targets, scale_info import torch.nn as nn import torch.nn.functional as F from typing import Tuple # Added for Tuple type hint class CombinedLoss(nn.Module): """ Combined loss for vertex classification and offset regression. Uses: - BCEWithLogitsLoss (with configurable negative/positive sample weighting) - Dice loss - MSE loss on refinement offsets (only over positive voxels) - Gaussian blur on the GT labels """ def __init__(self, vertex_weight: float = 1.0, refinement_weight: float = 0.0, dice_weight: float = 0.5, bce_neg_pos_ratio: float = 1.0, # Ratio of negative to positive sample weight in BCE blur_kernel_size: int = 5, blur_sigma: float = 1.0, eps: float = 1e-6): super().__init__() self.vertex_weight = vertex_weight self.refinement_weight = refinement_weight self.dice_weight = dice_weight self.bce_neg_pos_ratio = bce_neg_pos_ratio # Store the ratio self.eps = eps # BCE with logits (reduction='none' to apply custom weighting) self.bce_loss_fn = nn.BCEWithLogitsLoss(reduction='none') # MSE for offset regression self.mse_loss = nn.MSELoss() # build 3D gaussian kernel k = blur_kernel_size coords = torch.arange(k, dtype=torch.float32) - (k - 1) / 2 xx, yy, zz = torch.meshgrid(coords, coords, coords, indexing='ij') kernel = torch.exp(-(xx**2 + yy**2 + zz**2) / (2 * blur_sigma**2)) # shape (1,1,k,k,k) kernel = kernel.view(1, 1, k, k, k) self.register_buffer('gaussian_kernel', kernel) self.pad = k // 2 def forward(self, vertex_logits_pred: torch.Tensor, # (B,1,D,H,W) refinement_pred: torch.Tensor, # (B,3,D,H,W) vertex_gt: torch.Tensor, # (B,D,H,W), 0/1 refinement_gt: torch.Tensor # (B,3,D,H,W) ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # logits & gt logits = vertex_logits_pred.squeeze(1) # (B,D,H,W) gt = vertex_gt.float() # (B,D,H,W) # apply gaussian blur on gt gt_unsq = gt.unsqueeze(1) # (B,1,D,H,W) gt_blur = F.conv3d(gt_unsq, self.gaussian_kernel, padding=self.pad) # (B,1,D,H,W) gt_blur = gt_blur.clamp(0, 1) # ensure values are in [0, 1] gt_smooth = gt_blur.squeeze(1) # (B,D,H,W) # 1) Weighted BCE loss - positive when gt_smooth > 1e-3 (using smoothed GT for mask definition) pos_mask = gt_smooth > 1e-3 # Mask for "positive" regions based on smoothed GT neg_mask = ~pos_mask # Mask for "negative" regions bce_all = self.bce_loss_fn(logits, gt_smooth) # Calculate BCE loss for all elements # Calculate weighted BCE pos_weight_factor = 1.0 # Weight for positive samples' contribution neg_weight_factor = self.bce_neg_pos_ratio # Weight for negative samples' contribution bce = torch.tensor(0.0, device=logits.device) num_pos = pos_mask.sum() num_neg = neg_mask.sum() if num_pos > 0 and num_neg > 0: mean_pos_loss = bce_all[pos_mask].mean() mean_neg_loss = bce_all[neg_mask].mean() bce = pos_weight_factor * mean_pos_loss + neg_weight_factor * mean_neg_loss elif num_pos > 0: # Only positive samples contribute mean_pos_loss = bce_all[pos_mask].mean() bce = pos_weight_factor * mean_pos_loss elif num_neg > 0: # Only negative samples contribute mean_neg_loss = bce_all[neg_mask].mean() bce = neg_weight_factor * mean_neg_loss # If no samples (num_pos=0 and num_neg=0), bce remains 0.0 # 2) Dice loss prob = torch.sigmoid(logits) # Use binarized smoothed GT for Dice target, consistent with original gt_smooth_round_for_dice = gt_smooth intersection = (prob * gt_smooth_round_for_dice).sum(dim=[1,2,3]) union = prob.sum(dim=[1,2,3]) + gt_smooth_round_for_dice.sum(dim=[1,2,3]) dice_score = (2 * intersection + self.eps) / (union + self.eps) dice_loss = 1 - dice_score.mean() vertex_loss = bce + self.dice_weight * dice_loss # 3) Refinement MSE (only where original gt==1, i.e., true vertex locations) # Use the original hard GT for selecting voxels for refinement loss mask_pos_refinement = (gt > 0.5).unsqueeze(1) refinement_loss = torch.tensor(0., device=logits.device) if mask_pos_refinement.sum() > 0: # Ensure pred and gt have the same shape for masked selection expanded_mask = mask_pos_refinement.expand_as(refinement_pred) pred_offsets = refinement_pred[expanded_mask].view(-1, 3) gt_offsets = refinement_gt[expanded_mask].view(-1, 3) if pred_offsets.numel() > 0: # Ensure there are elements to compute loss on refinement_loss = self.mse_loss(pred_offsets, gt_offsets) # 4) Total loss total_loss = (self.vertex_weight * vertex_loss + self.refinement_weight * refinement_loss) return total_loss, vertex_loss, refinement_loss def train_epoch(model, dataloader, optimizer, criterion, device, current_epoch: int): model.train() total_loss_epoch = 0.0 vertex_loss_epoch = 0.0 refinement_loss_epoch = 0.0 for batch_idx, (voxel_grid_batch, vertex_labels_batch, refinement_targets_batch, _) in enumerate(dataloader): voxel_grid_batch = voxel_grid_batch.to(device) vertex_labels_batch = vertex_labels_batch.to(device) refinement_targets_batch = refinement_targets_batch.to(device) if False: print(f'Epoch {current_epoch+1}, Batch {batch_idx+1}/{len(dataloader)}') sample_voxel_features = voxel_grid_batch[0].cpu().numpy() sample_gt_labels = vertex_labels_batch[0].cpu().numpy() sample_gt_refinement = refinement_targets_batch[0].cpu().numpy() summed_xyz_in_voxels = sample_voxel_features[:3] occupied_voxel_mask = np.any(summed_xyz_in_voxels != 0, axis=0) plotter = pv.Plotter(window_size=[800,600]) plotter.background_color = 'white' if np.any(occupied_voxel_mask): occupied_voxel_indices = np.array(np.where(occupied_voxel_mask)).T input_points_display = pv.PolyData(occupied_voxel_indices + 0.5) plotter.add_mesh(input_points_display, color='cornflowerblue', point_size=5, render_points_as_spheres=True, label='Occupied Voxels (Centers)') gt_vertex_voxel_mask = sample_gt_labels > 0.5 if np.any(gt_vertex_voxel_mask): gt_vertex_indices_int = np.array(np.where(gt_vertex_voxel_mask)).T gt_offsets = sample_gt_refinement[:, gt_vertex_voxel_mask].T gt_vertex_positions_grid_space = gt_vertex_indices_int.astype(float) + 0.5 + gt_offsets target_vertices_display = pv.PolyData(gt_vertex_positions_grid_space) plotter.add_mesh(target_vertices_display, color='crimson', point_size=10, render_points_as_spheres=True, label='Target Vertices (GT)') plotter.show(title=f"Debug Viz E{current_epoch+1} B{batch_idx+1}", auto_close=False) else: print(f"Epoch {current_epoch+1} Batch {batch_idx+1}: No data to visualize for the first sample.") optimizer.zero_grad() vertex_logits_pred, refinement_pred = model(voxel_grid_batch) loss, vertex_loss, refinement_loss = criterion( vertex_logits_pred, refinement_pred, vertex_labels_batch, refinement_targets_batch ) print(f"Batch {batch_idx+1}/{len(dataloader)}: Loss={loss.item():.4f}, Vertex Loss={vertex_loss.item():.4f}, Refinement Loss={refinement_loss.item():.4f}") if loss > 0.000001: loss.backward() optimizer.step() total_loss_epoch += loss.item() vertex_loss_epoch += vertex_loss.item() refinement_loss_epoch += refinement_loss.item() if (batch_idx + 1) % 200 == 0: checkpoint_path = f"model_epoch_{current_epoch+1}_batch_{batch_idx+1}_grid_128v9.pth" # Consider updating filename if grid size changes torch.save(model.state_dict(), checkpoint_path) print(f"Saved batch checkpoint: {checkpoint_path}") avg_total_loss = total_loss_epoch / len(dataloader) if len(dataloader) > 0 else 0 avg_vertex_loss = vertex_loss_epoch / len(dataloader) if len(dataloader) > 0 else 0 avg_refinement_loss = refinement_loss_epoch / len(dataloader) if len(dataloader) > 0 else 0 return avg_total_loss, avg_vertex_loss, avg_refinement_loss def train_model(data_folder: str = "data", num_epochs: int = 100, batch_size: int = 4, neg_pos_ratio_val: float = 1.0): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Using device: {device}") data_files = get_data_files(data_folder) if not data_files: print(f"No data files found in {data_folder}. Exiting.") return GRID_SIZE_CFG = 128 VOXEL_SIZE_CFG = 0.5 dataset = VoxelDataset(data_files, voxel_size=VOXEL_SIZE_CFG, grid_size=GRID_SIZE_CFG) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8) model = VoxelUNet(in_channels=14, base_channels=32, bottleneck_expansion=4, mlp_hidden_factor= 10).to(device) optimizer = optim.Adam(model.parameters(), lr=1e-3) criterion = CombinedLoss( vertex_weight=10.0, refinement_weight=0.0, dice_weight=-0.0 ).to(device) print(f"Starting training: {num_epochs} epochs, Batch size: {batch_size}, Grid size: {GRID_SIZE_CFG}, Voxel size: {VOXEL_SIZE_CFG}, Initial LR: {optimizer.param_groups[0]['lr']}") for epoch in range(num_epochs): print(f"\n--- Epoch {epoch+1}/{num_epochs} ---") avg_loss, avg_vertex_loss, avg_refinement_loss = train_epoch( model, dataloader, optimizer, criterion, device, epoch ) print(f"Epoch {epoch+1} Summary: Avg Loss: {avg_loss:.4f}, " f"Avg Vertex Loss: {avg_vertex_loss:.4f}, " f"Avg Refinement Loss: {avg_refinement_loss:.4f}, " f"Current LR: {optimizer.param_groups[0]['lr']:.6f}") checkpoint_path = f"model_epoch_{epoch+1}_grid{GRID_SIZE_CFG}_smooth_bal{neg_pos_ratio_val}_v9.pth" torch.save(model.state_dict(), checkpoint_path) print(f"Saved checkpoint: {checkpoint_path}") final_model_path = f"final_model_grid{GRID_SIZE_CFG}_epochs{num_epochs}_smooth_bal{neg_pos_ratio_val}_v9.pth" torch.save(model.state_dict(), final_model_path) print(f"Training completed! Final model saved as {final_model_path}") def load_model_for_inference(model_path: str, device: torch.device, in_channels: int = 14, base_channels: int = 32) -> VoxelUNet: """Load a VoxelUNet model for inference.""" model = VoxelUNet(in_channels=14, base_channels=32, bottleneck_expansion=4, mlp_hidden_factor= 10) model.load_state_dict(torch.load(model_path, map_location=device)) model.to(device) model.eval() print(f"Model loaded from {model_path} and set to evaluation mode on {device}.") return model def predict_vertices(model: VoxelUNet, point_cloud_14d: np.ndarray, grid_size: int, device: torch.device, voxel_size_metric: float = 0.5, # Added for consistency, default matches voxelize_points vertex_threshold: float = 0.5) -> np.ndarray: """ Predict vertices from a 14D point cloud. Args: model: The trained VoxelUNet model. point_cloud_14d: (N, 14) NumPy array of the input point cloud. grid_size: The size of the voxel grid along X and Y dimensions (must match training). device: PyTorch device ('cuda' or 'cpu'). voxel_size_metric: The metric size of each voxel (must match training). vertex_threshold: Threshold for classifying a voxel as containing a vertex. Returns: predicted_vertices_original_space: (M, 3) NumPy array of predicted vertex coordinates in the original point cloud space (X, Y, Z order). Returns an empty array if no vertices are predicted or if the input point cloud results in an empty voxel grid. """ voxel_grid_tensor, _, scale_info = voxelize_points( point_cloud_14d, grid_size_xy=grid_size, voxel_size_metric=voxel_size_metric ) # Check if voxelization produced a valid grid (e.g., if input point cloud was empty) # voxelize_points returns a zero tensor for grid if input points are empty. # If voxel_grid_tensor is all zeros and no points were input, scale_info might be default. if voxel_grid_tensor.sum() == 0 and point_cloud_14d.shape[0] == 0: # This case implies empty input point cloud, voxelize_points handles this. # Predictions will naturally be empty if the grid is empty. pass # Continue, model will predict on zero grid. input_tensor = voxel_grid_tensor.unsqueeze(0).to(device) with torch.no_grad(): vertex_logits_pred_tensor, refinement_pred_tensor = model(input_tensor) vertex_prob_pred_tensor = torch.sigmoid(vertex_logits_pred_tensor) vertex_prob_pred_np = vertex_prob_pred_tensor.squeeze(0).squeeze(0).cpu().numpy() refinement_pred_np = refinement_pred_tensor.squeeze(0).cpu().numpy() # Shape (3, D, H, W) -> (dx,dy,dz channels) print(f"Vertex Probabilities Stats: Min={np.min(vertex_prob_pred_np):.4f}, Max={np.max(vertex_prob_pred_np):.4f}, Mean={np.mean(vertex_prob_pred_np):.4f}, Median={np.median(vertex_prob_pred_np):.4f}") if refinement_pred_np.size > 0: print(f"Refinement Predictions Stats: Min={np.min(refinement_pred_np):.4f}, Max={np.max(refinement_pred_np):.4f}, Mean={np.mean(refinement_pred_np):.4f}, Median={np.median(refinement_pred_np):.4f}") for i in range(refinement_pred_np.shape[0]): # Iterate over dx, dy, dz components print(f" Refinement Dim {i} (dx,dy,dz order) Stats: Min={np.min(refinement_pred_np[i]):.4f}, Max={np.max(refinement_pred_np[i]):.4f}, Mean={np.mean(refinement_pred_np[i]):.4f}, Median={np.median(refinement_pred_np[i]):.4f}") else: print("Refinement Predictions Stats: Array is empty.") predicted_mask = vertex_prob_pred_np > vertex_threshold # predicted_voxel_indices are (N_preds, 3) with columns (idx_z, idx_y, idx_x) predicted_voxel_indices_zyx = np.argwhere(predicted_mask) if not predicted_voxel_indices_zyx.size: return np.empty((0, 3), dtype=np.float32) # Extract refinement offsets for the predicted voxels # offsets_channels_first will be (3, N_preds) where channels are (dx, dy, dz) offsets_channels_first = refinement_pred_np[:, predicted_voxel_indices_zyx[:, 0], # z_indices predicted_voxel_indices_zyx[:, 1], # y_indices predicted_voxel_indices_zyx[:, 2]] # x_indices # Transpose to (N_preds, 3) where columns are (dx, dy, dz) offsets_xyz_order = offsets_channels_first.T # Calculate refined coordinates in continuous voxel grid space (X, Y, Z order) # Voxel center is at index + 0.5 # Refinement is added to this center. # predicted_voxel_indices_zyx[:, 2] is x_idx # predicted_voxel_indices_zyx[:, 1] is y_idx # predicted_voxel_indices_zyx[:, 0] is z_idx # offsets_xyz_order[:, 0] is dx # offsets_xyz_order[:, 1] is dy # offsets_xyz_order[:, 2] is dz refined_x_grid = predicted_voxel_indices_zyx[:, 2].astype(np.float32) + 0.5 #+ offsets_xyz_order[:, 0] refined_y_grid = predicted_voxel_indices_zyx[:, 1].astype(np.float32) + 0.5 #+ offsets_xyz_order[:, 1] refined_z_grid = predicted_voxel_indices_zyx[:, 0].astype(np.float32) + 0.5 #+ offsets_xyz_order[:, 2] # Stack to get (N_preds, 3) array in (X, Y, Z) order refined_grid_coords_xyz = np.stack((refined_x_grid, refined_y_grid, refined_z_grid), axis=-1) # Convert refined grid coordinates to original metric space grid_origin_metric = np.array(scale_info['grid_origin_metric']) # (ox, oy, oz) # Voxel_size_metric from scale_info should match the input voxel_size_metric parameter current_voxel_size_metric = scale_info['voxel_size_metric'] # predicted_vertices_original_space are (N_preds, 3) in (X,Y,Z) order predicted_vertices_original_space = refined_grid_coords_xyz * current_voxel_size_metric + grid_origin_metric return predicted_vertices_original_space.astype(np.float32) # Simple inference script def run_inference(model_path: str, data_file_path: str, output_file: str = None, grid_size: int = 128, voxel_size: float = 0.5, vertex_threshold: float = 0.5): """ Run inference on all data files in a directory, visualize with pyvista, and save results. """ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Using device: {device}") # Load model model = load_model_for_inference(model_path, device) # Get all data files from the directory data_files = get_data_files(data_file_path) if not data_files: print(f"No data files found in {data_file_path}") return print(f"Found {len(data_files)} data files to process") for i, file_path in enumerate(data_files): print(f"\n--- Processing file {i+1}/{len(data_files)}: {os.path.basename(file_path)} ---") # Load input data try: data = load_data(file_path) except Exception as e: print(f"Error loading {file_path}: {e}") continue if 'pcloud_14d' not in data: print(f"Error: File {file_path} does not contain 'pcloud_14d' key, skipping") continue # Extract original point cloud and ground-truth vertices pcloud = data['pcloud_14d'][:, :3] # (N,3) gt_vertices = np.array(data.get('wf_vertices', [])) # (M,3) or empty print(f"Input point cloud shape: {pcloud.shape}") if gt_vertices.size: print(f"GT vertices shape: {gt_vertices.shape}") # Run prediction print("Running inference...") try: predicted_vertices = predict_vertices( model=model, point_cloud_14d=data['pcloud_14d'], grid_size=grid_size, device=device, voxel_size_metric=voxel_size, vertex_threshold=vertex_threshold ) except Exception as e: print(f"Error during prediction for {file_path}: {e}") continue print(f"Predicted {len(predicted_vertices)} vertices") # --- Visualization --- plotter = pv.Plotter(window_size=[800,600]) plotter.background_color = 'white' # Original point cloud in light gray if pcloud.size: pc_cloud = pv.PolyData(pcloud) plotter.add_mesh(pc_cloud, color='lightgray', point_size=2, render_points_as_spheres=True, label='Input PC') # Ground-truth vertices in red if gt_vertices.size: gt_pd = pv.PolyData(gt_vertices) plotter.add_mesh(gt_pd, color='red', point_size=8, render_points_as_spheres=True, label='GT Vertices') # Predicted vertices in blue if predicted_vertices.size: pred_pd = pv.PolyData(predicted_vertices) plotter.add_mesh(pred_pd, color='blue', point_size=8, render_points_as_spheres=True, label='Predicted Vertices') plotter.add_legend() plotter.show(title=os.path.basename(file_path)) # Prepare output data output_data = { 'predicted_vertices': predicted_vertices, 'input_file': file_path, 'model_used': model_path, 'grid_size': grid_size, 'voxel_size': voxel_size, 'vertex_threshold': vertex_threshold, 'original_data': data } # Save results base_name = os.path.splitext(os.path.basename(file_path))[0] output_filename = f"{base_name}_predictions" try: save_data(output_data, output_filename) # Saves to 'data' subfolder by default print(f"Results saved to: data/{output_filename}.pkl") except Exception as e: print(f"Error saving results for {file_path}: {e}") print(f"\nCompleted processing {len(data_files)} files") if __name__ == "__main__": inference = False # Replace with your actual data folder path data_folder_train = 'YOUR_LOCAL_DATA_FOLDER_PATH' # Example: data_folder_train = '/path/to/your/training_data' num_epochs_train = 100 batch_size_train = 16 # This parameter now controls the ratio of negative to positive samples for BCE loss negative_to_positive_bce_ratio = 1 if inference: # Replace with your actual model path and data path for inference run_inference(model_path='YOUR_MODEL_PATH.pth', # Example: '/path/to/your/model.pth' data_file_path='YOUR_INFERENCE_DATA_FOLDER_PATH', # Example: '/path/to/your/inference_data' output_file=None, # Output will be saved in a 'data' subfolder relative to script grid_size=128, voxel_size=0.5, vertex_threshold=0.5 ) else: train_model(data_folder=data_folder_train, num_epochs=num_epochs_train, batch_size=batch_size_train, neg_pos_ratio_val=negative_to_positive_bce_ratio)