SimplePromptClassifier-85k / modeling_simple_classifier
Neweret's picture
Update modeling_simple_classifier
95039fd verified
raw
history blame contribute delete
878 Bytes
# modeling_simple_classifier.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
class SimpleClassifierConfig(PretrainedConfig):
    """Configuration for ``SimpleClassifier``.

    Subclasses ``transformers.PretrainedConfig``. ``PreTrainedModel.__init__``
    rejects any config that is not an instance of it, and
    ``from_pretrained`` / ``save_pretrained`` rely on its (de)serialization
    methods.

    Args:
        input_dim: Size of the last dimension of the model input
            (``in_features`` of the model's first linear layer).
        num_classes: Number of output logits produced by the model.
        p_dropout: Dropout probability applied after each hidden block.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``
            (e.g. ``id2label``, ``architectures``).

    Every argument has a default so the class can be instantiated with no
    arguments. ``PretrainedConfig`` does this itself, e.g. when diffing a
    config against its defaults before saving.
    """

    model_type = "simple_classifier"

    def __init__(self, input_dim=None, num_classes=None, p_dropout=0.1, **kwargs):
        # There is no sensible architecture-wide default for the two sizes.
        # Pass them explicitly, or let them come from a saved config.json.
        self.input_dim = input_dim
        self.num_classes = num_classes
        # NOTE(review): 0.1 is a conventional default chosen here, not a value
        # read from any published checkpoint (saved configs supply their own).
        self.p_dropout = p_dropout
        super().__init__(**kwargs)
class SimpleClassifier(PreTrainedModel):
    """Two-hidden-layer MLP classifier packaged as a Hugging Face model.

    Layer stack (hidden widths are fixed in code; the input width, output
    width and dropout rate come from ``config``)::

        x -> Linear(input_dim, 256) -> LayerNorm -> GELU -> Dropout
          -> Linear(256, 128)       -> LayerNorm -> GELU -> Dropout
          -> Linear(128, num_classes)

    Reads ``config.input_dim``, ``config.p_dropout`` and
    ``config.num_classes``.
    """

    config_class = SimpleClassifierConfig
    # forward() names its input ``x``, not the library default ``input_ids``.
    # Declaring it lets transformers code that looks the main input up by
    # name find it: token/flop estimation, and Trainer's
    # include_inputs_for_metrics path.
    main_input_name = "x"

    def __init__(self, config):
        super().__init__(config)
        # Attribute names double as state_dict keys; keep them stable so
        # existing checkpoints still load.
        self.linear1 = nn.Linear(config.input_dim, 256)
        self.ln1 = nn.LayerNorm(256)
        # One Dropout module is reused after both hidden blocks. It has no
        # parameters, so sharing it behaves like two separate modules.
        self.dropout = nn.Dropout(config.p_dropout)
        self.linear2 = nn.Linear(256, 128)
        self.ln2 = nn.LayerNorm(128)
        self.linear_out = nn.Linear(128, config.num_classes)
        # transformers' post-construction hook (weight initialization etc.).
        self.post_init()

    def forward(self, x):
        """Compute raw class logits for ``x``.

        Args:
            x: Float tensor whose last dimension is ``config.input_dim``
                (presumably ``(batch, input_dim)`` — confirm with callers).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``config.num_classes``. These are unnormalized
            logits; no softmax is applied.
        """
        # Each hidden block: linear projection, LayerNorm, GELU, then dropout.
        x = F.gelu(self.ln1(self.linear1(x)))
        x = self.dropout(x)
        x = F.gelu(self.ln2(self.linear2(x)))
        x = self.dropout(x)
        return self.linear_out(x)