Instructions to use amaye15/autoencoder-robust-demo with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use amaye15/autoencoder-robust-demo with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("feature-extraction", model="amaye15/autoencoder-robust-demo", trust_remote_code=True)

# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("amaye15/autoencoder-robust-demo", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """ | |
| PyTorch Autoencoder model for Hugging Face Transformers. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Optional, Tuple, Union, Dict, Any, List | |
| from dataclasses import dataclass | |
| import random | |
| from transformers import PreTrainedModel | |
| from transformers.modeling_outputs import BaseModelOutput | |
| from transformers.utils import ModelOutput | |
| try: | |
| from .configuration_autoencoder import AutoencoderConfig # when loaded via HF dynamic module | |
| except Exception: | |
| from configuration_autoencoder import AutoencoderConfig # local usage | |
class NeuralScaler(nn.Module):
    """Learnable alternative to StandardScaler using neural networks.

    In training mode the input is standardized with the batch mean/std plus
    learned adjustments produced by two small MLPs. In eval mode the running
    averages of the raw batch statistics are used instead. A learnable
    per-feature affine transform is applied afterwards.
    """

    def __init__(self, config: AutoencoderConfig):
        """Create the estimator networks, affine parameters and running buffers.

        Args:
            config: Configuration object; ``input_dim`` and
                ``preprocessing_hidden_dim`` are read here, and
                ``learn_inverse_preprocessing`` is read by the inverse path.
        """
        super().__init__()
        self.config = config
        input_dim = config.input_dim
        hidden_dim = config.preprocessing_hidden_dim

        # Networks to learn data-dependent statistics
        self.mean_estimator = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )
        self.std_estimator = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Softplus(),  # Ensure positive standard deviation
        )

        # Learnable affine transformation parameters
        self.weight = nn.Parameter(torch.ones(input_dim))
        self.bias = nn.Parameter(torch.zeros(input_dim))

        # Running statistics for inference (like BatchNorm)
        self.register_buffer('running_mean', torch.zeros(input_dim))
        self.register_buffer('running_std', torch.ones(input_dim))
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

        # Momentum for running statistics
        self.momentum = 0.1

    def _update_running(self, buffer: torch.Tensor, value: torch.Tensor, first_batch: bool) -> None:
        """Overwrite ``buffer`` on the first batch, otherwise blend ``value`` in with ``momentum``."""
        if first_batch:
            buffer.copy_(value)
        else:
            buffer.mul_(1 - self.momentum).add_(value, alpha=self.momentum)

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through neural scaler.

        Args:
            x: Input tensor (2D or 3D); the last dimension is the feature axis.
            inverse: Whether to apply inverse transformation

        Returns:
            Tuple of (transformed_tensor, regularization_loss)
        """
        if inverse:
            return self._inverse_transform(x)

        # Handle both 2D and 3D tensors
        original_shape = x.shape
        if x.dim() == 3:
            # (batch, seq, features) -> (batch*seq, features).
            # reshape (not view) so non-contiguous inputs, e.g. the result of
            # a transpose, are accepted instead of raising.
            x = x.reshape(-1, x.size(-1))

        if self.training:
            # Training mode: learn statistics from current batch
            batch_mean = x.mean(dim=0, keepdim=True)
            # The sample std of a single row is NaN; fall back to zero spread
            # and leave running_std untouched so the buffer is not poisoned.
            has_spread = x.size(0) > 1
            if has_spread:
                batch_std = x.std(dim=0, keepdim=True)
            else:
                batch_std = torch.zeros_like(batch_mean)

            # Learn data-dependent adjustments
            learned_mean_adj = self.mean_estimator(batch_mean)
            learned_std_adj = self.std_estimator(batch_std)

            # Combine batch statistics with learned adjustments
            effective_mean = batch_mean + learned_mean_adj
            effective_std = batch_std + learned_std_adj + 1e-8

            # Update running statistics with the raw batch statistics
            with torch.no_grad():
                self.num_batches_tracked += 1
                first_batch = bool(self.num_batches_tracked == 1)
                self._update_running(self.running_mean, batch_mean.squeeze(0), first_batch)
                if has_spread:
                    self._update_running(self.running_std, batch_std.squeeze(0), first_batch)
        else:
            # Inference mode: use running statistics
            effective_mean = self.running_mean.unsqueeze(0)
            effective_std = self.running_std.unsqueeze(0) + 1e-8

        # Normalize
        normalized = (x - effective_mean) / effective_std

        # Apply learnable affine transformation
        transformed = normalized * self.weight + self.bias

        # Reshape back to original shape if needed
        if len(original_shape) == 3:
            transformed = transformed.reshape(original_shape)

        # Regularization term: penalizes spread across features in the affine parameters
        reg_loss = 0.01 * (self.weight.var() + self.bias.var())
        return transformed, reg_loss

    def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply inverse transformation to get back original scale.

        Uses the running statistics (not the learned adjustments). Returns the
        input unchanged when ``config.learn_inverse_preprocessing`` is falsy.
        The second tuple element is always a zero scalar on ``x``'s device.
        """
        if not self.config.learn_inverse_preprocessing:
            return x, torch.tensor(0.0, device=x.device)

        # Handle both 2D and 3D tensors
        original_shape = x.shape
        if x.dim() == 3:
            # (batch, seq, features) -> (batch*seq, features)
            x = x.reshape(-1, x.size(-1))

        # Reverse affine transformation
        x = (x - self.bias) / (self.weight + 1e-8)

        # Reverse normalization using running statistics
        effective_mean = self.running_mean.unsqueeze(0)
        effective_std = self.running_std.unsqueeze(0) + 1e-8
        x = x * effective_std + effective_mean

        # Reshape back to original shape if needed
        if len(original_shape) == 3:
            x = x.reshape(original_shape)
        return x, torch.tensor(0.0, device=x.device)
class LearnableMinMaxScaler(nn.Module):
    """Learnable MinMax scaler that adapts bounds during training.

    Scales features to [0, 1] using batch min/range with learnable adjustments and
    a learnable affine transform. Supports 2D (B, F) and 3D (B, T, F) inputs.
    """

    def __init__(self, config: AutoencoderConfig):
        """Create the estimator networks, affine parameters and running buffers.

        Args:
            config: Configuration object; ``input_dim`` and
                ``preprocessing_hidden_dim`` are read here, and
                ``learn_inverse_preprocessing`` is read by the inverse path.
        """
        super().__init__()
        self.config = config
        input_dim = config.input_dim
        hidden_dim = config.preprocessing_hidden_dim

        # Networks to learn adjustments to batch min and range
        self.min_estimator = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )
        self.range_estimator = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Softplus(),  # Ensure positive adjustment to range
        )

        # Learnable affine transformation parameters
        self.weight = nn.Parameter(torch.ones(input_dim))
        self.bias = nn.Parameter(torch.zeros(input_dim))

        # Running statistics for inference
        self.register_buffer("running_min", torch.zeros(input_dim))
        self.register_buffer("running_range", torch.ones(input_dim))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.momentum = 0.1

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scale ``x`` (or undo the scaling when ``inverse`` is true).

        Args:
            x: Input tensor (2D or 3D); the last dimension is the feature axis.
            inverse: Whether to apply the inverse transformation.

        Returns:
            Tuple of (transformed_tensor, regularization_loss).
        """
        if inverse:
            return self._inverse_transform(x)

        original_shape = x.shape
        if x.dim() == 3:
            # reshape (not view) so non-contiguous inputs, e.g. the result of
            # a transpose, are accepted instead of raising.
            x = x.reshape(-1, x.size(-1))

        eps = 1e-8
        if self.training:
            batch_min = x.min(dim=0, keepdim=True).values
            batch_max = x.max(dim=0, keepdim=True).values
            batch_range = (batch_max - batch_min).clamp_min(eps)

            # Learn adjustments
            learned_min_adj = self.min_estimator(batch_min)
            learned_range_adj = self.range_estimator(batch_range)
            effective_min = batch_min + learned_min_adj
            effective_range = batch_range + learned_range_adj + eps

            # Update running stats with raw batch min/range for stable inversion
            with torch.no_grad():
                self.num_batches_tracked += 1
                if self.num_batches_tracked == 1:
                    self.running_min.copy_(batch_min.squeeze(0))
                    self.running_range.copy_(batch_range.squeeze(0))
                else:
                    self.running_min.mul_(1 - self.momentum).add_(batch_min.squeeze(0), alpha=self.momentum)
                    self.running_range.mul_(1 - self.momentum).add_(batch_range.squeeze(0), alpha=self.momentum)
        else:
            effective_min = self.running_min.unsqueeze(0)
            effective_range = self.running_range.unsqueeze(0)

        # Scale to [0, 1]
        scaled = (x - effective_min) / effective_range

        # Learnable affine transform
        transformed = scaled * self.weight + self.bias
        if len(original_shape) == 3:
            transformed = transformed.reshape(original_shape)

        # Regularization: encourage non-degenerate range and modest affine params
        reg_loss = 0.01 * (self.weight.var() + self.bias.var())
        if self.training:
            reg_loss = reg_loss + 0.001 * (1.0 / effective_range.clamp_min(1e-3)).mean()
        return transformed, reg_loss

    def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Undo the affine head and the min/range scaling using the running stats.

        Returns the input unchanged when ``config.learn_inverse_preprocessing``
        is falsy. The second tuple element is always a zero scalar on ``x``'s device.
        """
        if not self.config.learn_inverse_preprocessing:
            return x, torch.tensor(0.0, device=x.device)

        original_shape = x.shape
        if x.dim() == 3:
            x = x.reshape(-1, x.size(-1))

        # Reverse affine
        x = (x - self.bias) / (self.weight + 1e-8)

        # Reverse MinMax using running stats
        x = x * self.running_range.unsqueeze(0) + self.running_min.unsqueeze(0)

        if len(original_shape) == 3:
            x = x.reshape(original_shape)
        return x, torch.tensor(0.0, device=x.device)
class LearnableRobustScaler(nn.Module):
    """Learnable Robust scaler using median and IQR with learnable adjustments.

    Normalizes as (x - median) / IQR with learnable adjustments and an affine head.
    Supports 2D (B, F) and 3D (B, T, F) inputs.
    """

    def __init__(self, config: AutoencoderConfig):
        """Create the estimator networks, affine parameters and running buffers.

        Args:
            config: Configuration object; ``input_dim`` and
                ``preprocessing_hidden_dim`` are read here, and
                ``learn_inverse_preprocessing`` is read by the inverse path.
        """
        super().__init__()
        self.config = config
        input_dim = config.input_dim
        hidden_dim = config.preprocessing_hidden_dim

        self.median_estimator = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )
        self.iqr_estimator = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Softplus(),  # Ensure positive IQR adjustment
        )

        self.weight = nn.Parameter(torch.ones(input_dim))
        self.bias = nn.Parameter(torch.zeros(input_dim))

        self.register_buffer("running_median", torch.zeros(input_dim))
        self.register_buffer("running_iqr", torch.ones(input_dim))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.momentum = 0.1

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scale ``x`` (or undo the scaling when ``inverse`` is true).

        Args:
            x: Input tensor (2D or 3D); the last dimension is the feature axis.
            inverse: Whether to apply the inverse transformation.

        Returns:
            Tuple of (transformed_tensor, regularization_loss).
        """
        if inverse:
            return self._inverse_transform(x)

        original_shape = x.shape
        if x.dim() == 3:
            # reshape (not view) so non-contiguous inputs are accepted.
            x = x.reshape(-1, x.size(-1))

        eps = 1e-8
        if self.training:
            # torch.quantile requires q to share the input dtype; building it
            # with x.dtype keeps e.g. float64 models working.
            q = torch.tensor([0.25, 0.5, 0.75], device=x.device, dtype=x.dtype)
            qs = torch.quantile(x, q, dim=0)
            q25, med, q75 = qs[0:1, :], qs[1:2, :], qs[2:3, :]
            iqr = (q75 - q25).clamp_min(eps)

            learned_med_adj = self.median_estimator(med)
            learned_iqr_adj = self.iqr_estimator(iqr)
            effective_median = med + learned_med_adj
            effective_iqr = iqr + learned_iqr_adj + eps

            # Running stats track the raw batch median/IQR
            with torch.no_grad():
                self.num_batches_tracked += 1
                if self.num_batches_tracked == 1:
                    self.running_median.copy_(med.squeeze(0))
                    self.running_iqr.copy_(iqr.squeeze(0))
                else:
                    self.running_median.mul_(1 - self.momentum).add_(med.squeeze(0), alpha=self.momentum)
                    self.running_iqr.mul_(1 - self.momentum).add_(iqr.squeeze(0), alpha=self.momentum)
        else:
            effective_median = self.running_median.unsqueeze(0)
            effective_iqr = self.running_iqr.unsqueeze(0)

        normalized = (x - effective_median) / effective_iqr
        transformed = normalized * self.weight + self.bias
        if len(original_shape) == 3:
            transformed = transformed.reshape(original_shape)

        reg_loss = 0.01 * (self.weight.var() + self.bias.var())
        if self.training:
            reg_loss = reg_loss + 0.001 * (1.0 / effective_iqr.clamp_min(1e-3)).mean()
        return transformed, reg_loss

    def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Undo the affine head and the median/IQR scaling using the running stats.

        Returns the input unchanged when ``config.learn_inverse_preprocessing``
        is falsy. The second tuple element is always a zero scalar on ``x``'s device.
        """
        if not self.config.learn_inverse_preprocessing:
            return x, torch.tensor(0.0, device=x.device)

        original_shape = x.shape
        if x.dim() == 3:
            x = x.reshape(-1, x.size(-1))

        x = (x - self.bias) / (self.weight + 1e-8)
        x = x * self.running_iqr.unsqueeze(0) + self.running_median.unsqueeze(0)

        if len(original_shape) == 3:
            x = x.reshape(original_shape)
        return x, torch.tensor(0.0, device=x.device)
class LearnableYeoJohnsonPreprocessor(nn.Module):
    """Learnable Yeo-Johnson power transform with per-feature λ and affine head.

    Applies Yeo-Johnson transform elementwise with learnable lambda per feature,
    followed by standardization and a learnable affine transform. Supports 2D and 3D inputs.
    """

    def __init__(self, config: AutoencoderConfig):
        """Create the λ/affine parameters and the running-statistics buffers.

        Args:
            config: Configuration object; ``input_dim`` is read here, and
                ``learn_inverse_preprocessing`` is read by the inverse path.
        """
        super().__init__()
        self.config = config
        input_dim = config.input_dim

        # Learnable lambda per feature (unconstrained). Initialize around 1.0
        self.lmbda = nn.Parameter(torch.ones(input_dim))

        # Learnable affine parameters after standardization
        self.weight = nn.Parameter(torch.ones(input_dim))
        self.bias = nn.Parameter(torch.zeros(input_dim))

        # Running stats for transformed data
        self.register_buffer("running_mean", torch.zeros(input_dim))
        self.register_buffer("running_std", torch.ones(input_dim))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.momentum = 0.1

    def _yeo_johnson(self, x: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor:
        """Apply the Yeo-Johnson transform to ``x`` with per-feature ``lmbda``.

        Both sides of every ``torch.where`` are evaluated, so the branch that
        is *not* selected must stay finite, otherwise autograd multiplies a
        zero gradient by inf/NaN and returns NaN. Denominators and inputs are
        therefore replaced by harmless values wherever their branch is unused.
        """
        eps = 1e-6
        lmbda = lmbda.unsqueeze(0)  # broadcast over batch
        pos = x >= 0

        # Feed each formula only the values it is selected for (0 elsewhere)
        zeros = torch.zeros_like(x)
        x_pos = torch.where(pos, x, zeros)
        x_neg = torch.where(pos, zeros, x)

        # For x >= 0
        use_power = torch.abs(lmbda) > eps
        safe_lmbda = torch.where(use_power, lmbda, torch.ones_like(lmbda))
        base_pos = (x_pos + 1.0).clamp_min(eps)
        if_part = torch.where(
            use_power,
            (base_pos ** safe_lmbda - 1.0) / safe_lmbda,
            torch.log(base_pos),
        )

        # For x < 0
        two_minus_lambda = 2.0 - lmbda
        use_power_neg = torch.abs(two_minus_lambda) > eps
        safe_tml = torch.where(use_power_neg, two_minus_lambda, torch.ones_like(two_minus_lambda))
        base_neg = (1.0 - x_neg).clamp_min(eps)
        else_part = torch.where(
            use_power_neg,
            -(base_neg ** safe_tml - 1.0) / safe_tml,
            -torch.log(base_neg),
        )
        return torch.where(pos, if_part, else_part)

    def _yeo_johnson_inverse(self, y: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor:
        """Invert :meth:`_yeo_johnson`; uses the same unselected-branch guards."""
        eps = 1e-6
        lmbda = lmbda.unsqueeze(0)
        pos = y >= 0

        zeros = torch.zeros_like(y)
        y_pos = torch.where(pos, y, zeros)
        y_neg = torch.where(pos, zeros, y)

        # Inverse for y >= 0
        use_power = torch.abs(lmbda) > eps
        safe_lmbda = torch.where(use_power, lmbda, torch.ones_like(lmbda))
        x_pos = torch.where(
            use_power,
            (y_pos * safe_lmbda + 1.0).clamp_min(eps) ** (1.0 / safe_lmbda) - 1.0,
            torch.exp(y_pos) - 1.0,
        )

        # Inverse for y < 0
        two_minus_lambda = 2.0 - lmbda
        use_power_neg = torch.abs(two_minus_lambda) > eps
        safe_tml = torch.where(use_power_neg, two_minus_lambda, torch.ones_like(two_minus_lambda))
        x_neg = torch.where(
            use_power_neg,
            1.0 - (1.0 - y_neg * safe_tml).clamp_min(eps) ** (1.0 / safe_tml),
            1.0 - torch.exp(-y_neg),
        )
        return torch.where(pos, x_pos, x_neg)

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Transform ``x`` (or undo the transform when ``inverse`` is true).

        Args:
            x: Input tensor (2D or 3D); the last dimension is the feature axis.
            inverse: Whether to apply the inverse transformation.

        Returns:
            Tuple of (transformed_tensor, regularization_loss).
        """
        if inverse:
            return self._inverse_transform(x)

        orig_shape = x.shape
        if x.dim() == 3:
            # reshape (not view) so non-contiguous inputs are accepted.
            x = x.reshape(-1, x.size(-1))

        # Apply Yeo-Johnson
        y = self._yeo_johnson(x, self.lmbda)

        # Batch stats and running stats on transformed data
        if self.training:
            batch_mean = y.mean(dim=0, keepdim=True)
            batch_std = y.std(dim=0, keepdim=True).clamp_min(1e-6)
            with torch.no_grad():
                self.num_batches_tracked += 1
                if self.num_batches_tracked == 1:
                    self.running_mean.copy_(batch_mean.squeeze(0))
                    self.running_std.copy_(batch_std.squeeze(0))
                else:
                    self.running_mean.mul_(1 - self.momentum).add_(batch_mean.squeeze(0), alpha=self.momentum)
                    self.running_std.mul_(1 - self.momentum).add_(batch_std.squeeze(0), alpha=self.momentum)
            mean = batch_mean
            std = batch_std
        else:
            mean = self.running_mean.unsqueeze(0)
            std = self.running_std.unsqueeze(0)

        y_norm = (y - mean) / std
        out = y_norm * self.weight + self.bias

        if len(orig_shape) == 3:
            out = out.reshape(orig_shape)

        # Regularize lambda to avoid extreme values; encourage identity around 1
        reg = 0.001 * (self.lmbda - 1.0).pow(2).mean() + 0.01 * (self.weight.var() + self.bias.var())
        return out, reg

    def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Undo the affine head, the standardization (running stats) and the power transform.

        Returns the input unchanged when ``config.learn_inverse_preprocessing``
        is falsy. The second tuple element is always a zero scalar on ``x``'s device.
        """
        if not self.config.learn_inverse_preprocessing:
            return x, torch.tensor(0.0, device=x.device)

        orig_shape = x.shape
        if x.dim() == 3:
            x = x.reshape(-1, x.size(-1))

        # Reverse affine and normalization with running stats
        y = (x - self.bias) / (self.weight + 1e-8)
        y = y * self.running_std.unsqueeze(0) + self.running_mean.unsqueeze(0)

        # Inverse Yeo-Johnson
        out = self._yeo_johnson_inverse(y, self.lmbda)

        if len(orig_shape) == 3:
            out = out.reshape(orig_shape)
        return out, torch.tensor(0.0, device=x.device)
class CouplingLayer(nn.Module):
    """Affine coupling layer used as a building block of a normalizing flow.

    The features selected by ``mask`` pass through unchanged and drive two
    small MLPs whose outputs scale and shift the remaining features.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 64, mask_type: str = "alternating"):
        """Register the feature mask and build the scale/translation networks.

        Args:
            input_dim: Number of features per row.
            hidden_dim: Width of the hidden layers in both networks.
            mask_type: ``"alternating"`` (odd feature indices are 1) or
                ``"half"`` (the first ``input_dim // 2`` features are 1).

        Raises:
            ValueError: If ``mask_type`` is neither of the two known values.
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        positions = torch.arange(input_dim)
        if mask_type == "alternating":
            pattern = positions % 2
        elif mask_type == "half":
            # Float mask with ones on the leading half, zeros elsewhere
            pattern = (positions < input_dim // 2).to(torch.get_default_dtype())
        else:
            raise ValueError(f"Unknown mask type: {mask_type}")
        self.register_buffer('mask', pattern)

        # Input width = number of mask==1 features; output width = the rest
        conditioner_width = int(self.mask.sum().item())
        transformed_width = input_dim - conditioner_width

        def mlp(*tail: nn.Module) -> nn.Sequential:
            """Three-linear-layer MLP; ``tail`` is appended after the last linear layer."""
            return nn.Sequential(
                nn.Linear(conditioner_width, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, transformed_width),
                *tail,
            )

        self.scale_net = mlp(nn.Tanh())  # Tanh keeps the log-scale bounded
        self.translate_net = mlp()

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the coupling transform or its inverse.

        Args:
            x: Input tensor, indexed as (rows, features)
            inverse: Whether to apply inverse transformation

        Returns:
            Tuple of (transformed_tensor, log_determinant), the latter with one
            entry per row.
        """
        keep = self.mask.bool()
        passthrough = x[:, keep]
        to_transform = x[:, ~keep]

        # Scale (log-space) and translation, both driven by the untouched part
        log_scale = self.scale_net(passthrough)
        shift = self.translate_net(passthrough)

        if inverse:
            updated = (to_transform - shift) * torch.exp(-log_scale)
            log_det = -log_scale.sum(dim=1)
        else:
            updated = to_transform * torch.exp(log_scale) + shift
            log_det = log_scale.sum(dim=1)

        # Scatter both parts back into their original feature positions
        result = torch.zeros_like(x)
        result[:, keep] = passthrough
        result[:, ~keep] = updated
        return result, log_det
class NormalizingFlowPreprocessor(nn.Module):
    """Normalizing flow for learnable data preprocessing.

    Stacks ``config.flow_coupling_layers`` coupling layers (mask types
    alternate between "alternating" and "half"), optionally with a
    ``BatchNorm1d`` between consecutive layers.
    """

    def __init__(self, config: AutoencoderConfig):
        """Build the coupling layers and the optional batch-norm layers.

        Args:
            config: Configuration object; ``input_dim``,
                ``preprocessing_hidden_dim``, ``flow_coupling_layers`` and
                ``use_batch_norm`` are read.
        """
        super().__init__()
        self.config = config
        input_dim = config.input_dim
        hidden_dim = config.preprocessing_hidden_dim
        num_layers = config.flow_coupling_layers

        # Create coupling layers with alternating masks
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            mask_type = "alternating" if i % 2 == 0 else "half"
            self.layers.append(CouplingLayer(input_dim, hidden_dim, mask_type))

        # Optional: Add batch normalization between layers
        if config.use_batch_norm:
            self.batch_norms = nn.ModuleList([
                nn.BatchNorm1d(input_dim) for _ in range(num_layers - 1)
            ])
        else:
            self.batch_norms = None

    @staticmethod
    def _invert_batch_norm(bn: nn.BatchNorm1d, y: torch.Tensor) -> torch.Tensor:
        """Undo ``bn`` using its running statistics.

        This is the exact inverse of what ``bn`` computes in eval mode. In
        training mode the forward pass normalizes with batch statistics, so
        the result is only an approximation there. Assumes ``bn`` tracks
        running statistics (the ``BatchNorm1d`` default used by this class).
        """
        if bn.affine:
            y = (y - bn.bias) / bn.weight
        return y * torch.sqrt(bn.running_var + bn.eps) + bn.running_mean

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through normalizing flow.

        Args:
            x: Input tensor (2D or 3D); the last dimension is the feature axis.
            inverse: Whether to apply inverse transformation

        Returns:
            Tuple of (transformed_tensor, regularization_loss), where the loss
            is 0.01 times the mean absolute summed log-determinant of the
            coupling layers (batch-norm layers are not included in that sum).
        """
        # Handle both 2D and 3D tensors
        original_shape = x.shape
        if x.dim() == 3:
            # (batch, seq, features) -> (batch*seq, features); reshape (not
            # view) so non-contiguous inputs are accepted.
            x = x.reshape(-1, x.size(-1))

        # Match the input dtype so accumulation also works for non-float32 data
        log_det_total = torch.zeros(x.size(0), device=x.device, dtype=x.dtype)
        num_layers = len(self.layers)
        use_bn = self.batch_norms is not None

        if not inverse:
            # Forward pass
            for i, layer in enumerate(self.layers):
                x, log_det = layer(x, inverse=False)
                log_det_total = log_det_total + log_det
                # Apply batch normalization (except after the last layer)
                if use_bn and i < num_layers - 1:
                    x = self.batch_norms[i](x)
        else:
            # Inverse pass: walk the layers backwards. batch_norms[i] was
            # applied right after layers[i], so it has to be undone right
            # before layers[i] is inverted. Re-applying the forward batch norm
            # here would not invert anything and would also update its
            # running statistics in training mode.
            for i in reversed(range(num_layers)):
                if use_bn and i < num_layers - 1:
                    x = self._invert_batch_norm(self.batch_norms[i], x)
                x, log_det = self.layers[i](x, inverse=True)
                log_det_total = log_det_total + log_det

        # Reshape back to original shape if needed
        if len(original_shape) == 3:
            x = x.reshape(original_shape)

        # Convert log determinant to regularization loss
        # Encourage the flow to preserve information (log_det close to 0)
        reg_loss = 0.01 * log_det_total.abs().mean()
        return x, reg_loss
class LearnablePreprocessor(nn.Module):
    """Single entry point that wraps whichever learnable preprocessor the config selects."""

    def __init__(self, config: AutoencoderConfig):
        """Store the config and instantiate the selected preprocessor.

        Raises:
            ValueError: If preprocessing is enabled but no known type flag is set.
        """
        super().__init__()
        self.config = config
        self.preprocessor = self._select(config)

    @staticmethod
    def _select(config: AutoencoderConfig) -> nn.Module:
        """Return the preprocessing module for ``config``; flags are checked in a fixed order."""
        if not config.has_preprocessing:
            return nn.Identity()
        if config.is_neural_scaler:
            return NeuralScaler(config)
        if config.is_normalizing_flow:
            return NormalizingFlowPreprocessor(config)
        # The remaining flags are read with getattr and default to False
        # when the config object does not define them.
        if getattr(config, "is_minmax_scaler", False):
            return LearnableMinMaxScaler(config)
        if getattr(config, "is_robust_scaler", False):
            return LearnableRobustScaler(config)
        if getattr(config, "is_yeo_johnson", False):
            return LearnableYeoJohnsonPreprocessor(config)
        raise ValueError(f"Unknown preprocessing type: {config.preprocessing_type}")

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the wrapped preprocessor.

        Args:
            x: Input tensor
            inverse: Whether to apply inverse transformation

        Returns:
            Tuple of (transformed_tensor, regularization_loss). With no
            preprocessing configured, ``x`` is returned as-is together with a
            zero scalar on ``x``'s device.
        """
        if not isinstance(self.preprocessor, nn.Identity):
            return self.preprocessor(x, inverse=inverse)
        return x, torch.tensor(0.0, device=x.device)
# ModelOutput subclasses must be dataclasses: Transformers builds the output's
# keys from the dataclass fields and, in recent releases, refuses to
# instantiate a subclass that is not decorated.
@dataclass
class AutoencoderOutput(ModelOutput):
    """
    Output type of AutoencoderModel.

    Args:
        last_hidden_state (torch.FloatTensor): The latent representation of the input.
        reconstructed (torch.FloatTensor, optional): The reconstructed input.
        hidden_states (tuple(torch.FloatTensor), optional): Hidden states of the encoder layers.
        attentions (tuple(torch.FloatTensor), optional): Not used in basic autoencoder.
        preprocessing_loss (torch.FloatTensor, optional): Loss from learnable preprocessing.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    reconstructed: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    preprocessing_loss: Optional[torch.FloatTensor] = None
# ModelOutput subclasses must be dataclasses: Transformers builds the output's
# keys from the dataclass fields and, in recent releases, refuses to
# instantiate a subclass that is not decorated.
@dataclass
class AutoencoderForReconstructionOutput(ModelOutput):
    """
    Output type of AutoencoderForReconstruction.

    Args:
        loss (torch.FloatTensor, optional): The reconstruction loss.
        reconstructed (torch.FloatTensor): The reconstructed input.
        last_hidden_state (torch.FloatTensor): The latent representation.
        hidden_states (tuple(torch.FloatTensor), optional): Hidden states of the encoder layers.
        preprocessing_loss (torch.FloatTensor, optional): Loss from learnable preprocessing.
    """

    loss: Optional[torch.FloatTensor] = None
    reconstructed: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    preprocessing_loss: Optional[torch.FloatTensor] = None
class AutoencoderEncoder(nn.Module):
    """Feed-forward encoder: a stack of hidden blocks followed by a latent head."""

    def __init__(self, config: AutoencoderConfig):
        """Assemble the hidden stack and the latent projection(s).

        One block is created per entry of ``config.hidden_dims``: Linear, then
        BatchNorm1d (if ``use_batch_norm``), the configured activation, and
        Dropout (if ``dropout_rate > 0``). Variational configs get separate
        mean / log-variance heads; all others get a single output layer.
        """
        super().__init__()
        self.config = config

        modules = []
        width_in = config.input_dim
        for width_out in config.hidden_dims:
            modules.append(nn.Linear(width_in, width_out))
            if config.use_batch_norm:
                modules.append(nn.BatchNorm1d(width_out))
            modules.append(self._get_activation(config.activation))
            if config.dropout_rate > 0:
                modules.append(nn.Dropout(config.dropout_rate))
            width_in = width_out
        self.encoder = nn.Sequential(*modules)

        # width_in now holds the width of the last hidden block
        if config.is_variational:
            self.fc_mu = nn.Linear(width_in, config.latent_dim)
            self.fc_logvar = nn.Linear(width_in, config.latent_dim)
        else:
            self.fc_out = nn.Linear(width_in, config.latent_dim)

    def _get_activation(self, activation: str) -> nn.Module:
        """Return a new activation module for ``activation``.

        Raises:
            KeyError: If the name is not one of the supported activations.
        """
        factories = {
            "relu": nn.ReLU,
            "tanh": nn.Tanh,
            "sigmoid": nn.Sigmoid,
            "leaky_relu": nn.LeakyReLU,
            "gelu": nn.GELU,
            "swish": nn.SiLU,
            "silu": nn.SiLU,
            "elu": nn.ELU,
            "prelu": nn.PReLU,
            "relu6": nn.ReLU6,
            "hardtanh": nn.Hardtanh,
            "hardsigmoid": nn.Hardsigmoid,
            "hardswish": nn.Hardswish,
            "mish": nn.Mish,
            "softplus": nn.Softplus,
            "softsign": nn.Softsign,
            "tanhshrink": nn.Tanhshrink,
            "threshold": lambda: nn.Threshold(threshold=0.1, value=0),
        }
        build = factories[activation]
        return build()

    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Encode ``x``.

        Returns:
            The latent tensor, or ``(z, mu, logvar)`` for variational configs.
        """
        # Denoising configs corrupt the input with Gaussian noise while training
        if self.config.is_denoising and self.training:
            x = x + torch.randn_like(x) * self.config.noise_factor

        features = self.encoder(x)

        if not self.config.is_variational:
            latent = self.fc_out(features)
            # Sparse configs clamp the code to be non-negative. NOTE: this is
            # applied in training mode only; eval-mode codes are left as-is.
            if self.config.is_sparse and self.training:
                latent = F.relu(latent)
            return latent

        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        if not self.training:
            # Inference uses the mean as the latent code
            return mu, mu, logvar

        # Reparameterization trick: z = mu + eps * sigma
        sigma = torch.exp(0.5 * logvar)
        z = mu + torch.randn_like(sigma) * sigma
        return z, mu, logvar
class AutoencoderDecoder(nn.Module):
    """Feed-forward decoder mapping a latent code back to ``input_dim`` features."""

    def __init__(self, config: AutoencoderConfig):
        """Assemble the decoder stack.

        Layer widths are ``config.decoder_dims`` followed by ``config.input_dim``.
        Every layer except the last is followed by BatchNorm1d (if
        ``use_batch_norm``), the configured activation, and Dropout (if
        ``dropout_rate > 0``). The last layer is followed by a Sigmoid only
        when ``config.reconstruction_loss == "bce"``.
        """
        super().__init__()
        self.config = config

        widths = config.decoder_dims + [config.input_dim]
        final_index = len(widths) - 1

        modules = []
        width_in = config.latent_dim
        for index, width_out in enumerate(widths):
            modules.append(nn.Linear(width_in, width_out))
            if index == final_index:
                # Output layer: no norm/activation/dropout, only an optional squashing
                if config.reconstruction_loss == "bce":
                    modules.append(nn.Sigmoid())
            else:
                if config.use_batch_norm:
                    modules.append(nn.BatchNorm1d(width_out))
                modules.append(self._get_activation(config.activation))
                if config.dropout_rate > 0:
                    modules.append(nn.Dropout(config.dropout_rate))
            width_in = width_out
        self.decoder = nn.Sequential(*modules)

    def _get_activation(self, activation: str) -> nn.Module:
        """Return a new activation module for ``activation``.

        Raises:
            KeyError: If the name is not one of the supported activations.
        """
        factories = {
            "relu": nn.ReLU,
            "tanh": nn.Tanh,
            "sigmoid": nn.Sigmoid,
            "leaky_relu": nn.LeakyReLU,
            "gelu": nn.GELU,
            "swish": nn.SiLU,
            "silu": nn.SiLU,
            "elu": nn.ELU,
            "prelu": nn.PReLU,
            "relu6": nn.ReLU6,
            "hardtanh": nn.Hardtanh,
            "hardsigmoid": nn.Hardsigmoid,
            "hardswish": nn.Hardswish,
            "mish": nn.Mish,
            "softplus": nn.Softplus,
            "softsign": nn.Softsign,
            "tanhshrink": nn.Tanhshrink,
            "threshold": lambda: nn.Threshold(threshold=0.1, value=0),
        }
        build = factories[activation]
        return build()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode the latent tensor ``x`` through the sequential stack."""
        reconstruction = self.decoder(x)
        return reconstruction
class RecurrentEncoder(nn.Module):
    """Recurrent encoder for sequence data.

    Runs an LSTM/GRU/RNN over the sequence and uses the final hidden state of
    the last layer as the encoding.
    """

    def __init__(self, config: AutoencoderConfig):
        """Build the RNN and the optional projection, batch-norm and dropout layers.

        Args:
            config: Configuration object; ``rnn_type``, ``input_dim``,
                ``latent_dim``, ``num_layers``, ``dropout_rate``,
                ``bidirectional`` and ``use_batch_norm`` are read here.

        Raises:
            ValueError: If ``config.rnn_type`` is not "lstm", "gru" or "rnn".
        """
        super().__init__()
        self.config = config

        # Get RNN class
        if config.rnn_type == "lstm":
            rnn_class = nn.LSTM
        elif config.rnn_type == "gru":
            rnn_class = nn.GRU
        elif config.rnn_type == "rnn":
            rnn_class = nn.RNN
        else:
            raise ValueError(f"Unknown RNN type: {config.rnn_type}")

        # Create RNN layers (inter-layer dropout only applies with more than one layer)
        self.rnn = rnn_class(
            input_size=config.input_dim,
            hidden_size=config.latent_dim,
            num_layers=config.num_layers,
            batch_first=True,
            dropout=config.dropout_rate if config.num_layers > 1 else 0,
            bidirectional=config.bidirectional,
        )

        # Projection layer for bidirectional RNN (2 * latent_dim -> latent_dim)
        if config.bidirectional:
            self.projection = nn.Linear(config.latent_dim * 2, config.latent_dim)
        else:
            self.projection = None

        # Batch normalization
        if config.use_batch_norm:
            self.batch_norm = nn.BatchNorm1d(config.latent_dim)
        else:
            self.batch_norm = None

        # Dropout
        if config.dropout_rate > 0:
            self.dropout = nn.Dropout(config.dropout_rate)
        else:
            self.dropout = None

    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """
        Forward pass through recurrent encoder.

        Args:
            x: Input tensor of shape (batch_size, seq_len, input_dim)
            lengths: Sequence lengths for packed sequences (optional)

        Returns:
            Encoded representation of shape (batch_size, latent_dim), or a
            ``(z, mu, logvar)`` tuple for variational configs.
        """
        batch_size, _, _ = x.shape

        # Add noise for denoising autoencoders
        if self.config.is_denoising and self.training:
            noise = torch.randn_like(x) * self.config.noise_factor
            x = x + noise

        # Pack sequences if lengths provided
        if lengths is not None:
            # pack_padded_sequence requires a lengths tensor to live on the
            # CPU, even when x is on another device.
            if torch.is_tensor(lengths):
                lengths = lengths.cpu()
            x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)

        # RNN forward pass; only the final hidden state is used below, so the
        # per-step outputs (and the LSTM cell state) are discarded.
        if self.config.rnn_type == "lstm":
            _, (hidden, _) = self.rnn(x)
        else:
            _, hidden = self.rnn(x)

        # Use last hidden state as encoding
        if self.config.bidirectional:
            # Separate layers and directions, keep the last layer, then
            # concatenate forward and backward states per batch element.
            hidden = hidden.view(self.config.num_layers, 2, batch_size, self.config.latent_dim)
            hidden = hidden[-1]  # Take last layer
            hidden = hidden.transpose(0, 1).contiguous().view(batch_size, -1)
            # Project to latent dimension
            if self.projection is not None:
                hidden = self.projection(hidden)
        else:
            hidden = hidden[-1]  # Take last layer

        # Apply batch normalization
        if self.batch_norm is not None:
            hidden = self.batch_norm(hidden)

        # Apply dropout
        if self.dropout is not None and self.training:
            hidden = self.dropout(hidden)

        # Handle variational encoding
        if self.config.is_variational:
            # Split hidden into mean and log variance.
            # NOTE(review): each half has latent_dim // 2 features, so z is
            # half the configured latent size and an odd latent_dim gives
            # halves of different widths - confirm this is intended.
            half = self.config.latent_dim // 2
            mu = hidden[:, :half]
            logvar = hidden[:, half:]

            # Reparameterization trick
            if self.training:
                std = torch.exp(0.5 * logvar)
                eps = torch.randn_like(std)
                z = mu + eps * std
            else:
                z = mu
            return z, mu, logvar

        return hidden
class RecurrentDecoder(nn.Module):
    """Recurrent decoder that unrolls a latent vector into a sequence.

    The latent vector seeds the hidden state of every RNN layer. The RNN is
    then stepped ``target_length`` times and each step's output is projected
    to ``config.input_dim``.
    """

    def __init__(self, config: AutoencoderConfig):
        """Build the RNN, output projection and optional norm/dropout layers.

        Raises:
            ValueError: If ``config.rnn_type`` is not "lstm", "gru" or "rnn".
        """
        super().__init__()
        self.config = config

        rnn_classes = {"lstm": nn.LSTM, "gru": nn.GRU, "rnn": nn.RNN}
        if config.rnn_type not in rnn_classes:
            raise ValueError(f"Unknown RNN type: {config.rnn_type}")

        # The decoder is always unidirectional. Both its input and hidden
        # width are latent_dim.
        self.rnn = rnn_classes[config.rnn_type](
            input_size=config.latent_dim,
            hidden_size=config.latent_dim,
            num_layers=config.num_layers,
            batch_first=True,
            # Inter-layer dropout is only meaningful with more than one layer.
            dropout=config.dropout_rate if config.num_layers > 1 else 0,
            bidirectional=False,
        )

        # Maps each RNN output step back to the data space.
        self.output_projection = nn.Linear(config.latent_dim, config.input_dim)

        self.batch_norm = nn.BatchNorm1d(config.latent_dim) if config.use_batch_norm else None
        self.dropout = nn.Dropout(config.dropout_rate) if config.dropout_rate > 0 else None

    def forward(self, z: torch.Tensor, target_length: int, target_sequence: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Decode a latent vector into a sequence, one step at a time.

        Args:
            z: Latent representation of shape (batch_size, latent_dim).
            target_length: Number of steps to generate.
            target_sequence: Optional ground-truth sequence used for teacher
                forcing while training. A ground-truth frame is only fed when
                its feature width equals ``config.latent_dim`` (the RNN input
                size); otherwise the zero input is used.

        Returns:
            Decoded sequence of shape (batch_size, target_length, input_dim).
        """
        batch_size = z.size(0)
        latent_dim = self.config.latent_dim

        # Every layer starts from the same latent vector.
        h_0 = z.unsqueeze(0).repeat(self.config.num_layers, 1, 1)
        hidden = (h_0, torch.zeros_like(h_0)) if self.config.rnn_type == "lstm" else h_0

        # Placeholder input for steps that are not teacher-forced. It is built
        # from z so it shares z's device and dtype (needed for half/double).
        zero_input = z.new_zeros(batch_size, 1, latent_dim)

        # Ground truth can only be fed when it matches the RNN input width.
        can_feed_target = target_sequence is not None and target_sequence.size(-1) == latent_dim

        steps = []
        for t in range(target_length):
            # A fresh coin is flipped at every step, including the first one.
            use_teacher_forcing = (
                target_sequence is not None
                and self.training
                and random.random() < self.config.teacher_forcing_ratio
            )

            # The input is chosen explicitly per step so a frame fed on one
            # step is never silently reused on the next non-forced step.
            if use_teacher_forcing and t > 0 and can_feed_target:
                step_input = target_sequence[:, t - 1:t, :]
            else:
                step_input = zero_input

            output, hidden = self.rnn(step_input, hidden)

            features = output.squeeze(1)  # drop the length-1 sequence axis
            if self.batch_norm is not None:
                features = self.batch_norm(features)
            if self.dropout is not None and self.training:
                features = self.dropout(features)

            steps.append(self.output_projection(features).unsqueeze(1))

        if not steps:
            # torch.cat rejects an empty list; return an empty sequence instead.
            return z.new_zeros(batch_size, 0, self.config.input_dim)
        return torch.cat(steps, dim=1)
class AutoencoderModel(PreTrainedModel):
    """
    The bare Autoencoder Model outputting the latent representation and the reconstruction without any
    specific head on top.

    This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving).

    This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the
    PyTorch documentation for all matters related to general usage and behavior.
    """

    config_class = AutoencoderConfig
    base_model_prefix = "autoencoder"
    supports_gradient_checkpointing = False

    def __init__(self, config: AutoencoderConfig):
        """Build the optional preprocessor and the encoder/decoder pair selected by ``config``."""
        super().__init__(config)
        self.config = config

        # Optional learnable preprocessing applied before encoding.
        self.preprocessor = LearnablePreprocessor(config) if config.has_preprocessing else None

        # Recurrent configs operate on sequences; all others on flat vectors.
        if config.is_recurrent:
            self.encoder = RecurrentEncoder(config)
            self.decoder = RecurrentDecoder(config)
        else:
            self.encoder = AutoencoderEncoder(config)
            self.decoder = AutoencoderDecoder(config)

        if config.tie_weights:
            self._tie_weights()

        # Initialize weights
        self.post_init()

    def _tie_weights(self):
        """Tie encoder and decoder weights. Currently a no-op placeholder."""
        pass

    def get_input_embeddings(self):
        """Get input embeddings (not applicable for basic autoencoder)."""
        return None

    def set_input_embeddings(self, value):
        """Set input embeddings (not applicable for basic autoencoder)."""
        pass

    def forward(
        self,
        input_values: torch.Tensor,
        sequence_lengths: Optional[torch.Tensor] = None,
        target_length: Optional[int] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], AutoencoderOutput]:
        """
        Forward pass through the autoencoder.

        Args:
            input_values (torch.Tensor): Input tensor. Shape depends on autoencoder type:
                - Standard: (batch_size, input_dim)
                - Recurrent: (batch_size, seq_len, input_dim)
            sequence_lengths (torch.Tensor, optional): Sequence lengths for recurrent AE.
            target_length (int, optional): Target sequence length for recurrent decoder. Defaults to
                ``config.sequence_length`` when set, otherwise to the input sequence length.
            output_hidden_states (bool, optional): Whether to return hidden states.
            return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.

        Returns:
            AutoencoderOutput or tuple: The model outputs. The tuple form is
            ``(latent, reconstructed[, hidden_states])`` and does not carry the preprocessing loss.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Apply learnable preprocessing
        preprocessing_loss = torch.tensor(0.0, device=input_values.device)
        if self.preprocessor is not None:
            input_values, preprocessing_loss = self.preprocessor(input_values, inverse=False)

        # Encode. Only the recurrent encoder accepts sequence lengths.
        if self.config.is_recurrent:
            encoder_output = self.encoder(input_values, sequence_lengths)
        else:
            encoder_output = self.encoder(input_values)

        # Variational encoders return (z, mu, logvar); others return the latent only.
        if self.config.is_variational:
            latent, mu, logvar = encoder_output
        else:
            latent, mu, logvar = encoder_output, None, None
        # Stashed on the module so AutoencoderForReconstruction can compute the KL term.
        self._mu = mu
        self._logvar = logvar

        # Decode
        if self.config.is_recurrent:
            if target_length is None:
                if self.config.sequence_length is not None:
                    target_length = self.config.sequence_length
                else:
                    target_length = input_values.size(1)  # Use input sequence length
            # The (preprocessed) input doubles as the teacher-forcing target while training.
            reconstructed = self.decoder(latent, target_length, input_values if self.training else None)
        else:
            reconstructed = self.decoder(latent)

        # Apply inverse preprocessing to reconstruction
        if self.preprocessor is not None and self.config.learn_inverse_preprocessing:
            reconstructed, inverse_loss = self.preprocessor(reconstructed, inverse=True)
            # Out-of-place add: do not mutate the tensor returned by the preprocessor.
            preprocessing_loss = preprocessing_loss + inverse_loss

        hidden_states = None
        if output_hidden_states:
            hidden_states = (latent, mu, logvar) if self.config.is_variational else (latent,)

        if not return_dict:
            return tuple(v for v in (latent, reconstructed, hidden_states) if v is not None)

        return AutoencoderOutput(
            last_hidden_state=latent,
            reconstructed=reconstructed,
            hidden_states=hidden_states,
            preprocessing_loss=preprocessing_loss,
        )
class AutoencoderForReconstruction(PreTrainedModel):
    """
    Autoencoder Model with a reconstruction head on top for reconstruction tasks.

    This model inherits from PreTrainedModel and adds a reconstruction loss calculation.
    """

    config_class = AutoencoderConfig
    base_model_prefix = "autoencoder"

    # Losses from torch.nn.functional that share the (input, target, reduction=...) signature.
    _FUNCTIONAL_LOSSES = {
        "mse": F.mse_loss,
        "bce": F.binary_cross_entropy_with_logits,
        "l1": F.l1_loss,
        "huber": F.huber_loss,
        "smooth_l1": F.smooth_l1_loss,
    }

    def __init__(self, config: AutoencoderConfig):
        """Wrap an ``AutoencoderModel`` built from ``config``."""
        super().__init__(config)
        self.config = config

        # Initialize the base autoencoder model
        self.autoencoder = AutoencoderModel(config)

        # Initialize weights
        self.post_init()

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.autoencoder.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Set input embeddings."""
        self.autoencoder.set_input_embeddings(value)

    def _compute_reconstruction_loss(
        self,
        reconstructed: torch.Tensor,
        target: torch.Tensor
    ) -> torch.Tensor:
        """Compute the reconstruction loss selected by ``config.reconstruction_loss``.

        Raises:
            ValueError: If the configured loss name is not recognised.
        """
        loss_name = self.config.reconstruction_loss

        if loss_name in self._FUNCTIONAL_LOSSES:
            return self._FUNCTIONAL_LOSSES[loss_name](reconstructed, target, reduction="mean")
        if loss_name == "kl_div":
            # Both tensors are turned into distributions over the last axis first.
            return F.kl_div(F.log_softmax(reconstructed, dim=-1), F.softmax(target, dim=-1), reduction="mean")
        if loss_name == "cosine":
            return 1 - F.cosine_similarity(reconstructed, target, dim=-1).mean()

        custom_losses = {
            "focal": self._focal_loss,
            "dice": self._dice_loss,
            "tversky": self._tversky_loss,
            "ssim": self._ssim_loss,
            "perceptual": self._perceptual_loss,
        }
        if loss_name in custom_losses:
            return custom_losses[loss_name](reconstructed, target)

        raise ValueError(f"Unknown reconstruction loss: {self.config.reconstruction_loss}")

    def _focal_loss(self, pred: torch.Tensor, target: torch.Tensor, alpha: float = 1.0, gamma: float = 2.0) -> torch.Tensor:
        """Compute a focal-style loss built on the element-wise squared error.

        Each element's squared error is weighted by ``alpha * (1 - exp(-err)) ** gamma``,
        so elements with small error contribute less.
        """
        squared_error = F.mse_loss(pred, target, reduction="none")
        pt = torch.exp(-squared_error)
        return (alpha * (1 - pt) ** gamma * squared_error).mean()

    def _dice_loss(self, pred: torch.Tensor, target: torch.Tensor, smooth: float = 1e-6) -> torch.Tensor:
        """Compute Dice loss (1 - Dice coefficient) over all elements."""
        # reshape (not view) so non-contiguous tensors are accepted too.
        pred_flat = pred.reshape(-1)
        target_flat = target.reshape(-1)
        intersection = (pred_flat * target_flat).sum()
        dice = (2.0 * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)
        return 1 - dice

    def _tversky_loss(self, pred: torch.Tensor, target: torch.Tensor, alpha: float = 0.7, beta: float = 0.3, smooth: float = 1e-6) -> torch.Tensor:
        """Compute Tversky loss; ``alpha`` weights false negatives and ``beta`` false positives."""
        # reshape (not view) so non-contiguous tensors are accepted too.
        pred_flat = pred.reshape(-1)
        target_flat = target.reshape(-1)
        true_pos = (pred_flat * target_flat).sum()
        false_neg = (target_flat * (1 - pred_flat)).sum()
        false_pos = ((1 - target_flat) * pred_flat).sum()
        tversky = (true_pos + smooth) / (true_pos + alpha * false_neg + beta * false_pos + smooth)
        return 1 - tversky

    def _ssim_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute a simplified SSIM-based loss using statistics over the last axis."""
        mu1 = pred.mean(dim=-1, keepdim=True)
        mu2 = target.mean(dim=-1, keepdim=True)
        sigma1_sq = ((pred - mu1) ** 2).mean(dim=-1, keepdim=True)
        sigma2_sq = ((target - mu2) ** 2).mean(dim=-1, keepdim=True)
        sigma12 = ((pred - mu1) * (target - mu2)).mean(dim=-1, keepdim=True)
        # Stabilising constants for the two ratio terms.
        c1, c2 = 0.01, 0.03
        ssim = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / ((mu1**2 + mu2**2 + c1) * (sigma1_sq + sigma2_sq + c2))
        return 1 - ssim.mean()

    def _perceptual_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute a simplified perceptual loss: MSE between L2-normalised vectors (last axis)."""
        pred_norm = F.normalize(pred, p=2, dim=-1)
        target_norm = F.normalize(target, p=2, dim=-1)
        return F.mse_loss(pred_norm, target_norm)

    def forward(
        self,
        input_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        sequence_lengths: Optional[torch.Tensor] = None,
        target_length: Optional[int] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], AutoencoderForReconstructionOutput]:
        """
        Forward pass with reconstruction loss calculation.

        The returned loss is the reconstruction loss, plus the preprocessing loss when the base model
        reports one, plus at most one regulariser (KL for variational, L1 on the latent for sparse).

        Args:
            input_values (torch.Tensor): Input tensor. Shape depends on autoencoder type:
                - Standard: (batch_size, input_dim)
                - Recurrent: (batch_size, seq_len, input_dim)
            labels (torch.Tensor, optional): Target tensor for reconstruction. If None, uses input_values.
            sequence_lengths (torch.Tensor, optional): Sequence lengths for recurrent AE.
            target_length (int, optional): Target sequence length for recurrent decoder.
            output_hidden_states (bool, optional): Whether to return hidden states.
            return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.

        Returns:
            AutoencoderForReconstructionOutput or tuple: The model outputs including loss. The tuple form
            is ``(loss, reconstructed, latent[, hidden_states])``.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # If no labels provided, use input as target (standard autoencoder)
        if labels is None:
            labels = input_values

        # Always request the structured output; the tuple form is rebuilt below if needed.
        outputs = self.autoencoder(
            input_values=input_values,
            sequence_lengths=sequence_lengths,
            target_length=target_length,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        reconstructed = outputs.reconstructed
        latent = outputs.last_hidden_state
        hidden_states = outputs.hidden_states
        preprocessing_loss = getattr(outputs, "preprocessing_loss", None)

        recon_loss = self._compute_reconstruction_loss(reconstructed, labels)

        # Accumulate out-of-place so recon_loss itself is never mutated.
        total_loss = recon_loss
        if preprocessing_loss is not None:
            total_loss = total_loss + preprocessing_loss

        # mu / logvar are stashed on the base model during its forward pass.
        mu = getattr(self.autoencoder, "_mu", None)
        if self.config.is_variational and mu is not None:
            logvar = self.autoencoder._logvar
            # KL divergence to a standard normal, normalised by batch size and latent dim.
            kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            kl_loss = kl_loss / (mu.size(0) * mu.size(1))
            total_loss = total_loss + self.config.beta * kl_loss
        elif self.config.is_sparse:
            # L1 sparsity penalty on the latent code with a fixed weight of 0.1.
            total_loss = total_loss + 0.1 * torch.mean(torch.abs(latent))
        elif self.config.is_contractive:
            # retain_grad() raises on tensors that do not require grad (e.g. under
            # torch.no_grad()), so only request it when autograd is tracking latent.
            if latent.requires_grad:
                latent.retain_grad()
            # NOTE(review): latent.grad is only populated by a backward pass, so on a
            # fresh forward this is None and no contractive penalty is added -- confirm intent.
            if latent.grad is not None:
                total_loss = total_loss + 0.1 * torch.sum(latent.grad ** 2)

        if not return_dict:
            output = (reconstructed, latent)
            if hidden_states is not None:
                output = output + (hidden_states,)
            return (total_loss,) + output

        return AutoencoderForReconstructionOutput(
            loss=total_loss,
            reconstructed=reconstructed,
            last_hidden_state=latent,
            hidden_states=hidden_states,
            preprocessing_loss=preprocessing_loss,
        )