Commit
·
1ae2720
1
Parent(s):
48a795f
Fix model loading and file paths
Browse files
backend/baselines/ViT/ViT_LRP.py
CHANGED
|
@@ -445,12 +445,13 @@ def _conv_filter(state_dict, patch_size=16):
|
|
| 445 |
|
| 446 |
# ViT_LRP
|
| 447 |
def vit_base_patch16_224(pretrained=False, **kwargs):
|
|
|
|
| 448 |
model = VisionTransformer(
|
| 449 |
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs)
|
| 450 |
model.default_cfg = default_cfgs['vit_base_patch16_224']
|
| 451 |
if pretrained:
|
| 452 |
load_pretrained(
|
| 453 |
-
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
|
| 454 |
return model
|
| 455 |
|
| 456 |
def vit_large_patch16_224(pretrained=False, **kwargs):
|
|
|
|
| 445 |
|
| 446 |
# ViT_LRP
|
| 447 |
def vit_base_patch16_224(pretrained=False, **kwargs):
|
| 448 |
+
checkpoint_dir = kwargs.pop('checkpoint_dir', None)
|
| 449 |
model = VisionTransformer(
|
| 450 |
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs)
|
| 451 |
model.default_cfg = default_cfgs['vit_base_patch16_224']
|
| 452 |
if pretrained:
|
| 453 |
load_pretrained(
|
| 454 |
+
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter, checkpoint_dir=checkpoint_dir)
|
| 455 |
return model
|
| 456 |
|
| 457 |
def vit_large_patch16_224(pretrained=False, **kwargs):
|
backend/flask_backend.py
CHANGED
|
@@ -96,7 +96,7 @@ logger.info(f"Using checkpoint directory: {checkpoint_dir}")
|
|
| 96 |
|
| 97 |
import json
|
| 98 |
|
| 99 |
-
all_model_details_json_list= list(json.load(open("models_details.json")))
|
| 100 |
all_model_details_dict = dict()
|
| 101 |
for i, item in enumerate(all_model_details_json_list):
|
| 102 |
all_model_details_dict[item["model_id"]] = item
|
|
@@ -168,10 +168,10 @@ def load_model_by_id(model_id):
|
|
| 168 |
if model_id == "vit_base_patch16_224.augreg2_in21k_ft_in1k":
|
| 169 |
# Original default model logic
|
| 170 |
# Removed incorrect checkpoint_dir argument
|
| 171 |
-
model = vit_base_patch16_224(pretrained=True).to(target_device)
|
| 172 |
model.eval()
|
| 173 |
|
| 174 |
-
inference_model = vit_orig(pretrained=True).to(target_device)
|
| 175 |
inference_model.eval()
|
| 176 |
|
| 177 |
elif model_id == "vit_base_patch14_reg4_dinov2.lvd142m":
|
|
|
|
| 96 |
|
| 97 |
import json
|
| 98 |
|
| 99 |
+
all_model_details_json_list= list(json.load(open(os.path.join(_script_dir, "models_details.json"))))
|
| 100 |
all_model_details_dict = dict()
|
| 101 |
for i, item in enumerate(all_model_details_json_list):
|
| 102 |
all_model_details_dict[item["model_id"]] = item
|
|
|
|
| 168 |
if model_id == "vit_base_patch16_224.augreg2_in21k_ft_in1k":
|
| 169 |
# Original default model logic
|
| 170 |
# Removed incorrect checkpoint_dir argument
|
| 171 |
+
model = vit_base_patch16_224(pretrained=True, checkpoint_dir=checkpoint_dir).to(target_device)
|
| 172 |
model.eval()
|
| 173 |
|
| 174 |
+
inference_model = vit_orig(pretrained=True, checkpoint_dir=checkpoint_dir).to(target_device)
|
| 175 |
inference_model.eval()
|
| 176 |
|
| 177 |
elif model_id == "vit_base_patch14_reg4_dinov2.lvd142m":
|