from typing import Any

import torch.nn as nn
from transformers import AlbertModel, AutoConfig
from transformers import PreTrainedModel, AutoModel, AutoConfig


class ProtAlBertModel(PreTrainedModel):
    """ALBERT backbone followed by a small MLP head applied to the first-token embedding."""

    def __init__(
        self,
        config: Any,
        num_labels: int,
        hidden_size: int,
        model_name: str,
        *args,
        **kwargs,
    ):
        """
        Initialise the model.

        :param config: configuration object forwarded to ``PreTrainedModel.__init__``.
        :param num_labels: the number of labels (output width of the head).
        :param hidden_size: size of the hidden layer of the head that follows the
            first-token embedding.
        :param model_name: the name (or path) of the pretrained ALBERT checkpoint.
        :param args: accepted for call compatibility; unused.
        :param kwargs: accepted for call compatibility; unused.
        """
        super().__init__(config)
        # NOTE(review): this replaces the ``config`` handed to ``super().__init__``
        # with the backbone checkpoint's own config; confirm this is intended,
        # since it may affect what gets serialised with the model.
        # (A previous, unused ``AutoConfig.from_pretrained(model_name)`` call without
        # ``trust_remote_code`` was removed: it only added an extra fetch and could
        # fail before this trusted load ran.)
        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        # Attribute name kept as ``protbert`` so existing state-dict keys still match.
        self.protbert = AlbertModel.from_pretrained(
            model_name, config=self.config, trust_remote_code=True
        )
        backbone_dim = self.config.hidden_size
        # Classification head: dropout -> layer norm -> projection -> GELU -> logits.
        self.last_layer = nn.Sequential(
            nn.Dropout(0.1),
            nn.LayerNorm(backbone_dim),
            nn.Linear(backbone_dim, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, num_labels),
        )

    def forward(self, x):
        """
        Run the backbone and the head.

        :param x: mapping unpacked as keyword arguments into the ALBERT backbone
            (presumably tokenizer output such as ``input_ids`` / ``attention_mask``;
            TODO confirm against callers).
        :return: ``{"logits": tensor}`` produced by the head.
        """
        # Take the embedding at position 0 of the final hidden layer, i.e. the
        # [CLS] token when the tokenizer prepends one. Assumes last_hidden_state
        # is (batch, seq, dim).
        z = self.protbert(**x).last_hidden_state[:, 0, :]
        output = self.last_layer(z)
        return {"logits": output}