taegyeonglee committed on
Commit
1810b0b
·
verified ·
1 Parent(s): 6940c8b

Add HF-standard offline package (auto_map + modeling_kbert_mtl.py)

Browse files
Files changed (1) hide show
  1. modeling_kbert_mtl.py +14 -13
modeling_kbert_mtl.py CHANGED
@@ -1,28 +1,29 @@
1
  # modeling_kbert_mtl.py
2
  import torch
3
  import torch.nn as nn
4
- from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig
5
 
6
def _config_from_base_dict(base_cfg_dict: dict):
    """Build a transformers config from the plain ``base_model_config`` dict.

    Tries ``AutoConfig.for_model("bert", ...)`` first; if that fails for any
    reason, falls back to constructing a ``BertConfig`` directly from the
    same keyword arguments.

    Raises:
        ValueError: if *base_cfg_dict* is ``None``.
    """
    if base_cfg_dict is None:
        raise ValueError("config.base_model_config is required for offline load.")
    # "model_type" is supplied explicitly, so strip it from the kwargs.
    filtered = {key: value for key, value in base_cfg_dict.items() if key != "model_type"}
    try:
        return AutoConfig.for_model("bert", **filtered)
    except Exception:
        # Best-effort fallback: build a BertConfig straight from the kwargs.
        return BertConfig(**filtered)
 
16
 
17
  class KbertMTL(PreTrainedModel):
18
  config_class = BertConfig
19
 
20
  def __init__(self, config):
21
  super().__init__(config)
22
- base_cfg_dict = getattr(config, "base_model_config", None)
23
- base_cfg = _config_from_base_dict(base_cfg_dict)
24
 
25
- self.bert = AutoModel.from_config(base_cfg)
 
 
26
 
27
  hidden = self.bert.config.hidden_size
28
  self.head_senti = nn.Linear(hidden, 5)
@@ -38,7 +39,7 @@ class KbertMTL(PreTrainedModel):
38
  if self.has_token_type and token_type_ids is not None:
39
  kw["token_type_ids"] = token_type_ids
40
  out = self.bert(**kw)
41
- h = out.last_hidden_state[:, 0]
42
  return {
43
  "logits_senti": self.head_senti(h),
44
  "logits_act": self.head_act(h),
 
1
  # modeling_kbert_mtl.py
2
  import torch
3
  import torch.nn as nn
4
+ from transformers import PreTrainedModel, BertConfig, BertModel
5
 
6
def _bert_config_from_base_dict(base_cfg_dict: dict) -> BertConfig:
    """Reconstruct a ``BertConfig`` from a checkpoint's ``base_model_config`` dict.

    Forces ``model_type`` to ``"bert"`` and keeps only keys that a default
    ``BertConfig`` exposes, so unknown or stale entries in the saved dict
    cannot break offline instantiation.

    Raises:
        ValueError: if *base_cfg_dict* is ``None``.
    """
    if base_cfg_dict is None:
        raise ValueError("config.base_model_config is required for offline load.")

    # Work on a shallow copy; never mutate the caller's dict.
    source = dict(base_cfg_dict)
    source["model_type"] = "bert"

    # Whitelist of recognised config keys, taken from a default BertConfig.
    recognised = BertConfig().to_dict().keys()
    return BertConfig(**{key: value for key, value in source.items() if key in recognised})
17
 
18
  class KbertMTL(PreTrainedModel):
19
  config_class = BertConfig
20
 
21
  def __init__(self, config):
22
  super().__init__(config)
 
 
23
 
24
+ base_cfg_dict = getattr(config, "base_model_config", None)
25
+ bert_cfg = _bert_config_from_base_dict(base_cfg_dict)
26
+ self.bert = BertModel(bert_cfg)
27
 
28
  hidden = self.bert.config.hidden_size
29
  self.head_senti = nn.Linear(hidden, 5)
 
39
  if self.has_token_type and token_type_ids is not None:
40
  kw["token_type_ids"] = token_type_ids
41
  out = self.bert(**kw)
42
+ h = out.last_hidden_state[:, 0] # [CLS]
43
  return {
44
  "logits_senti": self.head_senti(h),
45
  "logits_act": self.head_act(h),