```python
from transformers import PreTrainedModel, PretrainedConfig
import torch
import torch.nn as nn
import torch.nn.functional as F


class BiLSTMConfig(PretrainedConfig):
    model_type = "bilstm"

    def __init__(self, vocab_size=64000, embedding_dim=1024, hidden_dim=512, num_labels=3, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_labels = num_labels


class BiLSTMClassifier(PreTrainedModel):
    config_class = BiLSTMConfig

    def __init__(self, config: BiLSTMConfig):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
        # Single-layer bidirectional LSTM over the embedded tokens.
        self.lstm = nn.LSTM(config.embedding_dim, config.hidden_dim,
                            batch_first=True, bidirectional=True)
        # Classifier head over the concatenated forward + backward final hidden states.
        self.fc = nn.Linear(config.hidden_dim * 2, config.num_labels)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None):
        x = self.embedding(input_ids)                # (batch, seq_len, embedding_dim)
        # h_n has shape (num_directions, batch, hidden_dim) for a single layer.
        # attention_mask is accepted for API compatibility but not used here.
        _, (h_n, _) = self.lstm(x)
        h_cat = torch.cat((h_n[0], h_n[1]), dim=1)   # (batch, 2 * hidden_dim)
        logits = self.fc(h_cat)
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}
```
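Because the class subclasses `PreTrainedModel`, it picks up the standard `save_pretrained`/`from_pretrained` machinery for free. Below is a minimal sketch of exercising it end to end; the random tensors, batch and sequence sizes, and the `"bilstm-demo"` output directory are illustrative placeholders, not part of the original code.

```python
import torch

# Assumes BiLSTMConfig and BiLSTMClassifier from the listing above are in scope.
config = BiLSTMConfig(vocab_size=64000, embedding_dim=1024, hidden_dim=512, num_labels=3)
model = BiLSTMClassifier(config)

# Fake batch: 4 sequences of 16 token ids (any tokenizer producing ids < vocab_size works).
input_ids = torch.randint(0, config.vocab_size, (4, 16))
labels = torch.randint(0, config.num_labels, (4,))

out = model(input_ids=input_ids, labels=labels)
print(out["loss"].item())    # scalar cross-entropy loss
print(out["logits"].shape)   # torch.Size([4, 3])

# Round-trip through the Hugging Face serialization API.
model.save_pretrained("bilstm-demo")
reloaded = BiLSTMClassifier.from_pretrained("bilstm-demo")
```

Note that loading works here because `config_class` is set on the model, so `BiLSTMClassifier.from_pretrained` can reconstruct the config directly; registering the `"bilstm"` type with `AutoConfig`/`AutoModel` would be a separate step.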