import argparse
import copy
import gc
import json
import os
import random
import warnings
from datetime import datetime

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from sklearn.metrics import roc_auc_score
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm

from baselines import *
from BetaMixture import BetaMixtureModel
from clip.custom_clip import get_coop
from data.cls_to_names import *
# build_medimeta_dataset is assumed to live alongside build_medmnist_dataset
from data.datautils import build_medmnist_dataset, build_medimeta_dataset
from medmnistc_data import *
from utils.tools import *

warnings.filterwarnings("ignore")
random.seed(0)

medimeta_testset_task_dict = {  # {test_set: [task_name, medmnist ID data], ...}
    "pbc": ["cell_class", "bloodmnist"],
    # "aml": ["morphological_class", "bloodmnist"],
    "mammo_mass": ["pathology", "breastmnist"],
    # "mammo_calc": ["pathology", "breastmnist"],
    "pneumonia": ["disease_class", "pneumoniamnist"],
    "fundus": ["disease_presence", "retinamnist"],
    "oct": ["disease_class", "octmnist"],
}

method_names = {
    # 'zero_shot_pt': 'Zero-Shot Pretrained',
    # 'zero_shot_ft': 'Zero-Shot Fine-tuned',
    'model_ensemble': 'Model Ensemble',
    'wise_ft': 'Model Souping',
    'tcube': 'Entropy-based',
    # 'conf': 'Confidence-based Interpolation',
    # 'tcube_MI': 'TCube (sample-wise)',
    'tcube_MI_bmm': 'Mutual Information',
}

ent_mi_dict = {'entropy': [], 'mi': [], 'agreement_diff': [],
               'correct_pt': [], 'correct_ft': [], 'x_entropy': []}
dyn_v_stat_plot = {method: [] for method in method_names.keys()}
dyn_v_stat_plot['conditions'] = []


def fetch_keys_for_value(dictionary, target_value):
    """Return all keys whose value's second entry (the MedMNIST ID) matches target_value."""
    return [key for key, value in dictionary.items() if value[1] == target_value]


def load_models(args, classnames, set_id=None):
    """Load the pretrained CLIP and its fine-tuned counterpart for a given test set."""
    # Load pretrained CLIP
    clip_pt = get_coop(args.arch, None, args.gpu, args.n_ctx, args.ctx_init, classnames)
    sd_pt = clip_pt.state_dict()

    # Load fine-tuned CLIP
    if set_id in medimeta_testset_task_dict.keys():
        ft_path = os.path.join(args.ft_path, f'fine_tuned_clip_{medimeta_testset_task_dict[set_id][1]}.pth')
    else:
        ft_path = os.path.join(args.ft_path, f'fine_tuned_clip_{set_id}.pth')
    sd_ft = torch.load(ft_path, map_location='cpu')  # saved sd_ft
    if 'pub' in ft_path.lower():
        sd_ft = sd_ft['state_dict']
    clip_ft = get_coop(args.arch, None, args.gpu, args.n_ctx, args.ctx_init,
                       state_dict=sd_ft, classnames=classnames)
    del sd_ft
    sd_ft = clip_ft.state_dict()  # sd_ft and sd_pt now have the same keys
    return clip_pt, sd_pt, clip_ft, sd_ft


def get_logits(model, dataloader, args, return_feats=False, normalize=True):
    # model.load_state_dict(state_dict)
    model.eval()
    logits, labels = [], []
    image_features, text_features = [], []
    with torch.no_grad():
        for inputs, label in tqdm(dataloader):
            inputs = inputs.cuda(args.gpu, non_blocking=True)
            label = label.cuda(args.gpu, non_blocking=True)
            if return_feats:
                outputs, img_feats, text_feats = model(inputs, return_logits=return_feats, normalize=normalize)
                image_features.append(img_feats)
                text_features.append(text_feats)
            else:
                outputs = model(inputs)
            logits.append(outputs)
            labels.append(label)
    if return_feats:
        return torch.cat(logits), torch.cat(labels), torch.cat(image_features), torch.cat(text_features)
    return torch.cat(logits), torch.cat(labels)


def self_entropy(logits, temperature=0.95):
    """Per-sample Shannon entropy of the temperature-scaled softmax distribution."""
    logits = logits / temperature
    probs = F.softmax(logits, dim=1)  # Compute probabilities
    return -(probs * torch.log(probs + 1e-9)).sum(dim=1)  # Compute entropy


def interpolation(lambdas, sd_pt, sd_ft):
    """Weight-space interpolation: merged = lambdas[0] * sd_pt + lambdas[1] * sd_ft."""
    merged_sd = {}
    for key in sd_ft.keys():
        merged_sd[key] = sd_pt[key] * lambdas[0] + sd_ft[key] * lambdas[1]
    return merged_sd
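# Illustrative sketch (never called by this script): it shows, on toy
# two-parameter state dicts, how `interpolation` above merges weights given a
# (lambda_pt, lambda_ft) pair. The tensors here are made up for the example.
def _demo_interpolation():
    sd_a = {'w': torch.zeros(2, 2), 'b': torch.zeros(2)}
    sd_b = {'w': torch.ones(2, 2), 'b': torch.ones(2)}
    merged = interpolation(torch.tensor([0.25, 0.75]), sd_a, sd_b)
    # Every entry is 0.25 * 0 + 0.75 * 1 = 0.75
    assert torch.allclose(merged['w'], torch.full((2, 2), 0.75))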
def compute_samplewise_tcube_weights(clip_pt, clip_ft, dataloader, args):
    logits_pt, _ = get_logits(clip_pt, dataloader, args, return_feats=False)
    logits_ft, _ = get_logits(clip_ft, dataloader, args, return_feats=False)
    ent_pt = self_entropy(logits_pt)
    ent_ft = self_entropy(logits_ft)
    expertise_pt = (-ent_pt).exp()
    expertise_ft = (-ent_ft).exp()
    total_expertise = expertise_pt + expertise_ft
    if args.offset:
        coef_bias = (ent_pt.std() / ent_pt.mean() + ent_ft.std() / ent_ft.mean()) / 2
        coef_biasw = (ent_pt.mean() + ent_ft.mean()) / ent_pt.mean()
        lambda_ft = (expertise_ft + (coef_bias / coef_biasw)) / (total_expertise + coef_bias)
    else:
        lambda_ft = expertise_ft / total_expertise  # Per-sample weight for the fine-tuned model

    # for ent vs mi plot ----------
    global ent_mi_dict
    # p_pt = torch.softmax(logits_pt, dim=1)
    # p_ft = torch.softmax(logits_ft, dim=1)
    # p_bar = (p_pt + p_ft) / 2.0
    # average_entropy = -(p_bar * torch.log(p_bar + 1e-8)).sum(dim=1)
    ent_mi_dict['entropy'] = lambda_ft
    # -----------------------------

    if args.batch_wise:
        batch_size = len(dataloader.dataset) // len(dataloader)
        num_batches = len(dataloader)
        # if args.lambda_mean_type == 'mean':  # perform batch-wise mean
        #     lambda_ft_batchwise = lambda_ft[:num_batches * batch_size].view(num_batches, batch_size).mean(dim=1)
        #     lambda_pt = 1 - lambda_ft
        #     lambda_pt_batchwise = lambda_pt[:num_batches * batch_size].view(num_batches, batch_size).mean(dim=1)
        #     return torch.stack([lambda_pt_batchwise, lambda_ft_batchwise], dim=0)  # Shape: (2, num_batches)
        # elif args.lambda_mean_type == 'bmm':
        lambda_ft_bmm = []
        lambda_ft_np = lambda_ft.cpu().numpy().reshape(-1, 1)
        bmm = BetaMixtureModel(n_mixtures=num_batches)
        bmm.fit(lambda_ft_np)
        for i in range(bmm.n_mixtures):  # n_mixtures = num_batches
            a, b = bmm.beta_params_[i, 0], bmm.beta_params_[i, 1]
            # print(f'beta means of {i}th cluster: {a/(a+b):.3f}')
            lambda_ft_bmm.append(a / (a + b))
        lambda_ft_bmm = torch.tensor(lambda_ft_bmm)
        lambda_pt = 1 - lambda_ft_bmm
        return torch.stack([lambda_pt, lambda_ft_bmm], dim=0)  # Shape: (2, num_batches)
        # coefs_label = bmm.predict(lambda_ft_np)  # unreachable; kept from an earlier variant

    lambda_pt = 1 - lambda_ft
    return torch.stack([lambda_pt, lambda_ft])  # Shape: (2, num_samples)
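# Illustrative sketch (never called by this script): the mutual-information
# score used in compute_samplewise_tcube_weights_MI below is the
# Jensen-Shannon-style quantity
#   MI = 0.5 * (KL(p_pt || p_bar) + KL(p_ft || p_bar)),  p_bar = (p_pt + p_ft) / 2,
# which is ~0 when the two models agree and grows with disagreement.
# The toy logits here are made up for the example.
def _demo_mutual_information():
    logits_pt = torch.tensor([[2.0, 0.0], [0.0, 2.0]])
    logits_ft = torch.tensor([[2.0, 0.0], [2.0, 0.0]])
    p_pt, p_ft = torch.softmax(logits_pt, dim=1), torch.softmax(logits_ft, dim=1)
    p_bar = (p_pt + p_ft) / 2.0
    kl_pt = torch.sum(p_pt * torch.log(p_pt / (p_bar + 1e-8)), dim=1)
    kl_ft = torch.sum(p_ft * torch.log(p_ft / (p_bar + 1e-8)), dim=1)
    mi = 0.5 * (kl_pt + kl_ft)
    # Sample 0: the models agree, so MI is ~0; sample 1: they disagree, so MI > 0.
    assert mi[0] < mi[1]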
def compute_samplewise_tcube_weights_MI(clip_pt, clip_ft, dataloader, args, delta=0.5, batch_wise=True):
    # Get logits from both models for all test samples
    logits_pt, labels = get_logits(clip_pt, dataloader, args, return_feats=False)
    logits_ft, _ = get_logits(clip_ft, dataloader, args, return_feats=False)

    # Per-sample probability distributions from both models
    p_pt = torch.softmax(logits_pt, dim=1)
    p_ft = torch.softmax(logits_ft, dim=1)
    pred_pt = p_pt.argmax(dim=1)
    pred_ft = p_ft.argmax(dim=1)
    correct_pt = pred_pt.eq(labels.squeeze())
    correct_ft = pred_ft.eq(labels.squeeze())

    # Average predictive distribution (consensus)
    p_bar = (p_pt + p_ft) / 2.0

    # KL divergence of each model from the consensus; summing over the class
    # dimension yields a per-sample value.
    kl_pt = torch.sum(p_pt * torch.log(p_pt / (p_bar + 1e-8)), dim=1)
    kl_ft = torch.sum(p_ft * torch.log(p_ft / (p_bar + 1e-8)), dim=1)

    # Mutual information (MI) as the average of the two KL divergences per sample
    MI = 0.5 * (kl_pt + kl_ft)
    MI_orig = MI

    # Map MI to an interpolation coefficient (lambda) via a sigmoid;
    # lam_min, lam_max, and gamma are taken from args when available.
    lam_min = args.lam_min if hasattr(args, 'lam_min') else 0.01
    lam_max = args.lam_max if hasattr(args, 'lam_max') else 0.99
    gamma = args.gamma if hasattr(args, 'gamma') else 0.5
    lambda_ft = lam_min + (lam_max - lam_min) * torch.sigmoid(gamma * MI)
    lambda_plot = lam_min + (lam_max - lam_min) * torch.sigmoid(gamma * MI_orig)

    # Entropy as an uncertainty measure for both models
    ent_pt = self_entropy(logits_pt)  # Pretrained uncertainty
    ent_ft = self_entropy(logits_ft)  # Fine-tuned uncertainty

    # Thresholds for extreme confidence of each model (defaults; adjust as needed)
    entropy_thresh_ft = getattr(args, 'entropy_thresh_ft', 0.05)
    entropy_thresh_pt = getattr(args, 'entropy_thresh_pt', 0.65)
    delta_extrap = delta  # Extrapolation factor

    # If the fine-tuned model is extremely confident, push lambda_ft upward;
    # if the pretrained model is extremely confident, push lambda_ft downward.
    lambda_ft = torch.where(
        ent_ft < entropy_thresh_ft,
        lambda_ft + delta_extrap,      # fine-tuned very confident: bump its weight up
        torch.where(
            ent_pt < entropy_thresh_pt,
            lambda_ft - delta_extrap,  # pretrained very confident: push fine-tuned weight down
            lambda_ft                  # otherwise keep the MI-computed value
        )
    )

    # Clamp lambda_ft (extrapolation above 1 up to 1.5 is allowed; values below 0 are clipped)
    lambda_ft = torch.clamp(lambda_ft, 0.0, 1.5)
    lambda_pt = 1 - lambda_ft  # Note: if lambda_ft > 1, lambda_pt becomes negative

    # for ent vs mi plot ----------
    global ent_mi_dict
    # alpha = 0.15  # Small influence; tune this!
    # alpha = np.random.uniform(0.5, 0.85)
    # MI = MI - alpha * ent_mi_dict['entropy']
    # ent_mi_dict['mi'] = lambda_ft
    ent_mi_dict['mi'] = MI
    # agreement_diff = torch.norm(p_ft - p_pt, p=1, dim=1)
    # ent_mi_dict['agreement_diff'] = agreement_diff
    ent_mi_dict['Ppt'] = p_pt
    ent_mi_dict['Pft'] = p_ft
    ent_mi_dict['correct_pt'] = correct_pt
    ent_mi_dict['correct_ft'] = correct_ft
    ce_pt = F.cross_entropy(logits_pt, labels.squeeze(), reduction='none')  # Cross-entropy, pretrained
    ce_ft = F.cross_entropy(logits_ft, labels.squeeze(), reduction='none')  # Cross-entropy, fine-tuned
    x_entropy_ratio = ce_ft / (ce_pt + ce_ft + 1e-9)  # Avoid division by zero
    ent_mi_dict['x_entropy'] = x_entropy_ratio
    ent_mi_dict['CE_pt'] = ce_pt
    ent_mi_dict['CE_ft'] = ce_ft
    # -----------------------------

    # Batch-wise averaging (if enabled) mirrors the entropy-based version.
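    # How the 'bmm' branch below aggregates weights, assuming BetaMixtureModel
    # exposes a fit/predict interface with a beta_params_ array of shape
    # (n_mixtures, 2): the sample-wise lambda_ft values in [0, 1] are clustered
    # into num_batches Beta components, and each component is summarized by its
    # Beta mean a / (a + b), yielding one scalar interpolation weight per batch
    # rather than per sample.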
    if batch_wise:
        batch_size = len(dataloader.dataset) // len(dataloader)
        num_batches = len(dataloader)
        if args.lambda_mean_type == 'mean':
            # Batch-wise mean
            lambda_ft_batchwise = lambda_ft[:num_batches * batch_size].view(num_batches, batch_size).mean(dim=1)
            lambda_pt_batchwise = 1 - lambda_ft_batchwise
            return torch.stack([lambda_pt_batchwise, lambda_ft_batchwise], dim=0)
        elif args.lambda_mean_type == 'bmm':
            # Use a Beta Mixture Model to cluster lambda values
            lambda_ft_bmm = []
            lambda_ft_np = lambda_ft.cpu().numpy().reshape(-1, 1)
            bmm = BetaMixtureModel(n_mixtures=num_batches)
            bmm.fit(lambda_ft_np)
            for i in range(bmm.n_mixtures):
                a, b = bmm.beta_params_[i, 0], bmm.beta_params_[i, 1]
                # print(f'Beta mean of {i}th cluster: {a/(a+b):.3f}')
                lambda_ft_bmm.append(a / (a + b))
            lambda_ft_bmm = torch.tensor(lambda_ft_bmm)
            lambda_pt_bmm = 1 - lambda_ft_bmm
            return torch.stack([lambda_pt_bmm, lambda_ft_bmm], dim=0)

    # lambda_pt = 1 - lambda_ft
    return torch.stack([lambda_pt, lambda_ft]), lambda_plot


def compute_and_evaluate_model_ensemble(clip_pt, clip_ft, dataloaders, args):
    # get_logits already returns labels aligned with the logits, so reuse them
    # instead of making a separate pass over the dataloader
    logits_pt, labels_final = get_logits(clip_pt, dataloaders[0], args, return_feats=False, normalize=False)
    logits_ft, _ = get_logits(clip_ft, dataloaders[0], args, return_feats=False, normalize=False)
    logits_final = (logits_pt + logits_ft) / 2.0
    return compute_metrics(logits_final, labels_final)


def compute_samplewise_conf_weights(clip_pt, clip_ft, dataloader, device="cuda"):
    clip_pt.to(device).eval()
    clip_ft.to(device).eval()
    all_lambdas = []
    with torch.no_grad():
        for images, _ in dataloader:  # Only inputs are needed, not labels
            images = images.to(device)
            # Model outputs (logits)
            logits_pt = clip_pt(images)
            logits_ft = clip_ft(images)
            # Logits -> confidence scores (max softmax probability per sample)
            conf_pt = F.softmax(logits_pt, dim=1).max(dim=1)[0]
            conf_ft = F.softmax(logits_ft, dim=1).max(dim=1)[0]
            # Stack confidence scores; shape: (num_models, batch_size)
            conf_stack = torch.stack([conf_pt, conf_ft], dim=0)
            # Normalize so the per-sample weights sum to 1
            lambdas = conf_stack / conf_stack.sum(dim=0, keepdim=True)
            all_lambdas.append(lambdas)
    # Concatenate across batches; shape: (num_models, num_samples)
    all_lambdas = torch.cat(all_lambdas, dim=1)
    return all_lambdas  # Row 0: pretrained CLIP, row 1: fine-tuned model


def evaluate_zero_shot(clip, dataloaders, classnames, args):
    """Evaluate a single model as-is, without any interpolation."""
    model = copy.deepcopy(clip)
    return evaluate_model(model, dataloaders[0], args)


def evaluate_wise_ft(clip_pt, sd_pt, sd_ft, dataloaders, args):
    """Evaluate weight-space interpolation (WiSE-FT) with a fixed alpha = 0.5."""
    model = copy.deepcopy(clip_pt)
    sd_pt = copy.deepcopy(sd_pt)
    sd_ft = copy.deepcopy(sd_ft)
    alpha = 0.5
    merged_sd = {key: (alpha * sd_ft[key] + (1 - alpha) * sd_pt[key]) for key in sd_ft.keys()}
    model.load_state_dict(merged_sd)
    return evaluate_model(model, dataloaders[0], args)
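# Illustrative sketch (never called by this script): WiSE-FT is usually
# reported as a sweep over the mixing coefficient alpha rather than the single
# alpha = 0.5 used above. A hypothetical sweep under the same state-dict
# convention might look like this:
def _demo_wise_ft_sweep(clip_pt, sd_pt, sd_ft, dataloaders, args):
    scores = {}
    for alpha in (0.0, 0.25, 0.5, 0.75, 1.0):  # alpha = 0 is pure pretrained, 1 is pure fine-tuned
        model = copy.deepcopy(clip_pt)
        merged_sd = {k: alpha * sd_ft[k] + (1 - alpha) * sd_pt[k] for k in sd_ft.keys()}
        model.load_state_dict(merged_sd)
        scores[alpha] = evaluate_model(model, dataloaders[0], args)  # (accuracy, AUC)
    return scores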
""" # original_sd = copy.deepcopy(clip_pt.state_dict()) # Store original model state model = copy.deepcopy(clip_pt) # Use reference to avoid deepcopy sd_pt = copy.deepcopy(sd_pt) sd_ft = copy.deepcopy(sd_ft) logits_final, labels_final = [], [] dataloader = dataloaders[0] if batch_wise else dataloaders[1] for i, (inputs, label) in enumerate(tqdm(dataloader)): inputs, label = inputs.cuda(args.gpu, non_blocking=True), label.cuda(args.gpu, non_blocking=True) merged_sd = interpolation(lambdas[:, i], sd_pt, sd_ft) model.load_state_dict(merged_sd, strict=False) # Load interpolated weights model.eval() with torch.no_grad(): outputs = model(inputs) logits_final.append(outputs) labels_final.append(label) logits_final = torch.cat(logits_final).cuda(args.gpu, non_blocking=True) labels_final = torch.cat(labels_final).cuda(args.gpu, non_blocking=True) return compute_metrics(logits_final, labels_final) def evaluate_model(model, dataloader, args): """ Generic evaluation function for a given model. """ logits_final, labels_final = [], [] model.eval() for inputs, label in tqdm(dataloader): inputs, label = inputs.cuda(args.gpu, non_blocking=True), label.cuda(args.gpu, non_blocking=True) with torch.no_grad(): outputs = model(inputs) logits_final.append(outputs) labels_final.append(label) logits_final = torch.cat(logits_final) labels_final = torch.cat(labels_final) return compute_metrics(logits_final, labels_final) def compute_metrics(logits_final, labels_final): """ Compute Accuracy and AUC metrics. """ logits_final_tensor = (logits_final) labels_final_tensor = (labels_final) acc = accuracy(logits_final_tensor, labels_final_tensor) probs = F.softmax(logits_final_tensor, dim=1).cpu().numpy() labels = labels_final_tensor.view(-1).cpu().numpy() if probs.shape[1] > 2: # Check if all classes are present in the labels unique_classes = np.unique(labels) n_classes = probs.shape[1] if len(unique_classes) < n_classes: # Not all classes are present in the test set # Calculate AUC only for present classes auc_scores = [] for cls in unique_classes: if np.sum(labels == cls) > 0: # Ensure class has samples # Binary classification: current class vs rest binary_labels = (labels == cls).astype(int) auc_scores.append(roc_auc_score(binary_labels, probs[:, cls])) auc = np.mean(auc_scores) if auc_scores else 0.5 # Default to 0.5 if no valid scores else: # All classes are present, use standard OvR auc = roc_auc_score(labels, probs, multi_class='ovr', average='macro') else: # For binary classification auc = roc_auc_score(labels, probs[:, 1]) return acc, auc*100 class CustomDataset(Dataset): def __init__(self, images, labels, transform=None): self.images = images self.labels = labels self.transform = transform def __len__(self): return len(self.images) def __getitem__(self, idx): image = Image.fromarray(self.images[idx]) label = self.labels[idx] if self.transform: image = self.transform(image) return image, label def get_transform(args): transform = transforms.Compose([ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BICUBIC), transforms.CenterCrop(args.resolution), transforms.Lambda(lambda image: image.convert('RGB')), transforms.ToTensor(), transforms.Normalize(mean=[.5], std=[.5]) ]) return transform def get_medmnistc_dataloader(args, set_id, batch_size=32, num_workers=4, split='test', dataset=None, severity=None): transform = get_transform(args) data_root = os.path.join(args.medmnistc_data, set_id, split) path = os.path.join(data_root, f'{dataset}_severity_{severity}.npz') if dataset not in ["clean", 
def get_medmnistc_dataloader(args, set_id, batch_size=32, num_workers=4, split='test', dataset=None, severity=None):
    transform = get_transform(args)
    data_root = os.path.join(args.medmnistc_data, set_id, split)
    path = (os.path.join(data_root, f'{dataset}_severity_{severity}.npz')
            if dataset not in ["clean", None]
            else os.path.join(data_root, "clean.npz"))
    if not os.path.exists(path):
        raise FileNotFoundError(f"Dataset file not found: {path}")
    data = np.load(path)
    images = data["images"]
    labels = data["labels"].squeeze()
    dataset = CustomDataset(images, labels, transform=transform)
    # shuffle=False keeps a fixed sample order, so repeated passes over this
    # loader (one per model when collecting logits) stay aligned
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)


def get_medimeta_dataloader(args, testset, batch_size=32, num_workers=4, split='test'):
    transform = get_transform(args)
    task_name = medimeta_testset_task_dict[testset][0].replace("_", " ")
    dataset = build_medimeta_dataset(args.medimeta_data, testset, task_name, transform)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)


def evaluate_on_test_set(args, set_id, _dataset, severities, clip_pt, sd_pt, clip_ft, sd_ft, classnames, results, test_set=None):
    if set_id not in results:
        results[set_id] = {}
    _dataset = test_set if test_set is not None else _dataset
    if _dataset not in results[set_id]:
        results[set_id][_dataset] = {}
    for severity in severities:
        print(f"\nEvaluating on _dataset: {_dataset}, severity: {severity}...")
        # dataloaders[0] batches by args.bs (for batch-wise lambdas);
        # dataloaders[1] yields single samples (for sample-wise lambdas)
        if test_set is not None:
            _dataloaders = [get_medimeta_dataloader(args, test_set, batch_size=args.bs),
                            get_medimeta_dataloader(args, test_set, batch_size=1)]
        else:
            _dataloaders = [get_medmnistc_dataloader(args, set_id, batch_size=args.bs, dataset=_dataset, severity=severity),
                            get_medmnistc_dataloader(args, set_id, batch_size=1, dataset=_dataset, severity=severity)]

        # lambdas_tcube = compute_samplewise_tcube_weights(clip_pt, clip_ft, _dataloaders[0], args)
        # lambdas_conf = compute_samplewise_conf_weights(clip_pt, clip_ft, _dataloaders[0], args.gpu)
        # lambdas_tcube_MI, lambdas_tcube_plot = compute_samplewise_tcube_weights_MI(clip_pt, clip_ft, _dataloaders[0], args, batch_wise=False)
        lambdas_tcube_MI_bmm = compute_samplewise_tcube_weights_MI(clip_pt, clip_ft, _dataloaders[0], args, batch_wise=True)

        # for ent vs mi plot ----------
        # entropies = ent_mi_dict['entropy']
        # MIs = ent_mi_dict['mi']
        # correct_pt = ent_mi_dict['correct_pt']
        # correct_ft = ent_mi_dict['correct_ft']
        # x_entropy = ent_mi_dict['x_entropy']
        # plot_entropy_vs_mi(entropies, MIs, agreement_diff=ent_mi_dict['agreement_diff'], save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/mi_v_ent2/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}_{severity}.png')
        # plot_entropy_vs_mi_by_correctness(entropies, MIs, correct_pt, correct_ft, save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/mi_v_ent_by_correctness/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}_{severity}.png')
        # plot_Xentropy_vs_mi_by_correctness(x_entropy, MIs, correct_pt, correct_ft, save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/xent_v_mi_by_correctness/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}_{severity}.png')
        # plot_xentropy_vs_mi_entire(x_entropy, MIs, save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/xent_v_mi_entire/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}_{severity}.png')
        # plot_stacked_ce_vs_mi_bins(MIs, ent_mi_dict['CE_pt'], ent_mi_dict['CE_ft'], save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/ce_v_mi_bins/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}_{severity}.png')
        # plot_ce_vs_mi_by_correctness(ent_mi_dict['CE_pt'], ent_mi_dict['CE_ft'], MIs, correct_pt, correct_ft, save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/ce_v_mi_by_correctness/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}_{severity}.png')
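        # lambdas_tcube_MI_bmm above has shape (2, num_batches): row 0 holds the
        # pretrained weights and row 1 the fine-tuned weights, one pair per
        # batch of _dataloaders[0]; evaluate_tcube indexes it as lambdas[:, i].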
"_")}/{set_id}_{_dataset}_{severity}.png') plot_confidence_vs_js(ent_mi_dict['Ppt'], ent_mi_dict['Pft'], save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/conf_v_jsd/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}_{severity}.png') # ----------------------------- lambdas_dict = { # 'zero_shot_pt': None, # 'zero_shot_ft': None, # 'model_ensemble': None, # 'wise_ft': None, # 'slerp': None, # 't_arithmetic': None, # 'm3': None, # 'tcube': lambdas_tcube, # # 'conf': lambdas_conf, # 'tcube_MI': lambdas_tcube_MI, # 'tcube_MI_bmm': lambdas_tcube_MI_bmm, } if severity not in results[set_id][_dataset]: results[set_id][_dataset][severity] = {} for method_type, lambdas in lambdas_dict.items(): print("Interpolating and evaluating on - interpolation method: ", method_type) global dyn_v_stat_plot if method_type == 'zero_shot_pt': acc, auc = evaluate_zero_shot(clip_pt, _dataloaders, classnames, args) elif method_type == 'zero_shot_ft': acc, auc = evaluate_zero_shot(clip_ft, _dataloaders, classnames, args) elif method_type == 'model_ensemble': acc, auc = compute_and_evaluate_model_ensemble(clip_pt, clip_ft, _dataloaders, args) elif method_type == 'wise_ft': acc, auc = evaluate_wise_ft(clip_pt, sd_pt, sd_ft, _dataloaders, args) elif method_type == 'slerp': acc, auc = evaluate_slerp(clip_pt, sd_pt, sd_ft, _dataloaders[0], args) elif method_type == 't_arithmetic': acc, auc = evaluate_task_arithmetic(clip_pt, sd_pt, sd_ft, _dataloaders[0], args) elif method_type == 'm3': acc, auc = evaluate_m3(clip_pt, sd_pt, sd_ft, _dataloaders[0], args) elif method_type == 'tcube': acc, auc = evaluate_tcube(clip_pt, sd_pt, sd_ft, lambdas, _dataloaders, args, batch_wise=args.batch_wise) elif method_type == 'conf': acc, auc = evaluate_tcube(clip_pt, sd_pt, sd_ft, lambdas, _dataloaders, args, batch_wise=False) elif method_type == 'tcube_MI': acc, auc = evaluate_tcube(clip_pt, sd_pt, sd_ft, lambdas, _dataloaders, args, batch_wise=False) elif method_type == 'tcube_MI_bmm': acc, auc = evaluate_tcube(clip_pt, sd_pt, sd_ft, lambdas, _dataloaders, args, batch_wise=True) print(f'Accuracy: {acc[0].item():.2f}%, AUC: {auc:.2f}%, Mean: {(acc[0].item()+auc)/2:.2f}%') results[set_id][_dataset][severity][method_type] = {'accuracy': acc[0].item(), 'auc': auc, 'mean': (acc[0].item()+auc)/2} if method_type in method_names: # if method_names[method_type] in dyn_v_stat_plot: # dyn_v_stat_plot[method_names[method_type]] = [] dyn_v_stat_plot[method_type].append(acc[0].item()) # for lambda histogram ------------------------------------- # dyn_v_stat_plot['conditions'].append(f'{set_id}_{_dataset}') # lambdas_dict_plot = {} # lambdas_dict_plot[_dataset] = lambdas_tcube[1] # plot_lambda_histogram(lambdas_dict_plot, save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/lambda_histogram_ER/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}.png') # ------------------------------------------------------------ del _dataloaders, lambdas_dict gc.collect() return results def evaluate_on_datasets(args, datasets, default_datasets, default_severity_range): results = {} for set_id in datasets: print(f"\nEvaluating on dataset: {set_id}\n") for _dataset in default_datasets: severities = [0] if _dataset in ["clean", "medimeta"] else range(default_severity_range[0], default_severity_range[1]+1) if _dataset == "medimeta": test_sets = fetch_keys_for_value(medimeta_testset_task_dict, set_id) for test_set in test_sets: classnames = eval("{}_classes".format(test_set.lower())) clip_pt, sd_pt, clip_ft, sd_ft = load_models(args, 
def evaluate_on_datasets(args, datasets, default_datasets, default_severity_range):
    results = {}
    for set_id in datasets:
        print(f"\nEvaluating on dataset: {set_id}\n")
        for _dataset in default_datasets:
            severities = ([0] if _dataset in ["clean", "medimeta"]
                          else range(default_severity_range[0], default_severity_range[1] + 1))
            if _dataset == "medimeta":
                test_sets = fetch_keys_for_value(medimeta_testset_task_dict, set_id)
                for test_set in test_sets:
                    classnames = eval("{}_classes".format(test_set.lower()))
                    clip_pt, sd_pt, clip_ft, sd_ft = load_models(args, classnames, set_id)
                    results = evaluate_on_test_set(args, set_id, _dataset, severities,
                                                   clip_pt, sd_pt, clip_ft, sd_ft,
                                                   classnames, results, test_set)
            else:
                classnames = eval("{}_classes".format(set_id.lower()))
                clip_pt, sd_pt, clip_ft, sd_ft = load_models(args, classnames, set_id)
                results = evaluate_on_test_set(args, set_id, _dataset, severities,
                                               clip_pt, sd_pt, clip_ft, sd_ft,
                                               classnames, results)
            del clip_pt, clip_ft, sd_ft
            # try:
            #     plot_delta_performance(dyn_v_stat_plot, save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/dynamic_vs_static/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}.png')
            # except Exception as e:
            #     print(f"An error occurred while plotting delta performance: {e}")
            #     pass
    return results


def print_results(results):
    now = datetime.now()
    formatted_date = now.strftime("%Y-%m-%d %H:%M:%S")
    print(f"\nResults (Evaluated on: {formatted_date}):")
    for set_id, result in results.items():
        print(f"\nDataset: {set_id}")
        print("=" * 75)
        print(f"{'_dataset':<20}{'Severity':<10}{'Method':<20}{'Accuracy':<10}{'AUC':<10}{'Mean':<10}")
        for _dataset, severity_dict in result.items():
            print("=" * 75)
            for severity, metrics_dict in severity_dict.items():
                print("-" * 80)
                for method_type, metrics in metrics_dict.items():
                    print(f"{_dataset:<20}{severity:<10}{method_type:<20}"
                          f"{metrics['accuracy']:<10.2f}{metrics['auc']:<10.2f}{metrics['mean']:<10.2f}")
        print("=" * 75)


def log_results(results, args):
    now = datetime.now()
    formatted_date = now.strftime("%Y-%m-%d %H:%M:%S")
    os.makedirs(os.path.dirname(args.log_path), exist_ok=True)
    with open(args.log_path, 'w') as log_file:
        log_file.write(f"\nResults (Evaluated on: {formatted_date}):\n")
        log_file.write("Arguments:\n")
        for arg, value in vars(args).items():
            log_file.write(f"{arg}: {value}\n")
        log_file.write("\n")
        for set_id, result in results.items():
            log_file.write(f"\nDataset Group: {set_id}\n")
            log_file.write("-" * 80 + "\n")
            # Write header, including severity
            header = f"{'_dataset':<20}{'Severity':<10}{'Method':<20}" \
                     f"{'Accuracy':<15}{'AUC':<15}{'Mean':<15}\n"
            log_file.write(header)
            for _dataset, severity_dict in result.items():
                for severity, metrics_dict in severity_dict.items():
                    for method_type, metrics in metrics_dict.items():
                        line = f"{_dataset:<20}{str(severity):<10}{method_type:<20}" \
                               f"{metrics['accuracy']:<15.2f}{metrics['auc']:<15.2f}{metrics['mean']:<15.2f}\n"
                        log_file.write(line)
                    log_file.write("-" * 80 + "\n")
            log_file.write("-" * 80 + "\n")


def save_json_results(results, args):
    json_results = {}
    for set_id, result in results.items():
        # For each set_id (e.g., modality grouping), iterate through datasets
        for dataset, severity_dict in result.items():
            for severity, metrics_dict in severity_dict.items():
                for method, metrics in metrics_dict.items():
                    if method not in json_results:
                        json_results[method] = {}
                    if dataset not in json_results[method]:
                        json_results[method][dataset] = {}
                    # Store results for each severity as-is (no averaging)
                    json_results[method][dataset][str(severity)] = {
                        "accuracy": metrics["accuracy"],
                        "auc": metrics["auc"],
                        "mean": metrics["mean"],
                    }
    os.makedirs(os.path.dirname(args.json_path), exist_ok=True)
    with open(args.json_path, 'w') as f:
        json.dump(json_results, f, indent=4)
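# Example invocation (the script name is hypothetical -- use whatever this file
# is saved as; all flags below are defined in main()'s argparse setup):
#   python tcube_eval.py --testset breastmnist/retinamnist --arch ViT-B/16 \
#       --gpu 0 --bs 32 --lambda_mean_type bmm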
def main():
    default_ft_path = [
        '/home/raza.imam/Documents/Umaima/TPT/finetuned_models/ViT-B_16'
    ]
    default_medmnistc_root = '/home/raza.imam/Documents/Umaima/TPT/MedMNIST-C'
    default_medimeta_root = '/home/raza.imam/Documents/Umaima/datasets/medimeta'
    default_testset = 'breastmnist/retinamnist/bloodmnist/octmnist'
    # 'breastmnist/retinamnist/bloodmnist/pneumoniamnist/octmnist'
    default_datasets = [
        "clean",
        "medimeta",
        # "gaussian_noise",
        "impulse_noise",
        # "motion_blur",
        # "zoom_blur",
        # "brightness",
        # "contrast",
        "pixelate",
    ]
    default_seed = 42
    default_arch = 'ViT-B/16'
    default_ctx_init = 'a_photo_of_a'
    default_gpu = 1
    default_severity_range = [5, 5]  # min 1 and max 5 are allowed
    default_batch_wise = True
    default_offset = False
    default_lambda_mean_type = 'mean'
    default_bs = 32
    save_time = datetime.now().strftime("%Y%m%d_%H%M")
    save_path = f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/{save_time}_{default_arch.replace("/", "_")}/'
    default_log_path = f'{save_path}log.txt'
    default_json_path = f'{save_path}dict.json'

    parser = argparse.ArgumentParser(description='Multi-Model Interpolation')
    parser.add_argument('medmnistc_data', metavar='DIR', nargs="?", default=default_medmnistc_root,
                        help='path to medmnistc dataset root')
    parser.add_argument('medimeta_data', metavar='DIR', nargs="?", default=default_medimeta_root,
                        help='path to medimeta dataset root')
    parser.add_argument('--ft_path', type=str, default=default_ft_path[0],
                        help='Path to the FT model state dicts')
    parser.add_argument('--log_path', type=str, default=default_log_path, help='Path to save results')
    parser.add_argument('--json_path', type=str, default=default_json_path,
                        help='Path to save results in JSON format')
    parser.add_argument('--testset', type=str, default=default_testset,
                        help='Dataset name(s), separated by "/"')
    parser.add_argument('--offset', action='store_true', default=default_offset, help='Use offset for TCube')
    parser.add_argument('--lambda_mean_type', type=str, default=default_lambda_mean_type,
                        help='Type of lambda mean for TCube ("mean" or "bmm")')
    parser.add_argument('--batch_wise', action='store_true', default=default_batch_wise)
    parser.add_argument('--seed', type=int, default=default_seed, help='Random seed')
    parser.add_argument('-a', '--arch', metavar='ARCH', default=default_arch, help='model architecture')
    parser.add_argument('--gpu', type=int, default=default_gpu, help='GPU ID')
    parser.add_argument('--n_ctx', default=4, type=int, help='number of tunable tokens')
    parser.add_argument('--ctx_init', default=default_ctx_init, type=str, help='init tunable prompts')
    parser.add_argument('--resolution', default=224, type=int, help='CLIP image resolution')
    parser.add_argument('--bs', default=default_bs, type=int, help='Batch size')
    args = parser.parse_args()
    print(args)

    torch.manual_seed(args.seed)
    datasets = args.testset.split("/")
    results = evaluate_on_datasets(args, datasets, default_datasets, default_severity_range)
    # print_results(results)
    log_results(results, args)
    save_json_results(results, args)


if __name__ == '__main__':
    main()