| """ |
| PyTorch Autoencoder model for Hugging Face Transformers. |
| """ |
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union, List
from dataclasses import dataclass
import random
|
|
| |
try:
    from transformers.modeling_utils import PreTrainedModel
except Exception:
    # Fallback for transformers versions that expose PreTrainedModel only at the top level.
    from transformers import PreTrainedModel
|
|
from transformers.utils import ModelOutput
|
|
# Support both package-relative and flat (script-style) imports for local modules.
try:
    from .configuration_autoencoder import AutoencoderConfig
except Exception:
    from configuration_autoencoder import AutoencoderConfig
|
|
| |
| try: |
| from .blocks import ( |
| BlockFactory, |
| BlockSequence, |
| LinearBlockConfig, |
| AttentionBlockConfig, |
| RecurrentBlockConfig, |
| ConvolutionalBlockConfig, |
| VariationalBlockConfig, |
| VariationalBlock, |
| ) |
| except Exception: |
| from blocks import ( |
| BlockFactory, |
| BlockSequence, |
| LinearBlockConfig, |
| AttentionBlockConfig, |
| RecurrentBlockConfig, |
| ConvolutionalBlockConfig, |
| VariationalBlockConfig, |
| VariationalBlock, |
| ) |
|
|
| |
| try: |
| from .utils import _get_activation |
| except Exception: |
| from utils import _get_activation |
|
|
| |
| try: |
| from .preprocessing import PreprocessingBlock |
| except Exception: |
| from preprocessing import PreprocessingBlock |
|
|
|
|
| @dataclass |
| class AutoencoderOutput(ModelOutput): |
| """ |
| Output type of AutoencoderModel. |
| |
| Args: |
| last_hidden_state (torch.FloatTensor): The latent representation of the input. |
| reconstructed (torch.FloatTensor, optional): The reconstructed input. |
| hidden_states (tuple(torch.FloatTensor), optional): Hidden states of the encoder layers. |
| attentions (tuple(torch.FloatTensor), optional): Not used in basic autoencoder. |
| preprocessing_loss (torch.FloatTensor, optional): Loss from learnable preprocessing. |
| """ |
|
|
| last_hidden_state: torch.FloatTensor = None |
| reconstructed: Optional[torch.FloatTensor] = None |
| hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
| attentions: Optional[Tuple[torch.FloatTensor]] = None |
| preprocessing_loss: Optional[torch.FloatTensor] = None |
|
|
|
|
| @dataclass |
| class AutoencoderForReconstructionOutput(ModelOutput): |
| """ |
| Output type of AutoencoderForReconstruction. |
| |
| Args: |
| loss (torch.FloatTensor, optional): The reconstruction loss. |
| reconstructed (torch.FloatTensor): The reconstructed input. |
| last_hidden_state (torch.FloatTensor): The latent representation. |
| hidden_states (tuple(torch.FloatTensor), optional): Hidden states of the encoder layers. |
| preprocessing_loss (torch.FloatTensor, optional): Loss from learnable preprocessing. |
| """ |
|
|
| loss: Optional[torch.FloatTensor] = None |
| reconstructed: torch.FloatTensor = None |
| last_hidden_state: torch.FloatTensor = None |
| hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
| preprocessing_loss: Optional[torch.FloatTensor] = None |
|
|
|
|
| class AutoencoderEncoder(nn.Module): |
| """Encoder part of the autoencoder.""" |
|
|
| def __init__(self, config: AutoencoderConfig): |
| super().__init__() |
| self.config = config |
|
|
| |
| layers = [] |
| input_dim = config.input_dim |
|
|
| for hidden_dim in config.hidden_dims: |
| layers.append(nn.Linear(input_dim, hidden_dim)) |
|
|
| if config.use_batch_norm: |
| layers.append(nn.BatchNorm1d(hidden_dim)) |
|
|
            layers.append(_get_activation(config.activation))
|
|
| if config.dropout_rate > 0: |
| layers.append(nn.Dropout(config.dropout_rate)) |
|
|
| input_dim = hidden_dim |
|
|
| self.encoder = nn.Sequential(*layers) |
|
|
| |
| if config.is_variational: |
| self.fc_mu = nn.Linear(input_dim, config.latent_dim) |
| self.fc_logvar = nn.Linear(input_dim, config.latent_dim) |
| else: |
| |
| self.fc_out = nn.Linear(input_dim, config.latent_dim) |
|
|
|
|
| def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: |
| """Forward pass through encoder.""" |
| |
        # Denoising: corrupt the input during training only.
        if self.config.is_denoising and self.training:
            noise = torch.randn_like(x) * self.config.noise_factor
            x = x + noise
|
|
| encoded = self.encoder(x) |
|
|
| if self.config.is_variational: |
| |
| mu = self.fc_mu(encoded) |
| logvar = self.fc_logvar(encoded) |
|
|
| |
            # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I),
            # keeps the sampling step differentiable; use the mean at eval time.
            if self.training:
                std = torch.exp(0.5 * logvar)
                eps = torch.randn_like(std)
                z = mu + eps * std
            else:
                z = mu
|
|
| return z, mu, logvar |
| else: |
| |
| latent = self.fc_out(encoded) |
|
|
| |
            # Keep the latent non-negative whenever the model is sparse, so
            # training and inference see the same representation; the sparsity
            # penalty itself is added by the reconstruction head.
            if self.config.is_sparse:
                latent = F.relu(latent)

            return latent
|
|
|
|
| class AutoencoderDecoder(nn.Module): |
| """Decoder part of the autoencoder.""" |
|
|
| def __init__(self, config: AutoencoderConfig): |
| super().__init__() |
| self.config = config |
|
|
| |
| layers = [] |
| input_dim = config.latent_dim |
| decoder_dims = config.decoder_dims + [config.input_dim] |
|
|
| for i, hidden_dim in enumerate(decoder_dims): |
| layers.append(nn.Linear(input_dim, hidden_dim)) |
|
|
| |
| if i < len(decoder_dims) - 1: |
| if config.use_batch_norm: |
| layers.append(nn.BatchNorm1d(hidden_dim)) |
|
|
| layers.append(_get_activation(config.activation)) |
|
|
| if config.dropout_rate > 0: |
| layers.append(nn.Dropout(config.dropout_rate)) |
| else: |
| |
                # Map outputs to [0, 1]; pair this decoder with plain BCE
                # (not BCE-with-logits) to avoid applying sigmoid twice.
                if config.reconstruction_loss == "bce":
                    layers.append(nn.Sigmoid())
|
|
| input_dim = hidden_dim |
|
|
| self.decoder = nn.Sequential(*layers) |
|
|
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| """Forward pass through decoder.""" |
| return self.decoder(x) |
|
|
|
|
| class RecurrentEncoder(nn.Module): |
| """Recurrent encoder for sequence data.""" |
|
|
| def __init__(self, config: AutoencoderConfig): |
| super().__init__() |
| self.config = config |
|
|
| |
| if config.rnn_type == "lstm": |
| rnn_class = nn.LSTM |
| elif config.rnn_type == "gru": |
| rnn_class = nn.GRU |
| elif config.rnn_type == "rnn": |
| rnn_class = nn.RNN |
| else: |
| raise ValueError(f"Unknown RNN type: {config.rnn_type}") |
|
|
| |
| self.rnn = rnn_class( |
| input_size=config.input_dim, |
| hidden_size=config.latent_dim, |
| num_layers=config.num_layers, |
| batch_first=True, |
| dropout=config.dropout_rate if config.num_layers > 1 else 0, |
| bidirectional=config.bidirectional |
| ) |
|
|
| |
| if config.bidirectional: |
| self.projection = nn.Linear(config.latent_dim * 2, config.latent_dim) |
| else: |
| self.projection = None |
|
|
| |
| if config.use_batch_norm: |
| self.batch_norm = nn.BatchNorm1d(config.latent_dim) |
| else: |
| self.batch_norm = None |
|
|
| |
        if config.dropout_rate > 0:
            self.dropout = nn.Dropout(config.dropout_rate)
        else:
            self.dropout = None

        # Dedicated variational heads keep the latent at full latent_dim
        # (splitting the hidden state in half would shrink z to latent_dim // 2).
        if config.is_variational:
            self.fc_mu = nn.Linear(config.latent_dim, config.latent_dim)
            self.fc_logvar = nn.Linear(config.latent_dim, config.latent_dim)
|
|
| def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: |
| """ |
| Forward pass through recurrent encoder. |
| |
| Args: |
| x: Input tensor of shape (batch_size, seq_len, input_dim) |
| lengths: Sequence lengths for packed sequences (optional) |
| |
| Returns: |
| Encoded representation or tuple for VAE |
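
        Example (illustrative sketch; assumes a non-variational config with
        ``input_dim=8`` and ``latent_dim=32``):

            >>> x = torch.randn(4, 10, 8)   # (batch, seq_len, input_dim)
            >>> z = encoder(x)              # shape: (4, 32)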
| """ |
| batch_size, seq_len, _ = x.shape |
|
|
| |
| if self.config.is_denoising and self.training: |
| noise = torch.randn_like(x) * self.config.noise_factor |
| x = x + noise |
|
|
| |
| if lengths is not None: |
            # pack_padded_sequence requires lengths on the CPU.
            x = nn.utils.rnn.pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
|
|
| |
| if self.config.rnn_type == "lstm": |
| output, (hidden, cell) = self.rnn(x) |
| else: |
| output, hidden = self.rnn(x) |
| cell = None |
|
|
| |
| if lengths is not None: |
| output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True) |
|
|
| |
| if self.config.bidirectional: |
| |
| hidden = hidden.view(self.config.num_layers, 2, batch_size, self.config.latent_dim) |
| hidden = hidden[-1] |
| hidden = hidden.transpose(0, 1).contiguous().view(batch_size, -1) |
|
|
| |
| if self.projection: |
| hidden = self.projection(hidden) |
| else: |
| hidden = hidden[-1] |
|
|
| |
| if self.batch_norm: |
| hidden = self.batch_norm(hidden) |
|
|
| |
| if self.dropout and self.training: |
| hidden = self.dropout(hidden) |
|
|
| |
        if self.config.is_variational:
            # Project to mean and log-variance with the dedicated heads.
            mu = self.fc_mu(hidden)
            logvar = self.fc_logvar(hidden)
|
|
| |
| if self.training: |
| std = torch.exp(0.5 * logvar) |
| eps = torch.randn_like(std) |
| z = mu + eps * std |
| else: |
| z = mu |
|
|
| return z, mu, logvar |
| else: |
| return hidden |
|
|
|
|
| class RecurrentDecoder(nn.Module): |
| """Recurrent decoder for sequence data.""" |
|
|
| def __init__(self, config: AutoencoderConfig): |
| super().__init__() |
| self.config = config |
|
|
| |
| if config.rnn_type == "lstm": |
| rnn_class = nn.LSTM |
| elif config.rnn_type == "gru": |
| rnn_class = nn.GRU |
| elif config.rnn_type == "rnn": |
| rnn_class = nn.RNN |
| else: |
| raise ValueError(f"Unknown RNN type: {config.rnn_type}") |
|
|
| |
| self.rnn = rnn_class( |
| input_size=config.latent_dim, |
| hidden_size=config.latent_dim, |
| num_layers=config.num_layers, |
| batch_first=True, |
| dropout=config.dropout_rate if config.num_layers > 1 else 0, |
| bidirectional=False |
| ) |
|
|
| |
| self.output_projection = nn.Linear(config.latent_dim, config.input_dim) |
|
|
| |
| if config.use_batch_norm: |
| self.batch_norm = nn.BatchNorm1d(config.latent_dim) |
| else: |
| self.batch_norm = None |
|
|
| |
| if config.dropout_rate > 0: |
| self.dropout = nn.Dropout(config.dropout_rate) |
| else: |
| self.dropout = None |
|
|
| def forward(self, z: torch.Tensor, target_length: int, target_sequence: Optional[torch.Tensor] = None) -> torch.Tensor: |
| """ |
| Forward pass through recurrent decoder. |
| |
| Args: |
| z: Latent representation of shape (batch_size, latent_dim) |
| target_length: Length of sequence to generate |
| target_sequence: Target sequence for teacher forcing (optional) |
| |
| Returns: |
| Decoded sequence of shape (batch_size, seq_len, input_dim) |
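
        Example (illustrative sketch; assumes ``latent_dim=32`` and ``input_dim=8``):

            >>> z = torch.randn(4, 32)
            >>> y = decoder(z, target_length=10)   # shape: (4, 10, 8)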
| """ |
| batch_size = z.size(0) |
| device = z.device |
|
|
| |
        # Seed every layer's hidden state with the latent code; the LSTM cell
        # state starts at zero.
        if self.config.rnn_type == "lstm":
            h_0 = z.unsqueeze(0).repeat(self.config.num_layers, 1, 1)
            c_0 = torch.zeros_like(h_0)
            hidden = (h_0, c_0)
        else:
            hidden = z.unsqueeze(0).repeat(self.config.num_layers, 1, 1)
|
|
| outputs = [] |
|
|
| |
        # The first decoder input is a zero vector; later steps may be replaced
        # by teacher forcing.
        current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)
|
|
| for t in range(target_length): |
| |
| use_teacher_forcing = (target_sequence is not None and |
| self.training and |
| random.random() < self.config.teacher_forcing_ratio) |
|
|
            if use_teacher_forcing and t > 0:
                # Feed the previous ground-truth step; fall back to zeros when the
                # target feature size does not match the RNN input size.
                current_input = target_sequence[:, t-1:t, :]
                if current_input.size(-1) != self.config.latent_dim:
                    current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)
|
|
| |
            # Single RNN step; `hidden` is (h, c) for LSTM and h otherwise.
            output, hidden = self.rnn(current_input, hidden)
|
|
| |
| output_flat = output.squeeze(1) |
|
|
| if self.batch_norm: |
| output_flat = self.batch_norm(output_flat) |
|
|
| if self.dropout and self.training: |
| output_flat = self.dropout(output_flat) |
|
|
| |
| step_output = self.output_projection(output_flat) |
| outputs.append(step_output.unsqueeze(1)) |
|
|
| |
            # Without teacher forcing the next input stays a zero vector; the
            # sequence information is carried entirely by the hidden state.
            if not use_teacher_forcing:
                current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)
|
|
| |
| return torch.cat(outputs, dim=1) |
|
|
|
|
| class AutoencoderModel(PreTrainedModel): |
| """ |
    The bare Autoencoder Model outputting raw latent states without any specific head on top.

    This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads etc.)

    This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the
    PyTorch documentation for all matters related to general usage and behavior.
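
    Example (illustrative sketch; the ``AutoencoderConfig`` fields shown are the
    ones this module reads, and their defaults may differ in your configuration):

        >>> config = AutoencoderConfig(input_dim=784, hidden_dims=[256, 64], latent_dim=16)
        >>> model = AutoencoderModel(config)
        >>> x = torch.randn(8, 784)
        >>> outputs = model(x)
        >>> outputs.last_hidden_state.shape   # latent code
        torch.Size([8, 16])
        >>> outputs.reconstructed.shape       # reconstruction
        torch.Size([8, 784])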
| """ |
|
|
| config_class = AutoencoderConfig |
| base_model_prefix = "autoencoder" |
| supports_gradient_checkpointing = False |
|
|
| def __init__(self, config: AutoencoderConfig): |
| super().__init__(config) |
| self.config = config |
|
|
| |
| if config.has_preprocessing: |
| self.pre_block = PreprocessingBlock(config, inverse=False) |
| else: |
| self.pre_block = None |
|
|
| |
| norm = "batch" if config.use_batch_norm else "none" |
|
|
| def default_linear_sequence(in_dim: int, dims: List[int], activation: str, normalization: str, dropout: float) -> List[LinearBlockConfig]: |
| cfgs: List[LinearBlockConfig] = [] |
| prev = in_dim |
| for h in dims: |
| cfgs.append( |
| LinearBlockConfig( |
| input_dim=prev, |
| output_dim=h, |
| activation=activation, |
| normalization=normalization, |
| dropout_rate=dropout, |
| use_residual=False, |
| ) |
| ) |
| prev = h |
| return cfgs |
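
        # Example: default_linear_sequence(784, [256, 64], "relu", "batch", 0.1)
        # produces two LinearBlockConfig entries, 784 -> 256 and 256 -> 64.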
|
|
| |
| if getattr(config, "encoder_blocks", None): |
| enc_cfgs = config.encoder_blocks |
| |
| last_out = None |
| for b in enc_cfgs: |
| if isinstance(b, dict): |
| last_out = b.get("output_dim", last_out) |
| else: |
| last_out = getattr(b, "output_dim", last_out) |
| enc_out_dim = last_out or (config.hidden_dims[-1] if config.hidden_dims else config.input_dim) |
| else: |
| enc_cfgs = default_linear_sequence(config.input_dim, config.hidden_dims, config.activation, norm, config.dropout_rate) |
| enc_out_dim = config.hidden_dims[-1] if config.hidden_dims else config.input_dim |
| base_encoder_seq: BlockSequence = BlockFactory.build_sequence(enc_cfgs) if len(enc_cfgs) > 0 else BlockSequence([]) |
| |
| self.encoder_seq = base_encoder_seq |
|
|
| |
        if config.is_variational:
            # Build the variational head at init time so its parameters are
            # registered, saved/loaded, and moved together with the module.
            self.variational = VariationalBlock(
                VariationalBlockConfig(input_dim=enc_out_dim, latent_dim=config.latent_dim)
            )
            self.to_latent = None
        else:
            self.variational = None
            self.to_latent = nn.Linear(enc_out_dim, config.latent_dim)
|
|
| |
| if getattr(config, "decoder_blocks", None): |
| dec_cfgs = config.decoder_blocks |
| else: |
| dec_dims = config.decoder_dims + [config.input_dim] |
| dec_cfgs = default_linear_sequence(config.latent_dim, dec_dims, config.activation, norm, config.dropout_rate) |
| |
| if len(dec_cfgs) > 0: |
| last = dec_cfgs[-1] |
| last.activation = "identity" |
| last.normalization = "none" |
| last.dropout_rate = 0.0 |
| self.decoder_seq: BlockSequence = BlockFactory.build_sequence(dec_cfgs) if len(dec_cfgs) > 0 else BlockSequence([]) |
|
|
| |
| if config.tie_weights: |
| self._tie_weights() |
|
|
| |
| self.post_init() |
|
|
    def _tie_weights(self):
        """Tie encoder and decoder weights (transpose relationship)."""
        # Not implemented yet; kept as a hook for configs with tie_weights=True.
        pass
|
|
| def get_input_embeddings(self): |
| """Get input embeddings (not applicable for basic autoencoder).""" |
| return None |
|
|
| def set_input_embeddings(self, value): |
| """Set input embeddings (not applicable for basic autoencoder).""" |
| pass |
|
|
| def forward( |
| self, |
| input_values: torch.Tensor, |
| sequence_lengths: Optional[torch.Tensor] = None, |
| target_length: Optional[int] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| ) -> Union[Tuple[torch.Tensor], AutoencoderOutput]: |
| """ |
| Forward pass through the autoencoder. |
| |
| Args: |
| input_values (torch.Tensor): Input tensor. Shape depends on autoencoder type: |
| - Standard: (batch_size, input_dim) |
| - Recurrent: (batch_size, seq_len, input_dim) |
            sequence_lengths (torch.Tensor, optional): Sequence lengths for recurrent AEs (unused by this feed-forward model).
            target_length (int, optional): Target sequence length for recurrent decoders (unused by this feed-forward model).
| output_hidden_states (bool, optional): Whether to return hidden states. |
| return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple. |
| |
| Returns: |
| AutoencoderOutput or tuple: The model outputs. |
| """ |
| output_hidden_states = ( |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| ) |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| |
        # Always define the preprocessing loss so the output field is populated
        # even when no preprocessing block is configured.
        preprocessing_loss = torch.tensor(0.0, device=input_values.device)
        if self.pre_block is not None:
            input_values = self.pre_block(input_values)

        enc_out = self.encoder_seq(input_values)

        if self.config.is_variational:
            # The variational head is built in __init__, so its parameters are
            # registered with the module rather than created lazily here.
            latent = self.variational(enc_out, training=self.training)
            self._mu = self.variational._mu
            self._logvar = self.variational._logvar
        else:
            latent = self.to_latent(enc_out) if self.to_latent is not None else enc_out
            self._mu, self._logvar = None, None
|
|
| |
| reconstructed = self.decoder_seq(latent) |
|
|
|
|
|
|
        hidden_states = None
        if output_hidden_states:
            if self.config.is_variational:
                hidden_states = (latent, self._mu, self._logvar)
            else:
                hidden_states = (latent,)
|
|
| if not return_dict: |
| return tuple(v for v in [latent, reconstructed, hidden_states] if v is not None) |
|
|
| return AutoencoderOutput( |
| last_hidden_state=latent, |
| reconstructed=reconstructed, |
| hidden_states=hidden_states, |
| preprocessing_loss=preprocessing_loss, |
| ) |
|
|
|
|
| class AutoencoderForReconstruction(PreTrainedModel): |
| """ |
| Autoencoder Model with a reconstruction head on top for reconstruction tasks. |
| |
| This model inherits from PreTrainedModel and adds a reconstruction loss calculation. |
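
    Example (illustrative sketch; config fields are assumptions consistent with
    how this module reads the configuration):

        >>> config = AutoencoderConfig(input_dim=784, hidden_dims=[256, 64], latent_dim=16)
        >>> model = AutoencoderForReconstruction(config)
        >>> x = torch.randn(8, 784)
        >>> outputs = model(input_values=x)   # labels default to the input
        >>> outputs.loss.backward()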
| """ |
|
|
| config_class = AutoencoderConfig |
| base_model_prefix = "autoencoder" |
|
|
| def __init__(self, config: AutoencoderConfig): |
| super().__init__(config) |
| self.config = config |
|
|
| |
| self.autoencoder = AutoencoderModel(config) |
|
|
| |
| self.post_init() |
|
|
|
|
|
|
| def get_input_embeddings(self): |
| """Get input embeddings.""" |
| return self.autoencoder.get_input_embeddings() |
|
|
| def set_input_embeddings(self, value): |
| """Set input embeddings.""" |
| self.autoencoder.set_input_embeddings(value) |
|
|
| def _compute_reconstruction_loss( |
| self, |
| reconstructed: torch.Tensor, |
| target: torch.Tensor |
| ) -> torch.Tensor: |
| """Compute reconstruction loss based on the configured loss type.""" |
| if self.config.reconstruction_loss == "mse": |
| return F.mse_loss(reconstructed, target, reduction="mean") |
| elif self.config.reconstruction_loss == "bce": |
| return F.binary_cross_entropy_with_logits(reconstructed, target, reduction="mean") |
| elif self.config.reconstruction_loss == "l1": |
| return F.l1_loss(reconstructed, target, reduction="mean") |
| elif self.config.reconstruction_loss == "huber": |
| return F.huber_loss(reconstructed, target, reduction="mean") |
| elif self.config.reconstruction_loss == "smooth_l1": |
| return F.smooth_l1_loss(reconstructed, target, reduction="mean") |
| elif self.config.reconstruction_loss == "kl_div": |
| return F.kl_div(F.log_softmax(reconstructed, dim=-1), F.softmax(target, dim=-1), reduction="mean") |
| elif self.config.reconstruction_loss == "cosine": |
| return 1 - F.cosine_similarity(reconstructed, target, dim=-1).mean() |
| elif self.config.reconstruction_loss == "focal": |
| return self._focal_loss(reconstructed, target) |
| elif self.config.reconstruction_loss == "dice": |
| return self._dice_loss(reconstructed, target) |
| elif self.config.reconstruction_loss == "tversky": |
| return self._tversky_loss(reconstructed, target) |
| elif self.config.reconstruction_loss == "ssim": |
| return self._ssim_loss(reconstructed, target) |
| elif self.config.reconstruction_loss == "perceptual": |
| return self._perceptual_loss(reconstructed, target) |
| else: |
| raise ValueError(f"Unknown reconstruction loss: {self.config.reconstruction_loss}") |
|
|
    def _focal_loss(self, pred: torch.Tensor, target: torch.Tensor, alpha: float = 1.0, gamma: float = 2.0) -> torch.Tensor:
        """Compute a focal-style loss that down-weights easy (low-error) elements."""
        elementwise_loss = F.mse_loss(pred, target, reduction="none")
        pt = torch.exp(-elementwise_loss)
        focal_loss = alpha * (1 - pt) ** gamma * elementwise_loss
        return focal_loss.mean()
|
|
| def _dice_loss(self, pred: torch.Tensor, target: torch.Tensor, smooth: float = 1e-6) -> torch.Tensor: |
| """Compute Dice loss for segmentation-like tasks.""" |
| pred_flat = pred.view(-1) |
| target_flat = target.view(-1) |
| intersection = (pred_flat * target_flat).sum() |
| dice = (2.0 * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth) |
| return 1 - dice |
|
|
| def _tversky_loss(self, pred: torch.Tensor, target: torch.Tensor, alpha: float = 0.7, beta: float = 0.3, smooth: float = 1e-6) -> torch.Tensor: |
| """Compute Tversky loss, a generalization of Dice loss.""" |
| pred_flat = pred.view(-1) |
| target_flat = target.view(-1) |
| true_pos = (pred_flat * target_flat).sum() |
| false_neg = (target_flat * (1 - pred_flat)).sum() |
| false_pos = ((1 - target_flat) * pred_flat).sum() |
| tversky = (true_pos + smooth) / (true_pos + alpha * false_neg + beta * false_pos + smooth) |
| return 1 - tversky |
|
|
| def _ssim_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: |
| """Compute SSIM-based loss (simplified version).""" |
| |
| mu1 = pred.mean(dim=-1, keepdim=True) |
| mu2 = target.mean(dim=-1, keepdim=True) |
| sigma1_sq = ((pred - mu1) ** 2).mean(dim=-1, keepdim=True) |
| sigma2_sq = ((target - mu2) ** 2).mean(dim=-1, keepdim=True) |
| sigma12 = ((pred - mu1) * (target - mu2)).mean(dim=-1, keepdim=True) |
|
|
| c1, c2 = 0.01, 0.03 |
| ssim = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / ((mu1**2 + mu2**2 + c1) * (sigma1_sq + sigma2_sq + c2)) |
| return 1 - ssim.mean() |
|
|
| def _perceptual_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: |
| """Compute perceptual loss (simplified version using feature differences).""" |
| |
| pred_norm = F.normalize(pred, p=2, dim=-1) |
| target_norm = F.normalize(target, p=2, dim=-1) |
| return F.mse_loss(pred_norm, target_norm) |
|
|
| def forward( |
| self, |
| input_values: torch.Tensor, |
| labels: Optional[torch.Tensor] = None, |
| sequence_lengths: Optional[torch.Tensor] = None, |
| target_length: Optional[int] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| ) -> Union[Tuple[torch.Tensor], AutoencoderForReconstructionOutput]: |
| """ |
| Forward pass with reconstruction loss calculation. |
| |
| Args: |
| input_values (torch.Tensor): Input tensor. Shape depends on autoencoder type: |
| - Standard: (batch_size, input_dim) |
| - Recurrent: (batch_size, seq_len, input_dim) |
| labels (torch.Tensor, optional): Target tensor for reconstruction. If None, uses input_values. |
| sequence_lengths (torch.Tensor, optional): Sequence lengths for recurrent AE. |
| target_length (int, optional): Target sequence length for recurrent decoder. |
| output_hidden_states (bool, optional): Whether to return hidden states. |
| return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple. |
| |
| Returns: |
| AutoencoderForReconstructionOutput or tuple: The model outputs including loss. |
| """ |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| |
        if labels is None:
            labels = input_values

        # The contractive penalty below needs d(latent)/d(input); enable input
        # gradients on a detached view so the loss target itself stays grad-free.
        if self.config.is_contractive and self.training and not input_values.requires_grad:
            input_values = input_values.detach().requires_grad_(True)
|
|
| |
| outputs = self.autoencoder( |
| input_values=input_values, |
| sequence_lengths=sequence_lengths, |
| target_length=target_length, |
| output_hidden_states=output_hidden_states, |
| return_dict=True, |
| ) |
|
|
| reconstructed = outputs.reconstructed |
| latent = outputs.last_hidden_state |
| hidden_states = outputs.hidden_states |
|
|
| |
| recon_loss = self._compute_reconstruction_loss(reconstructed, labels) |
|
|
| |
| total_loss = recon_loss |
|
|
| |
        if outputs.preprocessing_loss is not None:
            total_loss = total_loss + outputs.preprocessing_loss

        if self.config.is_variational and getattr(self.autoencoder, '_mu', None) is not None:
            # Closed-form KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + logvar - mu^2 - exp(logvar)),
            # averaged over batch and latent dimensions. Accumulate instead of
            # overwriting so the preprocessing loss is preserved.
            mu, logvar = self.autoencoder._mu, self.autoencoder._logvar
            kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            kl_loss = kl_loss / (mu.size(0) * mu.size(1))
            total_loss = total_loss + self.config.beta * kl_loss

        elif self.config.is_sparse:
            # L1 penalty on the latent code encourages sparse activations.
            sparsity_loss = torch.mean(torch.abs(latent))
            total_loss = total_loss + 0.1 * sparsity_loss

        elif self.config.is_contractive and self.training and input_values.requires_grad:
            # Contractive penalty: squared norm of d(latent)/d(input), computed as
            # a cheap surrogate via the gradient of the summed latent. The original
            # retain_grad() check could never fire, because backward() has not run
            # at this point in the forward pass.
            grads = torch.autograd.grad(
                latent.sum(), input_values, create_graph=True, retain_graph=True, allow_unused=True
            )[0]
            if grads is not None:
                contractive_loss = grads.pow(2).sum() / latent.size(0)
                total_loss = total_loss + 0.1 * contractive_loss
|
|
| loss = total_loss |
|
|
| if not return_dict: |
| output = (reconstructed, latent) |
| if hidden_states is not None: |
| output = output + (hidden_states,) |
| return ((loss,) + output) if loss is not None else output |
|
|
| return AutoencoderForReconstructionOutput( |
| loss=loss, |
| reconstructed=reconstructed, |
| last_hidden_state=latent, |
| hidden_states=hidden_states, |
            preprocessing_loss=outputs.preprocessing_loss,
| ) |
|
|