from typing import Any

import torch.nn as nn
from transformers import AlbertModel, AutoConfig, PreTrainedModel


class ProtAlBertModel(PreTrainedModel):
    def __init__(self, config: Any, num_labels: int, hidden_size: int, model_name: str, *args, **kwargs):
"""
Initialise the model.
:param: config class for the PreTrainedModel class
:param hidden_size: size of the hidden layer after the CLS token.
:param num_labels: the number of labels.
:param model_name: the name of the model.
"""
        super().__init__(config)
        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        self.protbert = AlbertModel.from_pretrained(
            model_name, config=self.config, trust_remote_code=True
        )
        # Classification head applied on top of the [CLS] embedding.
        self.last_layer = nn.Sequential(
            nn.Dropout(0.1),
            nn.LayerNorm(self.config.hidden_size),
            nn.Linear(self.config.hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, num_labels),
        )
    def forward(self, x):
        # Take the last-layer embedding of the [CLS] token
        z = self.protbert(**x).last_hidden_state[:, 0, :]
        output = self.last_layer(z)
        return {"logits": output}