Spaces:
Running
on
Zero
Running
on
Zero
| """Copyright(c) 2023 lyuwenyu. All Rights Reserved. | |
| https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055#0583 | |
| """ | |
| import torch | |
| from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names | |
| from ...core import register | |
| from .utils import IntermediateLayerGetter | |
# NOTE(review): `register` is imported at the top of this file but is not applied
# to this class; if TimmModel must be constructible from a config (`type: TimmModel`),
# it may need the registration decorator — confirm against `...core.register`.
class TimmModel(torch.nn.Module):
    """Backbone wrapper that exposes selected intermediate outputs of a timm model.

    The model is built with ``timm.create_model`` and wrapped in
    ``IntermediateLayerGetter`` so that ``forward`` produces the outputs of the
    modules listed in ``return_layers``.

    Args:
        name: timm model name passed to ``timm.create_model`` (e.g. ``"resnet34"``).
        return_layers: names of the modules whose outputs should be returned; each
            must appear in ``model.feature_info.module_name()``. Any iterable is
            accepted; it is stored as a list.
        pretrained: forwarded to ``timm.create_model``.
        exportable: forwarded to ``timm.create_model``.
        features_only: forwarded to ``timm.create_model``. The code below calls
            ``feature_info.module_name()``, ``.reduction()`` and ``.channels()``;
            NOTE(review): presumably that API is only available with
            ``features_only=True`` — confirm before changing this default.
        **kwargs: extra keyword arguments forwarded to ``timm.create_model``.

    Attributes:
        model: the ``IntermediateLayerGetter`` wrapping the timm model.
        return_layers: the requested layer names, as a list, in caller order.
        return_idx: index of each requested layer within
            ``feature_info.module_name()``, in caller order.
        strides: ``feature_info.reduction()`` value for each requested layer.
        channels: ``feature_info.channels()`` value for each requested layer.

    Raises:
        ValueError: if any entry of ``return_layers`` is not one of the model's
            feature module names.
    """

    def __init__(
        self, name, return_layers, pretrained=False, exportable=True, features_only=True, **kwargs
    ) -> None:
        super().__init__()
        # Imported lazily so timm is only needed when this backbone is instantiated.
        import timm

        model = timm.create_model(
            name,
            pretrained=pretrained,
            exportable=exportable,
            features_only=features_only,
            **kwargs,
        )

        # Materialize once: the names are iterated several times below, and a
        # one-shot iterable (e.g. a generator) would be exhausted after the first pass.
        return_layers = list(return_layers)

        feature_info = model.feature_info
        module_names = list(feature_info.module_name())

        # Explicit check rather than `assert`, so validation still runs under `python -O`.
        unknown = [layer for layer in return_layers if layer not in module_names]
        if unknown:
            raise ValueError(
                f"return_layers should be a subset of {module_names}, got unknown {unknown}"
            )

        self.model = IntermediateLayerGetter(model, return_layers)

        all_reductions = feature_info.reduction()
        all_channels = feature_info.channels()
        # Position of each requested layer in feature_info, following the caller's order.
        return_idx = [module_names.index(layer) for layer in return_layers]
        self.strides = [all_reductions[i] for i in return_idx]
        self.channels = [all_channels[i] for i in return_idx]
        self.return_idx = return_idx
        self.return_layers = return_layers

    def forward(self, x: torch.Tensor):
        """Run the wrapped backbone on ``x`` and return its intermediate outputs.

        The result is whatever ``IntermediateLayerGetter`` returns.
        NOTE(review): presumably an iterable of feature tensors for
        ``self.return_layers`` — confirm against ``.utils``.
        """
        return self.model(x)
if __name__ == "__main__":
    # Smoke test: build a resnet34 backbone (pretrained defaults to False), feed it a
    # random tensor of shape (1, 3, 640, 640), and print `.shape` of each element of
    # the output. Because this file uses relative imports, run it as a module
    # (`python -m <package>.<this_module>`) rather than as a plain script path.
    backbone = TimmModel(name="resnet34", return_layers=["layer2", "layer3"])
    dummy_batch = torch.rand(1, 3, 640, 640)
    for feature_map in backbone(dummy_batch):
        print(feature_map.shape)

# Example config snippet (YAML) referencing this class; presumably consumed by the
# project's config/registry system — confirm:
#
#   model:
#     type: TimmModel
#     name: resnet34
#     return_layers: ['layer2', 'layer4']