prot_albert_pfam / prot_albert_model.py
from typing import Any
import torch.nn as nn
from transformers import AlbertModel, AutoConfig, PreTrainedModel


class ProtAlBertModel(PreTrainedModel):
    def __init__(self, config: Any, num_labels: int, hidden_size: int, model_name: str, *args, **kwargs):
        """
        Initialise the model.

        :param config: configuration object for the PreTrainedModel base class.
        :param num_labels: the number of output labels.
        :param hidden_size: size of the hidden layer applied to the [CLS] token embedding.
        :param model_name: the name or path of the pretrained ALBERT checkpoint to load.
        """
super().__init__(config)
        # Load the configuration of the pretrained ALBERT checkpoint.
        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        # Pretrained ALBERT encoder used as the backbone.
        self.protbert = AlbertModel.from_pretrained(
            model_name, config=self.config, trust_remote_code=True
        )
        # Classification head applied to the [CLS] token embedding.
        self.last_layer = nn.Sequential(
            nn.Dropout(0.1),
            nn.LayerNorm(self.config.hidden_size),
            nn.Linear(self.config.hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, num_labels),
        )

    def forward(self, x):
        # x is the tokenizer output (input_ids, attention_mask, ...).
        # Take the last-layer embedding of the [CLS] token.
        z = self.protbert(**x).last_hidden_state[:, 0, :]
        output = self.last_layer(z)
        return {"logits": output}
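

# Minimal usage sketch. The checkpoint name "Rostlab/prot_albert" and the values for
# num_labels and hidden_size below are illustrative assumptions, not taken from this repo.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "Rostlab/prot_albert"  # hypothetical checkpoint name
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    model = ProtAlBertModel(config, num_labels=100, hidden_size=256, model_name=model_name)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # ProtAlbert-style tokenizers typically expect space-separated amino acids.
    batch = tokenizer(["M K T A Y I A K Q R"], return_tensors="pt", padding=True)
    out = model(batch)
    print(out["logits"].shape)  # (batch_size, num_labels)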