| """ |
| Modular, block-based components for building autoencoders in PyTorch. |
| |
| Core goals: |
| - Composable building blocks with consistent interfaces |
| - Support 2D (B, F) and 3D (B, T, F) tensors where applicable |
| - Simple configs to construct blocks and sequences |
| - Safe-by-default validation and helpful errors |
| |
| This module is intentionally self-contained to allow gradual integration with |
| existing models. It does not mutate current behavior. |
| """ |
from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint  # used for gradient checkpointing in BlockSequence

try:
    from .configuration_autoencoder import (
        BlockConfig,
        LinearBlockConfig,
        AttentionBlockConfig,
        RecurrentBlockConfig,
        ConvolutionalBlockConfig,
        VariationalBlockConfig,
    )
except Exception:
    from configuration_autoencoder import (
        BlockConfig,
        LinearBlockConfig,
        AttentionBlockConfig,
        RecurrentBlockConfig,
        ConvolutionalBlockConfig,
        VariationalBlockConfig,
    )

try:
    from .utils import _get_activation, _get_norm, _flatten_3d_to_2d, _maybe_restore_3d
except Exception:
    from utils import _get_activation, _get_norm, _flatten_3d_to_2d, _maybe_restore_3d


class BaseBlock(nn.Module):
    """Abstract base class for all blocks.

    Blocks accept 2D (B, F) or 3D (B, T, F) tensors and return a tensor of the
    same rank whose last dimension equals ``output_dim``.
    """

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    @property
    def output_dim(self) -> int:
        raise NotImplementedError


class ResidualBlock(BaseBlock):
    """Base class for blocks supporting residual connections.

    Performs a plain residual add when the input and output feature dims match;
    otherwise falls back to a learned linear projection of the input. The
    residual contribution can be scaled via ``residual_scale``.
    """

    def __init__(
        self,
        residual: bool = False,
        residual_scale: float = 1.0,
        proj_dim_in: Optional[int] = None,
        proj_dim_out: Optional[int] = None,
    ):
        super().__init__()
        self.use_residual = residual
        self.residual_scale = residual_scale
        self._proj: Optional[nn.Module] = None
        if residual and proj_dim_in is not None and proj_dim_out is not None and proj_dim_in != proj_dim_out:
            self._proj = nn.Linear(proj_dim_in, proj_dim_out)

    def _apply_residual(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        if not self.use_residual:
            return y
        x2d, hint = _flatten_3d_to_2d(x)
        y2d, _ = _flatten_3d_to_2d(y)
        if x2d.shape[-1] != y2d.shape[-1]:
            if self._proj is None:
                # Lazily create the projection when the dims were not known at
                # construction time. Note: parameters created here are not seen
                # by optimizers built before the first forward pass.
                self._proj = nn.Linear(x2d.shape[-1], y2d.shape[-1]).to(y2d.device)
            x2d = self._proj(x2d)
        out = x2d + self.residual_scale * y2d
        return _maybe_restore_3d(out, hint)


class LinearBlock(ResidualBlock):
    """Linear transformation followed by normalization, activation, and dropout.

    - Handles both 2D (B, F) and 3D (B, T, F) tensors
    - Optional normalization: batch | layer | group | instance | none
    - Configurable activation
    - Optional dropout
    - Optional residual connection (with automatic projection on dim mismatch)
    """

    def __init__(self, cfg: LinearBlockConfig):
        super().__init__(
            residual=cfg.use_residual,
            residual_scale=cfg.residual_scale,
            proj_dim_in=cfg.input_dim,
            proj_dim_out=cfg.output_dim,
        )
        self.cfg = cfg
        self.linear = nn.Linear(cfg.input_dim, cfg.output_dim)
        if cfg.normalization == "layer":
            self.norm = nn.LayerNorm(cfg.output_dim)
        else:
            self.norm = _get_norm(cfg.normalization, cfg.output_dim)
        self.act = _get_activation(cfg.activation)
        self.drop = nn.Dropout(cfg.dropout_rate) if cfg.dropout_rate and cfg.dropout_rate > 0 else nn.Identity()

    @property
    def output_dim(self) -> int:
        return self.cfg.output_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_in = x
        x2d, hint = _flatten_3d_to_2d(x)
        y = self.linear(x2d)
        # The norm layer is applied on the flattened (N, F) view.
        y = self.norm(y)
        y = self.act(y)
        y = self.drop(y)
        y = _maybe_restore_3d(y, hint)
        return self._apply_residual(x_in, y)


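# Usage sketch (illustrative, not executed): assuming LinearBlockConfig accepts the
# keyword arguments read above, the block maps the feature dim to ``output_dim``
# while preserving tensor rank:
#
#     cfg = LinearBlockConfig(input_dim=64, output_dim=32, activation="relu",
#                             normalization="layer", dropout_rate=0.1,
#                             use_residual=True, residual_scale=1.0)
#     block = LinearBlock(cfg)
#     block(torch.randn(8, 64)).shape        # -> (8, 32)
#     block(torch.randn(8, 20, 64)).shape    # -> (8, 20, 32)
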
class AttentionBlock(BaseBlock):
    """Multi-head self-attention with a feed-forward sub-layer and post-norm residuals.

    Expects 3D input (B, T, D); 2D input (B, D) is treated as (B, 1, D).
    Optional ``attn_mask`` and ``key_padding_mask`` are forwarded to the
    underlying ``nn.MultiheadAttention``.
    """

    def __init__(self, cfg: AttentionBlockConfig):
        super().__init__()
        self.cfg = cfg
        d_model = cfg.input_dim
        self.mha = nn.MultiheadAttention(d_model, num_heads=cfg.num_heads, dropout=cfg.dropout_rate, batch_first=True)
        self.ln1 = nn.LayerNorm(d_model)
        ffn_dim = cfg.ffn_dim or (4 * d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ffn_dim),
            _get_activation("gelu"),
            nn.Dropout(cfg.dropout_rate),
            nn.Linear(ffn_dim, d_model),
        )
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(cfg.dropout_rate)

    @property
    def output_dim(self) -> int:
        return self.cfg.input_dim

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        squeeze_back = x.dim() == 2
        if squeeze_back:
            x = x.unsqueeze(1)

        # Self-attention sub-layer with post-norm residual.
        residual = x
        attn_out, _ = self.mha(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
        x = self.ln1(residual + self.dropout(attn_out))

        # Feed-forward sub-layer with post-norm residual.
        residual = x
        x = self.ffn(x)
        x = self.ln2(residual + self.dropout(x))

        if squeeze_back:
            x = x.squeeze(1)
        return x


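# Usage sketch (illustrative, not executed): assuming AttentionBlockConfig accepts
# the fields read above. True entries in ``key_padding_mask`` mark padded positions
# that attention should ignore:
#
#     cfg = AttentionBlockConfig(input_dim=64, num_heads=4, dropout_rate=0.1, ffn_dim=None)
#     block = AttentionBlock(cfg)
#     x = torch.randn(2, 10, 64)
#     pad = torch.zeros(2, 10, dtype=torch.bool)
#     pad[1, 7:] = True                            # mask the tail of sequence 1
#     block(x, key_padding_mask=pad).shape         # -> (2, 10, 64)
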
class RecurrentBlock(BaseBlock):
    """Recurrent processing block supporting LSTM, GRU, or vanilla RNN.

    Input: 3D (B, T, F) preferred; 2D (B, F) is treated as (B, 1, F).
    The block pools the sequence to its last (valid) timestep, so the output is
    (B, output_dim) for 2D input and (B, 1, output_dim) for 3D input, where
    ``output_dim`` is ``cfg.output_dim`` if set, else ``hidden_size`` times the
    number of directions.
    """

    def __init__(self, cfg: RecurrentBlockConfig):
        super().__init__()
        self.cfg = cfg
        rnn_type = cfg.rnn_type.lower()
        rnn_cls = {"lstm": nn.LSTM, "gru": nn.GRU, "rnn": nn.RNN}.get(rnn_type)
        if rnn_cls is None:
            raise ValueError(f"Unknown rnn_type: {cfg.rnn_type}")
        self.rnn = rnn_cls(
            input_size=cfg.input_dim,
            hidden_size=cfg.hidden_size,
            num_layers=cfg.num_layers,
            batch_first=True,
            dropout=cfg.dropout_rate if cfg.num_layers > 1 else 0.0,
            bidirectional=cfg.bidirectional,
        )
        out_dim = cfg.hidden_size * (2 if cfg.bidirectional else 1)
        self._out_dim = cfg.output_dim or out_dim
        self.proj = None if self._out_dim == out_dim else nn.Linear(out_dim, self._out_dim)

    @property
    def output_dim(self) -> int:
        return self._out_dim

    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
        squeeze_back = False
        if x.dim() == 2:
            x = x.unsqueeze(1)
            squeeze_back = True
        if lengths is not None:
            # pack_padded_sequence requires lengths on the CPU.
            lengths = lengths.cpu() if torch.is_tensor(lengths) else torch.as_tensor(lengths)
            x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        if isinstance(self.rnn, nn.LSTM):
            out, (h, c) = self.rnn(x)
        else:
            out, h = self.rnn(x)
        if lengths is not None:
            out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
            # Select the last *valid* timestep per sequence rather than the padded one.
            idx = (lengths - 1).clamp(min=0).to(out.device).view(-1, 1, 1).expand(-1, 1, out.size(-1))
            y = out.gather(1, idx).squeeze(1)
        else:
            y = out[:, -1, :]
        if self.proj is not None:
            y = self.proj(y)
        if squeeze_back:
            # 2D input -> 2D output (B, output_dim).
            return y
        # 3D input -> keep rank: (B, 1, output_dim).
        return y.unsqueeze(1)


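# Usage sketch (illustrative, not executed): assuming RecurrentBlockConfig accepts
# the fields read above, a padded batch is pooled to its last valid step:
#
#     cfg = RecurrentBlockConfig(input_dim=32, hidden_size=64, num_layers=1,
#                                rnn_type="gru", dropout_rate=0.0,
#                                bidirectional=False, output_dim=None)
#     block = RecurrentBlock(cfg)
#     x = torch.randn(4, 12, 32)              # (B, T, F), padded
#     lengths = torch.tensor([12, 7, 9, 3])
#     block(x, lengths=lengths).shape         # -> (4, 1, 64)
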
class ConvolutionalBlock(BaseBlock):
    """1D convolutional block for sequence-like data.

    Accepts 3D (B, T, F) or 2D (B, F), the latter treated as (B, 1, F). The
    feature dimension F is used as the channel dimension of the convolution.
    """

    def __init__(self, cfg: ConvolutionalBlockConfig):
        super().__init__()
        self.cfg = cfg
        padding = cfg.padding
        if isinstance(padding, str) and padding == "same":
            # kernel_size // 2 preserves the sequence length for odd kernels and stride 1.
            pad = cfg.kernel_size // 2
        else:
            pad = int(padding)
        self.conv = nn.Conv1d(cfg.input_dim, cfg.output_dim, kernel_size=cfg.kernel_size, padding=pad)
        if cfg.normalization == "layer":
            # GroupNorm with a single group normalizes over channels and works on
            # (B, C, L) tensors, acting as a layer norm for the conv output.
            self.norm = nn.GroupNorm(1, cfg.output_dim)
        else:
            self.norm = _get_norm(cfg.normalization, cfg.output_dim)
        self.act = _get_activation(cfg.activation)
        self.drop = nn.Dropout(cfg.dropout_rate) if cfg.dropout_rate and cfg.dropout_rate > 0 else nn.Identity()

    @property
    def output_dim(self) -> int:
        return self.cfg.output_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        squeeze_back = False
        if x.dim() == 2:
            x = x.unsqueeze(1)
            squeeze_back = True
        # (B, T, F) -> (B, F, T) for Conv1d.
        x = x.transpose(1, 2)
        y = self.conv(x)
        # Only channel-wise norms are applied here; other norm types expect the
        # feature dim last and are skipped on the (B, C, L) layout.
        if isinstance(self.norm, (nn.BatchNorm1d, nn.InstanceNorm1d, nn.GroupNorm)):
            y = self.norm(y)
        y = self.act(y)
        y = self.drop(y)
        y = y.transpose(1, 2)
        if squeeze_back:
            y = y.squeeze(1)
        return y


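# Usage sketch (illustrative, not executed): assuming ConvolutionalBlockConfig
# accepts the fields read above; "same" padding keeps the sequence length for
# odd kernel sizes:
#
#     cfg = ConvolutionalBlockConfig(input_dim=32, output_dim=64, kernel_size=3,
#                                    padding="same", normalization="batch",
#                                    activation="relu", dropout_rate=0.0)
#     block = ConvolutionalBlock(cfg)
#     block(torch.randn(4, 50, 32)).shape      # -> (4, 50, 64)
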
class VariationalBlock(BaseBlock):
    """Projects to mu/logvar and applies the reparameterization trick.

    Input can be 2D (B, F) or 3D (B, T, F); for 3D input the projection is
    applied per timestep and the same rank is returned. The most recent mu and
    logvar are stored on the module (``_mu`` / ``_logvar``) for downstream loss
    computation.
    """

    def __init__(self, cfg: VariationalBlockConfig):
        super().__init__()
        self.cfg = cfg
        self.fc_mu = nn.Linear(cfg.input_dim, cfg.latent_dim)
        self.fc_logvar = nn.Linear(cfg.input_dim, cfg.latent_dim)
        self._mu: Optional[torch.Tensor] = None
        self._logvar: Optional[torch.Tensor] = None

    @property
    def output_dim(self) -> int:
        return self.cfg.latent_dim

    def forward(self, x: torch.Tensor, training: Optional[bool] = None) -> torch.Tensor:
        if training is None:
            training = self.training
        x2d, hint = _flatten_3d_to_2d(x)
        mu = self.fc_mu(x2d)
        logvar = self.fc_logvar(x2d)
        if training:
            # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            z = mu + eps * std
        else:
            z = mu
        self._mu = mu
        self._logvar = logvar
        return _maybe_restore_3d(z, hint)


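# Usage sketch (illustrative, not executed): the block keeps mu/logvar from the
# most recent forward pass, which downstream code can use for the standard
# Gaussian KL term of a VAE loss:
#
#     vb = VariationalBlock(VariationalBlockConfig(input_dim=64, latent_dim=16))
#     z = vb(torch.randn(8, 64))
#     kl = -0.5 * torch.sum(1 + vb._logvar - vb._mu.pow(2) - vb._logvar.exp(), dim=-1).mean()
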
class BlockSequence(nn.Module):
    """Compose multiple blocks into a validated sequence.

    - Validates the dimension flow between consecutive blocks
    - Supports per-block gradient checkpointing via ``forward(..., checkpoint=True)``
    - Supports optional additive skip connections: pass ``skips`` as a list of
      (src_idx, dst_idx) pairs; the output of block ``src_idx`` is added to the
      output of block ``dst_idx`` (the two must have matching dimensions)
    """

    def __init__(self, blocks: Sequence[BaseBlock], validate_dims: bool = True, skips: Optional[List[Tuple[int, int]]] = None):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        self.skips = skips or []
        if validate_dims and len(blocks) > 1:
            for i in range(1, len(blocks)):
                prev = blocks[i - 1]
                cur = blocks[i]
                prev_out = getattr(prev, "output_dim", None)
                cur_in = getattr(getattr(cur, "cfg", None), "input_dim", None)
                if prev_out is None or cur_in is None:
                    continue
                if prev_out != cur_in:
                    raise ValueError(
                        f"Dimension mismatch between block {i - 1} (output_dim={prev_out}) "
                        f"and block {i} (input_dim={cur_in})"
                    )

    def forward(self, x: torch.Tensor, checkpoint: bool = False, **kwargs) -> torch.Tensor:
        activations: Dict[int, torch.Tensor] = {}
        for i, block in enumerate(self.blocks):
            if checkpoint and x.requires_grad:
                # Gradient checkpointing trades compute for memory by recomputing
                # the block during the backward pass.
                x = torch.utils.checkpoint.checkpoint(lambda inp: block(inp, **kwargs), x)
            else:
                x = block(x, **kwargs)
            activations[i] = x
            # Additive skip connections terminating at this block.
            for src, dst in self.skips:
                if dst == i and src in activations:
                    x = x + activations[src]
        return x


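# Usage sketch (illustrative, not executed): blocks can be composed directly; a
# skip pair (0, 2) adds the output of block 0 to the output of block 2, so the
# two must share the same feature dimension:
#
#     seq = BlockSequence([block_a, block_b, block_c], skips=[(0, 2)])
#     y = seq(x)                      # plain forward
#     y = seq(x, checkpoint=True)     # recompute blocks in backward to save memory
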
class BlockFactory:
    """Factory that builds blocks and sequences from configs.

    Accepts either the typed ``*BlockConfig`` objects or plain dicts with a
    ``type`` key. Intentionally minimal; extend as needed.
    """

    @staticmethod
    def build_block(cfg: Union[BlockConfig, Dict[str, Any]]) -> BaseBlock:
        if isinstance(cfg, dict):
            # Dict configs carry the block type under "type"; the remaining keys
            # are passed to the matching config class as keyword arguments.
            type_name = cfg.get("type")
            params = dict(cfg)
            params.pop("type", None)
            if type_name == "linear":
                return LinearBlock(LinearBlockConfig(**params))
            if type_name == "attention":
                return AttentionBlock(AttentionBlockConfig(**params))
            if type_name == "recurrent":
                return RecurrentBlock(RecurrentBlockConfig(**params))
            if type_name == "conv1d":
                return ConvolutionalBlock(ConvolutionalBlockConfig(**params))
            if type_name == "variational":
                return VariationalBlock(VariationalBlockConfig(**params))
            raise ValueError(f"Unsupported block type in dict cfg: {type_name} cfg={cfg}")

        if isinstance(cfg, LinearBlockConfig) or getattr(cfg, "type", None) == "linear":
            if not isinstance(cfg, LinearBlockConfig):
                cfg = LinearBlockConfig(**cfg.__dict__)
            return LinearBlock(cfg)
        if isinstance(cfg, AttentionBlockConfig) or getattr(cfg, "type", None) == "attention":
            if not isinstance(cfg, AttentionBlockConfig):
                cfg = AttentionBlockConfig(**cfg.__dict__)
            return AttentionBlock(cfg)
        if isinstance(cfg, RecurrentBlockConfig) or getattr(cfg, "type", None) == "recurrent":
            if not isinstance(cfg, RecurrentBlockConfig):
                cfg = RecurrentBlockConfig(**cfg.__dict__)
            return RecurrentBlock(cfg)
        if isinstance(cfg, ConvolutionalBlockConfig) or getattr(cfg, "type", None) == "conv1d":
            if not isinstance(cfg, ConvolutionalBlockConfig):
                cfg = ConvolutionalBlockConfig(**cfg.__dict__)
            return ConvolutionalBlock(cfg)
        if isinstance(cfg, VariationalBlockConfig) or getattr(cfg, "type", None) == "variational":
            if not isinstance(cfg, VariationalBlockConfig):
                cfg = VariationalBlockConfig(**cfg.__dict__)
            return VariationalBlock(cfg)
        raise ValueError(f"Unsupported block type: {cfg}")

    @staticmethod
    def build_sequence(configs: Sequence[Union[BlockConfig, Dict[str, Any]]]) -> BlockSequence:
        blocks: List[BaseBlock] = [BlockFactory.build_block(c) for c in configs]
        return BlockSequence(blocks)


__all__ = [
    "BlockConfig",
    "LinearBlockConfig",
    "AttentionBlockConfig",
    "RecurrentBlockConfig",
    "ConvolutionalBlockConfig",
    "VariationalBlockConfig",
    "BaseBlock",
    "ResidualBlock",
    "LinearBlock",
    "AttentionBlock",
    "RecurrentBlock",
    "ConvolutionalBlock",
    "VariationalBlock",
    "BlockSequence",
    "BlockFactory",
]
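

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch): the config field names below are
    # assumed to match the *BlockConfig dataclasses in configuration_autoencoder
    # and the strings accepted by _get_activation/_get_norm; adjust if they differ.
    seq = BlockFactory.build_sequence(
        [
            {"type": "linear", "input_dim": 32, "output_dim": 64, "activation": "relu",
             "normalization": "layer", "dropout_rate": 0.1,
             "use_residual": False, "residual_scale": 1.0},
            {"type": "linear", "input_dim": 64, "output_dim": 16, "activation": "relu",
             "normalization": "none", "dropout_rate": 0.0,
             "use_residual": False, "residual_scale": 1.0},
        ]
    )
    x = torch.randn(4, 10, 32)  # (B, T, F)
    print(seq(x).shape)  # expected: torch.Size([4, 10, 16])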