Upload modeling_nemotron_h.py
modeling_nemotron_h.py  CHANGED  (+11 -11)
@@ -24,25 +24,25 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
 
-from 
-from 
-from 
-from 
+from transformers.activations import ACT2FN
+from transformers.cache_utils import DynamicCache  # we need __iter__ and __len__ of pkv
+from transformers.generation import GenerationMixin
+from transformers.modeling_attn_mask_utils import (
     AttentionMaskConverter,
 )
-from 
-from 
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
     ModelOutput,
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     logging,
 )
-from 
+from transformers.utils.import_utils import (
     is_causal_conv1d_available,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
-    is_mamba_2_ssm_available,
+    is_mamba_2_ssm_available,
 )
 from .configuration_nemotron_h import NemotronHConfig
 
@@ -70,7 +70,7 @@ else:
     causal_conv1d_update, causal_conv1d_fn = None, None
 
 if is_flash_attn_2_available():
-    from 
+    from transformers.modeling_flash_attention_utils import _flash_attention_forward
 
 is_fast_path_available = all(
     (
@@ -844,8 +844,8 @@ class NemotronHAttention(nn.Module):
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
-        if config.
-            self.head_dim = config.
+        if config.head_dim is not None:
+            self.head_dim = config.head_dim
         else:
             self.head_dim = config.hidden_size // config.num_attention_heads
         self.num_key_value_heads = config.num_key_value_heads
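The added lines import everything by absolute transformers.* paths, which is what a standalone remote-code file typically needs to reach the installed library, and the attention change lets an explicit config.head_dim override the default hidden_size // num_attention_heads. A minimal usage sketch under those assumptions (the repo id below is a placeholder, not something named in this commit):

# Sketch only: assumes a Hub checkpoint that ships this modeling_nemotron_h.py
# as remote code. The repo id is a placeholder, not taken from this commit.
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "some-org/some-nemotron-h-checkpoint"  # placeholder

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

# Mirrors the updated head_dim logic in NemotronHAttention.__init__:
# an explicit config.head_dim wins, otherwise it is derived from hidden_size.
head_dim = (
    config.head_dim
    if getattr(config, "head_dim", None) is not None
    else config.hidden_size // config.num_attention_heads
)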
