Update modeling_internimage.py
#2
by
parakh01
- opened
- modeling_internimage.py +55 -15
modeling_internimage.py
CHANGED
|
@@ -800,23 +800,31 @@ class InternImage(nn.Module):
|
|
| 800 |
'pooler_output': x if self.num_classes > 0 else None
|
| 801 |
}
|
| 802 |
|
| 803 |
-
def forward(self,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 804 |
if self.use_clip_projector: # for InternImage-H/G
|
| 805 |
-
outputs = self.forward_clip_projector(
|
| 806 |
else: # for InternImage-T/S/B/L/XL
|
| 807 |
-
outputs = self.forward_features(
|
| 808 |
|
| 809 |
-
hidden_states = outputs['hidden_states']
|
| 810 |
-
pooler_output = outputs['pooler_output']
|
|
|
|
| 811 |
|
| 812 |
if self.num_classes > 0:
|
| 813 |
-
logits = self.head(pooler_output)
|
| 814 |
else:
|
| 815 |
logits = None
|
| 816 |
|
|
|
|
|
|
|
|
|
|
| 817 |
return BackboneOutput(
|
| 818 |
hidden_states=hidden_states,
|
| 819 |
-
last_hidden_state=
|
| 820 |
pooler_output=pooler_output,
|
| 821 |
logits=logits
|
| 822 |
)
|
|
@@ -853,8 +861,17 @@ class InternImageModel(PreTrainedModel):
|
|
| 853 |
remove_center=config.remove_center, # for InternImage-H/G
|
| 854 |
)
|
| 855 |
|
| 856 |
-
def forward(self,
|
| 857 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 858 |
|
| 859 |
|
| 860 |
class InternImageModelForImageClassification(PreTrainedModel):
|
|
@@ -862,6 +879,7 @@ class InternImageModelForImageClassification(PreTrainedModel):
|
|
| 862 |
|
| 863 |
def __init__(self, config):
|
| 864 |
super().__init__(config)
|
|
|
|
| 865 |
self.model = InternImage(
|
| 866 |
core_op=config.core_op,
|
| 867 |
channels=config.channels,
|
|
@@ -888,12 +906,34 @@ class InternImageModelForImageClassification(PreTrainedModel):
|
|
| 888 |
remove_center=config.remove_center, # for InternImage-H/G
|
| 889 |
)
|
| 890 |
|
| 891 |
-
def forward(self,
|
| 892 |
-
|
| 893 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 894 |
if labels is not None:
|
| 895 |
-
logits = outputs[
|
| 896 |
loss = F.cross_entropy(logits, labels)
|
| 897 |
-
outputs['loss'] = loss
|
| 898 |
|
| 899 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 800 |
'pooler_output': x if self.num_classes > 0 else None
|
| 801 |
}
|
| 802 |
|
| 803 |
+
def forward(self,
|
| 804 |
+
pixel_values,
|
| 805 |
+
output_attentions=None,
|
| 806 |
+
output_hidden_states=None,
|
| 807 |
+
return_dict=None):
|
| 808 |
if self.use_clip_projector: # for InternImage-H/G
|
| 809 |
+
outputs = self.forward_clip_projector(pixel_values)
|
| 810 |
else: # for InternImage-T/S/B/L/XL
|
| 811 |
+
outputs = self.forward_features(pixel_values)
|
| 812 |
|
| 813 |
+
hidden_states = outputs['hidden_states'] if output_hidden_states is not None else None
|
| 814 |
+
pooler_output = outputs['pooler_output'] if output_attentions is not None else None
|
| 815 |
+
last_hidden_state = outputs['hidden_states'][-1] if output_hidden_states is not None else None
|
| 816 |
|
| 817 |
if self.num_classes > 0:
|
| 818 |
+
logits = self.head(outputs['pooler_output'])
|
| 819 |
else:
|
| 820 |
logits = None
|
| 821 |
|
| 822 |
+
if not return_dict:
|
| 823 |
+
return tuple(v for v in [logits, hidden_states, pooler_output, last_hidden_state] if v is not None)
|
| 824 |
+
|
| 825 |
return BackboneOutput(
|
| 826 |
hidden_states=hidden_states,
|
| 827 |
+
last_hidden_state=last_hidden_state,
|
| 828 |
pooler_output=pooler_output,
|
| 829 |
logits=logits
|
| 830 |
)
|
|
|
|
| 861 |
remove_center=config.remove_center, # for InternImage-H/G
|
| 862 |
)
|
| 863 |
|
| 864 |
+
def forward(self,
|
| 865 |
+
pixel_values,
|
| 866 |
+
output_attentions=None,
|
| 867 |
+
output_hidden_states=None,
|
| 868 |
+
return_dict=None):
|
| 869 |
+
|
| 870 |
+
return self.model.forward_features(
|
| 871 |
+
pixel_values,
|
| 872 |
+
output_attentions=output_attentions,
|
| 873 |
+
output_hidden_states=output_hidden_states,
|
| 874 |
+
return_dict=return_dict)
|
| 875 |
|
| 876 |
|
| 877 |
class InternImageModelForImageClassification(PreTrainedModel):
|
|
|
|
| 879 |
|
| 880 |
def __init__(self, config):
|
| 881 |
super().__init__(config)
|
| 882 |
+
self.config = config
|
| 883 |
self.model = InternImage(
|
| 884 |
core_op=config.core_op,
|
| 885 |
channels=config.channels,
|
|
|
|
| 906 |
remove_center=config.remove_center, # for InternImage-H/G
|
| 907 |
)
|
| 908 |
|
| 909 |
+
def forward(self,
|
| 910 |
+
pixel_values,
|
| 911 |
+
labels=None,
|
| 912 |
+
output_attentions=None,
|
| 913 |
+
output_hidden_states=None,
|
| 914 |
+
return_dict=None):
|
| 915 |
+
|
| 916 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 917 |
+
|
| 918 |
+
outputs = self.model.forward(
|
| 919 |
+
pixel_values,
|
| 920 |
+
output_attentions=output_attentions,
|
| 921 |
+
output_hidden_states=output_hidden_states,
|
| 922 |
+
return_dict=return_dict)
|
| 923 |
+
|
| 924 |
+
loss = None
|
| 925 |
if labels is not None:
|
| 926 |
+
logits = outputs.logits if return_dict else outputs[0]
|
| 927 |
loss = F.cross_entropy(logits, labels)
|
|
|
|
| 928 |
|
| 929 |
+
if not return_dict:
|
| 930 |
+
output = (outputs[0],) + outputs[1:]
|
| 931 |
+
return ((loss,) + output) if loss is not None else output
|
| 932 |
+
|
| 933 |
+
return BackboneOutput(
|
| 934 |
+
loss = loss,
|
| 935 |
+
logits = outputs.logits,
|
| 936 |
+
hidden_states = outputs.hidden_states,
|
| 937 |
+
last_hidden_state = outputs.last_hidden_state,
|
| 938 |
+
pooler_output = outputs.pooler_output
|
| 939 |
+
)
|