Update README.md
Browse files
README.md
CHANGED
|
@@ -1,3 +1,60 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: apache-2.0
|
| 3 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
---
|
| 4 |
+
|
| 5 |
+
To use Sum-of-Parts(SOP), you would need to install exlib. Currently SOP is only available on the dev branch https://github.com/BrachioLab/exlib/tree/dev
|
| 6 |
+
|
| 7 |
+
To use SOP trained for `google/vit-base-patch16-224`, follow the following code.
|
| 8 |
+
|
| 9 |
+
### Load the model
|
| 10 |
+
```
|
| 11 |
+
import torch
|
| 12 |
+
import os
|
| 13 |
+
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
| 14 |
+
|
| 15 |
+
import sys
|
| 16 |
+
from exlib.modules.sop import WrappedModel, SOPConfig, SOPImageCls, get_chained_attr
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# init backbone model
|
| 20 |
+
backbone_model = AutoModelForImageClassification.from_pretrained('google/vit-base-patch16-224')
|
| 21 |
+
processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224')
|
| 22 |
+
|
| 23 |
+
# get needed wrapped models
|
| 24 |
+
original_model = WrappedModel(backbone_model, output_type='logits')
|
| 25 |
+
wrapped_backbone_model = WrappedModel(backbone_model, output_type='tuple')
|
| 26 |
+
projection_layer = WrappedModel(wrapped_backbone_model, output_type='hidden_states')
|
| 27 |
+
|
| 28 |
+
# load trained sop model
|
| 29 |
+
model = SOPImageCls.from_pretrained('BrachioLab/sop-vit-base-patch16-224',
|
| 30 |
+
blackbox_model=wrapped_backbone_model,
|
| 31 |
+
projection_layer=projection_layer)
|
| 32 |
+
model.eval();
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
### Open an image
|
| 36 |
+
```
|
| 37 |
+
from PIL import Image
|
| 38 |
+
|
| 39 |
+
# Open an example image
|
| 40 |
+
# image_path = '../../examples/ILSVRC2012_val_00000873.JPEG'
|
| 41 |
+
image_path = '../../examples/ILSVRC2012_val_00000247.JPEG'
|
| 42 |
+
image = Image.open(image_path)
|
| 43 |
+
image.show()
|
| 44 |
+
image_rgb = image.convert("RGB")
|
| 45 |
+
inputs = torch.tensor(processor(image_rgb)['pixel_values'])
|
| 46 |
+
inputs.shape # (1, 3, 224, 224)
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
### Get the output from SOP
|
| 50 |
+
```
|
| 51 |
+
# Get the outputs from the model
|
| 52 |
+
outputs = model(inputs, return_tuple=True)
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
### Show the groups
|
| 56 |
+
```
|
| 57 |
+
from exlib.modules.sop import show_masks_weights
|
| 58 |
+
|
| 59 |
+
show_masks_weights(inputs, outputs, i=0) # This allows you to see the group masks with group attribution scores.
|
| 60 |
+
```
|