# TCube_Merging / t_cube.py
import argparse
import torch
import torch.nn.functional as F
import numpy as np
import os
from data.datautils import build_medmnist_dataset
from torchvision import transforms
from utils.tools import *
from BetaMixture import BetaMixtureModel
from clip.custom_clip import get_coop
from data.cls_to_names import *
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from medmnistc_data import *
import copy
import json
from PIL import Image
from torch.utils.data import Dataset
from datetime import datetime
import warnings
import gc
from baselines import *
warnings.filterwarnings("ignore")
import random
random.seed(0)
medimeta_testset_task_dict = {
# {test_set: [task_name, medmnist ID data],...}
"pbc": ["cell_class","bloodmnist"],
# "aml": ["morphological_class","bloodmnist"],
"mammo_mass": ["pathology","breastmnist"],
# "mammo_calc": ["pathology","breastmnist"],
"pneumonia": ["disease_class","pneumoniamnist"],
"fundus": ["disease_presence","retinamnist"],
"oct": ["disease_class","octmnist"]
}
method_names = {
# 'zero_shot_pt': 'Zero-Shot Pretrained',
# 'zero_shot_ft': 'Zero-Shot Fine-tuned',
'model_ensemble': 'Model Ensemble',
'wise_ft': 'Model Souping',
'tcube': 'Entropy-based',
# 'conf': 'Confidence-based Interpolation',
# 'tcube_MI': 'TCube (sample-wise)',
'tcube_MI_bmm': 'Mutual Information',
}
ent_mi_dict = {'entropy': [], 'mi': [], 'agreement_diff': [], 'correct_pt': [], 'correct_ft': [], 'x_entropy': []}
dyn_v_stat_plot = {method: [] for method in method_names.keys()}
dyn_v_stat_plot['conditions'] = []
def fetch_keys_for_value(dictionary, target_value):
return [key for key, value in dictionary.items() if value[1] == target_value]
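# Example (hedged, illustration only): fetch_keys_for_value(medimeta_testset_task_dict, 'bloodmnist')
# returns ['pbc'], i.e. the MediMeta test sets whose in-distribution MedMNIST ID matches.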
def load_models(args, classnames, set_id=None):
    """Load the pretrained (zero-shot) CLIP and its fine-tuned counterpart for the given test set."""
    # Load pretrained CLIP
clip_pt = get_coop(args.arch, None, args.gpu, args.n_ctx, args.ctx_init, classnames)
sd_pt = clip_pt.state_dict()
# Load ft clip
if set_id in medimeta_testset_task_dict.keys():
ft_path = os.path.join(args.ft_path, f'fine_tuned_clip_{medimeta_testset_task_dict[set_id][1]}.pth')
else:
ft_path = os.path.join(args.ft_path, f'fine_tuned_clip_{set_id}.pth')
sd_ft = torch.load(ft_path, map_location='cpu') # saved sd_ft
if 'pub' in ft_path.lower():
sd_ft = sd_ft['state_dict']
clip_ft = get_coop(args.arch, None, args.gpu, args.n_ctx, args.ctx_init, state_dict=sd_ft, classnames=classnames)
del sd_ft
sd_ft = clip_ft.state_dict() # sd_ft and sd_pt now have same keys
return clip_pt, sd_pt, clip_ft, sd_ft
def get_logits(model, dataloader, args, return_feats=False, normalize=True):
model.eval()
logits = []
labels = []
image_features = []
text_features = []
with torch.no_grad():
for inputs, label in tqdm(dataloader):
inputs = inputs.cuda(args.gpu, non_blocking=True)
label = label.cuda(args.gpu, non_blocking=True)
if return_feats:
outputs, img_feats, text_feats = model(inputs, return_logits=return_feats, normalize=normalize)
image_features.append(img_feats)
text_features.append(text_feats)
else:
outputs = model(inputs)
logits.append(outputs)
labels.append(label)
if return_feats:
return torch.cat(logits), torch.cat(labels), torch.cat(image_features), torch.cat(text_features)
return torch.cat(logits), torch.cat(labels)
def self_entropy(logits, temperature=0.95):
logits = logits / temperature
probs = torch.nn.functional.softmax(logits, dim=1) # Compute probabilities
return -(probs * torch.log(probs + 1e-9)).sum(dim=1) # Compute entropy
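# Illustration (hedged, not used by the pipeline): for C classes, self_entropy returns
# values in [0, log C]; uniform logits give log(C) (~2.079 for C = 8), a sharply peaked
# row gives a value near 0, and the temperature < 1 slightly sharpens the softmax first.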
def interpolation(lambdas, sd_pt, sd_ft):
merged_sd = {}
for key in sd_ft.keys():
interpolated_value = sd_pt[key] * lambdas[0] + sd_ft[key] * lambdas[1]
merged_sd[key] = interpolated_value
return merged_sd
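# Example (hedged): interpolation([0.25, 0.75], sd_pt, sd_ft) yields a state dict in
# which every tensor is 0.25 * pretrained + 0.75 * fine-tuned; lambdas = [0.5, 0.5]
# reproduces the uniform soup used by evaluate_wise_ft below.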
def compute_samplewise_tcube_weights(clip_pt, clip_ft, dataloader, args):
logits_pt, _ = get_logits(clip_pt, dataloader, args, return_feats=False)
logits_ft, _ = get_logits(clip_ft, dataloader, args, return_feats=False)
ent_pt = self_entropy(logits_pt)
ent_ft = self_entropy(logits_ft)
expertise_pt = (-ent_pt).exp()
expertise_ft = (-ent_ft).exp()
total_expertise = expertise_pt + expertise_ft
if args.offset:
coef_bias = (ent_pt.std()/ent_pt.mean() + ent_ft.std()/ent_ft.mean()) / 2
coef_biasw = (ent_pt.mean() + ent_ft.mean()) / ent_pt.mean()
lambda_ft = (expertise_ft + (coef_bias/coef_biasw)) / (total_expertise + coef_bias)
else:
lambda_ft = expertise_ft / total_expertise # Per sample for fine-tuned
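    # Equivalently (no-offset branch, illustration only): lambda_ft is a softmax over
    # negative entropies, exp(-ent_ft) / (exp(-ent_pt) + exp(-ent_ft)), so whichever
    # model is less uncertain on a sample receives the larger merging weight.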
# for ent vs mi plot ----------
global ent_mi_dict
# p_pt = torch.softmax(logits_pt, dim=1)
# p_ft = torch.softmax(logits_ft, dim=1)
# p_bar = (p_pt + p_ft) / 2.0
# average_entropy = -(p_bar * torch.log(p_bar + 1e-8)).sum(dim=1)
    ent_mi_dict['entropy'] = lambda_ft  # entropy-based lambda_ft, logged for the ent-vs-MI plots
# -----------------------------
    if args.batch_wise:
        num_batches = len(dataloader)
        # Summarize per-sample lambdas into one lambda per batch: fit a Beta Mixture
        # Model with one component per batch and use each component's mean a/(a+b).
        lambda_ft_np = lambda_ft.cpu().numpy().reshape(-1, 1)
        bmm = BetaMixtureModel(n_mixtures=num_batches)
        bmm.fit(lambda_ft_np)
        lambda_ft_bmm = []
        for i in range(bmm.n_mixtures):  # n_mixtures == num_batches
            a, b = bmm.beta_params_[i, 0], bmm.beta_params_[i, 1]
            lambda_ft_bmm.append(a / (a + b))
        lambda_ft_bmm = torch.tensor(lambda_ft_bmm)
        lambda_pt = 1 - lambda_ft_bmm
        return torch.stack([lambda_pt, lambda_ft_bmm], dim=0)  # Shape: (2, num_batches)
    lambda_pt = 1 - lambda_ft
    return torch.stack([lambda_pt, lambda_ft])  # Shape: (2, num_samples)
def compute_samplewise_tcube_weights_MI(clip_pt, clip_ft, dataloader, args, delta=0.5, batch_wise=True):
# Get logits from both models for all test samples
logits_pt, labels = get_logits(clip_pt, dataloader, args, return_feats=False)
logits_ft, _ = get_logits(clip_ft, dataloader, args, return_feats=False)
# Compute the probability distributions for each sample from both models
p_pt = torch.softmax(logits_pt, dim=1)
p_ft = torch.softmax(logits_ft, dim=1)
pred_pt = p_pt.argmax(dim=1)
pred_ft = p_ft.argmax(dim=1)
correct_pt = pred_pt.eq(labels.squeeze())
correct_ft = pred_ft.eq(labels.squeeze())
# Compute the average predictive distribution (consensus)
p_bar = (p_pt + p_ft) / 2.0
    # Compute the KL divergence of each model's distribution from the average.
    # Summing over the class dimension yields a per-sample value; the epsilon in
    # both numerator and denominator guards against log(0) and division by zero.
    kl_pt = torch.sum(p_pt * torch.log((p_pt + 1e-8) / (p_bar + 1e-8)), dim=1)
    kl_ft = torch.sum(p_ft * torch.log((p_ft + 1e-8) / (p_bar + 1e-8)), dim=1)
    # Mutual information (MI) per sample: the average of the two KL divergences.
    # For two models this is exactly the Jensen-Shannon divergence between p_pt and p_ft.
    MI = 0.5 * (kl_pt + kl_ft)
    # Map MI to an interpolation coefficient (lambda) using a sigmoid.
    # lam_min, lam_max, and gamma fall back to defaults when absent from args.
    lam_min = getattr(args, 'lam_min', 0.01)
    lam_max = getattr(args, 'lam_max', 0.99)
    gamma = getattr(args, 'gamma', 0.5)
    lambda_ft = lam_min + (lam_max - lam_min) * torch.sigmoid(gamma * MI)
    lambda_plot = lambda_ft.clone()  # pre-adjustment copy, kept for plotting
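    # Illustration (hedged): with the defaults above, MI = 0 (perfect agreement)
    # gives lambda_ft = lam_min + (lam_max - lam_min) * 0.5 = 0.50, an even split;
    # as MI grows, sigmoid(gamma * MI) -> 1 and lambda_ft approaches lam_max.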
# Compute entropy as uncertainty measure for both models
ent_pt = self_entropy(logits_pt) # Pretrained uncertainty
ent_ft = self_entropy(logits_ft) # Fine-tuned uncertainty
# Set thresholds for extreme confidence for each model (default values; adjust as needed)
entropy_thresh_ft = getattr(args, 'entropy_thresh_ft', 0.05)
entropy_thresh_pt = getattr(args, 'entropy_thresh_pt', 0.65)
delta_extrap = delta # Extrapolation factor
# If fine-tuned model is extremely confident, push lambda_ft upward;
# if pretrained model is extremely confident, push lambda_ft downward.
    lambda_ft = torch.where(
        ent_ft < entropy_thresh_ft,
        # if fine-tuned is very confident, bump *its* current weight up by delta_extrap
        lambda_ft + delta_extrap,
        torch.where(
            ent_pt < entropy_thresh_pt,
            # if pretrained is very confident, push down the fine-tuned weight
            lambda_ft - delta_extrap,
            # otherwise keep the MI-computed value
            lambda_ft
        )
    )
    # Clamp lambda_ft to [0, 1.5]: extrapolation above 1 is allowed, below 0 is not
    lambda_ft = torch.clamp(lambda_ft, 0.0, 1.5)
    lambda_pt = 1 - lambda_ft  # Note: if lambda_ft > 1, lambda_pt becomes negative
# for ent vs mi plot ----------
global ent_mi_dict
# alpha = 0.15 # Small influence; tune this!
# alpha = np.random.uniform(0.5, 0.85)
# MI = MI - alpha * ent_mi_dict['entropy']
# ent_mi_dict['mi'] = lambda_ft
ent_mi_dict['mi'] = MI
# agreement_diff = torch.norm(p_ft - p_pt, p=1, dim=1)
# ent_mi_dict['agreement_diff'] = agreement_diff
ent_mi_dict['Ppt'] = p_pt
ent_mi_dict['Pft'] = p_ft
ent_mi_dict['correct_pt'] = correct_pt
ent_mi_dict['correct_ft'] = correct_ft
ce_pt = F.cross_entropy(logits_pt, labels.squeeze(), reduction='none') # Cross-entropy for pretrained model
ce_ft = F.cross_entropy(logits_ft, labels.squeeze(), reduction='none') # Cross-entropy for fine-tuned model
x_entropy_ratio = ce_ft / (ce_pt + ce_ft + 1e-9) # Avoid division by zero
ent_mi_dict['x_entropy'] = x_entropy_ratio
ent_mi_dict['CE_pt'] = ce_pt
ent_mi_dict['CE_ft'] = ce_ft
# -----------------------------
# Batch-wise averaging (if enabled) is handled similarly to the entropy-based version.
if batch_wise:
batch_size = len(dataloader.dataset) // len(dataloader)
num_batches = len(dataloader)
if args.lambda_mean_type == 'mean': # Batch-wise mean
lambda_ft_batchwise = lambda_ft[:num_batches * batch_size].view(num_batches, batch_size).mean(dim=1)
lambda_pt_batchwise = 1 - lambda_ft_batchwise
return torch.stack([lambda_pt_batchwise, lambda_ft_batchwise], dim=0)
elif args.lambda_mean_type == 'bmm':
# When using a Beta Mixture Model to cluster lambda values
lambda_ft_bmm = []
lambda_ft_np = lambda_ft.cpu().numpy().reshape(-1,1)
bmm = BetaMixtureModel(n_mixtures=num_batches)
bmm.fit(lambda_ft_np)
for i in range(bmm.n_mixtures):
a, b = bmm.beta_params_[i, 0], bmm.beta_params_[i, 1]
# print(f'Beta mean of {i}th cluster: {a/(a+b):.3f}')
lambda_ft_bmm.append(a/(a+b))
lambda_ft_bmm = torch.tensor(lambda_ft_bmm)
lambda_pt_bmm = 1 - lambda_ft_bmm
return torch.stack([lambda_pt_bmm, lambda_ft_bmm], dim=0)
# lambda_pt = 1 - lambda_ft
return torch.stack([lambda_pt, lambda_ft]), lambda_plot
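# Hedged, self-contained sketch (never invoked by the pipeline): the MI -> lambda
# mapping above, run on synthetic logits with the default lam_min/lam_max/gamma.
# Names prefixed with `_demo` are illustrative only and not part of T-Cube.
def _demo_mi_to_lambda():
    g = torch.Generator().manual_seed(0)
    logits_pt = torch.randn(16, 5, generator=g)
    logits_ft = torch.randn(16, 5, generator=g)
    p_pt = torch.softmax(logits_pt, dim=1)
    p_ft = torch.softmax(logits_ft, dim=1)
    p_bar = (p_pt + p_ft) / 2.0
    kl_pt = torch.sum(p_pt * torch.log((p_pt + 1e-8) / (p_bar + 1e-8)), dim=1)
    kl_ft = torch.sum(p_ft * torch.log((p_ft + 1e-8) / (p_bar + 1e-8)), dim=1)
    mi = 0.5 * (kl_pt + kl_ft)  # per-sample Jensen-Shannon divergence, bounded by log 2
    lam_ft = 0.01 + (0.99 - 0.01) * torch.sigmoid(0.5 * mi)
    return lam_ft  # each entry lies strictly between 0.50 and ~0.59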
def compute_and_evaluate_model_ensemble(clip_pt, clip_ft, dataloaders, args):
    # Average the raw (unnormalized) logits of the two models; get_logits already
    # returns the aligned labels, so no extra pass over the dataloader is needed.
    logits_pt, labels_final = get_logits(clip_pt, dataloaders[0], args, return_feats=False, normalize=False)
    logits_ft, _ = get_logits(clip_ft, dataloaders[0], args, return_feats=False, normalize=False)
    logits_final = (logits_pt + logits_ft) / 2.0
    return compute_metrics(logits_final, labels_final)
def compute_samplewise_conf_weights(clip_pt, clip_ft, dataloader, device="cuda"):
clip_pt.to(device).eval()
clip_ft.to(device).eval()
all_lambdas = []
with torch.no_grad():
for images, _ in dataloader: # We only need inputs, not labels
images = images.to(device)
# Get model outputs (logits)
logits_pt = clip_pt(images)
logits_ft = clip_ft(images)
# Convert logits to confidence scores (softmax)
conf_pt = F.softmax(logits_pt, dim=1).max(dim=1)[0] # Max confidence per sample
conf_ft = F.softmax(logits_ft, dim=1).max(dim=1)[0]
# Stack confidence scores
conf_stack = torch.stack([conf_pt, conf_ft], dim=0) # Shape: (num_models, batch_size)
# Normalize confidence scores to get lambdas
lambdas = conf_stack / conf_stack.sum(dim=0, keepdim=True) # Ensures sum=1 for each sample
all_lambdas.append(lambdas)
# Concatenate results across all batches
all_lambdas = torch.cat(all_lambdas, dim=1) # Shape: (num_models, num_samples)
return all_lambdas # First row is pre-trained CLIP, second row is fine-tuned model
def evaluate_zero_shot(clip, dataloaders, classnames, args):
""" Evaluate using zero-shot """
model = copy.deepcopy(clip)
return evaluate_model(model, dataloaders[0], args)
def evaluate_wise_ft(clip_pt, sd_pt, sd_ft, dataloaders, args):
""" Evaluate using weight-space interpolation (WiSE-FT). """
model = copy.deepcopy(clip_pt)
sd_pt = copy.deepcopy(sd_pt)
sd_ft = copy.deepcopy(sd_ft)
alpha = 0.5
merged_sd = {key: (alpha * sd_ft[key] + (1 - alpha) * sd_pt[key]) for key in sd_ft.keys()}
model.load_state_dict(merged_sd)
return evaluate_model(model, dataloaders[0], args)
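# Note (illustration only): with alpha = 0.5 the merged_sd above is exactly
# interpolation([0.5, 0.5], sd_pt, sd_ft); the original WiSE-FT recipe sweeps alpha,
# whereas this baseline fixes it at the midpoint soup.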
def evaluate_tcube(clip_pt, sd_pt, sd_ft, lambdas, dataloaders, args, batch_wise=True):
""" Evaluate using TCube (Entropy-based Weight Interpolation). """
    model = copy.deepcopy(clip_pt)  # deep copy: interpolated weights must not clobber the shared pretrained model
sd_pt = copy.deepcopy(sd_pt)
sd_ft = copy.deepcopy(sd_ft)
logits_final, labels_final = [], []
dataloader = dataloaders[0] if batch_wise else dataloaders[1]
for i, (inputs, label) in enumerate(tqdm(dataloader)):
inputs, label = inputs.cuda(args.gpu, non_blocking=True), label.cuda(args.gpu, non_blocking=True)
merged_sd = interpolation(lambdas[:, i], sd_pt, sd_ft)
model.load_state_dict(merged_sd, strict=False) # Load interpolated weights
model.eval()
with torch.no_grad():
outputs = model(inputs)
logits_final.append(outputs)
labels_final.append(label)
logits_final = torch.cat(logits_final).cuda(args.gpu, non_blocking=True)
labels_final = torch.cat(labels_final).cuda(args.gpu, non_blocking=True)
return compute_metrics(logits_final, labels_final)
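# Design note: evaluate_tcube rebuilds and reloads a full interpolated state dict for
# every batch (or every sample when batch_wise=False), so its wall-clock cost is
# typically dominated by load_state_dict rather than by the forward passes.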
def evaluate_model(model, dataloader, args):
""" Generic evaluation function for a given model. """
logits_final, labels_final = [], []
model.eval()
for inputs, label in tqdm(dataloader):
inputs, label = inputs.cuda(args.gpu, non_blocking=True), label.cuda(args.gpu, non_blocking=True)
with torch.no_grad():
outputs = model(inputs)
logits_final.append(outputs)
labels_final.append(label)
logits_final = torch.cat(logits_final)
labels_final = torch.cat(labels_final)
return compute_metrics(logits_final, labels_final)
def compute_metrics(logits_final, labels_final):
""" Compute Accuracy and AUC metrics. """
    acc = accuracy(logits_final, labels_final)
    probs = F.softmax(logits_final, dim=1).cpu().numpy()
    labels = labels_final.view(-1).cpu().numpy()
if probs.shape[1] > 2:
# Check if all classes are present in the labels
unique_classes = np.unique(labels)
n_classes = probs.shape[1]
if len(unique_classes) < n_classes:
# Not all classes are present in the test set
# Calculate AUC only for present classes
auc_scores = []
for cls in unique_classes:
if np.sum(labels == cls) > 0: # Ensure class has samples
# Binary classification: current class vs rest
binary_labels = (labels == cls).astype(int)
auc_scores.append(roc_auc_score(binary_labels, probs[:, cls]))
auc = np.mean(auc_scores) if auc_scores else 0.5 # Default to 0.5 if no valid scores
else:
# All classes are present, use standard OvR
auc = roc_auc_score(labels, probs, multi_class='ovr', average='macro')
else:
# For binary classification
auc = roc_auc_score(labels, probs[:, 1])
return acc, auc*100
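# Hedged sketch (illustration only): the partial-class fallback in compute_metrics
# reduces to macro-averaged one-vs-rest AUCs over the classes that actually occur.
def _demo_partial_class_auc():
    probs = np.array([[0.7, 0.2, 0.1],
                      [0.1, 0.8, 0.1],
                      [0.2, 0.7, 0.1]])
    labels = np.array([0, 1, 1])  # class 2 never appears in this toy test split
    aucs = [roc_auc_score((labels == c).astype(int), probs[:, c]) for c in np.unique(labels)]
    return float(np.mean(aucs))  # 1.0 here: both present classes are ranked perfectly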
class CustomDataset(Dataset):
def __init__(self, images, labels, transform=None):
self.images = images
self.labels = labels
self.transform = transform
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
image = Image.fromarray(self.images[idx])
label = self.labels[idx]
if self.transform:
image = self.transform(image)
return image, label
def get_transform(args):
transform = transforms.Compose([
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(args.resolution),
transforms.Lambda(lambda image: image.convert('RGB')),
transforms.ToTensor(),
transforms.Normalize(mean=[.5], std=[.5])
])
return transform
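# Note: Normalize(mean=[.5], std=[.5]) broadcasts the single value across all three
# RGB channels, mapping inputs from [0, 1] to [-1, 1].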
def get_medmnistc_dataloader(args, set_id, batch_size=32, num_workers=4, split='test', dataset=None, severity=None):
transform = get_transform(args)
data_root = os.path.join(args.medmnistc_data, set_id, split)
path = os.path.join(data_root, f'{dataset}_severity_{severity}.npz') if dataset not in ["clean", None] else os.path.join(data_root, "clean.npz")
if not os.path.exists(path):
raise FileNotFoundError(f"Dataset file not found: {path}")
data = np.load(path)
images = data["images"]
labels = data["labels"].squeeze()
    ds = CustomDataset(images, labels, transform=transform)  # avoid shadowing the `dataset` (corruption name) argument
    # shuffle=False: the lambda computation iterates this loader once per model, so the
    # sample order must be identical across passes for per-sample comparisons to align
    return torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
def get_medimeta_dataloader(args, testset, batch_size=32, num_workers=4, split='test'):
transform = get_transform(args)
task_name = medimeta_testset_task_dict[testset][0].replace("_", " ")
dataset = build_medimeta_dataset(args.medimeta_data, testset, task_name, transform)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)  # shuffle=False keeps sample order stable across passes
def evaluate_on_test_set(args, set_id, _dataset, severities, clip_pt, sd_pt, clip_ft, sd_ft, classnames, results, test_set=None):
if set_id not in results:
results[set_id] = {}
_dataset = test_set if test_set is not None else _dataset
if _dataset not in results[set_id]:
results[set_id][_dataset] = {}
for severity in severities:
print(f"\nEvaluating on _dataset: {_dataset}, severity: {severity}.....")
if test_set is not None:
_dataloaders = [get_medimeta_dataloader(args, test_set, batch_size=args.bs),
get_medimeta_dataloader(args, test_set, batch_size=1)]
else:
_dataloaders = [get_medmnistc_dataloader(args, set_id, batch_size=args.bs, dataset=_dataset, severity=severity),
get_medmnistc_dataloader(args, set_id, batch_size=1, dataset=_dataset, severity=severity)]
        lambdas_tcube = compute_samplewise_tcube_weights(clip_pt, clip_ft, _dataloaders[0], args)
# lambdas_conf = compute_samplewise_conf_weights(clip_pt, clip_ft, _dataloaders[0], args.gpu)
# lambdas_tcube_MI, lambdas_tcube_plot = compute_samplewise_tcube_weights_MI(clip_pt, clip_ft, _dataloaders[0], args, batch_wise=False)
lambdas_tcube_MI_bmm = compute_samplewise_tcube_weights_MI(clip_pt, clip_ft, _dataloaders[0], args, batch_wise=True)
# for ent vs mi plot ----------
# entropies = ent_mi_dict['entropy']
# MIs = ent_mi_dict['mi']
# correct_pt = ent_mi_dict['correct_pt']
# correct_ft = ent_mi_dict['correct_ft']
# x_entropy = ent_mi_dict['x_entropy']
# plot_entropy_vs_mi(entropies, MIs, agreement_diff=ent_mi_dict['agreement_diff'], save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/mi_v_ent2/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}_{severity}.png')
# plot_entropy_vs_mi_by_correctness(entropies, MIs, correct_pt, correct_ft, save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/mi_v_ent_by_correctness/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}_{severity}.png')
# plot_Xentropy_vs_mi_by_correctness(x_entropy, MIs, correct_pt, correct_ft, save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/xent_v_mi_by_correctness/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}_{severity}.png')
# plot_xentropy_vs_mi_entire(x_entropy, MIs, save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/xent_v_mi_entire/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}_{severity}.png')
# plot_stacked_ce_vs_mi_bins(MIs, ent_mi_dict['CE_pt'], ent_mi_dict['CE_ft'], save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/ce_v_mi_bins/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}_{severity}.png')
# plot_ce_vs_mi_by_correctness(ent_mi_dict['CE_pt'], ent_mi_dict['CE_ft'], MIs, correct_pt, correct_ft, save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/ce_v_mi_by_correctness/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}_{severity}.png')
plot_confidence_vs_js(ent_mi_dict['Ppt'], ent_mi_dict['Pft'], save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/conf_v_jsd/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}_{severity}.png')
# -----------------------------
        # Active methods mirror method_names above; alternatives stay commented out.
        lambdas_dict = {
            # 'zero_shot_pt': None,
            # 'zero_shot_ft': None,
            'model_ensemble': None,
            'wise_ft': None,
            # 'slerp': None,
            # 't_arithmetic': None,
            # 'm3': None,
            'tcube': lambdas_tcube,
            # 'conf': lambdas_conf,
            # 'tcube_MI': lambdas_tcube_MI,
            'tcube_MI_bmm': lambdas_tcube_MI_bmm,
        }
if severity not in results[set_id][_dataset]:
results[set_id][_dataset][severity] = {}
for method_type, lambdas in lambdas_dict.items():
print("Interpolating and evaluating on - interpolation method: ", method_type)
global dyn_v_stat_plot
if method_type == 'zero_shot_pt':
acc, auc = evaluate_zero_shot(clip_pt, _dataloaders, classnames, args)
elif method_type == 'zero_shot_ft':
acc, auc = evaluate_zero_shot(clip_ft, _dataloaders, classnames, args)
elif method_type == 'model_ensemble':
acc, auc = compute_and_evaluate_model_ensemble(clip_pt, clip_ft, _dataloaders, args)
elif method_type == 'wise_ft':
acc, auc = evaluate_wise_ft(clip_pt, sd_pt, sd_ft, _dataloaders, args)
elif method_type == 'slerp':
acc, auc = evaluate_slerp(clip_pt, sd_pt, sd_ft, _dataloaders[0], args)
elif method_type == 't_arithmetic':
acc, auc = evaluate_task_arithmetic(clip_pt, sd_pt, sd_ft, _dataloaders[0], args)
elif method_type == 'm3':
acc, auc = evaluate_m3(clip_pt, sd_pt, sd_ft, _dataloaders[0], args)
elif method_type == 'tcube':
acc, auc = evaluate_tcube(clip_pt, sd_pt, sd_ft, lambdas, _dataloaders, args, batch_wise=args.batch_wise)
elif method_type == 'conf':
acc, auc = evaluate_tcube(clip_pt, sd_pt, sd_ft, lambdas, _dataloaders, args, batch_wise=False)
elif method_type == 'tcube_MI':
acc, auc = evaluate_tcube(clip_pt, sd_pt, sd_ft, lambdas, _dataloaders, args, batch_wise=False)
elif method_type == 'tcube_MI_bmm':
acc, auc = evaluate_tcube(clip_pt, sd_pt, sd_ft, lambdas, _dataloaders, args, batch_wise=True)
print(f'Accuracy: {acc[0].item():.2f}%, AUC: {auc:.2f}%, Mean: {(acc[0].item()+auc)/2:.2f}%')
results[set_id][_dataset][severity][method_type] = {'accuracy': acc[0].item(), 'auc': auc, 'mean': (acc[0].item()+auc)/2}
if method_type in method_names:
# if method_names[method_type] in dyn_v_stat_plot:
# dyn_v_stat_plot[method_names[method_type]] = []
dyn_v_stat_plot[method_type].append(acc[0].item())
# for lambda histogram -------------------------------------
# dyn_v_stat_plot['conditions'].append(f'{set_id}_{_dataset}')
# lambdas_dict_plot = {}
# lambdas_dict_plot[_dataset] = lambdas_tcube[1]
# plot_lambda_histogram(lambdas_dict_plot, save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/lambda_histogram_ER/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}.png')
# ------------------------------------------------------------
del _dataloaders, lambdas_dict
gc.collect()
return results
def evaluate_on_datasets(args, datasets, default_datasets, default_severity_range):
results = {}
for set_id in datasets:
print(f"\nEvaluating on dataset: {set_id}\n")
for _dataset in default_datasets:
severities = [0] if _dataset in ["clean", "medimeta"] else range(default_severity_range[0], default_severity_range[1]+1)
if _dataset == "medimeta":
test_sets = fetch_keys_for_value(medimeta_testset_task_dict, set_id)
for test_set in test_sets:
classnames = eval("{}_classes".format(test_set.lower()))
clip_pt, sd_pt, clip_ft, sd_ft = load_models(args, classnames, set_id)
results = evaluate_on_test_set(args, set_id, _dataset, severities, clip_pt, sd_pt, clip_ft, sd_ft, classnames, results, test_set)
else:
classnames = eval("{}_classes".format(set_id.lower()))
clip_pt, sd_pt, clip_ft, sd_ft = load_models(args, classnames, set_id)
results = evaluate_on_test_set(args, set_id, _dataset, severities, clip_pt, sd_pt, clip_ft, sd_ft, classnames, results)
            del clip_pt, sd_pt, clip_ft, sd_ft
# try:
# plot_delta_performance(dyn_v_stat_plot, save_path=f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/plots/dynamic_vs_static/{(args.arch).replace("/", "_")}/{set_id}_{_dataset}.png')
# except Exception as e:
# print(f"An error occurred while plotting delta performance: {e}")
# pass
return results
def print_results(results):
now = datetime.now()
formatted_date = now.strftime("%Y-%m-%d %H:%M:%S")
print(f"\nResults (Evaluated on: {formatted_date}):")
    for set_id, result in results.items():
        print(f"\nDataset: {set_id}")
        print("=" * 80)
        print(f"{'_dataset':<20}{'Severity':<10}{'Method':<20}{'Accuracy':<10}{'AUC':<10}{'Mean':<10}")
        for _dataset, severity_dict in result.items():
            print("=" * 80)
            for severity, metrics_dict in severity_dict.items():
                print("-" * 80)
                for method_type, metrics in metrics_dict.items():
                    print(f"{_dataset:<20}{severity:<10}{method_type:<20}{metrics['accuracy']:<10.2f}{metrics['auc']:<10.2f}{metrics['mean']:<10.2f}")
        print("=" * 80)
def log_results(results, args):
now = datetime.now()
formatted_date = now.strftime("%Y-%m-%d %H:%M:%S")
os.makedirs(os.path.dirname(args.log_path), exist_ok=True)
with open(args.log_path, 'w') as log_file:
log_file.write(f"\nResults (Evaluated on: {formatted_date}):\n")
log_file.write(f"Arguments:\n")
for arg, value in vars(args).items():
log_file.write(f"{arg}: {value}\n")
log_file.write("\n")
for set_id, result in results.items():
log_file.write(f"\nDataset Group: {set_id}\n")
log_file.write("-" * 80 + "\n")
# Write header including severity
header = f"{'_dataset':<20}{'Severity':<10}{'Method':<20}" \
f"{'Accuracy':<15}{'AUC':<15}{'Mean':<15}\n"
log_file.write(header)
for _dataset, severity_dict in result.items():
for severity, metrics_dict in severity_dict.items():
for method_type, metrics in metrics_dict.items():
line = f"{_dataset:<20}{str(severity):<10}{method_type:<20}" \
f"{metrics['accuracy']:<15.2f}{metrics['auc']:<15.2f}{metrics['mean']:<15.2f}\n"
log_file.write(line)
log_file.write("-" * 80 + "\n")
log_file.write("-" * 80 + "\n")
def save_json_results(results, args):
json_results = {}
for set_id, result in results.items():
# For each set_id (e.g., modality grouping), iterate through datasets
for dataset, severity_dict in result.items():
for severity, metrics_dict in severity_dict.items():
for method, metrics in metrics_dict.items():
if method not in json_results:
json_results[method] = {}
if dataset not in json_results[method]:
json_results[method][dataset] = {}
# Store results for each severity as is (no averaging)
json_results[method][dataset][str(severity)] = {
"accuracy": metrics["accuracy"],
"auc": metrics["auc"],
"mean": metrics["mean"]
}
os.makedirs(os.path.dirname(args.json_path), exist_ok=True)
with open(args.json_path, 'w') as f:
json.dump(json_results, f, indent=4)
def main():
default_ft_path = [
'/home/raza.imam/Documents/Umaima/TPT/finetuned_models/ViT-B_16'
]
default_medmnistc_root = '/home/raza.imam/Documents/Umaima/TPT/MedMNIST-C'
default_medimeta_root = '/home/raza.imam/Documents/Umaima/datasets/medimeta'
default_testset = 'breastmnist/retinamnist/bloodmnist/octmnist' # 'breastmnist/retinamnist/bloodmnist/pneumoniamnist/octmnist'
default_datasets = [
"clean",
"medimeta",
# "gaussian_noise",
"impulse_noise",
# "motion_blur",
# "zoom_blur",
# "brightness",
# "contrast",
"pixelate",
]
default_seed = 42
default_arch = 'ViT-B/16'
default_ctx_init = 'a_photo_of_a'
default_gpu = 1
    default_severity_range = [5, 5]  # valid severities are 1 through 5
default_batch_wise = True
default_offset = False
default_lambda_mean_type = 'mean'
default_bs = 32
save_time = datetime.now().strftime("%Y%m%d_%H%M")
save_path = f'/home/raza.imam/Documents/Umaima/TPT/results_tcube/{save_time}_{default_arch.replace("/", "_")}/'
default_log_path = f'{save_path}log.txt'
default_json_path = f'{save_path}dict.json'
parser = argparse.ArgumentParser(description='Multi-Model Interpolation')
parser.add_argument('medmnistc_data', metavar='DIR', nargs="?", default=default_medmnistc_root, help='path to medmnistc dataset root')
parser.add_argument('medimeta_data', metavar='DIR', nargs="?", default=default_medimeta_root, help='path to medimeta dataset root')
parser.add_argument('--ft_path', type=str, default=default_ft_path[0], help='Paths to FT model state dicts')
parser.add_argument('--log_path', type=str, default=default_log_path, help='Path to save results')
parser.add_argument('--json_path', type=str, default=default_json_path, help='Path to save results in json format')
parser.add_argument('--testset', type=str, default=default_testset, help='Dataset name')
parser.add_argument('--offset', action='store_true', default=default_offset, help='Use offset for TCube')
parser.add_argument('--lambda_mean_type', type=str, default=default_lambda_mean_type, help='Type of lambda mean for TCube')
parser.add_argument('--batch_wise', action='store_true', default=default_batch_wise)
parser.add_argument('--seed', type=int, default=default_seed, help='Random seed')
parser.add_argument('-a', '--arch', metavar='ARCH', default=default_arch, help='model architecture')
parser.add_argument('--gpu', type=int, default=default_gpu, help='GPU ID')
parser.add_argument('--n_ctx', default=4, type=int, help='number of tunable tokens')
parser.add_argument('--ctx_init', default=default_ctx_init, type=str, help='init tunable prompts')
parser.add_argument('--resolution', default=224, type=int, help='CLIP image resolution')
parser.add_argument('--bs', default=default_bs, type=int, help='Batch size')
args = parser.parse_args()
print(args)
torch.manual_seed(args.seed)
datasets = args.testset.split("/")
results = evaluate_on_datasets(args, datasets, default_datasets, default_severity_range)
# print_results(results)
log_results(results, args)
save_json_results(results, args)
if __name__ == '__main__':
main()