import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import RobertaPreTrainedModel, RobertaModel

from .configuration_emoaxis import EmoAxisConfig


class EmoAxis(RobertaPreTrainedModel):
    config_class = EmoAxisConfig

    def __init__(self, config):
        super().__init__(config)
        # RoBERTa encoder without the default pooler; pooling is done manually in forward().
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # Classification head applied to the pooled sentence embedding.
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.25),
            nn.Linear(512, config.num_classes),
        )
        # Initialize weights and apply final processing.
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        last_hidden_state = outputs.hidden_states[-1]
        # Masked mean pooling over token embeddings (padding positions excluded).
        mask = attention_mask.unsqueeze(-1).float()
        text_emb = (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        # L2-normalize the pooled embedding so it can also serve as a sentence vector.
        text_emb = F.normalize(text_emb, p=2, dim=1)
        logits = self.mlp(text_emb)
        return text_emb, logits
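

# Usage sketch (assumptions, not part of the original file): the checkpoint path
# "path/to/emoaxis-checkpoint" is a placeholder for a directory saved with
# save_pretrained(), and "roberta-base" stands in for whichever tokenizer the
# model was actually trained with. Whether softmax or sigmoid is appropriate for
# the logits depends on how the classification head was trained.
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("roberta-base")
#     model = EmoAxis.from_pretrained("path/to/emoaxis-checkpoint")
#     model.eval()
#
#     batch = tokenizer(
#         ["I can't believe this happened!"],
#         return_tensors="pt", padding=True, truncation=True,
#     )
#     with torch.no_grad():
#         text_emb, logits = model(**batch)
#     probs = logits.softmax(dim=-1)  # or logits.sigmoid() for a multi-label setup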