SasikaA073 committed on
Commit
1ae2720
·
1 Parent(s): 48a795f

Fix model loading and file paths

Browse files
backend/baselines/ViT/ViT_LRP.py CHANGED
@@ -445,12 +445,13 @@ def _conv_filter(state_dict, patch_size=16):
445
 
446
  # ViT_LRP
447
  def vit_base_patch16_224(pretrained=False, **kwargs):
 
448
  model = VisionTransformer(
449
  patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs)
450
  model.default_cfg = default_cfgs['vit_base_patch16_224']
451
  if pretrained:
452
  load_pretrained(
453
- model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
454
  return model
455
 
456
  def vit_large_patch16_224(pretrained=False, **kwargs):
 
445
 
446
  # ViT_LRP
447
  def vit_base_patch16_224(pretrained=False, **kwargs):
448
+ checkpoint_dir = kwargs.pop('checkpoint_dir', None)
449
  model = VisionTransformer(
450
  patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs)
451
  model.default_cfg = default_cfgs['vit_base_patch16_224']
452
  if pretrained:
453
  load_pretrained(
454
+ model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter, checkpoint_dir=checkpoint_dir)
455
  return model
456
 
457
  def vit_large_patch16_224(pretrained=False, **kwargs):
backend/flask_backend.py CHANGED
@@ -96,7 +96,7 @@ logger.info(f"Using checkpoint directory: {checkpoint_dir}")
96
 
97
  import json
98
 
99
- all_model_details_json_list= list(json.load(open("models_details.json")))
100
  all_model_details_dict = dict()
101
  for i, item in enumerate(all_model_details_json_list):
102
  all_model_details_dict[item["model_id"]] = item
@@ -168,10 +168,10 @@ def load_model_by_id(model_id):
168
  if model_id == "vit_base_patch16_224.augreg2_in21k_ft_in1k":
169
  # Original default model logic
170
  # Removed incorrect checkpoint_dir argument
171
- model = vit_base_patch16_224(pretrained=True).to(target_device)
172
  model.eval()
173
 
174
- inference_model = vit_orig(pretrained=True).to(target_device)
175
  inference_model.eval()
176
 
177
  elif model_id == "vit_base_patch14_reg4_dinov2.lvd142m":
 
96
 
97
  import json
98
 
99
+ all_model_details_json_list= list(json.load(open(os.path.join(_script_dir, "models_details.json"))))
100
  all_model_details_dict = dict()
101
  for i, item in enumerate(all_model_details_json_list):
102
  all_model_details_dict[item["model_id"]] = item
 
168
  if model_id == "vit_base_patch16_224.augreg2_in21k_ft_in1k":
169
  # Original default model logic
170
  # Removed incorrect checkpoint_dir argument
171
+ model = vit_base_patch16_224(pretrained=True, checkpoint_dir=checkpoint_dir).to(target_device)
172
  model.eval()
173
 
174
+ inference_model = vit_orig(pretrained=True, checkpoint_dir=checkpoint_dir).to(target_device)
175
  inference_model.eval()
176
 
177
  elif model_id == "vit_base_patch14_reg4_dinov2.lvd142m":