import torch
import torch.nn as nn
import torchvision
import numpy as np
from torch.autograd import Variable
import torchvision.models as models
import transformers
import torchvision.transforms
import torchxrayvision as xrv
from transformers import ViTModel, ViTConfig


class VisualFeatureExtractor(nn.Module):
    def __init__(self, model_name='densenet201', pretrained=False):
        super(VisualFeatureExtractor, self).__init__()
        # NOTE: the model_name argument is currently overridden here, so the
        # ViT chest X-ray checkpoint ('chexnet' branch below) is always used.
        self.model_name = 'chexnet'
        self.pretrained = pretrained
        self.model, self.out_features, self.avg_func, self.bn, self.linear = self.__get_model()
        self.activation = nn.ReLU()

    def __get_model(self):
        model = None
        out_features = None
        func = None
        if self.model_name == 'resnet152':
            resnet = models.resnet152(pretrained=self.pretrained)
            modules = list(resnet.children())[:-2]
            model = nn.Sequential(*modules)
            out_features = resnet.fc.in_features
            func = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=0)
        elif self.model_name == 'densenet201':
            densenet = models.densenet201(pretrained=self.pretrained)
            modules = list(densenet.features)
            model = nn.Sequential(*modules)
            func = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=0)
            out_features = densenet.classifier.in_features
        elif self.model_name == 'chexnet':
            print("vit chest xray pretrained model loading")
            # Load the Vision Transformer (ViT) model configuration
            config = ViTConfig.from_pretrained('nickmuchi/vit-finetuned-chest-xray-pneumonia')
            # Initialize the ViT model with this configuration
            vit_model = ViTModel(config)
            # Load the state dict, excluding 'classifier.bias' and 'classifier.weight'
            state_dict = torch.load('pytorch_model.bin', map_location=torch.device('cpu'))
            state_dict = {k: v for k, v in state_dict.items() if not k.startswith('classifier')}
            vit_model.load_state_dict(state_dict, strict=False)
            model = vit_model
            out_features = config.hidden_size
        linear = nn.Linear(in_features=out_features, out_features=out_features)
        bn = nn.BatchNorm1d(num_features=out_features, momentum=0.1)
        return model, out_features, func, bn, linear

    def forward(self, images):
        """
        :param images: input images
        :return: visual_features, avg_features
        """
        model_output = self.model(images)
        # Extract the pooler_output
        pooler_output = model_output.pooler_output
        # Apply the linear layer, batch normalization and activation
        avg_features = self.activation(self.bn(self.linear(pooler_output)))
        return model_output.last_hidden_state, avg_features

    # def forward(self, images):
    #     """
    #     :param images:
    #     :return:
    #     """
    #     visual_features = self.model(images)
    #     avg_features = self.avg_func(visual_features).squeeze()
    #     # avg_features = self.activation(self.bn(self.linear(visual_features)))
    #     return visual_features, avg_features
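

# MLC ("multi-label classification") head: given the pooled visual features it
# predicts tag probabilities with a sigmoid classifier and embeds the indices of
# the top-k tags to obtain the semantic features consumed by the co-attention
# module. Note that the ViT backbone instantiated in __init__ below is not used
# in forward().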
class MLC(nn.Module):
    def __init__(self,
                 classes=210,
                 sementic_features_dim=512,
                 fc_in_features=2048,
                 k=10):
        super(MLC, self).__init__()
        pretrained_model_name = "nickmuchi/vit-finetuned-chest-xray-pneumonia"
        vit_config = ViTConfig.from_pretrained(pretrained_model_name)
        self.vit = ViTModel(vit_config)
        # Adjust the classifier to the desired number of classes
        self.classifier = nn.Linear(in_features=vit_config.hidden_size, out_features=classes)
        self.embed = nn.Embedding(classes, sementic_features_dim)
        self.k = k
        self.sigmoid = nn.Sigmoid()
        self.__init_weight()

    def __init_weight(self):
        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None:
            self.classifier.bias.data.fill_(0)

    def forward(self, avg_features):
        tags = self.sigmoid(self.classifier(avg_features))
        semantic_features = self.embed(torch.topk(tags, self.k)[1])
        return tags, semantic_features


# class MLC(nn.Module):
#     def __init__(self,
#                  classes=210,
#                  sementic_features_dim=512,
#                  fc_in_features=2048,
#                  k=10):
#         super(MLC, self).__init__()
#         self.classifier = nn.Linear(in_features=fc_in_features, out_features=classes)
#         self.embed = nn.Embedding(classes, sementic_features_dim)
#         self.k = k
#         self.sigmoid = nn.Sigmoid()
#         self.__init_weight()
#
#     def __init_weight(self):
#         # Example: initialize weights with a different strategy
#         nn.init.xavier_uniform_(self.classifier.weight)
#         if self.classifier.bias is not None:
#             self.classifier.bias.data.fill_(0)
#
#     def forward(self, avg_features):
#         tags = self.sigmoid(self.classifier(avg_features))
#         semantic_features = self.embed(torch.topk(tags, self.k)[1])
#         return tags, semantic_features
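

# CoAttention fuses the pooled visual features and the top-k semantic (tag)
# embeddings, conditioned on the previous sentence-LSTM hidden state, into a
# single context vector ctx. The v1-v5 variants compute the same attention and
# differ only in where (or whether) batch normalization is applied.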
class CoAttention(nn.Module):
    def __init__(self,
                 version='v1',
                 embed_size=512,
                 hidden_size=512,
                 visual_size=2048,
                 k=10,
                 momentum=0.1):
        super(CoAttention, self).__init__()
        self.version = version

        self.W_v = nn.Linear(in_features=visual_size, out_features=visual_size)
        self.bn_v = nn.BatchNorm1d(num_features=visual_size, momentum=momentum)
        self.W_v_h = nn.Linear(in_features=hidden_size, out_features=visual_size)
        self.bn_v_h = nn.BatchNorm1d(num_features=visual_size, momentum=momentum)
        self.W_v_att = nn.Linear(in_features=visual_size, out_features=visual_size)
        self.bn_v_att = nn.BatchNorm1d(num_features=visual_size, momentum=momentum)

        self.W_a = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.bn_a = nn.BatchNorm1d(num_features=k, momentum=momentum)
        self.W_a_h = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.bn_a_h = nn.BatchNorm1d(num_features=1, momentum=momentum)
        self.W_a_att = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.bn_a_att = nn.BatchNorm1d(num_features=k, momentum=momentum)

        # self.W_fc = nn.Linear(in_features=visual_size, out_features=embed_size)  # for v3
        self.W_fc = nn.Linear(in_features=visual_size + hidden_size, out_features=embed_size)
        self.bn_fc = nn.BatchNorm1d(num_features=embed_size, momentum=momentum)

        self.tanh = nn.Tanh()
        # NOTE: no dim argument is given; recent PyTorch versions warn about the
        # implicit softmax dimension choice.
        self.softmax = nn.Softmax()

        self.__init_weight()

    def __init_weight(self):
        self.W_v.weight.data.uniform_(-0.1, 0.1)
        self.W_v.bias.data.fill_(0)
        self.W_v_h.weight.data.uniform_(-0.1, 0.1)
        self.W_v_h.bias.data.fill_(0)
        self.W_v_att.weight.data.uniform_(-0.1, 0.1)
        self.W_v_att.bias.data.fill_(0)
        self.W_a.weight.data.uniform_(-0.1, 0.1)
        self.W_a.bias.data.fill_(0)
        self.W_a_h.weight.data.uniform_(-0.1, 0.1)
        self.W_a_h.bias.data.fill_(0)
        self.W_a_att.weight.data.uniform_(-0.1, 0.1)
        self.W_a_att.bias.data.fill_(0)
        self.W_fc.weight.data.uniform_(-0.1, 0.1)
        self.W_fc.bias.data.fill_(0)

    def forward(self, avg_features, semantic_features, h_sent):
        if self.version == 'v1':
            return self.v1(avg_features, semantic_features, h_sent)
        elif self.version == 'v2':
            return self.v2(avg_features, semantic_features, h_sent)
        elif self.version == 'v3':
            return self.v3(avg_features, semantic_features, h_sent)
        elif self.version == 'v4':
            return self.v4(avg_features, semantic_features, h_sent)
        elif self.version == 'v5':
            return self.v5(avg_features, semantic_features, h_sent)

    def v1(self, avg_features, semantic_features, h_sent) -> object:
        """
        only training
        :rtype: object
        """
        W_v = self.bn_v(self.W_v(avg_features))
        W_v_h = self.bn_v_h(self.W_v_h(h_sent.squeeze(1)))

        alpha_v = self.softmax(self.bn_v_att(self.W_v_att(self.tanh(W_v + W_v_h))))
        v_att = torch.mul(alpha_v, avg_features)

        W_a_h = self.bn_a_h(self.W_a_h(h_sent))
        W_a = self.bn_a(self.W_a(semantic_features))
        alpha_a = self.softmax(self.bn_a_att(self.W_a_att(self.tanh(torch.add(W_a_h, W_a)))))
        a_att = torch.mul(alpha_a, semantic_features).sum(1)

        ctx = self.W_fc(torch.cat([v_att, a_att], dim=1))

        return ctx, alpha_v, alpha_a

    def v2(self, avg_features, semantic_features, h_sent) -> object:
        """
        no bn
        :rtype: object
        """
        W_v = self.W_v(avg_features)
        W_v_h = self.W_v_h(h_sent.squeeze(1))

        alpha_v = self.softmax(self.W_v_att(self.tanh(W_v + W_v_h)))
        v_att = torch.mul(alpha_v, avg_features)

        W_a_h = self.W_a_h(h_sent)
        W_a = self.W_a(semantic_features)
        alpha_a = self.softmax(self.W_a_att(self.tanh(torch.add(W_a_h, W_a))))
        a_att = torch.mul(alpha_a, semantic_features).sum(1)

        ctx = self.W_fc(torch.cat([v_att, a_att], dim=1))

        return ctx, alpha_v, alpha_a

    def v3(self, avg_features, semantic_features, h_sent) -> object:
        """
        bn on the feature projections only, not on the attention layers
        :rtype: object
        """
        W_v = self.bn_v(self.W_v(avg_features))
        W_v_h = self.bn_v_h(self.W_v_h(h_sent.squeeze(1)))

        alpha_v = self.softmax(self.W_v_att(self.tanh(W_v + W_v_h)))
        v_att = torch.mul(alpha_v, avg_features)

        W_a_h = self.bn_a_h(self.W_a_h(h_sent))
        W_a = self.bn_a(self.W_a(semantic_features))
        alpha_a = self.softmax(self.W_a_att(self.tanh(torch.add(W_a_h, W_a))))
        a_att = torch.mul(alpha_a, semantic_features).sum(1)

        ctx = self.W_fc(torch.cat([v_att, a_att], dim=1))

        return ctx, alpha_v, alpha_a

    def v4(self, avg_features, semantic_features, h_sent):
        """
        no bn (same computation as v2)
        """
        W_v = self.W_v(avg_features)
        W_v_h = self.W_v_h(h_sent.squeeze(1))

        alpha_v = self.softmax(self.W_v_att(self.tanh(torch.add(W_v, W_v_h))))
        v_att = torch.mul(alpha_v, avg_features)

        W_a_h = self.W_a_h(h_sent)
        W_a = self.W_a(semantic_features)
        alpha_a = self.softmax(self.W_a_att(self.tanh(torch.add(W_a_h, W_a))))
        a_att = torch.mul(alpha_a, semantic_features).sum(1)

        ctx = self.W_fc(torch.cat([v_att, a_att], dim=1))

        return ctx, alpha_v, alpha_a

    def v5(self, avg_features, semantic_features, h_sent):
        """
        bn applied once, on the summed projections, before the tanh
        """
        W_v = self.W_v(avg_features)
        W_v_h = self.W_v_h(h_sent.squeeze(1))

        alpha_v = self.softmax(self.W_v_att(self.tanh(self.bn_v(torch.add(W_v, W_v_h)))))
        v_att = torch.mul(alpha_v, avg_features)

        W_a_h = self.W_a_h(h_sent)
        W_a = self.W_a(semantic_features)
        alpha_a = self.softmax(self.W_a_att(self.tanh(self.bn_a(torch.add(W_a_h, W_a)))))
        a_att = torch.mul(alpha_a, semantic_features).sum(1)

        ctx = self.W_fc(torch.cat([v_att, a_att], dim=1))

        return ctx, alpha_v, alpha_a
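

# SentenceLSTM is the sentence-level decoder: at each step it consumes the
# co-attention context ctx, emits a topic vector for the word decoder and a
# 2-dimensional stop score used to decide whether another sentence should be
# generated.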
class SentenceLSTM(nn.Module):
    def __init__(self,
                 version='v1',
                 embed_size=512,
                 hidden_size=512,
                 num_layers=1,
                 dropout=0.3,
                 momentum=0.1):
        super(SentenceLSTM, self).__init__()
        self.version = version

        self.lstm = nn.LSTM(input_size=embed_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            dropout=dropout)

        self.W_t_h = nn.Linear(in_features=hidden_size,
                               out_features=embed_size,
                               bias=True)
        self.bn_t_h = nn.BatchNorm1d(num_features=1, momentum=momentum)

        self.W_t_ctx = nn.Linear(in_features=embed_size,
                                 out_features=embed_size,
                                 bias=True)
        self.bn_t_ctx = nn.BatchNorm1d(num_features=1, momentum=momentum)

        self.W_stop_s_1 = nn.Linear(in_features=hidden_size,
                                    out_features=embed_size,
                                    bias=True)
        self.bn_stop_s_1 = nn.BatchNorm1d(num_features=1, momentum=momentum)

        self.W_stop_s = nn.Linear(in_features=hidden_size,
                                  out_features=embed_size,
                                  bias=True)
        self.bn_stop_s = nn.BatchNorm1d(num_features=1, momentum=momentum)

        self.W_stop = nn.Linear(in_features=embed_size,
                                out_features=2,
                                bias=True)
        self.bn_stop = nn.BatchNorm1d(num_features=1, momentum=momentum)

        self.W_topic = nn.Linear(in_features=embed_size,
                                 out_features=embed_size,
                                 bias=True)
        self.bn_topic = nn.BatchNorm1d(num_features=1, momentum=momentum)

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.__init_weight()

    def __init_weight(self):
        self.W_t_h.weight.data.uniform_(-0.1, 0.1)
        self.W_t_h.bias.data.fill_(0)
        self.W_t_ctx.weight.data.uniform_(-0.1, 0.1)
        self.W_t_ctx.bias.data.fill_(0)
        self.W_stop_s_1.weight.data.uniform_(-0.1, 0.1)
        self.W_stop_s_1.bias.data.fill_(0)
        self.W_stop_s.weight.data.uniform_(-0.1, 0.1)
        self.W_stop_s.bias.data.fill_(0)
        self.W_stop.weight.data.uniform_(-0.1, 0.1)
        self.W_stop.bias.data.fill_(0)
        self.W_topic.weight.data.uniform_(-0.1, 0.1)
        self.W_topic.bias.data.fill_(0)

    def forward(self, ctx, prev_hidden_state, states=None) -> object:
        """
        :rtype: object
        """
        if self.version == 'v1':
            return self.v1(ctx, prev_hidden_state, states)
        elif self.version == 'v2':
            return self.v2(ctx, prev_hidden_state, states)
        elif self.version == 'v3':
            return self.v3(ctx, prev_hidden_state, states)

    def v1(self, ctx, prev_hidden_state, states=None):
        """
        v1 (only training)
        :param ctx:
        :param prev_hidden_state:
        :param states:
        :return:
        """
        ctx = ctx.unsqueeze(1)
        hidden_state, states = self.lstm(ctx, states)
        topic = self.W_topic(self.sigmoid(self.bn_t_h(self.W_t_h(hidden_state))
                                          + self.bn_t_ctx(self.W_t_ctx(ctx))))
        p_stop = self.W_stop(self.sigmoid(self.bn_stop_s_1(self.W_stop_s_1(prev_hidden_state))
                                          + self.bn_stop_s(self.W_stop_s(hidden_state))))
        return topic, p_stop, hidden_state, states

    def v2(self, ctx, prev_hidden_state, states=None):
        """
        v2
        :rtype: object
        """
        ctx = ctx.unsqueeze(1)
        hidden_state, states = self.lstm(ctx, states)
        topic = self.bn_topic(self.W_topic(self.tanh(self.bn_t_h(self.W_t_h(hidden_state)
                                                                 + self.W_t_ctx(ctx)))))
        p_stop = self.bn_stop(self.W_stop(self.tanh(self.bn_stop_s(self.W_stop_s_1(prev_hidden_state)
                                                                   + self.W_stop_s(hidden_state)))))
        return topic, p_stop, hidden_state, states

    def v3(self, ctx, prev_hidden_state, states=None):
        """
        v3
        :rtype: object
        """
        ctx = ctx.unsqueeze(1)
        hidden_state, states = self.lstm(ctx, states)
        topic = self.W_topic(self.tanh(self.W_t_h(hidden_state) + self.W_t_ctx(ctx)))
        p_stop = self.W_stop(self.tanh(self.W_stop_s_1(prev_hidden_state) + self.W_stop_s(hidden_state)))
        return topic, p_stop, hidden_state, states
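

# WordLSTM is the word-level decoder: conditioned on a topic vector (and the
# caption prefix during training) it predicts the next word; sample() decodes
# greedily for up to n_max steps.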
class WordLSTM(nn.Module):
    def __init__(self,
                 embed_size,
                 hidden_size,
                 vocab_size,
                 num_layers,
                 n_max=50):
        super(WordLSTM, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.__init_weights()
        self.n_max = n_max
        self.vocab_size = vocab_size

    def __init_weights(self):
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.linear.weight.data.uniform_(-0.1, 0.1)
        self.linear.bias.data.fill_(0)

    def forward(self, topic_vec, captions):
        embeddings = self.embed(captions)
        embeddings = torch.cat((topic_vec, embeddings), 1)
        hidden, _ = self.lstm(embeddings)
        outputs = self.linear(hidden[:, -1, :])
        return outputs

    def sample(self, features, start_tokens):
        sampled_ids = np.zeros((np.shape(features)[0], self.n_max))
        sampled_ids[:, 0] = start_tokens.view(-1, )
        predicted = start_tokens
        embeddings = features
        for i in range(1, self.n_max):
            # Greedy decoding: append the embedding of the last prediction and re-run the LSTM
            predicted = self.embed(predicted)
            embeddings = torch.cat([embeddings, predicted], dim=1)
            hidden_states, _ = self.lstm(embeddings)
            hidden_states = hidden_states[:, -1, :]
            outputs = self.linear(hidden_states)
            predicted = torch.max(outputs, 1)[1]
            sampled_ids[:, i] = predicted
            predicted = predicted.unsqueeze(1)
        return sampled_ids
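

# Quick shape-check when the module is run directly; the tensors below are
# random placeholders rather than real data.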
if __name__ == '__main__':
    import torchvision.transforms as transforms
    import warnings
    warnings.filterwarnings("ignore")

    extractor = VisualFeatureExtractor(model_name='resnet152')
    mlc = MLC(fc_in_features=extractor.out_features)
    co_att = CoAttention(visual_size=extractor.out_features)
    sent_lstm = SentenceLSTM()
    word_lstm = WordLSTM(embed_size=512, hidden_size=512, vocab_size=100, num_layers=1)

    images = torch.randn((4, 3, 224, 224))
    captions = torch.ones((4, 10)).long()
    hidden_state = torch.randn((4, 1, 512))

    # # image_file = '../data/images/CXR2814_IM-1239-1001.png'
    # # # images = Image.open(image_file).convert('RGB')
    # # # captions = torch.ones((1, 10)).long()
    # # # hidden_state = torch.randn((10, 512))
    #
    # norm = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    #
    # transform = transforms.Compose([
    #     transforms.Resize(256),
    #     transforms.TenCrop(224),
    #     transforms.Lambda(lambda crops: torch.stack([norm(transforms.ToTensor()(crop)) for crop in crops])),
    # ])
    # images = transform(images)
    # images.unsqueeze_(0)
    #
    # # bs, ncrops, c, h, w = images.size()
    # # images = images.view(-1, c, h, w)

    print("images:{}".format(images.shape))
    print("captions:{}".format(captions.shape))
    print("hidden_states:{}".format(hidden_state.shape))

    visual_features, avg_features = extractor.forward(images)
    print("visual_features:{}".format(visual_features.shape))
    print("avg features:{}".format(avg_features.shape))

    tags, semantic_features = mlc.forward(avg_features)
    print("tags:{}".format(tags.shape))
    print("semantic_features:{}".format(semantic_features.shape))

    ctx, alpht_v, alpht_a = co_att.forward(avg_features, semantic_features, hidden_state)
    print("ctx:{}".format(ctx.shape))
    print("alpht_v:{}".format(alpht_v.shape))
    print("alpht_a:{}".format(alpht_a.shape))

    topic, p_stop, hidden_state, states = sent_lstm.forward(ctx, hidden_state)
    # p_stop_avg = p_stop.view(bs, ncrops, -1).mean(1)
    print("Topic:{}".format(topic.shape))
    print("P_STOP:{}".format(p_stop.shape))
    # print("P_stop_avg:{}".format(p_stop_avg.shape))

    words = word_lstm.forward(topic, captions)
    print("words:{}".format(words.shape))

    # Build a class-activation-style map from the visual attention weights.
    # NOTE: this assumes visual_features is a 4-D convolutional feature map
    # (as produced by the resnet152/densenet201 branches); with the ViT branch
    # visual_features is a 3-D last_hidden_state and this reshape does not apply.
    cam = torch.mul(visual_features, alpht_v.view(alpht_v.shape[0], alpht_v.shape[1], 1, 1)).sum(1)
    cam.squeeze_()
    cam = cam.cpu().data.numpy()
    for i in range(cam.shape[0]):
        heatmap = cam[i]
        heatmap = heatmap / np.max(heatmap)
        print(heatmap.shape)