import io
import logging
import traceback
import json
import sys
import os
import argparse

import torch
import torch.nn.functional as F
import numpy as np
import cv2

from flask import Flask, request, send_file, jsonify
from flask_cors import CORS

from PIL import Image
import torchvision.transforms as transforms

logging.basicConfig(
    level=logging.DEBUG,
    format="[%(asctime)s] [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)

logger.info("Starting Flask + ViT-LRP server...")

from baselines.ViT.ViT_LRP import vit_base_patch16_224
from baselines.ViT.ViT_explanation_generator import Baselines, LRP
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_orig
from baselines.ViT.ViT_LRP import vit_base_patch14_reg4_dinov2


use_thresholding = False

normalize = transforms.Normalize(
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5],
)

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])
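
# Illustrative note: with the default pipeline above, transform(pil_img) yields a
# float tensor of shape [3, 224, 224] with values roughly in [-1, 1] (mean/std 0.5
# normalization). load_model_by_id() may replace `transform` for models that expect
# a different input resolution.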

VALID_ATTR_METHODS = [
    "rollout",
    "transformer_attribution",
    "full",
    "last_layer",
    "last_layer_attn",
    "attn_gradcam",
]

_script_dir = os.path.dirname(os.path.abspath(__file__))

with open(os.path.join(_script_dir, "cls2idx.json"), "r") as f:
    CLS2IDX = json.load(f)


def get_class_name(idx: int) -> str:
    return CLS2IDX.get(str(idx), f"class_{idx}")
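
# Example (illustrative): CLS2IDX is keyed by string indices, so with a standard
# ImageNet-1k mapping get_class_name(207) would return something like
# "golden retriever", while an index missing from cls2idx.json falls back to "class_207".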


def show_cam_on_image(img, mask):
    """
    img: H x W x 3, float in [0,1]
    mask: H x W, float in [0,1]
    """
    logger.debug(f"show_cam_on_image: img.shape={img.shape}, mask.shape={mask.shape}")
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return cam
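
# Note: cv2.applyColorMap returns a BGR image, so the overlay produced here is BGR
# float in [0,1]; generate_visualization() scales it to uint8 and the /api/heatmap
# handler converts it to RGB before encoding the PNG.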


checkpoint_dir = os.path.join(_script_dir, "model_checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)
logger.info(f"Using checkpoint directory: {checkpoint_dir}")

with open(os.path.join(_script_dir, "models_details.json"), "r") as f:
    all_model_details_json_list = list(json.load(f))

all_model_details_dict = dict()
for item in all_model_details_json_list:
    all_model_details_dict[item["model_id"]] = item

_current_device = "cpu"
SELECTED_MODEL_ID = "vit_base_patch16_224.augreg2_in21k_ft_in1k"

model = None
inference_model = None
baselines = None
attribution_generator = None


def get_device(requested_device=None):
    """
    Get and validate device. Move models if needed.
    requested_device: str or None - 'cpu', 'cuda', or 'mps'
    """
    global model, inference_model, baselines, attribution_generator, _current_device

    if requested_device is None:
        requested_device = "cuda" if torch.cuda.is_available() else "cpu"

    if requested_device == "cuda" and not torch.cuda.is_available():
        logger.warning("CUDA requested but not available, falling back to CPU")
        requested_device = "cpu"
    elif requested_device == "mps":
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            requested_device = "mps"
        else:
            logger.warning("MPS requested but not available, falling back to CPU")
            requested_device = "cpu"
    elif requested_device not in ["cpu", "cuda", "mps"]:
        logger.warning(f"Invalid device '{requested_device}', falling back to CPU")
        requested_device = "cpu"

    if _current_device != requested_device:
        logger.info(f"Moving models to device: {requested_device}")
        if model is not None:
            model = model.to(requested_device)
        if inference_model is not None:
            inference_model = inference_model.to(requested_device)
            # Rebuild the wrapper only when an inference model actually exists.
            baselines = Baselines(inference_model)

        if model is not None:
            attribution_generator = LRP(model)

        _current_device = requested_device

    return requested_device
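
# Example (illustrative): get_device("cuda") returns "cuda" on a CUDA-capable machine
# and logs a warning and returns "cpu" otherwise; get_device(None) auto-selects CUDA
# when available.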


def load_model_by_id(model_id):
    global model, inference_model, baselines, attribution_generator, _current_device, SELECTED_MODEL_ID

    logger.info(f"Loading model: {model_id}...")

    target_device = _current_device

    try:
        if model_id == "vit_base_patch16_224.augreg2_in21k_ft_in1k":
            model = vit_base_patch16_224(pretrained=True, checkpoint_dir=checkpoint_dir).to(target_device)
            model.eval()

            inference_model = vit_orig(pretrained=True, checkpoint_dir=checkpoint_dir).to(target_device)
            inference_model.eval()

        elif model_id == "vit_base_patch14_reg4_dinov2.lvd142m":
            model = vit_base_patch14_reg4_dinov2(
                pretrained=False,
                img_size=518,
            )
            model.head = torch.nn.Linear(
                in_features=768,
                out_features=1000,
                bias=True,
            )

            rel_path = "vit_base_patch14_reg4_dinov2_head_finetuned_best.pth"
            if model_id in all_model_details_dict:
                rel_path = all_model_details_dict[model_id].get("finetuned_head_path", rel_path)

            candidate_names = [
                rel_path,
                "vit_base_patch14_reg4_dinov2_head_finetuned_best",
                "vit_base_patch14_reg4_dinov2_head_finetuned_best.pth",
            ]

            ckpt_path = None
            for name in candidate_names:
                p = os.path.join(checkpoint_dir, "head_finetuned", os.path.basename(name))
                if os.path.exists(p):
                    ckpt_path = p
                    break

                if name.endswith(".pth"):
                    p_no_ext = os.path.join(checkpoint_dir, "head_finetuned", os.path.basename(name)[:-4])
                    if os.path.exists(p_no_ext):
                        ckpt_path = p_no_ext
                        break

            if ckpt_path is None:
                raise FileNotFoundError(f"Could not find checkpoint for {model_id} in {checkpoint_dir}/head_finetuned")

            logger.info(f"Loading weights from {ckpt_path}")
            state_dict = torch.load(ckpt_path, map_location="cpu")

            if "pos_embed" in state_dict:
                chkpt_pos_embed = state_dict["pos_embed"]
                model_pos_embed = model.pos_embed
                logger.info(f"Checkpoint pos_embed: {chkpt_pos_embed.shape}, Model pos_embed: {model_pos_embed.shape}")

                if chkpt_pos_embed.shape != model_pos_embed.shape:
                    logger.warning(f"Resizing pos_embed from {chkpt_pos_embed.shape} to {model_pos_embed.shape}")
                    if chkpt_pos_embed.shape[1] == 1369 and model_pos_embed.shape[1] == 1374:
                        new_pos_embed = model_pos_embed.clone().detach()
                        new_pos_embed[:, 5:, :] = chkpt_pos_embed
                        state_dict["pos_embed"] = new_pos_embed
                        logger.info("Successfully patched pos_embed with CLS/REG tokens preserved from init.")
                    else:
                        logger.warning("Unknown pos_embed mismatch pattern. Attempting standard interpolation if compatible, or failing.")

            model.load_state_dict(state_dict)
            model.to(target_device)
            model.eval()

            inference_model = vit_base_patch14_reg4_dinov2(
                pretrained=False,
                img_size=518,
            )
            inference_model.head = torch.nn.Linear(768, 1000, bias=True)
            inference_model.load_state_dict(state_dict)
            inference_model.to(target_device)
            inference_model.eval()

        else:
            raise ValueError(f"Model ID {model_id} not implemented in backend loading logic.")

        baselines = Baselines(inference_model)
        attribution_generator = LRP(model)

        if model_id in all_model_details_dict:
            img_size = all_model_details_dict[model_id].get("input_image_resolution", 224)
        else:
            img_size = 224

        global transform
        if img_size == 224:
            transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            transform = transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.ToTensor(),
                normalize,
            ])
        logger.info(f"Updated transform for model resolution: {img_size}x{img_size}")

        SELECTED_MODEL_ID = model_id
        logger.info(f"Successfully loaded {model_id}")

    except Exception as e:
        logger.error(f"Failed to load model {model_id}: {e}")
        raise


def compute_attribution_map(original_image, class_index=None, method="transformer_attribution", device="cpu"):
    """
    original_image: tensor [3, H, W] (normalized; H=W=model input resolution, 224 for the default model)
    class_index: int or None
    method: one of VALID_ATTR_METHODS
    device: str - device to use ('cpu', 'cuda', 'mps')
    returns: torch.Tensor [224, 224] on CPU, in [0,1]
    """
    device = get_device(device)

    input_tensor = original_image.unsqueeze(0).to(device)
    input_tensor.requires_grad_(True)

    try:
        if model is None:
            raise RuntimeError("No model selected or loaded. Please select a model first.")

        if method == "attn_gradcam":
            transformer_attribution = baselines.generate_cam_attn(input_tensor, index=class_index)
        else:
            transformer_attribution = attribution_generator.generate_LRP(
                input_tensor,
                method=method,
                index=class_index,
            )
        logger.debug(f"transformer_attribution shape after generate_LRP: {transformer_attribution.shape}, dims: {transformer_attribution.dim()}")

        if transformer_attribution.dim() == 0:
            raise ValueError(f"Received scalar tensor from generate_LRP (method={method})")
        elif transformer_attribution.dim() == 1:
            numel = transformer_attribution.numel()
            size = int(np.sqrt(numel))
            if size * size == numel:
                transformer_attribution = transformer_attribution.reshape(size, size)
            else:
                raise ValueError(f"Cannot reshape 1D tensor of size {numel} to square")

        if transformer_attribution.dim() == 2:
            transformer_attribution = transformer_attribution.unsqueeze(0).unsqueeze(0)
        elif transformer_attribution.dim() == 3:
            if transformer_attribution.shape[0] == 1:
                transformer_attribution = transformer_attribution.unsqueeze(0)
            else:
                transformer_attribution = transformer_attribution.mean(dim=0, keepdim=True).unsqueeze(0)
        elif transformer_attribution.dim() == 4:
            if transformer_attribution.shape[0] != 1:
                transformer_attribution = transformer_attribution[0:1]
            if transformer_attribution.shape[1] != 1:
                transformer_attribution = transformer_attribution.mean(dim=1, keepdim=True)

        logger.debug(f"transformer_attribution shape after normalization: {transformer_attribution.shape}")

        if method == "full":
            if transformer_attribution.shape[-2:] != (224, 224):
                transformer_attribution = F.interpolate(
                    transformer_attribution,
                    size=(224, 224),
                    mode="bilinear",
                    align_corners=False,
                )
        else:
            current_h, current_w = transformer_attribution.shape[-2:]
            if (current_h, current_w) != (14, 14):
                if current_h * current_w == 14 * 14:
                    transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
                else:
                    transformer_attribution = F.interpolate(
                        transformer_attribution,
                        size=(14, 14),
                        mode="bilinear",
                        align_corners=False,
                    )

            transformer_attribution = F.interpolate(
                transformer_attribution,
                size=(224, 224),
                mode="bilinear",
                align_corners=False,
            )

        transformer_attribution = transformer_attribution.squeeze()
        if transformer_attribution.dim() != 2:
            transformer_attribution = transformer_attribution.reshape(224, 224)

    except Exception as e:
        logger.error(f"Error processing attribution map: {e}")
        logger.error(f"Method: {method}, tensor shape: {transformer_attribution.shape if 'transformer_attribution' in locals() else 'N/A'}")
        raise

    min_val = transformer_attribution.min()
    max_val = transformer_attribution.max()
    attr_norm = (transformer_attribution - min_val) / (max_val - min_val + 1e-8)

    return attr_norm.detach().cpu()
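
# Example (illustrative): for a tensor `x` produced by `transform`,
#   attr = compute_attribution_map(x, class_index=None, method="transformer_attribution")
# returns a [224, 224] CPU tensor scaled to [0, 1]; class_index=None leaves the target
# class up to the explanation generator (typically the predicted class).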


def generate_visualization(original_image, class_index=None, method="transformer_attribution", device="cpu"):
    """
    original_image: tensor [3, H, W] (normalized)
    device: str - device to use ('cpu', 'cuda', 'mps')
    returns: np.array HxWx3 (BGR, uint8)
    """
    attr = compute_attribution_map(original_image, class_index=class_index, method=method, device=device)
    transformer_attribution = attr.numpy()

    image_transformer_attribution = original_image.permute(1, 2, 0).cpu().numpy()
    img_min = image_transformer_attribution.min()
    img_max = image_transformer_attribution.max()
    image_transformer_attribution = (
        image_transformer_attribution - img_min
    ) / (img_max - img_min + 1e-8)

    # The attribution map is always 224x224; if the model input resolution differs
    # (e.g. the 518x518 DINOv2 variant), resize the map to the image grid before overlaying.
    img_h, img_w = image_transformer_attribution.shape[:2]
    if transformer_attribution.shape != (img_h, img_w):
        transformer_attribution = cv2.resize(transformer_attribution, (img_w, img_h))

    vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)
    vis = np.uint8(255 * vis)

    return vis


def run_perturbation(img_raw, img_norm, target_index=None, perturbation_type="positive", method="transformer_attribution", device="cpu"):
    """
    img_raw: [3, H, W] tensor in [0,1]
    img_norm: [3, H, W] normalized tensor
    target_index: int or None
    perturbation_type: str - 'positive' or 'negative'
    method: str - attribution method
    device: str - device to use ('cpu', 'cuda', 'mps')

    returns:
        target_class_idx
        target_class_name
        original_prediction: {
            "top1": {...},
            "target": {...}
        }
        perturbation_results: list of {
            "fraction", "top1_class_idx", "top1_class_name",
            "top1_prob", "target_prob"
        }
    """
    device = get_device(device)

    # Perturb in the raw image's own resolution (224x224 for the default model).
    img_h, img_w = img_raw.shape[-2], img_raw.shape[-1]
    base_size = img_h * img_w
    perturbation_steps = [0.01, 0.05, 0.08, 0.1, 0.15, 0.3, 0.35, 0.4, 0.45]
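
    # Semantics (as implemented below): "positive" zeroes the most-attributed pixels
    # first, so the target probability should drop quickly if the explanation is
    # faithful; "negative" flips the sign and removes the least-attributed pixels first.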

    with torch.no_grad():
        logits = inference_model(img_norm.unsqueeze(0).to(device))
        probs = torch.softmax(logits, dim=1)
        top1_prob, top1_idx = torch.max(probs, dim=1)

    top1_idx_int = int(top1_idx.item())
    top1_prob_float = float(top1_prob.item())

    if target_index is None:
        target_index = top1_idx_int

    target_index_int = int(target_index)
    target_prob_float = float(probs[0, target_index_int].item())

    original_prediction = {
        "top1": {
            "class_idx": top1_idx_int,
            "class_name": get_class_name(top1_idx_int),
            "prob": top1_prob_float,
        },
        "target": {
            "class_idx": target_index_int,
            "class_name": get_class_name(target_index_int),
            "prob": target_prob_float,
        },
    }

    attr = compute_attribution_map(img_norm, class_index=target_index_int, method=method, device=device)
    # The attribution map comes back at 224x224; resize it to the raw image grid so
    # that the top-k pixel indices line up with the image being perturbed.
    if attr.shape != (img_h, img_w):
        attr = F.interpolate(
            attr.unsqueeze(0).unsqueeze(0),
            size=(img_h, img_w),
            mode="bilinear",
            align_corners=False,
        ).squeeze(0).squeeze(0)
    vis = attr.reshape(-1)

    if perturbation_type == "negative":
        vis = -vis

    data = img_raw.unsqueeze(0)
    org_shape = data.shape

    perturbation_results = []

    for frac in perturbation_steps:
        k = int(base_size * frac)

        _data = data.clone()

        _, idx = torch.topk(vis, k, dim=-1)

        idx_expanded = idx.unsqueeze(0).unsqueeze(0).repeat(
            org_shape[0], org_shape[1], 1
        )

        _data_flat = _data.view(org_shape[0], org_shape[1], -1)
        _data_flat.scatter_(-1, idx_expanded, 0.0)

        _data_perturbed = _data_flat.view(org_shape)

        pert_single = _data_perturbed.squeeze(0)
        pert_norm = normalize(pert_single).unsqueeze(0).to(device)

        with torch.no_grad():
            logits_p = inference_model(pert_norm)
            probs_p = torch.softmax(logits_p, dim=1)
            top1_prob_p, top1_idx_p = torch.max(probs_p, dim=1)
            top1_idx_p_int = int(top1_idx_p.item())
            top1_prob_p_float = float(top1_prob_p.item())
            target_prob_p = float(probs_p[0, target_index_int].item())

        perturbation_results.append({
            "fraction": float(frac),
            "top1_class_idx": top1_idx_p_int,
            "top1_class_name": get_class_name(top1_idx_p_int),
            "top1_prob": top1_prob_p_float,
            "target_prob": target_prob_p,
        })

    return (
        target_index_int,
        get_class_name(target_index_int),
        original_prediction,
        perturbation_results,
    )


app = Flask(__name__)

CORS(app, resources={r"/api/*": {"origins": "*"}})
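
# Note: CORS is wide open ("*") for all /api/* routes, which is convenient for a local
# frontend during development; tighten the allowed origins before exposing the server
# publicly.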


@app.route("/api/heatmap", methods=["POST"])
def heatmap():
    logger.info("➡️ /api/heatmap called")

    logger.debug(f"request.files keys: {list(request.files.keys())}")
    logger.debug(f"request.form keys: {list(request.form.keys())}")

    if "image" not in request.files:
        logger.warning("No 'image' field in request.files")
        return jsonify({"error": "No image field in form-data (expected 'image')"}), 400

    file = request.files["image"]
    logger.info(
        f"Received file: filename='{file.filename}', content_type='{file.content_type}'"
    )

    if file.filename == "":
        logger.warning("Empty filename in uploaded file.")
        return jsonify({"error": "Empty filename"}), 400

    target_index = request.form.get("target_index", None)
    class_index = None
    if target_index not in (None, ""):
        try:
            class_index = int(target_index)
        except ValueError:
            return jsonify({"error": "target_index must be an integer"}), 400

    method = request.form.get("method", "transformer_attribution")
    if method not in VALID_ATTR_METHODS:
        return jsonify({"error": f"Invalid method '{method}'. Must be one of: {', '.join(VALID_ATTR_METHODS)}"}), 400

    device = request.form.get("device", "cpu")
    if device not in ["cpu", "cuda", "mps"]:
        logger.warning(f"Invalid device '{device}' requested, using CPU")
        device = "cpu"

    try:
        logger.debug("Opening image with PIL...")
        img = Image.open(file.stream).convert("RGB")
        logger.debug(f"Image opened: size={img.size}, mode={img.mode}")

        logger.debug("Applying torchvision transforms...")
        img_tensor = transform(img)
        logger.debug(
            f"Image transformed: shape={tuple(img_tensor.shape)}, "
            f"dtype={img_tensor.dtype}"
        )

        logger.debug(f"Generating visualization on device: {device}...")
        vis_bgr = generate_visualization(img_tensor, class_index=class_index, method=method, device=device)
        logger.debug(
            f"Visualization generated: vis_bgr.shape={vis_bgr.shape}, "
            f"dtype={vis_bgr.dtype}"
        )

        logger.debug("Converting BGR numpy array to RGB PIL Image...")
        vis_rgb = cv2.cvtColor(vis_bgr, cv2.COLOR_BGR2RGB)
        vis_pil = Image.fromarray(vis_rgb)

        buf = io.BytesIO()
        vis_pil.save(buf, format="PNG")
        buf.seek(0)

        logger.info("✅ Successfully generated heatmap, sending PNG response.")
        return send_file(buf, mimetype="image/png")

    except Exception as e:
        logger.error("🔥 Exception in /api/heatmap")
        logger.error(e)
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
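
# Example request (illustrative):
#   curl -X POST http://localhost:5000/api/heatmap \
#        -F "image=@cat.jpg" -F "method=transformer_attribution" -F "device=cpu" \
#        --output heatmap.png
# Optional form field: target_index (int) to explain a specific class.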


@app.route("/api/perturbation", methods=["POST"])
def perturbation():
    if "image" not in request.files:
        return jsonify({"error": "No image field in form-data (expected 'image')"}), 400

    file = request.files["image"]
    if file.filename == "":
        return jsonify({"error": "Empty filename"}), 400

    target_index_raw = request.form.get("target_index", None)
    target_index = None
    if target_index_raw not in (None, ""):
        try:
            target_index = int(target_index_raw)
        except ValueError:
            return jsonify({"error": "target_index must be an integer"}), 400

    perturbation_type = request.form.get("perturbation_type", "positive").lower()
    if perturbation_type not in ("positive", "negative"):
        return jsonify({"error": "perturbation_type must be 'positive' or 'negative'"}), 400

    method = request.form.get("method", "transformer_attribution")
    if method not in VALID_ATTR_METHODS:
        return jsonify({"error": f"Invalid method '{method}'. Must be one of: {', '.join(VALID_ATTR_METHODS)}"}), 400

    device = request.form.get("device", "cpu")
    if device not in ["cpu", "cuda", "mps"]:
        logger.warning(f"Invalid device '{device}' requested, using CPU")
        device = "cpu"

    try:
        img = Image.open(file.stream).convert("RGB")
        img_norm = transform(img)
        # Undo the mean/std 0.5 normalization to recover the raw [0,1] image.
        img_raw = img_norm * 0.5 + 0.5

        (
            target_idx,
            target_name,
            original_prediction,
            perturbation_results,
        ) = run_perturbation(
            img_raw,
            img_norm,
            target_index=target_index,
            perturbation_type=perturbation_type,
            method=method,
            device=device,
        )

        return jsonify({
            "perturbation_type": perturbation_type,
            "method": method,
            "target_class_idx": target_idx,
            "target_class_name": target_name,
            "original_prediction": original_prediction,
            "perturbation_results": perturbation_results,
        })

    except Exception as e:
        logger.error(f"Error in /api/perturbation: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
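
# Example request (illustrative):
#   curl -X POST http://localhost:5000/api/perturbation \
#        -F "image=@cat.jpg" -F "perturbation_type=positive" -F "method=transformer_attribution"
# Returns JSON with the original prediction and per-fraction perturbed predictions.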


@app.route("/api/infer", methods=["POST"])
def infer():
    """
    Run plain classification on a given image (original or perturbed).
    Returns top-5 predictions with class names and probabilities.
    """
    if "image" not in request.files:
        return jsonify({"error": "No image field in form-data (expected 'image')"}), 400

    file = request.files["image"]
    if file.filename == "":
        return jsonify({"error": "Empty filename"}), 400

    device = request.form.get("device", "cpu")
    if device not in ["cpu", "cuda", "mps"]:
        logger.warning(f"Invalid device '{device}' requested, using CPU")
        device = "cpu"

    device = get_device(device)

    try:
        img = Image.open(file.stream).convert("RGB")
        img_norm = transform(img)

        with torch.no_grad():
            logits = inference_model(img_norm.unsqueeze(0).to(device))
            probs = torch.softmax(logits, dim=1)

        topk = 5
        top_probs, top_idxs = torch.topk(probs, k=topk, dim=1)
        top_probs = top_probs[0].cpu().tolist()
        top_idxs = top_idxs[0].cpu().tolist()

        predictions = []
        for idx, p in zip(top_idxs, top_probs):
            predictions.append({
                "class_idx": int(idx),
                "class_name": get_class_name(int(idx)),
                "prob": float(p),
            })

        return jsonify({"predictions": predictions})

    except Exception as e:
        logger.error(f"Error in /api/infer: {e}")
        return jsonify({"error": str(e)}), 500
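
# Example request (illustrative):
#   curl -X POST http://localhost:5000/api/infer -F "image=@cat.jpg"
# Response: {"predictions": [{"class_idx": ..., "class_name": ..., "prob": ...}, ...]}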


@app.route("/api/selected_model", methods=["POST"])
def set_selected_model():
    """
    Set the selected model for the session.
    Expects JSON with {"model_id": "model_id_here"}.
    """
    global SELECTED_MODEL_ID
    logger.info("➡️ /api/selected_model called")

    try:
        data = request.get_json()
        if not data or "model_id" not in data:
            return jsonify({"error": "Missing 'model_id' in request body"}), 400

        model_id = data["model_id"]

        # Free cached GPU memory from the previously loaded model, if any.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        load_model_by_id(model_id)

        logger.debug(f"Setting global selected model: {SELECTED_MODEL_ID}")

        return jsonify({"selected_model": SELECTED_MODEL_ID}), 200

    except Exception as e:
        logger.error(f"Error setting selected model: {e}")
        return jsonify({"error": str(e)}), 500
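
# Example request (illustrative):
#   curl -X POST http://localhost:5000/api/selected_model \
#        -H "Content-Type: application/json" \
#        -d '{"model_id": "vit_base_patch16_224.augreg2_in21k_ft_in1k"}'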


@app.route("/api/models", methods=["GET"])
def get_models():
    """
    Returns a list of available models from models_details.json
    """
    logger.info("➡️ /api/models called")
    try:
        details_path = os.path.join(_script_dir, "models_details.json")
        with open(details_path, "r") as f:
            all_models = json.load(f)

        available_models = [m for m in all_models if m.get("available") is True]

        return jsonify({"models": available_models})
    except Exception as e:
        logger.error(f"Error reading models_details.json: {e}")
        return jsonify({"error": str(e)}), 500
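
# Example request (illustrative):
#   curl http://localhost:5000/api/models
# Returns only the entries in models_details.json whose "available" flag is true.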


# Load the default model at import time so the API is usable immediately.
try:
    if SELECTED_MODEL_ID:
        load_model_by_id(SELECTED_MODEL_ID)
except Exception as e:
    logger.error(f"Failed to load default model on startup: {e}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Flask backend for ViT explainability")
    parser.add_argument("--port", type=int, default=5000, help="Port to run the Flask server on (default: 5000)")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the Flask server to (default: 0.0.0.0)")
    # Debug mode is on by default; pass --no-debug to turn it off.
    parser.add_argument("--debug", dest="debug", action="store_true", help="Run Flask in debug mode (default)")
    parser.add_argument("--no-debug", dest="debug", action="store_false", help="Disable Flask debug mode")
    parser.set_defaults(debug=True)
    args = parser.parse_args()

    logger.info(f"Running Flask app on http://{args.host}:{args.port}")
    app.run(host=args.host, port=args.port, debug=args.debug)