Spaces:
Sleeping
Sleeping
| import math | |
| import torch | |
| from torch import nn | |
def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
    """Fill ``tensor`` in place with truncated-normal samples whose std is ``std``.

    NOTE: PyTorch's ``nn.init.trunc_normal_`` is not mathematically correct in
    this sense: its ``std`` is the scale of the underlying (untruncated) normal,
    so the standard deviation of the initialized tensor is smaller than ``std``.
    This function is a PyTorch version of the JAX truncated-normal init (the
    default init method in Flax). It rescales the underlying normal so that the
    truncated distribution itself has standard deviation ``std``.

    References:
        https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
        https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199

    Args:
        tensor: Tensor overwritten in place (expected to be floating-point,
            since it is filled via ``uniform_`` / ``erfinv_``).
        std: Target standard deviation of the initialized values. ``0`` fills
            the tensor with zeros.
        lower: Lower truncation bound, in units of the underlying standard normal.
        upper: Upper truncation bound, in units of the underlying standard
            normal. Must be strictly greater than ``lower`` when ``std != 0``.

    Returns:
        The same ``tensor`` object.

    Raises:
        ValueError: If ``std != 0`` and ``lower`` is not strictly less than ``upper``.

    Note:
        With asymmetric bounds the samples are not zero-mean; only the standard
        deviation is corrected.
    """
    with torch.no_grad():
        if std == 0:
            tensor.zero_()
        else:
            if not lower < upper:
                raise ValueError(f"lower ({lower}) must be strictly less than upper ({upper})")

            sqrt2 = math.sqrt(2)
            # Bounds mapped into erf-space: Phi(x) = (1 + erf(x / sqrt2)) / 2.
            a = math.erf(lower / sqrt2)
            b = math.erf(upper / sqrt2)
            # Z = Phi(upper) - Phi(lower): standard-normal mass inside the bounds.
            z = (b - a) / 2

            c = (2 * math.pi) ** -0.5
            pdf_lower = c * math.exp(-0.5 * lower ** 2)
            pdf_upper = c * math.exp(-0.5 * upper ** 2)
            # Variance of a standard normal truncated to [l, u]:
            #   1 + (l*phi(l) - u*phi(u)) / Z - ((phi(l) - phi(u)) / Z)^2
            # Each density must be paired with its own bound; swapping them only
            # happens to give the right answer when the bounds are symmetric.
            truncated_var = (
                1
                - (upper * pdf_upper - lower * pdf_lower) / z
                - ((pdf_lower - pdf_upper) / z) ** 2
            )
            comp_std = std / math.sqrt(truncated_var)

            # Inverse-CDF sampling: U ~ Uniform(a, b) in erf-space, so
            # sqrt2 * erfinv(U) is a standard normal restricted to [lower, upper].
            tensor.uniform_(a, b)
            tensor.erfinv_()
            tensor.mul_(sqrt2 * comp_std)
            # Clamp to the scaled bounds in case floating-point rounding
            # (e.g. erfinv at +/-1) produces values outside them.
            tensor.clip_(lower * comp_std, upper * comp_std)
    return tensor