from typing import Any

import torch.nn as nn
from transformers import AlbertModel, AutoConfig, PreTrainedModel


class ProtAlBertModel(PreTrainedModel):
    def __init__(self, config: Any, num_labels: int, hidden_size: int, model_name: str, *args, **kwargs):
        """
        Initialise the model.

        :param config: config object passed to the PreTrainedModel base class.
        :param num_labels: the number of output labels.
        :param hidden_size: size of the hidden layer applied after the CLS token.
        :param model_name: the name of the pretrained model to load.
        """
        super().__init__(config)
        # Load the configuration and pretrained ALBERT backbone for the given checkpoint.
        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        self.protbert = AlbertModel.from_pretrained(
            model_name, config=self.config, trust_remote_code=True
        )
        # Classification head applied to the CLS-token representation.
        self.last_layer = nn.Sequential(
            nn.Dropout(0.1),
            nn.LayerNorm(self.config.hidden_size),
            nn.Linear(self.config.hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, num_labels),
        )

    def forward(self, x):
        # Encode the batch and take the last hidden state of the CLS token
        # (position 0) as the sequence-level representation.
        z = self.protbert(**x).last_hidden_state[:, 0, :]
        output = self.last_layer(z)
        return {"logits": output}
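

# A minimal usage sketch, assuming the "Rostlab/prot_albert" checkpoint and a
# two-label task; swap in the checkpoint, label count, and sequences you
# actually use. ProtTrans-style tokenizers expect space-separated amino acids.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "Rostlab/prot_albert"
    config = AutoConfig.from_pretrained(model_name)
    model = ProtAlBertModel(config, num_labels=2, hidden_size=256, model_name=model_name)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    inputs = tokenizer("M K T A Y I A K Q R", return_tensors="pt")
    logits = model(dict(inputs))["logits"]
    print(logits.shape)  # torch.Size([1, 2])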