Commit
·
6f7b340
1
Parent(s):
ade58fc
update for pipeline
Browse files- config.json +2 -1
- modeling_hhem_v2.py +9 -2
config.json
CHANGED
|
@@ -8,5 +8,6 @@
|
|
| 8 |
},
|
| 9 |
"model_type": "HHEMv2Config",
|
| 10 |
"torch_dtype": "float32",
|
| 11 |
-
"transformers_version": "4.39.3"
|
|
|
|
| 12 |
}
|
|
|
|
| 8 |
},
|
| 9 |
"model_type": "HHEMv2Config",
|
| 10 |
"torch_dtype": "float32",
|
| 11 |
+
"transformers_version": "4.39.3",
|
| 12 |
+
"id2label": {"0": "hallucinated", "1": "consistent"}
|
| 13 |
}
|
modeling_hhem_v2.py
CHANGED
|
@@ -45,8 +45,15 @@ class HHEMv2ForSequenceClassification(PreTrainedModel):
|
|
| 45 |
# combined_model = PeftModel.from_pretrained(base_model, checkpoint, is_trainable=False)
|
| 46 |
# self.t5 = combined_model
|
| 47 |
|
| 48 |
-
def forward(self, **kwargs):
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
def predict(self, text_pairs):
|
| 52 |
tokenizer = self.tokenzier
|
|
|
|
| 45 |
# combined_model = PeftModel.from_pretrained(base_model, checkpoint, is_trainable=False)
|
| 46 |
# self.t5 = combined_model
|
| 47 |
|
| 48 |
+
def forward(self, **kwargs): # To cope with `text-classiication` pipeline
|
| 49 |
+
self.t5.eval()
|
| 50 |
+
with torch.no_grad():
|
| 51 |
+
outputs = self.t5(**kwargs)
|
| 52 |
+
logits = outputs.logits
|
| 53 |
+
logits = logits[:, 0, :]
|
| 54 |
+
outputs.logits = logits
|
| 55 |
+
return outputs
|
| 56 |
+
# return self.t5(**kwargs)
|
| 57 |
|
| 58 |
def predict(self, text_pairs):
|
| 59 |
tokenizer = self.tokenzier
|