Since `transformers` v4.56.0` the dictionary `ALL_STATIC_CACHE_IMPLEMENTATIONS` replaced `NEED_SETUP_CACHE_CLASSES_MAPPING`
#9
by
blewis-hir
- opened
- modeling_decilm.py +14 -2
modeling_decilm.py
CHANGED
|
@@ -19,15 +19,23 @@
|
|
| 19 |
# limitations under the License.
|
| 20 |
|
| 21 |
import math
|
|
|
|
| 22 |
from typing import List, Optional, Tuple, Union
|
| 23 |
|
|
|
|
| 24 |
import torch
|
| 25 |
import torch.nn.functional as F
|
| 26 |
import torch.utils.checkpoint
|
| 27 |
from torch import nn
|
| 28 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 29 |
from transformers import GenerationConfig
|
| 30 |
-
from transformers.generation.utils import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
from transformers.modeling_utils import PreTrainedModel
|
| 32 |
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
| 33 |
from transformers.utils import (
|
|
@@ -810,7 +818,10 @@ class DeciLMPreTrainedModel(PreTrainedModel):
|
|
| 810 |
# DeciLM-specific code
|
| 811 |
generation_config, model_kwargs = super()._prepare_generation_config(generation_config, *args, **kwargs)
|
| 812 |
generation_config.cache_implementation = "variable"
|
| 813 |
-
|
|
|
|
|
|
|
|
|
|
| 814 |
return generation_config, model_kwargs
|
| 815 |
|
| 816 |
|
|
@@ -1148,6 +1159,7 @@ class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin):
|
|
| 1148 |
output_hidden_states: Optional[bool] = None,
|
| 1149 |
return_dict: Optional[bool] = None,
|
| 1150 |
cache_position: Optional[torch.LongTensor] = None,
|
|
|
|
| 1151 |
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 1152 |
r"""
|
| 1153 |
Args:
|
|
|
|
| 19 |
# limitations under the License.
|
| 20 |
|
| 21 |
import math
|
| 22 |
+
import importlib.metdata
|
| 23 |
from typing import List, Optional, Tuple, Union
|
| 24 |
|
| 25 |
+
from packaging.version import Version
|
| 26 |
import torch
|
| 27 |
import torch.nn.functional as F
|
| 28 |
import torch.utils.checkpoint
|
| 29 |
from torch import nn
|
| 30 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 31 |
from transformers import GenerationConfig
|
| 32 |
+
from transformers.generation.utils import GenerationMixin, GenerateOutput
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
if Version(importlib.metadata.version("transformers")) <= Version("4.56.0.dev0")
|
| 36 |
+
from transformers.generation.configuration_utils import NEED_SETUP_CACHE_CLASSES_MAPPING
|
| 37 |
+
else:
|
| 38 |
+
from transformers.generation.configuration_utils import ALL_STATIC_CACHE_IMPLEMENTATIONS
|
| 39 |
from transformers.modeling_utils import PreTrainedModel
|
| 40 |
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
| 41 |
from transformers.utils import (
|
|
|
|
| 818 |
# DeciLM-specific code
|
| 819 |
generation_config, model_kwargs = super()._prepare_generation_config(generation_config, *args, **kwargs)
|
| 820 |
generation_config.cache_implementation = "variable"
|
| 821 |
+
if transformers_version <= Version("4.56.0.dev0")
|
| 822 |
+
NEED_SETUP_CACHE_CLASSES_MAPPING["variable"] = VariableCache
|
| 823 |
+
else:
|
| 824 |
+
ALL_STATIC_CACHE_IMPLEMENTATIONS["variable"] = VariableCache
|
| 825 |
return generation_config, model_kwargs
|
| 826 |
|
| 827 |
|
|
|
|
| 1159 |
output_hidden_states: Optional[bool] = None,
|
| 1160 |
return_dict: Optional[bool] = None,
|
| 1161 |
cache_position: Optional[torch.LongTensor] = None,
|
| 1162 |
+
**kwargs,
|
| 1163 |
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 1164 |
r"""
|
| 1165 |
Args:
|