from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

def _get_activation(name: Optional[str]) -> nn.Module:
    """Map an activation name to a freshly constructed nn.Module."""
    if name is None:
        return nn.Identity()
    name = name.lower()
    mapping = {
        "relu": nn.ReLU(),
        "gelu": nn.GELU(),
        "silu": nn.SiLU(),
        "swish": nn.SiLU(),  # "swish" is an alias for SiLU
        "tanh": nn.Tanh(),
        "sigmoid": nn.Sigmoid(),
        "leaky_relu": nn.LeakyReLU(0.2),
        "elu": nn.ELU(),
        "mish": nn.Mish(),
        "softplus": nn.Softplus(),
        "identity": nn.Identity(),
    }
    if name not in mapping:
        raise ValueError(f"Unknown activation: {name}")
    return mapping[name]
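
# A minimal usage sketch (assumed call sites, not part of the original module):
#   _get_activation("GELU")  -> nn.GELU()     (names are lower-cased first)
#   _get_activation(None)    -> nn.Identity()
#   _get_activation("prelu") -> ValueError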

def _get_norm(name: Optional[str], num_features: int) -> nn.Module:
    """Map a normalization name to an nn.Module over `num_features` features."""
    if name is None:
        return nn.Identity()
    # Lower-case before comparing so "None", "Batch", etc. are accepted.
    name = name.lower()
    if name == "none":
        return nn.Identity()
    if name == "batch":
        return nn.BatchNorm1d(num_features)
    if name == "layer":
        return nn.LayerNorm(num_features)
    if name == "instance":
        return nn.InstanceNorm1d(num_features)
    if name == "group":
        # Start from at most 8 groups and walk down to the largest count
        # that divides num_features, since nn.GroupNorm requires it.
        groups = max(1, min(8, num_features))
        while num_features % groups != 0 and groups > 1:
            groups -= 1
        if groups == 1:
            # No usable divisor: fall back to LayerNorm rather than a
            # degenerate single-group GroupNorm.
            return nn.LayerNorm(num_features)
        return nn.GroupNorm(groups, num_features)
    raise ValueError(f"Unknown normalization: {name}")
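# A worked example of the group-count search above (my illustration, not
# original code): for num_features=12 the loop tries 8, then 7, and settles on
# GroupNorm(6, 12); for a prime such as num_features=13 it walks down to 1 and
# falls back to LayerNorm(13).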

def _flatten_3d_to_2d(x: torch.Tensor) -> Tuple[torch.Tensor, Optional[Tuple[int, int]]]:
    """Collapse (batch, time, features) to (batch * time, features).

    Returns the flattened tensor and a (batch, time) hint for restoring the
    original shape; 2-D inputs pass through unchanged with a None hint.
    """
    if x.dim() == 3:
        b, t, f = x.shape
        return x.reshape(b * t, f), (b, t)
    return x, None


def _maybe_restore_3d(x: torch.Tensor, shape_hint: Optional[Tuple[int, int]]) -> torch.Tensor:
    """Undo _flatten_3d_to_2d using the (batch, time) hint, if one was given."""
    if shape_hint is None:
        return x
    b, t = shape_hint
    f = x.shape[-1]
    return x.reshape(b, t, f)
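

# A minimal smoke-test sketch (assumed usage, not part of the original module):
# round-trips a (batch, time, features) tensor through the flatten/restore pair
# and applies a looked-up norm and activation to the flattened view.
if __name__ == "__main__":
    x = torch.randn(4, 10, 32)                    # (batch, time, features)
    flat, hint = _flatten_3d_to_2d(x)             # -> (40, 32) with hint (4, 10)
    norm = _get_norm("layer", flat.shape[-1])     # LayerNorm over 32 features
    act = _get_activation("gelu")
    y = _maybe_restore_3d(act(norm(flat)), hint)  # back to (4, 10, 32)
    assert y.shape == x.shape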