| from transformers import PreTrainedModel | |
| from .unet3d import UNet, UNetDeepSup | |
| from .UNetConfigs import UNet3DConfig, UNetMSS3DConfig | |
class UNet3D(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the project's ``UNet``.

    The wrapped network is built from four fields of a ``UNet3DConfig``:
    ``in_ch``, ``out_ch``, ``init_features`` and ``dropout_rate``. Inheriting
    from ``PreTrainedModel`` lets the model be handled through the
    ``transformers`` save/load machinery keyed on ``config_class``.
    """

    # Tells transformers which config type pairs with this model.
    config_class = UNet3DConfig

    def __init__(self, config):
        """Build the wrapped ``UNet`` from *config*.

        Args:
            config: a ``UNet3DConfig`` (or compatible object) exposing
                ``in_ch``, ``out_ch``, ``init_features`` and ``dropout_rate``.
        """
        super().__init__(config)
        # Hyper-parameters are forwarded verbatim; no defaults are applied here.
        backbone = UNet(
            in_ch=config.in_ch,
            out_ch=config.out_ch,
            init_features=config.init_features,
            dropout_rate=config.dropout_rate,
        )
        self.model = backbone

    def forward(self, x):
        """Run the wrapped ``UNet`` on *x* and return its output unchanged.

        NOTE(review): the expected shape of *x* is defined by ``UNet`` in
        ``.unet3d``, presumably (batch, in_ch, D, H, W). Confirm there.
        """
        return self.model(x)
class UNetMSS3D(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the project's ``UNetDeepSup``.

    The wrapped network is built from four fields of a ``UNetMSS3DConfig``:
    ``in_ch``, ``out_ch``, ``init_features`` and ``dropout_rate``.

    NOTE(review): going by the names, ``UNetDeepSup`` is presumably a
    deep-supervision UNet variant and "MSS" a multi-scale-supervision tag.
    Both readings are unverified here; check ``.unet3d``.
    """

    # Tells transformers which config type pairs with this model.
    config_class = UNetMSS3DConfig

    def __init__(self, config):
        """Build the wrapped ``UNetDeepSup`` from *config*.

        Args:
            config: a ``UNetMSS3DConfig`` (or compatible object) exposing
                ``in_ch``, ``out_ch``, ``init_features`` and ``dropout_rate``.
        """
        super().__init__(config)
        # Hyper-parameters are forwarded verbatim; no defaults are applied here.
        backbone = UNetDeepSup(
            in_ch=config.in_ch,
            out_ch=config.out_ch,
            init_features=config.init_features,
            dropout_rate=config.dropout_rate,
        )
        self.model = backbone

    def forward(self, x):
        """Run the wrapped ``UNetDeepSup`` on *x* and return its output unchanged.

        NOTE(review): whatever ``UNetDeepSup`` returns is passed straight
        through. For a deep-supervision net this may be several outputs rather
        than one tensor. Confirm in ``.unet3d``.
        """
        return self.model(x)