|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import pickle |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from typing import List, Dict, Tuple, Optional |
|
|
import json |
|
|
|
|
|
class FastPointNet(nn.Module):
    """
    Fast PointNet implementation for 3D vertex prediction from point cloud patches.

    Takes 11D point clouds and predicts 3D vertex coordinates. Optionally also
    predicts a non-negative score and classification logits from the same
    shared global feature.

    Architecture (as built below):
      * seven point-wise ``Conv1d`` (kernel size 1) + BatchNorm layers, with a
        residual connection around the sixth,
      * squeeze-and-excite style channel attention over the 2048 channels,
      * a 0.7 / 0.3 blend of max- and mean-pooling over the point axis,
      * three shared fully connected layers (GroupNorm on the first two),
      * one five-layer fully connected head per enabled output.

    Args:
        input_dim: Number of per-point input features.
        output_dim: Size of the regressed position vector.
        max_points: Stored on the instance; not used inside ``forward``.
        predict_score: Build and run the score head.
        predict_class: Build and run the classification head.
        num_classes: Number of classification logits.
    """

    def __init__(self, input_dim=11, output_dim=3, max_points=1024, predict_score=True, predict_class=True, num_classes=1):
        super().__init__()
        self.max_points = max_points
        self.predict_score = predict_score
        self.predict_class = predict_class
        self.num_classes = num_classes

        # NOTE: sub-modules are registered in the original order on purpose so
        # state_dict / optimizer checkpoints stay interchangeable.

        # Point-wise feature extractor (kernel size 1 == shared per-point MLP).
        self.conv1 = nn.Conv1d(input_dim, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 256, 1)
        self.conv4 = nn.Conv1d(256, 512, 1)
        self.conv5 = nn.Conv1d(512, 1024, 1)
        self.conv6 = nn.Conv1d(1024, 1024, 1)
        self.conv7 = nn.Conv1d(1024, 2048, 1)

        # Channel attention: pool over points, bottleneck to 128, gate in (0, 1).
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(2048, 128, 1),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 2048, 1),
            nn.Sigmoid()
        )

        # Shared trunk applied to the pooled global feature.
        self.shared_fc1 = nn.Linear(2048, 1024)
        self.shared_fc2 = nn.Linear(1024, 512)
        self.shared_fc3 = nn.Linear(512, 512)

        # Position regression head.
        self.pos_fc1 = nn.Linear(512, 512)
        self.pos_fc2 = nn.Linear(512, 256)
        self.pos_fc3 = nn.Linear(256, 128)
        self.pos_fc4 = nn.Linear(128, 64)
        self.pos_fc5 = nn.Linear(64, output_dim)

        # Optional score head (single scalar per sample).
        if self.predict_score:
            self.score_fc1 = nn.Linear(512, 512)
            self.score_fc2 = nn.Linear(512, 256)
            self.score_fc3 = nn.Linear(256, 128)
            self.score_fc4 = nn.Linear(128, 64)
            self.score_fc5 = nn.Linear(64, 1)

        # Optional classification head (raw logits).
        if self.predict_class:
            self.class_fc1 = nn.Linear(512, 512)
            self.class_fc2 = nn.Linear(512, 256)
            self.class_fc3 = nn.Linear(256, 128)
            self.class_fc4 = nn.Linear(128, 64)
            self.class_fc5 = nn.Linear(64, num_classes)

        # Batch norm for the convolutional stack.
        self.bn1 = nn.BatchNorm1d(64, momentum=0.1)
        self.bn2 = nn.BatchNorm1d(128, momentum=0.1)
        self.bn3 = nn.BatchNorm1d(256, momentum=0.1)
        self.bn4 = nn.BatchNorm1d(512, momentum=0.1)
        self.bn5 = nn.BatchNorm1d(1024, momentum=0.1)
        self.bn6 = nn.BatchNorm1d(1024, momentum=0.1)
        self.bn7 = nn.BatchNorm1d(2048, momentum=0.1)

        # Group norm for the first two shared FC layers.
        self.gn1 = nn.GroupNorm(32, 1024)
        self.gn2 = nn.GroupNorm(16, 512)

        # Dropout modules reused at several depths.
        self.dropout_light = nn.Dropout(0.1)
        self.dropout_medium = nn.Dropout(0.2)
        self.dropout_heavy = nn.Dropout(0.3)

    def _run_head(self, features, fc1, fc2, fc3, fc4, fc5):
        """Run one five-layer head (LeakyReLU + light/medium/light dropout) and return fc5's raw output."""
        hidden = F.leaky_relu(fc1(features), negative_slope=0.01, inplace=True)
        hidden = self.dropout_light(hidden)
        hidden = F.leaky_relu(fc2(hidden), negative_slope=0.01, inplace=True)
        hidden = self.dropout_medium(hidden)
        hidden = F.leaky_relu(fc3(hidden), negative_slope=0.01, inplace=True)
        hidden = self.dropout_light(hidden)
        hidden = F.leaky_relu(fc4(hidden), negative_slope=0.01, inplace=True)
        return fc5(hidden)

    def forward(self, x):
        """
        Forward pass with residual connections and attention

        Args:
            x: (batch_size, input_dim, max_points) tensor. No validity mask is
               taken, so every column (including any padding) takes part in
               normalisation, attention and pooling.

        Returns:
            ``position`` alone when both optional heads are disabled, otherwise
            a tuple ``(position[, score][, classification])`` in that order.
            ``score`` has passed through softplus (non-negative);
            ``classification`` is raw logits.
        """
        # Point-wise feature extraction.
        x1 = F.leaky_relu(self.bn1(self.conv1(x)), negative_slope=0.01, inplace=True)
        x2 = F.leaky_relu(self.bn2(self.conv2(x1)), negative_slope=0.01, inplace=True)
        x3 = F.leaky_relu(self.bn3(self.conv3(x2)), negative_slope=0.01, inplace=True)
        x4 = F.leaky_relu(self.bn4(self.conv4(x3)), negative_slope=0.01, inplace=True)
        x5 = F.leaky_relu(self.bn5(self.conv5(x4)), negative_slope=0.01, inplace=True)

        # Residual connection around conv6 (same channel count in and out).
        x6 = F.leaky_relu(self.bn6(self.conv6(x5)) + x5, negative_slope=0.01, inplace=True)
        x7 = F.leaky_relu(self.bn7(self.conv7(x6)), negative_slope=0.01, inplace=True)

        # Per-channel gating, broadcast over the point axis.
        attention_weights = self.channel_attention(x7)
        x7_attended = x7 * attention_weights

        # Blend of max- and mean-pooling over points.
        max_pool = torch.max(x7_attended, 2)[0]
        avg_pool = torch.mean(x7_attended, 2)
        global_features = 0.7 * max_pool + 0.3 * avg_pool

        # GroupNorm expects (N, C, *), hence the temporary trailing axis.
        shared1 = F.leaky_relu(self.gn1(self.shared_fc1(global_features).unsqueeze(-1)).squeeze(-1),
                               negative_slope=0.01, inplace=True)
        shared1 = self.dropout_light(shared1)

        shared2 = F.leaky_relu(self.gn2(self.shared_fc2(shared1).unsqueeze(-1)).squeeze(-1),
                               negative_slope=0.01, inplace=True)
        shared2 = self.dropout_medium(shared2)

        # Residual connection around the third shared layer.
        shared3 = F.leaky_relu(self.shared_fc3(shared2), negative_slope=0.01, inplace=True)
        shared_features = self.dropout_light(shared3) + shared2

        # Heads are evaluated in a fixed order (position, score, class) so the
        # dropout RNG stream is consumed exactly as before.
        position = self._run_head(shared_features, self.pos_fc1, self.pos_fc2,
                                  self.pos_fc3, self.pos_fc4, self.pos_fc5)
        outputs = [position]

        if self.predict_score:
            raw_score = self._run_head(shared_features, self.score_fc1, self.score_fc2,
                                       self.score_fc3, self.score_fc4, self.score_fc5)
            outputs.append(F.softplus(raw_score))

        if self.predict_class:
            outputs.append(self._run_head(shared_features, self.class_fc1, self.class_fc2,
                                          self.class_fc3, self.class_fc4, self.class_fc5))

        if len(outputs) == 1:
            return outputs[0]
        return tuple(outputs)
|
|
|
|
|
class PatchDataset(Dataset):
    """
    Dataset class for loading saved patches for PointNet training.

    Updated for 11D patches.

    Every ``*.pkl`` file directly inside ``dataset_dir`` is one sample. Each
    pickle must hold a dict with key ``'patch_11d'`` (indexed as a 2-D array
    with one row per point) and may hold ``'assigned_wf_vertex'`` (target
    position) and ``'cluster_center'`` (initial prediction).
    """

    def __init__(self, dataset_dir: str, max_points: int = 1024, augment: bool = True):
        """
        Args:
            dataset_dir: Directory scanned (non-recursively) for ``.pkl`` files.
            max_points: Fixed number of points each patch is sampled / padded to.
            augment: Apply random augmentation to samples that have a target.
        """
        self.dataset_dir = dataset_dir
        self.max_points = max_points
        self.augment = augment

        # os.listdir returns entries in arbitrary order; sort so that a given
        # index maps to the same file on every run and every machine.
        self.patch_files = sorted(
            os.path.join(dataset_dir, entry)
            for entry in os.listdir(dataset_dir)
            if entry.endswith('.pkl')
        )

        print(f"Found {len(self.patch_files)} patch files in {dataset_dir}")

    def __len__(self):
        """Return the number of patch files found at construction time."""
        return len(self.patch_files)

    def __getitem__(self, idx):
        """
        Load and process a patch for training.

        Returns:
            patch_data: (11, max_points) float tensor of point cloud data
            target: (3,) float tensor of target 3D coordinates (zeros if the
                patch has no assigned vertex)
            valid_mask: (max_points,) boolean tensor indicating valid points
            initial_pred: float tensor holding the stored ``cluster_center``
                (zeros(3) if absent); it is NOT augmented
            classification: scalar tensor for binary classification
                (1 if GT vertex present, 0 if not)
        """
        patch_file = self.patch_files[idx]

        # NOTE: pickle.load can execute arbitrary code; only point this
        # dataset at directories whose contents you trust.
        with open(patch_file, 'rb') as f:
            patch_info = pickle.load(f)

        # Cast to float so in-place augmentation never truncates integer data.
        patch_11d = np.asarray(patch_info['patch_11d'], dtype=np.float64)
        target = patch_info.get('assigned_wf_vertex', None)
        initial_pred = patch_info.get('cluster_center', None)

        has_gt_vertex = 1.0 if target is not None else 0.0

        if target is None:
            # Placeholder; the training loop masks position loss for these.
            target = np.zeros(3)
        else:
            target = np.array(target, dtype=np.float64)

        num_points = patch_11d.shape[0]

        if num_points >= self.max_points:
            # Random subsample without replacement down to max_points.
            indices = np.random.choice(num_points, self.max_points, replace=False)
            patch_sampled = patch_11d[indices]
            valid_mask = np.ones(self.max_points, dtype=bool)
        else:
            # Zero-pad at the end and remember which rows are real.
            patch_sampled = np.zeros((self.max_points, 11))
            patch_sampled[:num_points] = patch_11d
            valid_mask = np.zeros(self.max_points, dtype=bool)
            valid_mask[:num_points] = True

        # Only samples with a real target are augmented.
        if self.augment and has_gt_vertex > 0:
            patch_sampled, target = self._augment_patch(patch_sampled, valid_mask, target)

        # Transpose to channels-first (features, points) for Conv1d.
        patch_tensor = torch.from_numpy(patch_sampled.T).float()
        target_tensor = torch.from_numpy(target).float()
        valid_mask_tensor = torch.from_numpy(valid_mask)

        if initial_pred is not None:
            # np.array(...) first so lists / tuples work, not only ndarrays.
            initial_pred_tensor = torch.as_tensor(np.array(initial_pred, dtype=np.float32))
        else:
            initial_pred_tensor = torch.zeros(3).float()

        classification_tensor = torch.tensor(has_gt_vertex).float()

        return patch_tensor, target_tensor, valid_mask_tensor, initial_pred_tensor, classification_tensor

    def _augment_patch(self, patch_sampled, valid_mask, target):
        """
        Apply data augmentation to patch and target.

        Only augment valid points and update target accordingly. The same
        rotation (about the z axis, within +/- pi/12), translation (uniform in
        [-0.05, 0.05] per axis) and isotropic scale (uniform in [0.95, 1.05])
        are applied to the first three columns of the valid rows and to
        ``target``. Remaining columns of valid rows get Gaussian noise
        (sigma 0.01). Padded rows are left untouched.

        Returns:
            ``(patch_sampled, target)``; ``patch_sampled`` is modified in place.
        """
        # Boolean indexing yields a copy; it is written back at the end.
        valid_points = patch_sampled[valid_mask]

        if len(valid_points) > 0:
            # Rotation about the z axis.
            angle = np.random.uniform(-np.pi/12, np.pi/12)
            cos_a, sin_a = np.cos(angle), np.sin(angle)
            rotation_matrix = np.array([[cos_a, -sin_a, 0],
                                        [sin_a, cos_a, 0],
                                        [0, 0, 1]])

            valid_points[:, :3] = valid_points[:, :3] @ rotation_matrix.T
            # Matmul allocates a new array, so the caller's target is not mutated.
            target = target @ rotation_matrix.T

            # Shared translation.
            translation = np.random.uniform(-0.05, 0.05, 3)
            valid_points[:, :3] += translation
            target += translation

            # Shared isotropic scale.
            scale = np.random.uniform(0.95, 1.05)
            valid_points[:, :3] *= scale
            target *= scale

            # Jitter the non-coordinate features.
            if valid_points.shape[1] > 3:
                noise = np.random.normal(0, 0.01, valid_points[:, 3:].shape)
                valid_points[:, 3:] += noise

            patch_sampled[valid_mask] = valid_points

        return patch_sampled, target
|
|
|
|
|
def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
    """
    Save patches from prediction pipeline to create a training dataset.

    Each patch is pickled to ``<dataset_dir>/<entry_id>_patch_<i>.pkl``. Files
    that already exist are left untouched. New files are written to a
    temporary name and renamed into place, so an interrupted run cannot leave
    a truncated ``.pkl`` that later runs would skip as "already saved".

    Args:
        patches: List of patch dictionaries from generate_patches()
        dataset_dir: Directory to save the dataset
        entry_id: Unique identifier for this entry/image
    """
    os.makedirs(dataset_dir, exist_ok=True)

    saved = 0
    for i, patch in enumerate(patches):
        filepath = os.path.join(dataset_dir, f"{entry_id}_patch_{i}.pkl")

        # Never overwrite a patch that is already on disk.
        if os.path.exists(filepath):
            continue

        # The temp name does not end in '.pkl', so a dataset scan running
        # concurrently will not pick up a half-written file.
        tmp_path = f"{filepath}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(patch, f)
            os.replace(tmp_path, filepath)
        finally:
            # Only still present if dumping or renaming failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        saved += 1

    skipped = len(patches) - saved
    message = f"Saved {saved} patches for entry {entry_id}"
    if skipped:
        message += f" ({skipped} already existed, skipped)"
    print(message)
|
|
|
|
|
|
|
|
def collate_fn(batch):
    """
    Batch samples for the DataLoader, dropping patches without valid points.

    Args:
        batch: Iterable of 5-tuples
            ``(patch_data, target, valid_mask, initial_pred, classification)``.

    Returns:
        ``None`` when no sample has at least one valid point; otherwise a
        5-tuple of tensors, each the stack of the corresponding field over the
        kept samples, in the same field order as the input.
    """
    # Keep only samples whose mask marks at least one point as valid.
    kept = [
        (points, vertex, mask, init_pred, label)
        for points, vertex, mask, init_pred, label in batch
        if mask.sum() > 0
    ]

    if not kept:
        return None

    # Transpose list-of-samples into per-field groups and stack each one.
    points, vertices, masks, init_preds, labels = (
        torch.stack(field) for field in zip(*kept)
    )
    return points, vertices, masks, init_preds, labels
|
|
|
|
|
|
|
|
def init_weights(m):
    """
    Initialise one module in place; intended for use with ``model.apply``.

    * ``nn.Conv1d`` / ``nn.Linear``: Kaiming-uniform weights (fan-in,
      leaky-ReLU gain with slope 0.01) and zero bias.
    * ``nn.BatchNorm1d`` / ``nn.GroupNorm``: weight 1, bias 0.
    * Any other module is left untouched.

    Parameters that are ``None`` (``bias=False`` or ``affine=False``) are
    skipped rather than raising.
    """
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        nn.init.kaiming_uniform_(m.weight, a=0.01, mode='fan_in', nonlinearity='leaky_relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):
        # With affine=False both parameters are None.
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
|
|
|
|
|
def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, batch_size: int = 32, lr: float = 0.001,
                   score_weight: float = 0.1, class_weight: float = 0.5, num_workers: int = 20):
    """
    Train the FastPointNet model on saved patches.

    Updated for 11D input.

    The loss is ``pos_loss + score_weight * score_loss + class_weight * class_loss``.
    Position and score losses are computed only on samples flagged as having a
    ground-truth vertex; the classification loss uses every sample.

    Args:
        dataset_dir: Directory of ``.pkl`` patches read by ``PatchDataset``.
        model_save_path: Destination of the final checkpoint. Intermediate
            checkpoints are written every 10 epochs next to it with
            ``_epoch_<n>`` inserted before the file extension.
        epochs: Number of passes over the dataset.
        batch_size: Samples per batch (incomplete trailing batches are dropped).
        lr: Initial AdamW learning rate.
        score_weight: Weight of the score regression loss.
        class_weight: Weight of the classification loss.
        num_workers: DataLoader worker processes.

    Returns:
        The trained model (left in training mode, on the training device).

    Raises:
        ValueError: If ``dataset_dir`` contains no patch files.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on device: {device}")

    dataset = PatchDataset(dataset_dir, max_points=1024, augment=True)
    print(f"Dataset loaded with {len(dataset)} samples")

    if len(dataset) == 0:
        # Fail with a clear message instead of an opaque sampler error.
        raise ValueError(f"No patch files found in {dataset_dir}")

    # Make sure checkpoints can be written before spending time on training.
    save_dir = os.path.dirname(model_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                            collate_fn=collate_fn, drop_last=True)

    model = FastPointNet(input_dim=11, output_dim=3, max_points=1024, predict_score=True, predict_class=True, num_classes=1)

    model.apply(init_weights)
    model.to(device)

    position_criterion = nn.SmoothL1Loss()
    score_criterion = nn.SmoothL1Loss()
    # pos_weight is a buffer of the loss module; keep it on the training device.
    classification_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(2.0)).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4, betas=(0.9, 0.999))

    # Stepped once per epoch below.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2)

    # Split once so '.pth' inside a directory name is never rewritten and
    # other extensions still get distinct per-epoch files.
    checkpoint_root, checkpoint_ext = os.path.splitext(model_save_path)

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        total_pos_loss = 0.0
        total_score_loss = 0.0
        total_class_loss = 0.0
        num_batches = 0

        for batch_idx, batch_data in enumerate(dataloader):
            # collate_fn returns None when every sample in the batch was empty.
            if batch_data is None:
                continue

            patch_data, targets, valid_masks, initial_preds, classifications = batch_data
            patch_data = patch_data.to(device)
            targets = targets.to(device)
            classifications = classifications.to(device)

            optimizer.zero_grad()
            predictions, predicted_scores, predicted_classes = model(patch_data)

            # Samples that actually have a ground-truth vertex.
            has_gt_mask = classifications > 0.5

            if has_gt_mask.any():
                gt_predictions = predictions[has_gt_mask]
                gt_targets = targets[has_gt_mask]
                pos_loss = position_criterion(gt_predictions, gt_targets)

                # The current position error is the regression *label* for the
                # score head. Detach it so the score loss trains only the score
                # head and cannot pull positions toward the predicted score.
                actual_distances = torch.norm(gt_predictions - gt_targets, dim=1, keepdim=True).detach()
                score_loss = score_criterion(predicted_scores[has_gt_mask], actual_distances)
            else:
                pos_loss = torch.tensor(0.0, device=device)
                score_loss = torch.tensor(0.0, device=device)

            # squeeze(-1), not squeeze(): collate_fn may shrink a batch to one
            # sample, and a bare squeeze() would then produce a 0-dim tensor
            # that no longer matches the (1,) target.
            class_loss = classification_criterion(predicted_classes.squeeze(-1), classifications)

            total_batch_loss = pos_loss + score_weight * score_loss + class_weight * class_loss

            total_batch_loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            total_loss += total_batch_loss.item()
            total_pos_loss += pos_loss.item()
            total_score_loss += score_loss.item()
            total_class_loss += class_loss.item()
            num_batches += 1

            if batch_idx % 50 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
                      f"Total Loss: {total_batch_loss.item():.6f}, "
                      f"Pos Loss: {pos_loss.item():.6f}, "
                      f"Score Loss: {score_loss.item():.6f}, "
                      f"Class Loss: {class_loss.item():.6f}")

        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        avg_pos_loss = total_pos_loss / num_batches if num_batches > 0 else 0
        avg_score_loss = total_score_loss / num_batches if num_batches > 0 else 0
        avg_class_loss = total_class_loss / num_batches if num_batches > 0 else 0

        print(f"Epoch {epoch+1}/{epochs} completed, "
              f"Avg Total Loss: {avg_loss:.6f}, "
              f"Avg Pos Loss: {avg_pos_loss:.6f}, "
              f"Avg Score Loss: {avg_score_loss:.6f}, "
              f"Avg Class Loss: {avg_class_loss:.6f}")

        scheduler.step()

        # Periodic checkpoint every 10 epochs.
        if (epoch + 1) % 10 == 0:
            checkpoint_path = f"{checkpoint_root}_epoch_{epoch+1}{checkpoint_ext}"
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch + 1,
                'loss': avg_loss,
            }, checkpoint_path)

    # Final checkpoint.
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epochs,
    }, model_save_path)

    print(f"Model saved to {model_save_path}")
    return model
|
|
|
|
|
def load_pointnet_model(model_path: str, device: torch.device = None, predict_score: bool = True) -> FastPointNet:
    """
    Load a trained FastPointNet model.

    Updated for 11D input.

    Args:
        model_path: Checkpoint file containing a ``'model_state_dict'`` entry.
        device: Device to load onto; CUDA when available, else CPU, if omitted.
        predict_score: Whether the rebuilt network includes the score head;
            passed through to ``FastPointNet``.

    Returns:
        The network with weights restored, moved to ``device`` and switched to
        evaluation mode.
    """
    target_device = device
    if target_device is None:
        target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build the network first, then restore its weights from the checkpoint.
    net = FastPointNet(input_dim=11, output_dim=3, max_points=1024, predict_score=predict_score)

    saved = torch.load(model_path, map_location=target_device)
    weights = saved['model_state_dict']
    net.load_state_dict(weights)

    net.to(target_device)
    net.eval()

    return net
|
|
|
|
|
def predict_vertex_from_patch(model: "FastPointNet", patch: Dict, device: torch.device = None) -> Tuple[np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]:
    """
    Predict 3D vertex coordinates, confidence score, and classification from a patch using trained PointNet.

    Updated for 11D patches.

    Args:
        model: Network to run. Its ``predict_score`` / ``predict_class`` flags
            decide how the output is unpacked, and its ``max_points`` (1024 if
            absent) is the number of points the patch is sampled / padded to.
            The model's train/eval mode is not changed here.
        patch: Mapping with ``'patch_11d'`` (one row per point, 11 columns)
            and ``'cluster_center'`` (offset added to the predicted position).
        device: Device for the input tensor. Defaults to the device of the
            model's parameters, so the input always matches the model.

    Returns:
        ``(position, score, classification)``. ``position`` is the network
        output plus ``cluster_center``. ``score`` / ``classification`` are
        ``None`` when the corresponding head is disabled; ``classification``
        has been passed through a sigmoid.

    Note:
        Patches with at least ``max_points`` points are randomly subsampled
        with NumPy's global RNG, so repeated calls can differ unless it is
        seeded.
    """
    if device is None:
        # Follow the model rather than guessing: a CPU model on a CUDA machine
        # would otherwise get a CUDA input and fail.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    patch_11d = np.asarray(patch['patch_11d'])

    max_points = getattr(model, 'max_points', 1024)
    num_points = patch_11d.shape[0]

    if num_points >= max_points:
        # Random subsample without replacement.
        indices = np.random.choice(num_points, max_points, replace=False)
        patch_sampled = patch_11d[indices]
    else:
        # Zero-pad at the end.
        patch_sampled = np.zeros((max_points, 11))
        patch_sampled[:num_points] = patch_11d

    # (points, features) -> (1, features, points)
    patch_tensor = torch.from_numpy(patch_sampled.T).float().unsqueeze(0)
    patch_tensor = patch_tensor.to(device)

    with torch.no_grad():
        outputs = model(patch_tensor)

    # The model returns a bare tensor or a 2-/3-tuple depending on its heads.
    score = None
    classification = None
    if model.predict_score and model.predict_class:
        position, score, classification = outputs
    elif model.predict_score:
        position, score = outputs
    elif model.predict_class:
        position, classification = outputs
    else:
        position = outputs

    position = position.cpu().numpy().squeeze()
    if score is not None:
        score = score.cpu().numpy().squeeze()
    if classification is not None:
        # The head emits logits; convert to probabilities.
        classification = torch.sigmoid(classification).cpu().numpy().squeeze()

    # Shift the prediction by the stored cluster centre.
    offset = patch['cluster_center']
    position += offset

    return position, score, classification
|
|
|
|
|
|