import argparse
import copy
import gc
import json
import os
import random
import warnings
from datetime import datetime

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from sklearn.metrics import roc_auc_score
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm

from baselines import *
from BetaMixture import BetaMixtureModel
from clip.custom_clip import get_coop
from data.cls_to_names import *
# build_medimeta_dataset is assumed to live alongside build_medmnist_dataset.
from data.datautils import build_medmnist_dataset, build_medimeta_dataset
from medmnistc_data import *
from utils.tools import *

warnings.filterwarnings("ignore")

random.seed(0)

medimeta_testset_task_dict = {
    "pbc": ["cell_class", "bloodmnist"],
    "mammo_mass": ["pathology", "breastmnist"],
    "pneumonia": ["disease_class", "pneumoniamnist"],
    "fundus": ["disease_presence", "retinamnist"],
    "oct": ["disease_class", "octmnist"],
}

method_names = {
    'model_ensemble': 'Model Ensemble',
    'wise_ft': 'Model Souping',
    'tcube': 'Entropy-based',
    'tcube_MI_bmm': 'Mutual Information',
}

ent_mi_dict = {'entropy': [], 'mi': [], 'agreement_diff': [], 'correct_pt': [], 'correct_ft': [], 'x_entropy': []}
dyn_v_stat_plot = {method: [] for method in method_names.keys()}
dyn_v_stat_plot['conditions'] = []


def fetch_keys_for_value(dictionary, target_value):
    """Return the keys whose value's second element equals target_value."""
    return [key for key, value in dictionary.items() if value[1] == target_value]
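
# Example: fetch_keys_for_value(medimeta_testset_task_dict, "bloodmnist") -> ["pbc"],
# since "pbc" maps to ["cell_class", "bloodmnist"] above.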


def load_models(args, classnames, set_id=None):
    """Load the pre-trained (PT) CLIP and its fine-tuned (FT) counterpart for set_id."""
    clip_pt = get_coop(args.arch, None, args.gpu, args.n_ctx, args.ctx_init, classnames)
    sd_pt = clip_pt.state_dict()

    if set_id in medimeta_testset_task_dict.keys():
        ft_path = os.path.join(args.ft_path, f'fine_tuned_clip_{medimeta_testset_task_dict[set_id][1]}.pth')
    else:
        ft_path = os.path.join(args.ft_path, f'fine_tuned_clip_{set_id}.pth')
    sd_ft = torch.load(ft_path, map_location='cpu')
    if 'pub' in ft_path.lower():
        sd_ft = sd_ft['state_dict']
    clip_ft = get_coop(args.arch, None, args.gpu, args.n_ctx, args.ctx_init, state_dict=sd_ft, classnames=classnames)
    del sd_ft
    sd_ft = clip_ft.state_dict()
    return clip_pt, sd_pt, clip_ft, sd_ft


def get_logits(model, dataloader, args, return_feats=False, normalize=True):
    """Run the model over the dataloader, collecting logits, labels and (optionally) features."""
    model.eval()
    logits = []
    labels = []
    image_features = []
    text_features = []
    with torch.no_grad():
        for inputs, label in tqdm(dataloader):
            inputs = inputs.cuda(args.gpu, non_blocking=True)
            label = label.cuda(args.gpu, non_blocking=True)
            if return_feats:
                outputs, img_feats, text_feats = model(inputs, return_logits=return_feats, normalize=normalize)
                image_features.append(img_feats)
                text_features.append(text_feats)
            else:
                outputs = model(inputs)
            logits.append(outputs)
            labels.append(label)

    if return_feats:
        return torch.cat(logits), torch.cat(labels), torch.cat(image_features), torch.cat(text_features)
    return torch.cat(logits), torch.cat(labels)


def self_entropy(logits, temperature=0.95):
    """Shannon entropy of the temperature-scaled softmax, per sample."""
    logits = logits / temperature
    probs = torch.nn.functional.softmax(logits, dim=1)
    return -(probs * torch.log(probs + 1e-9)).sum(dim=1)
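
# Rough intuition (values approximate): a confident prediction has low entropy,
# and hence a high expertise weight exp(-H) downstream:
#   self_entropy(torch.tensor([[5.0, 0.0, 0.0]]))  # ~0.06  -> exp(-H) ~ 0.94
#   self_entropy(torch.tensor([[1.0, 1.0, 1.0]]))  # ln(3) ~ 1.10 -> exp(-H) ~ 0.33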


def interpolation(lambdas, sd_pt, sd_ft):
    """Convex combination of two state dicts: lambdas[0] * PT + lambdas[1] * FT."""
    merged_sd = {}
    for key in sd_ft.keys():
        interpolated_value = sd_pt[key] * lambdas[0] + sd_ft[key] * lambdas[1]
        merged_sd[key] = interpolated_value
    return merged_sd
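
# Usage sketch: a 30/70 PT/FT merge, loaded non-strictly as in evaluate_tcube:
#   merged = interpolation(torch.tensor([0.3, 0.7]), sd_pt, sd_ft)
#   model.load_state_dict(merged, strict=False)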


def compute_samplewise_tcube_weights(clip_pt, clip_ft, dataloader, args):
    """Entropy-based interpolation weights: each model's expertise is exp(-entropy)."""
    logits_pt, _ = get_logits(clip_pt, dataloader, args, return_feats=False)
    logits_ft, _ = get_logits(clip_ft, dataloader, args, return_feats=False)
    ent_pt = self_entropy(logits_pt)
    ent_ft = self_entropy(logits_ft)
    expertise_pt = (-ent_pt).exp()
    expertise_ft = (-ent_ft).exp()

    total_expertise = expertise_pt + expertise_ft
    if args.offset:
        coef_bias = (ent_pt.std() / ent_pt.mean() + ent_ft.std() / ent_ft.mean()) / 2
        coef_biasw = (ent_pt.mean() + ent_ft.mean()) / ent_pt.mean()
        lambda_ft = (expertise_ft + (coef_bias / coef_biasw)) / (total_expertise + coef_bias)
    else:
        lambda_ft = expertise_ft / total_expertise

    global ent_mi_dict
    # Stored under 'entropy': the entropy-derived FT weights, for later plots.
    ent_mi_dict['entropy'] = lambda_ft

    if args.batch_wise:
        num_batches = len(dataloader)
        # Summarize the sample-wise weights into one weight per batch: fit a
        # Beta mixture and take each component's mean a / (a + b).
        lambda_ft_bmm = []
        lambda_ft_np = lambda_ft.cpu().numpy().reshape(-1, 1)
        bmm = BetaMixtureModel(n_mixtures=num_batches)
        bmm.fit(lambda_ft_np)
        for i in range(bmm.n_mixtures):
            a, b = bmm.beta_params_[i, 0], bmm.beta_params_[i, 1]
            lambda_ft_bmm.append(a / (a + b))
        lambda_ft_bmm = torch.tensor(lambda_ft_bmm)
        lambda_pt = 1 - lambda_ft_bmm
        return torch.stack([lambda_pt, lambda_ft_bmm], dim=0)

    lambda_pt = 1 - lambda_ft
    return torch.stack([lambda_pt, lambda_ft])
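
# Worked example (values approximate): with H_ft ~ 0.1 and H_pt ~ 1.0,
#   lambda_ft = e^-0.1 / (e^-0.1 + e^-1.0) = 0.905 / (0.905 + 0.368) ~ 0.71,
# so the merge leans toward the fine-tuned model. In the batch-wise path, a
# fitted Beta(8, 2) component contributes its mean, 8 / (8 + 2) = 0.8.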


def compute_samplewise_tcube_weights_MI(clip_pt, clip_ft, dataloader, args, delta=0.5, batch_wise=True):
    """Interpolation weights driven by the disagreement between the two models.

    The 'MI' score below is the Jensen-Shannon divergence between the two
    predictive distributions: 0.5 * (KL(p_pt || p_bar) + KL(p_ft || p_bar))
    with p_bar = (p_pt + p_ft) / 2.
    """
    logits_pt, labels = get_logits(clip_pt, dataloader, args, return_feats=False)
    logits_ft, _ = get_logits(clip_ft, dataloader, args, return_feats=False)

    p_pt = torch.softmax(logits_pt, dim=1)
    p_ft = torch.softmax(logits_ft, dim=1)

    pred_pt = p_pt.argmax(dim=1)
    pred_ft = p_ft.argmax(dim=1)
    correct_pt = pred_pt.eq(labels.squeeze())
    correct_ft = pred_ft.eq(labels.squeeze())

    p_bar = (p_pt + p_ft) / 2.0
    kl_pt = torch.sum(p_pt * torch.log(p_pt / (p_bar + 1e-8)), dim=1)
    kl_ft = torch.sum(p_ft * torch.log(p_ft / (p_bar + 1e-8)), dim=1)
    MI = 0.5 * (kl_pt + kl_ft)
    MI_orig = MI

    # Map the divergence into [lam_min, lam_max] with a sigmoid.
    lam_min = args.lam_min if hasattr(args, 'lam_min') else 0.01
    lam_max = args.lam_max if hasattr(args, 'lam_max') else 0.99
    gamma = args.gamma if hasattr(args, 'gamma') else 0.5
    lambda_ft = lam_min + (lam_max - lam_min) * torch.sigmoid(gamma * MI)
    # lambda_plot keeps the pre-extrapolation weights for plotting.
    lambda_plot = lam_min + (lam_max - lam_min) * torch.sigmoid(gamma * MI_orig)

    ent_pt = self_entropy(logits_pt)
    ent_ft = self_entropy(logits_ft)

    # Extrapolate when one model is clearly confident: push lambda_ft up when
    # the fine-tuned model's entropy is very low, down when the pre-trained
    # model's is.
    entropy_thresh_ft = getattr(args, 'entropy_thresh_ft', 0.05)
    entropy_thresh_pt = getattr(args, 'entropy_thresh_pt', 0.65)
    delta_extrap = delta

    lambda_ft = torch.where(
        ent_ft < entropy_thresh_ft,
        lambda_ft + delta_extrap,
        torch.where(
            ent_pt < entropy_thresh_pt,
            lambda_ft - delta_extrap,
            lambda_ft
        )
    )

    lambda_ft = torch.clamp(lambda_ft, 0.0, 1.5)
    lambda_pt = 1 - lambda_ft

    global ent_mi_dict
    ent_mi_dict['mi'] = MI
    ent_mi_dict['Ppt'] = p_pt
    ent_mi_dict['Pft'] = p_ft
    ent_mi_dict['correct_pt'] = correct_pt
    ent_mi_dict['correct_ft'] = correct_ft
    ce_pt = F.cross_entropy(logits_pt, labels.squeeze(), reduction='none')
    ce_ft = F.cross_entropy(logits_ft, labels.squeeze(), reduction='none')
    x_entropy_ratio = ce_ft / (ce_pt + ce_ft + 1e-9)
    ent_mi_dict['x_entropy'] = x_entropy_ratio
    ent_mi_dict['CE_pt'] = ce_pt
    ent_mi_dict['CE_ft'] = ce_ft

    if batch_wise:
        batch_size = len(dataloader.dataset) // len(dataloader)
        num_batches = len(dataloader)
        if args.lambda_mean_type == 'mean':
            lambda_ft_batchwise = lambda_ft[:num_batches * batch_size].view(num_batches, batch_size).mean(dim=1)
            lambda_pt_batchwise = 1 - lambda_ft_batchwise
            return torch.stack([lambda_pt_batchwise, lambda_ft_batchwise], dim=0)
        elif args.lambda_mean_type == 'bmm':
            lambda_ft_bmm = []
            lambda_ft_np = lambda_ft.cpu().numpy().reshape(-1, 1)
            bmm = BetaMixtureModel(n_mixtures=num_batches)
            bmm.fit(lambda_ft_np)
            for i in range(bmm.n_mixtures):
                a, b = bmm.beta_params_[i, 0], bmm.beta_params_[i, 1]
                lambda_ft_bmm.append(a / (a + b))
            lambda_ft_bmm = torch.tensor(lambda_ft_bmm)
            lambda_pt_bmm = 1 - lambda_ft_bmm
            return torch.stack([lambda_pt_bmm, lambda_ft_bmm], dim=0)

    return torch.stack([lambda_pt, lambda_ft]), lambda_plot
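
# Quick check on the divergence term (values approximate): for binary
# predictions p_pt = [0.9, 0.1] and p_ft = [0.5, 0.5], p_bar = [0.7, 0.3] and
#   0.5 * (KL(p_pt || p_bar) + KL(p_ft || p_bar)) ~ 0.5 * (0.116 + 0.087) ~ 0.10,
# while exact agreement gives 0, so lambda_ft grows with disagreement.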


def compute_and_evaluate_model_ensemble(clip_pt, clip_ft, dataloaders, args):
    """Average the raw (unnormalized) logits of the two models, then evaluate."""
    logits_pt, labels_final = get_logits(clip_pt, dataloaders[0], args, return_feats=False, normalize=False)
    logits_ft, _ = get_logits(clip_ft, dataloaders[0], args, return_feats=False, normalize=False)

    logits_final = (logits_pt + logits_ft) / 2.0
    # get_logits already collects labels in the same order as the logits, so a
    # separate pass over the dataloader is unnecessary.
    return compute_metrics(logits_final, labels_final)


def compute_samplewise_conf_weights(clip_pt, clip_ft, dataloader, device="cuda"):
    """Per-sample weights from each model's max-softmax confidence."""
    clip_pt.to(device).eval()
    clip_ft.to(device).eval()
    all_lambdas = []
    with torch.no_grad():
        for images, _ in dataloader:
            images = images.to(device)

            logits_pt = clip_pt(images)
            logits_ft = clip_ft(images)

            conf_pt = F.softmax(logits_pt, dim=1).max(dim=1)[0]
            conf_ft = F.softmax(logits_ft, dim=1).max(dim=1)[0]

            conf_stack = torch.stack([conf_pt, conf_ft], dim=0)
            lambdas = conf_stack / conf_stack.sum(dim=0, keepdim=True)
            all_lambdas.append(lambdas)

    all_lambdas = torch.cat(all_lambdas, dim=1)
    return all_lambdas
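
# Example (approximate): max-softmax confidences conf_pt = 0.6 and conf_ft = 0.9
# normalize to lambdas [0.4, 0.6], since 0.9 / (0.6 + 0.9) = 0.6.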


def evaluate_zero_shot(clip, dataloaders, classnames, args):
    """ Evaluate using zero-shot """
    model = copy.deepcopy(clip)
    return evaluate_model(model, dataloaders[0], args)


def evaluate_wise_ft(clip_pt, sd_pt, sd_ft, dataloaders, args):
    """ Evaluate using weight-space interpolation (WiSE-FT). """
    model = copy.deepcopy(clip_pt)
    sd_pt = copy.deepcopy(sd_pt)
    sd_ft = copy.deepcopy(sd_ft)

    alpha = 0.5
    merged_sd = {key: (alpha * sd_ft[key] + (1 - alpha) * sd_pt[key]) for key in sd_ft.keys()}
    model.load_state_dict(merged_sd)

    return evaluate_model(model, dataloaders[0], args)
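
# Note: alpha is fixed at 0.5 here (a uniform "soup"); WiSE-FT is typically
# swept over the mixing coefficient, e.g.:
#   for alpha in (0.1, 0.3, 0.5, 0.7, 0.9):
#       merged = {k: alpha * sd_ft[k] + (1 - alpha) * sd_pt[k] for k in sd_ft}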


def evaluate_tcube(clip_pt, sd_pt, sd_ft, lambdas, dataloaders, args, batch_wise=True):
    """ Evaluate using TCube (Entropy-based Weight Interpolation). """

    model = copy.deepcopy(clip_pt)
    sd_pt = copy.deepcopy(sd_pt)
    sd_ft = copy.deepcopy(sd_ft)

    logits_final, labels_final = [], []
    dataloader = dataloaders[0] if batch_wise else dataloaders[1]
    for i, (inputs, label) in enumerate(tqdm(dataloader)):
        inputs, label = inputs.cuda(args.gpu, non_blocking=True), label.cuda(args.gpu, non_blocking=True)

        # Re-merge the two state dicts with this batch's (or sample's) lambdas.
        merged_sd = interpolation(lambdas[:, i], sd_pt, sd_ft)
        model.load_state_dict(merged_sd, strict=False)
        model.eval()

        with torch.no_grad():
            outputs = model(inputs)
        logits_final.append(outputs)
        labels_final.append(label)

    logits_final = torch.cat(logits_final).cuda(args.gpu, non_blocking=True)
    labels_final = torch.cat(labels_final).cuda(args.gpu, non_blocking=True)

    return compute_metrics(logits_final, labels_final)
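
# Note: lambdas must carry one column per iteration of the chosen dataloader,
# i.e. one per batch when batch_wise=True (dataloaders[0]) and one per sample
# otherwise (dataloaders[1], which uses batch_size=1).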


def evaluate_model(model, dataloader, args):
    """ Generic evaluation function for a given model. """
    logits_final, labels_final = [], []
    model.eval()
    for inputs, label in tqdm(dataloader):
        inputs, label = inputs.cuda(args.gpu, non_blocking=True), label.cuda(args.gpu, non_blocking=True)
        with torch.no_grad():
            outputs = model(inputs)
        logits_final.append(outputs)
        labels_final.append(label)

    logits_final = torch.cat(logits_final)
    labels_final = torch.cat(labels_final)
    return compute_metrics(logits_final, labels_final)


def compute_metrics(logits_final, labels_final):
    """ Compute Accuracy and AUC metrics. """
    acc = accuracy(logits_final, labels_final)
    probs = F.softmax(logits_final, dim=1).cpu().numpy()
    labels = labels_final.view(-1).cpu().numpy()
    if probs.shape[1] > 2:
        unique_classes = np.unique(labels)
        n_classes = probs.shape[1]
        if len(unique_classes) < n_classes:
            # Some classes are absent from this split: fall back to a macro
            # average of one-vs-rest AUCs over the classes that are present.
            auc_scores = []
            for cls in unique_classes:
                if np.sum(labels == cls) > 0:
                    binary_labels = (labels == cls).astype(int)
                    auc_scores.append(roc_auc_score(binary_labels, probs[:, cls]))
            auc = np.mean(auc_scores) if auc_scores else 0.5
        else:
            auc = roc_auc_score(labels, probs, multi_class='ovr', average='macro')
    else:
        auc = roc_auc_score(labels, probs[:, 1])
    return acc, auc * 100
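
# Binary-case sanity check: labels [0, 0, 1, 1] with positive-class probs
# [0.1, 0.4, 0.35, 0.8] give AUC = 0.75, since one of the four (neg, pos)
# pairs is mis-ranked.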


class CustomDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.fromarray(self.images[idx])
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label
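
# Usage sketch, mirroring get_medmnistc_dataloader below:
#   data = np.load(path)
#   ds = CustomDataset(data["images"], data["labels"].squeeze(), transform=get_transform(args))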


def get_transform(args):
    transform = transforms.Compose([
        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(args.resolution),
        transforms.Lambda(lambda image: image.convert('RGB')),
        transforms.ToTensor(),
        # A single-value mean/std broadcasts across all three RGB channels.
        transforms.Normalize(mean=[.5], std=[.5])
    ])
    return transform


def get_medmnistc_dataloader(args, set_id, batch_size=32, num_workers=4, split='test', dataset=None, severity=None):
    transform = get_transform(args)
    data_root = os.path.join(args.medmnistc_data, set_id, split)
    path = os.path.join(data_root, f'{dataset}_severity_{severity}.npz') if dataset not in ["clean", None] else os.path.join(data_root, "clean.npz")
    if not os.path.exists(path):
        raise FileNotFoundError(f"Dataset file not found: {path}")
    data = np.load(path)
    images = data["images"]
    labels = data["labels"].squeeze()
    dataset = CustomDataset(images, labels, transform=transform)
    # shuffle=False: the pipeline compares logits and lambdas from separate
    # passes over this loader sample-by-sample, so ordering must be stable.
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
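
# e.g. a severity-5 corruption split:
#   loader = get_medmnistc_dataloader(args, "bloodmnist", dataset="impulse_noise", severity=5)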


def get_medimeta_dataloader(args, testset, batch_size=32, num_workers=4, split='test'):
    transform = get_transform(args)
    task_name = medimeta_testset_task_dict[testset][0].replace("_", " ")
    dataset = build_medimeta_dataset(args.medimeta_data, testset, task_name, transform)
    # shuffle=False for the same ordering-stability reason as above.
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)


def evaluate_on_test_set(args, set_id, _dataset, severities, clip_pt, sd_pt, clip_ft, sd_ft, classnames, results, test_set=None):
    if set_id not in results:
        results[set_id] = {}
    _dataset = test_set if test_set is not None else _dataset
    if _dataset not in results[set_id]:
        results[set_id][_dataset] = {}

    for severity in severities:
        print(f"\nEvaluating on _dataset: {_dataset}, severity: {severity}.....")
        if test_set is not None:
            _dataloaders = [get_medimeta_dataloader(args, test_set, batch_size=args.bs),
                            get_medimeta_dataloader(args, test_set, batch_size=1)]
        else:
            _dataloaders = [get_medmnistc_dataloader(args, set_id, batch_size=args.bs, dataset=_dataset, severity=severity),
                            get_medmnistc_dataloader(args, set_id, batch_size=1, dataset=_dataset, severity=severity)]

        # The entropy-based weight computation was elided in the original; it
        # is restored here to mirror the 'tcube' entry of method_names.
        lambdas_tcube = compute_samplewise_tcube_weights(clip_pt, clip_ft, _dataloaders[0], args)
        lambdas_tcube_MI_bmm = compute_samplewise_tcube_weights_MI(clip_pt, clip_ft, _dataloaders[0], args, batch_wise=True)

        plot_confidence_vs_js(ent_mi_dict['Ppt'], ent_mi_dict['Pft'], save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/conf_v_jsd/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}_{severity}.png')

        # Reconstructed mapping (the original entries were elided): keys mirror
        # method_names; methods that do not consume interpolation weights map
        # to None.
        lambdas_dict = {
            'model_ensemble': None,
            'wise_ft': None,
            'tcube': lambdas_tcube,
            'tcube_MI_bmm': lambdas_tcube_MI_bmm,
        }

        if severity not in results[set_id][_dataset]:
            results[set_id][_dataset][severity] = {}

        for method_type, lambdas in lambdas_dict.items():
            print("Interpolating and evaluating on - interpolation method: ", method_type)
            global dyn_v_stat_plot
            if method_type == 'zero_shot_pt':
                acc, auc = evaluate_zero_shot(clip_pt, _dataloaders, classnames, args)
            elif method_type == 'zero_shot_ft':
                acc, auc = evaluate_zero_shot(clip_ft, _dataloaders, classnames, args)
            elif method_type == 'model_ensemble':
                acc, auc = compute_and_evaluate_model_ensemble(clip_pt, clip_ft, _dataloaders, args)
            elif method_type == 'wise_ft':
                acc, auc = evaluate_wise_ft(clip_pt, sd_pt, sd_ft, _dataloaders, args)
            elif method_type == 'slerp':
                acc, auc = evaluate_slerp(clip_pt, sd_pt, sd_ft, _dataloaders[0], args)
            elif method_type == 't_arithmetic':
                acc, auc = evaluate_task_arithmetic(clip_pt, sd_pt, sd_ft, _dataloaders[0], args)
            elif method_type == 'm3':
                acc, auc = evaluate_m3(clip_pt, sd_pt, sd_ft, _dataloaders[0], args)
            elif method_type == 'tcube':
                acc, auc = evaluate_tcube(clip_pt, sd_pt, sd_ft, lambdas, _dataloaders, args, batch_wise=args.batch_wise)
            elif method_type == 'conf':
                acc, auc = evaluate_tcube(clip_pt, sd_pt, sd_ft, lambdas, _dataloaders, args, batch_wise=False)
            elif method_type == 'tcube_MI':
                acc, auc = evaluate_tcube(clip_pt, sd_pt, sd_ft, lambdas, _dataloaders, args, batch_wise=False)
            elif method_type == 'tcube_MI_bmm':
                acc, auc = evaluate_tcube(clip_pt, sd_pt, sd_ft, lambdas, _dataloaders, args, batch_wise=True)

            print(f'Accuracy: {acc[0].item():.2f}%, AUC: {auc:.2f}%, Mean: {(acc[0].item()+auc)/2:.2f}%')

            results[set_id][_dataset][severity][method_type] = {'accuracy': acc[0].item(), 'auc': auc, 'mean': (acc[0].item()+auc)/2}
            if method_type in method_names:
                dyn_v_stat_plot[method_type].append(acc[0].item())

        del _dataloaders, lambdas_dict
        gc.collect()

    return results


def evaluate_on_datasets(args, datasets, default_datasets, default_severity_range):
    results = {}
    for set_id in datasets:
        print(f"\nEvaluating on dataset: {set_id}\n")

        for _dataset in default_datasets:
            severities = [0] if _dataset in ["clean", "medimeta"] else range(default_severity_range[0], default_severity_range[1] + 1)

            if _dataset == "medimeta":
                test_sets = fetch_keys_for_value(medimeta_testset_task_dict, set_id)
                for test_set in test_sets:
                    classnames = eval("{}_classes".format(test_set.lower()))
                    clip_pt, sd_pt, clip_ft, sd_ft = load_models(args, classnames, set_id)
                    results = evaluate_on_test_set(args, set_id, _dataset, severities, clip_pt, sd_pt, clip_ft, sd_ft, classnames, results, test_set)
            else:
                classnames = eval("{}_classes".format(set_id.lower()))
                clip_pt, sd_pt, clip_ft, sd_ft = load_models(args, classnames, set_id)
                results = evaluate_on_test_set(args, set_id, _dataset, severities, clip_pt, sd_pt, clip_ft, sd_ft, classnames, results)

            del clip_pt, clip_ft, sd_ft

    return results


def print_results(results):
    now = datetime.now()
    formatted_date = now.strftime("%Y-%m-%d %H:%M:%S")
    print(f"\nResults (Evaluated on: {formatted_date}):")
    for set_id, result in results.items():
        print(f"\nDataset: {set_id}")
        print("=" * 75)
        print(f"{'_dataset':<20}{'Severity':<10}{'Method':<20}{'Accuracy':<10}{'AUC':<10}{'Mean':<10}")
        for _dataset, severity_dict in result.items():
            print("=" * 75)
            for severity, metrics_dict in severity_dict.items():
                print("-" * 80)
                for method_type, metrics in metrics_dict.items():
                    print(f"{_dataset:<20}{severity:<10}{method_type:<20}{metrics['accuracy']:<10.2f}{metrics['auc']:<10.2f}{metrics['mean']:<10.2f}")
        print("=" * 75)


def log_results(results, args):
    now = datetime.now()
    formatted_date = now.strftime("%Y-%m-%d %H:%M:%S")

    os.makedirs(os.path.dirname(args.log_path), exist_ok=True)
    with open(args.log_path, 'w') as log_file:
        log_file.write(f"\nResults (Evaluated on: {formatted_date}):\n")
        log_file.write("Arguments:\n")
        for arg, value in vars(args).items():
            log_file.write(f"{arg}: {value}\n")
        log_file.write("\n")

        for set_id, result in results.items():
            log_file.write(f"\nDataset Group: {set_id}\n")
            log_file.write("-" * 80 + "\n")

            header = f"{'_dataset':<20}{'Severity':<10}{'Method':<20}" \
                     f"{'Accuracy':<15}{'AUC':<15}{'Mean':<15}\n"
            log_file.write(header)

            for _dataset, severity_dict in result.items():
                for severity, metrics_dict in severity_dict.items():
                    for method_type, metrics in metrics_dict.items():
                        line = f"{_dataset:<20}{str(severity):<10}{method_type:<20}" \
                               f"{metrics['accuracy']:<15.2f}{metrics['auc']:<15.2f}{metrics['mean']:<15.2f}\n"
                        log_file.write(line)
                    log_file.write("-" * 80 + "\n")
            log_file.write("-" * 80 + "\n")


def save_json_results(results, args):
    json_results = {}

    for set_id, result in results.items():
        for dataset, severity_dict in result.items():
            for severity, metrics_dict in severity_dict.items():
                for method, metrics in metrics_dict.items():
                    if method not in json_results:
                        json_results[method] = {}
                    if dataset not in json_results[method]:
                        json_results[method][dataset] = {}

                    json_results[method][dataset][str(severity)] = {
                        "accuracy": metrics["accuracy"],
                        "auc": metrics["auc"],
                        "mean": metrics["mean"]
                    }

    os.makedirs(os.path.dirname(args.json_path), exist_ok=True)
    with open(args.json_path, 'w') as f:
        json.dump(json_results, f, indent=4)


def main():
    default_ft_path = [
        '/home/raza.imam/Documents/Umaima/TPT/finetuned_models/ViT-B_16'
    ]
    default_medmnistc_root = '/home/raza.imam/Documents/Umaima/TPT/MedMNIST-C'
    default_medimeta_root = '/home/raza.imam/Documents/Umaima/datasets/medimeta'
    default_testset = 'breastmnist/retinamnist/bloodmnist/octmnist'
    default_datasets = [
        "clean",
        "medimeta",
        "impulse_noise",
        "pixelate",
    ]
    default_seed = 42
    default_arch = 'ViT-B/16'
    default_ctx_init = 'a_photo_of_a'
    default_gpu = 1
    default_severity_range = [5, 5]
    default_batch_wise = True
    default_offset = False
    default_lambda_mean_type = 'mean'
    default_bs = 32
    save_time = datetime.now().strftime("%Y%m%d_%H%M")
    save_path = f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/{save_time}_{default_arch.replace("/", "_")}/'
    default_log_path = f'{save_path}log.txt'
    default_json_path = f'{save_path}dict.json'

    parser = argparse.ArgumentParser(description='Multi-Model Interpolation')
    parser.add_argument('medmnistc_data', metavar='DIR', nargs="?", default=default_medmnistc_root, help='path to medmnistc dataset root')
    parser.add_argument('medimeta_data', metavar='DIR', nargs="?", default=default_medimeta_root, help='path to medimeta dataset root')
    parser.add_argument('--ft_path', type=str, default=default_ft_path[0], help='Paths to FT model state dicts')
    parser.add_argument('--log_path', type=str, default=default_log_path, help='Path to save results')
    parser.add_argument('--json_path', type=str, default=default_json_path, help='Path to save results in json format')
    parser.add_argument('--testset', type=str, default=default_testset, help='Dataset name')
    parser.add_argument('--offset', action='store_true', default=default_offset, help='Use offset for TCube')
    parser.add_argument('--lambda_mean_type', type=str, default=default_lambda_mean_type, help='Type of lambda mean for TCube')
    parser.add_argument('--batch_wise', action='store_true', default=default_batch_wise)
    parser.add_argument('--seed', type=int, default=default_seed, help='Random seed')
    parser.add_argument('-a', '--arch', metavar='ARCH', default=default_arch, help='model architecture')
    parser.add_argument('--gpu', type=int, default=default_gpu, help='GPU ID')
    parser.add_argument('--n_ctx', default=4, type=int, help='number of tunable tokens')
    parser.add_argument('--ctx_init', default=default_ctx_init, type=str, help='init tunable prompts')
    parser.add_argument('--resolution', default=224, type=int, help='CLIP image resolution')
    parser.add_argument('--bs', default=default_bs, type=int, help='Batch size')
    args = parser.parse_args()
    print(args)

    torch.manual_seed(args.seed)

    datasets = args.testset.split("/")
    results = evaluate_on_datasets(args, datasets, default_datasets, default_severity_range)

    log_results(results, args)
    save_json_results(results, args)


if __name__ == '__main__':
    main()