| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import PreTrainedModel | |
class SimpleClassifierConfig:
    """Configuration object for :class:`SimpleClassifier`.

    Holds the three hyperparameters the model's ``__init__`` reads:

    Args:
        input_dim: Size of the input feature vector.
        num_classes: Number of output classes (width of the final layer).
        p_dropout: Dropout probability applied after each hidden layer.
        **kwargs: Any extra fields are stored as attributes verbatim, so the
            config stays forward compatible with additional settings.

    NOTE(review): HuggingFace ``PreTrainedModel`` normally expects its config
    to subclass ``transformers.PretrainedConfig`` — confirm whether this plain
    class is sufficient for how the model is loaded/saved in this project.
    """

    model_type = "simple_classifier"

    def __init__(self, input_dim=768, num_classes=2, p_dropout=0.1, **kwargs):
        # Fix: the original class defined none of these attributes, so
        # constructing SimpleClassifier with a bare SimpleClassifierConfig()
        # raised AttributeError on config.input_dim.
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.p_dropout = p_dropout
        # Store any extra settings without validating them.
        for key, value in kwargs.items():
            setattr(self, key, value)
class SimpleClassifier(PreTrainedModel):
    """A two-hidden-layer MLP classification head.

    Pipeline per hidden layer: Linear -> LayerNorm -> GELU -> Dropout,
    finishing with a plain linear projection to ``config.num_classes`` logits.

    The config must provide ``input_dim``, ``p_dropout`` and ``num_classes``.
    """

    config_class = SimpleClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        # Hidden widths are fixed by design; named here for readability.
        width_a, width_b = 256, 128
        # Attribute names are load-bearing: they determine state-dict keys,
        # so they must not be renamed.
        self.linear1 = nn.Linear(config.input_dim, width_a)
        self.ln1 = nn.LayerNorm(width_a)
        self.dropout = nn.Dropout(config.p_dropout)
        self.linear2 = nn.Linear(width_a, width_b)
        self.ln2 = nn.LayerNorm(width_b)
        self.linear_out = nn.Linear(width_b, config.num_classes)
        # HuggingFace hook: runs weight init / final setup.
        self.post_init()

    def forward(self, x):
        """Return unnormalized class logits for input features ``x``.

        ``x`` is expected to have ``config.input_dim`` as its last dimension
        (presumably ``(batch, input_dim)`` — confirm against callers).
        """
        hidden = self.dropout(F.gelu(self.ln1(self.linear1(x))))
        hidden = self.dropout(F.gelu(self.ln2(self.linear2(hidden))))
        return self.linear_out(hidden)