Spaces:
Sleeping
Sleeping
| import torch.nn as nn | |
| import timm | |
class Model(nn.Module):
    """Wrap a timm backbone and replace its classifier with a fresh linear head.

    By default the new head has a single output per sample, matching the
    previous hard-coded ``nn.Linear(in_features, 1)`` head.
    """

    def __init__(
        self,
        model_name: str,
        pretrained: bool = True,
        num_classes: int = 1,
    ):
        """Create the timm model and swap in a new classifier head.

        Args:
            model_name: Architecture name passed to ``timm.create_model``
                (e.g. a ConvNeXt variant).
            pretrained: Whether timm should load pretrained weights.
            num_classes: Output width of the new linear head. Defaults to 1,
                as before. Per timm's convention, ``0`` replaces the head
                with an identity (feature-extractor mode).
        """
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        # reset_classifier is part of timm's common model interface. Unlike
        # assigning to ``head.fc`` directly, it works for architectures whose
        # classifier lives elsewhere (``fc``, ``head``, ...). It also keeps
        # ``self.model.num_classes`` consistent with the new head.
        self.model.reset_classifier(num_classes)

    def forward(self, x):
        """Run the wrapped timm model on ``x`` and return its output.

        NOTE(review): presumably ``x`` is an image batch
        ``(batch, channels, height, width)`` and the result is
        ``(batch, num_classes)`` logits. Confirm against callers.
        """
        return self.model(x)