| | import torch |
| |
|
| | |
| | from transformers import PreTrainedModel, AutoConfig, T5ForTokenClassification, AutoModel, AutoTokenizer, AutoModelForTokenClassification |
| |
|
| | from .configuration_hhem_v2 import HHEMv2Config |
| |
|
class HHEMv2Model(PreTrainedModel):
    """Bare HHEM v2 model shell.

    The class only binds ``HHEMv2Config`` as its configuration type and
    forwards the configuration to ``PreTrainedModel``; it defines no
    sub-modules and no ``forward`` of its own.
    """

    # Configuration class associated with this model type.
    config_class = HHEMv2Config

    def __init__(self, config: HHEMv2Config):
        """Initialise the base ``PreTrainedModel`` with ``config``."""
        super().__init__(config)
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| |
|
class HHEMv2ForSequenceClassification(PreTrainedModel):
    """Sequence-pair scorer built on top of a T5 token-classification model.

    The wrapped model (``self.t5``) is run on a prompt built from each text
    pair, and the logits at index 0 of the second dimension of its output
    are used as the pair-level logits (presumably the first token position;
    confirm against the foundation model).
    """

    # Configuration class associated with this model type.
    config_class = HHEMv2Config

    def __init__(self, config=None):
        """Build the T5 backbone and tokenizer named by ``config.foundation``.

        Args:
            config: An ``HHEMv2Config``. When omitted, a fresh default
                ``HHEMv2Config()`` is created for this instance.
        """
        # A default of ``HHEMv2Config()`` in the signature would be evaluated
        # once at definition time and shared by every instance; build it here.
        if config is None:
            config = HHEMv2Config()
        super().__init__(config)
        self.t5 = T5ForTokenClassification(
            AutoConfig.from_pretrained(config.foundation)
        )
        # Prompt template; ``predict`` fills it with ``text1`` / ``text2``.
        self.prompt = config.prompt
        # NOTE: ``tokenzier`` is a historical misspelling. It is kept because
        # external code may already read it; ``tokenizer`` is the same object
        # under the correctly spelled name.
        self.tokenzier = AutoTokenizer.from_pretrained(config.foundation)
        self.tokenizer = self.tokenzier

    def populate(self, model: AutoModel):
        """Replace the wrapped T5 model with an already-trained one.

        This method is intended only for Vectara employees who prepare the
        model for publishing. Users do not need to call it.

        Args:
            model: The pretrained model to store as ``self.t5``.
        """
        self.t5 = model

    def _run_t5(self, **model_inputs):
        """Run ``self.t5`` in eval mode with gradients disabled; return its output."""
        self.t5.eval()
        with torch.no_grad():
            return self.t5(**model_inputs)

    def forward(self, **kwargs):
        """Run the wrapped model and keep only the logits at index 0.

        Inference always happens in eval mode and under ``torch.no_grad()``.

        Args:
            **kwargs: Passed through unchanged to ``self.t5``.

        Returns:
            The output object of ``self.t5`` with its ``logits`` attribute
            replaced by ``logits[:, 0, :]``.
        """
        outputs = self._run_t5(**kwargs)
        outputs.logits = outputs.logits[:, 0, :]
        return outputs

    def predict(self, text_pairs):
        """Score a batch of text pairs.

        Args:
            text_pairs: Iterable of pairs; element 0 is substituted for
                ``{text1}`` and element 1 for ``{text2}`` in the prompt.

        Returns:
            A 1-D tensor with one score per pair: the softmax probability at
            column 1 of the index-0 logits. An empty input yields an empty
            tensor on the model's device.
        """
        # Materialise once so one-shot iterables work and emptiness can be
        # checked before the tokenizer is called with an empty batch.
        pairs = list(text_pairs)
        if not pairs:
            return torch.empty(0, device=self.t5.device)

        prompts = [
            self.prompt.format(text1=pair[0], text2=pair[1]) for pair in pairs
        ]
        inputs = self.tokenzier(
            prompts, return_tensors='pt', padding=True
        ).to(self.t5.device)

        logits = self._run_t5(**inputs).logits[:, 0, :]
        probabilities = torch.softmax(logits, dim=-1)
        return probabilities[:, 1]
| |
|
| | |
| |
|