"""Build an SMP U-Net whose ResNet-101 encoder is initialised with MicroNet weights.

A ``"micronet"`` entry is registered in segmentation_models_pytorch's global
encoder registry so that ``encoder_weights="micronet"`` resolves to the
checkpoint URL below. The script then runs one forward pass on a random tensor
as a smoke test of the constructed model.
"""
from segmentation_models_pytorch import Unet
from segmentation_models_pytorch.encoders import encoders
import torch

# Register (or replace, if already present) a "micronet" weights entry for the
# resnet101 encoder. This mutates SMP's module-level registry, so it must run
# before any Unet(..., encoder_weights="micronet") call.
# NOTE(review): assumes an SMP version whose encoder loader reads
# settings["url"]; releases that resolve weights via a Hugging Face "repo_id"
# may need a different entry shape -- confirm against the installed version.
encoders["resnet101"]["pretrained_settings"]["micronet"] = {
    "url": "https://huggingface.co/jstuckner/microscopy-resnet101-micronet/resolve/main/resnet101_micronet_weights.pth",
    # Preprocessing metadata for these weights; the mean/std values are the
    # usual ImageNet per-channel statistics.
    "input_space": "RGB",
    "input_range": [0, 1],
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
}

# Build the U-Net; encoder_weights="micronet" selects the entry registered
# above (presumably fetching the checkpoint from its URL, so network access
# is needed on first use).
model = Unet(
    encoder_name="resnet101",
    encoder_weights="micronet",
    classes=1,
    activation=None,  # no output activation: the model returns raw logits
)
# Switch BatchNorm/Dropout to inference behaviour. In training mode BatchNorm
# normalises with the current batch's statistics and updates its running
# buffers even inside torch.no_grad(), which would both skew the output and
# mutate the model.
model.eval()

# Smoke test: one random 3-channel 256x256 input through the network
# (SMP's U-Net expects spatial sizes divisible by 32 with the default encoder).
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    y = model(x)
print("Output shape:", y.shape)