Since `transformers` v4.56.0, the dictionary `ALL_STATIC_CACHE_IMPLEMENTATIONS` has replaced `NEED_SETUP_CACHE_CLASSES_MAPPING`

#9
Files changed (1) hide show
  1. modeling_decilm.py +14 -2
modeling_decilm.py CHANGED
@@ -19,15 +19,23 @@
19
  # limitations under the License.
20
 
21
  import math
 
22
  from typing import List, Optional, Tuple, Union
23
 
 
24
  import torch
25
  import torch.nn.functional as F
26
  import torch.utils.checkpoint
27
  from torch import nn
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
  from transformers import GenerationConfig
30
- from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING, GenerationMixin, GenerateOutput
 
 
 
 
 
 
31
  from transformers.modeling_utils import PreTrainedModel
32
  from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
33
  from transformers.utils import (
@@ -810,7 +818,10 @@ class DeciLMPreTrainedModel(PreTrainedModel):
810
  # DeciLM-specific code
811
  generation_config, model_kwargs = super()._prepare_generation_config(generation_config, *args, **kwargs)
812
  generation_config.cache_implementation = "variable"
813
- NEED_SETUP_CACHE_CLASSES_MAPPING["variable"] = VariableCache
 
 
 
814
  return generation_config, model_kwargs
815
 
816
 
@@ -1148,6 +1159,7 @@ class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin):
1148
  output_hidden_states: Optional[bool] = None,
1149
  return_dict: Optional[bool] = None,
1150
  cache_position: Optional[torch.LongTensor] = None,
 
1151
  ) -> Union[Tuple, CausalLMOutputWithPast]:
1152
  r"""
1153
  Args:
 
19
  # limitations under the License.
20
 
21
  import math
22
+ import importlib.metadata
23
  from typing import List, Optional, Tuple, Union
24
 
25
+ from packaging.version import Version
26
  import torch
27
  import torch.nn.functional as F
28
  import torch.utils.checkpoint
29
  from torch import nn
30
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
  from transformers import GenerationConfig
32
+ from transformers.generation.utils import GenerationMixin, GenerateOutput
33
+
34
+
35
+ if Version(importlib.metadata.version("transformers")) <= Version("4.56.0.dev0"):
36
+ from transformers.generation.configuration_utils import NEED_SETUP_CACHE_CLASSES_MAPPING
37
+ else:
38
+ from transformers.generation.configuration_utils import ALL_STATIC_CACHE_IMPLEMENTATIONS
39
  from transformers.modeling_utils import PreTrainedModel
40
  from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
41
  from transformers.utils import (
 
818
  # DeciLM-specific code
819
  generation_config, model_kwargs = super()._prepare_generation_config(generation_config, *args, **kwargs)
820
  generation_config.cache_implementation = "variable"
821
+ if Version(importlib.metadata.version("transformers")) <= Version("4.56.0.dev0"):
822
+ NEED_SETUP_CACHE_CLASSES_MAPPING["variable"] = VariableCache
823
+ else:
824
+ ALL_STATIC_CACHE_IMPLEMENTATIONS["variable"] = VariableCache
825
  return generation_config, model_kwargs
826
 
827
 
 
1159
  output_hidden_states: Optional[bool] = None,
1160
  return_dict: Optional[bool] = None,
1161
  cache_position: Optional[torch.LongTensor] = None,
1162
+ **kwargs,
1163
  ) -> Union[Tuple, CausalLMOutputWithPast]:
1164
  r"""
1165
  Args: