Update param_init_fns.py
param_init_fns.py  CHANGED  (+49 -51)
@@ -2,22 +2,26 @@ import math
 import warnings
 from collections.abc import Sequence
 from functools import partial
-from typing import Optional, Tuple, Union
+from typing import Any, Callable, Optional, Tuple, Union
 import torch
 from torch import nn
+from .fc import FC_CLASS_REGISTRY
 from .norm import NORM_CLASS_REGISTRY
+try:
+    import transformer_engine.pytorch as te
+except:
+    te = None
 
-def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs):
+def torch_default_param_init_fn_(module: nn.Module, **kwargs: Any) -> None:
     del kwargs
-    if verbose > 1:
-        warnings.warn(f"Initializing network using module's reset_parameters attribute")
-    if hasattr(module, 'reset_parameters'):
+    if hasattr(module, 'reset_parameters') and isinstance(module.reset_parameters, Callable):
         module.reset_parameters()
 
-def fused_init_helper_(module: nn.Module, init_fn_):
+def fused_init_helper_(module: nn.Module, init_fn_: Callable) -> None:
     _fused = getattr(module, '_fused', None)
     if _fused is None:
         raise RuntimeError(f'Internal logic error')
+    assert isinstance(module.weight, torch.Tensor)
     (dim, splits) = _fused
     splits = (0, *splits, module.weight.size(dim))
     for (s, e) in zip(splits[:-1], splits[1:]):
@@ -25,10 +29,8 @@ def fused_init_helper_(module: nn.Module, init_fn_):
         slice_indices[dim] = slice(s, e)
         init_fn_(module.weight[slice_indices])
 
-def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
+def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
     del kwargs
-    if verbose > 1:
-        warnings.warn(f'If model has bias parameters they are initialized to 0.')
     init_div_is_residual = init_div_is_residual
     if init_div_is_residual is False:
         div_is_residual = 1.0
@@ -36,20 +38,18 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
         div_is_residual = math.sqrt(2 * n_layers)
     elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
         div_is_residual = init_div_is_residual
-    elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
+    elif init_div_is_residual.isnumeric():
         div_is_residual = float(init_div_is_residual)
     else:
         div_is_residual = 1.0
         raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
-    if init_div_is_residual is not False:
-        if verbose > 1:
-            warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' + f'Set `init_div_is_residual: false` in init config to disable this.')
-    if isinstance(module, nn.Linear):
+    if isinstance(module, tuple(set(FC_CLASS_REGISTRY.values()))):
         if hasattr(module, '_fused'):
             fused_init_helper_(module, init_fn_)
         else:
             init_fn_(module.weight)
         if module.bias is not None:
+            assert isinstance(module.bias, torch.Tensor)
             torch.nn.init.zeros_(module.bias)
         if init_div_is_residual is not False and getattr(module, '_is_residual', False):
             with torch.no_grad():
@@ -60,8 +60,6 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
             if std == 0:
                 warnings.warn(f'Embedding layer initialized to 0.')
             emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
-            if verbose > 1:
-                warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.')
         elif emb_init_uniform_lim is not None:
             lim = emb_init_uniform_lim
             if isinstance(lim, Sequence):
@@ -75,17 +73,13 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
                 lim = [-lim, lim]
             (a, b) = lim
             emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
-            if verbose > 1:
-                warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.')
         else:
             emb_init_fn_ = init_fn_
         emb_init_fn_(module.weight)
     elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
-        if verbose > 1:
-            warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.')
-        if hasattr(module, 'weight') and module.weight is not None:
+        if hasattr(module, 'weight') and isinstance(module.weight, torch.Tensor):
             torch.nn.init.ones_(module.weight)
-        if hasattr(module, 'bias') and module.bias is not None:
+        if hasattr(module, 'bias') and isinstance(module.bias, torch.Tensor):
             torch.nn.init.zeros_(module.bias)
     elif isinstance(module, nn.MultiheadAttention):
         if module._qkv_same_embed_dim:
@@ -114,32 +108,45 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
                 module.out_proj.weight.div_(div_is_residual)
         if module.out_proj.bias is not None:
             torch.nn.init.zeros_(module.out_proj.bias)
+    elif te is not None and isinstance(module, te.LayerNormMLP):
+        if isinstance(module.layer_norm_weight, torch.Tensor):
+            torch.nn.init.ones_(module.layer_norm_weight)
+        if isinstance(module.layer_norm_bias, torch.Tensor):
+            torch.nn.init.zeros_(module.layer_norm_bias)
+        init_fn_(module.fc1_weight)
+        if module.fc1_bias is not None:
+            assert isinstance(module.fc1_bias, torch.Tensor)
+            torch.nn.init.zeros_(module.fc1_bias)
+        init_fn_(module.fc2_weight)
+        if module.fc2_bias is not None:
+            assert isinstance(module.fc2_bias, torch.Tensor)
+            torch.nn.init.zeros_(module.fc2_bias)
+        with torch.no_grad():
+            module.fc2_weight.div_(div_is_residual)
     else:
         for _ in module.parameters(recurse=False):
             raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.')
 
-def _normal_init_(std, mean=0.0):
+def _normal_init_(std: float, mean: float=0.0) -> Callable:
     return partial(torch.nn.init.normal_, mean=mean, std=std)
 
-def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
+def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
     del kwargs
     init_fn_ = _normal_init_(std=std)
-    if verbose > 1:
-        warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
-    generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+    generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
+def baseline_param_init_fn_(module: nn.Module, init_std: Optional[float], n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
     del kwargs
     if init_std is None:
         raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
-    _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+    _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
+def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
     del kwargs
     std = math.sqrt(2 / (5 * d_model))
-    _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+    _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
+def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
     """From section 2.3.1 of GPT-NeoX-20B:
 
     An Open-Source Autoregressive Language Model — Black et. al. (2022)
@@ -148,34 +155,25 @@ def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init
     """
     del kwargs
     residual_div = n_layers / math.sqrt(10)
-    if verbose > 1:
-        warnings.warn(f'setting init_div_is_residual to {residual_div}')
-    small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+    small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
+def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', **kwargs: Any) -> None:
     del kwargs
-    if verbose > 1:
-        warnings.warn(f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
     kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
-    generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+    generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
+def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', **kwargs: Any) -> None:
     del kwargs
-    if verbose > 1:
-        warnings.warn(f'Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
     kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
-    generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+    generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
+def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, **kwargs: Any) -> None:
     del kwargs
     xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
-    if verbose > 1:
-        warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}')
-    generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+    generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
+def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, **kwargs: Any) -> None:
+    del kwargs
     xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
-    if verbose > 1:
-        warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}')
-    generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+    generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}
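These initializers are driven indirectly: an MPT-style model's init_config['name'] selects an entry from MODEL_INIT_REGISTRY, and the chosen callable is applied to every submodule with nn.Module.apply. The snippet below is a minimal, self-contained sketch of that pattern; toy_small_init_, block, and the sizes n_layers=4, d_model=64 are made up for the illustration, and a simplified small_init_-style rule is inlined because the real file cannot be imported standalone (it depends on the package-relative .fc and .norm registries).

import math
from functools import partial

import torch
from torch import nn

def toy_small_init_(module: nn.Module, n_layers: int, d_model: int, **kwargs) -> None:
    """Simplified stand-in for small_param_init_fn_ (illustration only)."""
    del kwargs
    std = math.sqrt(2 / (5 * d_model))  # same std formula as small_param_init_fn_
    if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
        # Weights tagged with `_is_residual` are scaled down, mirroring the
        # sqrt(2 * n_layers) divisor used by generic_param_init_fn_.
        if getattr(module, '_is_residual', False):
            with torch.no_grad():
                module.weight.div_(math.sqrt(2 * n_layers))

n_layers, d_model = 4, 64  # made-up sizes for the sketch
block = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.Linear(4 * d_model, d_model))
block[1]._is_residual = True  # mark the output projection as a residual branch
block.apply(partial(toy_small_init_, n_layers=n_layers, d_model=d_model))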
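A note on the `_fused` path that this commit keeps but now type-asserts: a module whose weight stacks several logical matrices along one dimension (for example a fused QKV projection) can expose `_fused = (dim, split_points)`, and fused_init_helper_ initializes each slab separately so fan-based schemes see the logical sub-matrix rather than the concatenated weight. Below is a standalone sketch of the same slicing; init_fused_slices_, qkv, and d_model=8 are invented for the example, and a tuple index is used instead of the list index in the real helper.

import torch
from torch import nn

d_model = 8  # made-up size for the sketch
qkv = nn.Linear(d_model, 3 * d_model, bias=False)  # stand-in for a fused QKV projection
qkv._fused = (0, (d_model, 2 * d_model))  # (dim, interior split points)

def init_fused_slices_(module: nn.Module, init_fn_) -> None:
    # Same scheme as fused_init_helper_: pad the split points with 0 and the
    # full size along `dim`, then initialize each [s, e) slab independently.
    dim, splits = module._fused
    splits = (0, *splits, module.weight.size(dim))
    for s, e in zip(splits[:-1], splits[1:]):
        idx = [slice(None)] * module.weight.ndim
        idx[dim] = slice(s, e)
        init_fn_(module.weight[tuple(idx)])

init_fused_slices_(qkv, torch.nn.init.xavier_uniform_)  # each d_model x d_model slab gets its own fan-in/fan-out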

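For a concrete sense of the magnitudes these schemes produce, here is a quick check with illustrative sizes (d_model=4096 and n_layers=32 are not taken from this repo's config): small_param_init_fn_ would use std = sqrt(2 / (5 * 4096)) ≈ 0.0099, and neox_param_init_fn_ passes init_div_is_residual = 32 / sqrt(10) ≈ 10.1, so tagged residual weights end up near std ≈ 0.001.

import math
d_model, n_layers = 4096, 32             # illustrative sizes, not from this diff
std = math.sqrt(2 / (5 * d_model))       # ≈ 0.00988, as in small_param_init_fn_
residual_div = n_layers / math.sqrt(10)  # ≈ 10.12, as in neox_param_init_fn_
print(std, residual_div, std / residual_div)  # ≈ 0.00988 10.12 0.00098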