"""Evaluation helpers (meters, accuracy, checkpoint loading) and plotting utilities."""

import os
import random
import shutil
import time
from enum import Enum

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torchvision.transforms as transforms
from matplotlib.ticker import FormatStrFormatter, MaxNLocator
from PIL import Image
from scipy.stats import pearsonr
from sklearn.metrics import f1_score, precision_score, recall_score
from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid

# from t_cube import get_logits


# Axis labels shared by the joint-panel plots below.
_JS_YLABEL = r"$\mathbf{\sigma\left(\mathrm{JS}(P_{pt},P_{ft})\right)}$"
_ENTROPY_RATIO_XLABEL = r"$\mathbf{\frac{H(P_{ft})}{H(P_{ft})+H(P_{pt})}}$"
_CE_RATIO_XLABEL = r"$\mathbf{\frac{CE(P_{ft},Y)}{CE(P_{ft},Y)+CE(P_{pt},Y)}}$"


def _to_numpy(x):
    """Return *x* as a NumPy array; tensor-like inputs are detached and moved to CPU."""
    if hasattr(x, 'detach'):
        x = x.detach()
    if hasattr(x, 'cpu'):
        return x.cpu().numpy()
    return np.asarray(x)


def _ensure_parent_dir(path):
    """Create the directory that will contain *path* (current directory if none)."""
    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)


def set_random_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


class Summary(Enum):
    """Selects which statistic `AverageMeter.summary` reports."""
    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3


class AverageMeter(object):
    """Computes and stores the average and current value"""

    # Summary format per summary type; NONE yields an empty string.
    _SUMMARY_FORMATS = {
        Summary.NONE: '',
        Summary.AVERAGE: '{name} {avg:.3f}',
        Summary.SUM: '{name} {sum:.3f}',
        Summary.COUNT: '{name} {count:.3f}',
    }

    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Clear the current value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* as the latest value, weighted by *n* in the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        """Return the one-line summary selected by `summary_type`.

        Raises:
            ValueError: if `summary_type` is not a `Summary` member.
        """
        try:
            fmtstr = self._SUMMARY_FORMATS[self.summary_type]
        except (KeyError, TypeError):
            raise ValueError('invalid summary type %r' % self.summary_type) from None
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    """Prints a set of meters, prefixed with a `[batch/num_batches]` counter."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the current state of every meter for batch index *batch*."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def display_summary(self):
        """Print the summary line of every meter."""
        entries = [" *"]
        entries += [meter.summary() for meter in self.meters]
        print(' '.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the batch index to the width of the total batch count.
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k

    Args:
        output: score tensor indexed as (sample, class).
        target: tensor of class indices, one per sample.
        topk: iterable of k values.

    Returns:
        A list with one single-element tensor per k, holding the top-k
        accuracy in percent.
    """
    with torch.no_grad():
        # Clamp to the number of classes so k larger than the class count
        # does not make `topk` fail (e.g. Acc@5 on a 2-class problem).
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        # The previous `output.topk(1)` kept only the best prediction, which
        # made every Acc@k (k > 1) identical to Acc@1.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def macro_prf(output, target):
    """
    Returns macro-precision, macro-recall, and macro-F1 in percentages.
    """
    preds = output.argmax(dim=1).detach().cpu().numpy()
    y_true = target.detach().cpu().numpy()
    p = precision_score(y_true, preds, average='macro', zero_division=0)
    r = recall_score(y_true, preds, average='macro', zero_division=0)
    f = f1_score(y_true, preds, average='macro', zero_division=0)
    return [p * 100, r * 100, f * 100]


def load_model_weight(load_path, model, device, args):
    """Load a checkpoint into *model* and set `args.start_epoch` from it.

    The checkpoint must provide 'state_dict' and 'epoch'. If the state dict
    cannot be loaded into *model* directly, it is loaded non-strictly into
    `model.prompt_generator` instead. When *load_path* is not a file, a
    message is printed and nothing else happens. Any 'best_acc1' entry in
    the checkpoint is not used.
    """
    if not os.path.isfile(load_path):
        print("=> no checkpoint found at '{}'".format(load_path))
        return

    print("=> loading checkpoint '{}'".format(load_path))
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(load_path, map_location=device)
    state_dict = checkpoint['state_dict']
    # Ignore fixed token vectors
    state_dict.pop("token_prefix", None)
    state_dict.pop("token_suffix", None)

    args.start_epoch = checkpoint['epoch']

    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # load_state_dict raises RuntimeError on key/shape mismatch.
        # TODO: implement this method for the generator class
        model.prompt_generator.load_state_dict(state_dict, strict=False)
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(load_path, checkpoint['epoch']))

    del checkpoint
    torch.cuda.empty_cache()


def validate(val_loader, model, criterion, args, output_mask=None):
    """Evaluate *model* on *val_loader* and return the average top-1 accuracy.

    Args:
        val_loader: iterable of (images, target) batches.
        model: module called as `model(images)`.
        criterion: loss called as `criterion(output, target)`.
        args: namespace providing `gpu` and `print_freq`.
        output_mask: optional index (list, array or tensor) selecting the
            output columns to keep; `None` or an empty mask keeps all.

    Returns:
        The running average of Acc@1 (`top1.avg`).
    """
    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # `if output_mask:` raises for multi-element tensors/arrays, so test
    # explicitly; an empty mask still means "no masking" as before.
    use_mask = output_mask is not None and len(output_mask) > 0

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            with torch.cuda.amp.autocast():
                output = model(images)
                if use_mask:
                    output = output[:, output_mask]
                loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

    progress.display_summary()

    return top1.avg


def plot_img(image, save_path='saved_plot.png', target=None, predicted=None):
    """Min-max normalise *image* and save it as a 3x3 inch figure without axes.

    A tensor is squeezed and permuted from channel-first to channel-last;
    any other input is used as given. `target` and `predicted` are accepted
    but currently unused (the title is disabled).
    """
    if isinstance(image, torch.Tensor):
        image_array = image.to('cpu').squeeze().permute(1, 2, 0).detach().numpy()
    else:
        image_array = image

    lowest = image_array.min()
    span = image_array.max() - lowest
    # A constant image has zero range; avoid dividing by zero (NaN image).
    image_array = (image_array - lowest) / span if span > 0 else image_array - lowest

    plt.figure(figsize=(3, 3), tight_layout=True)
    plt.imshow(image_array)
    # title = f'Target: {target}, Pred: {predicted}'
    plt.axis('off')
    # plt.title(title, fontsize=10)
    plt.savefig(save_path)
    plt.close()


to_pil = ToPILImage()


def plot_pil_img(image, save_path='saved_plot.png'):
    """Save *image* to *save_path*, converting it with `ToPILImage` if it is not a PIL image."""
    if not isinstance(image, Image.Image):
        img_noi = to_pil(image)
    else:
        img_noi = image
    img_noi.save(save_path)


def plot_entropy_vs_mi(
    entropies: np.ndarray,
    mi_values: np.ndarray,
    agreement_diff: np.ndarray = None,
    entropy_thresh: float = None,
    mi_thresh: float = None,
    figsize: tuple = (4.5, 4.5),
    save_path: str = 'mi_vs_entropy.png',
):
    """
    Plot MI vs. Predictive Entropy with optional coloring by agreement.

    Inputs may be NumPy arrays or tensors.

    Args:
        entropies (np.ndarray): Consensus predictive entropy values.
        mi_values (np.ndarray): Mutual information values.
        agreement_diff (np.ndarray, optional): Difference in predictions (L1).
        entropy_thresh (float, optional): Vertical threshold line.
        mi_thresh (float, optional): Horizontal threshold line.
        figsize (tuple): Plot size; only the first entry is used (as the
            JointGrid height).
        save_path (str): Where to save the figure.
    """
    # Previously `.cpu().numpy()` was called unconditionally, which failed
    # for the NumPy inputs the signature advertises.
    entropies = _to_numpy(entropies)
    mi_values = _to_numpy(mi_values)
    if agreement_diff is not None:
        agreement_diff = _to_numpy(agreement_diff)

    corr, _ = pearsonr(entropies, mi_values)

    # Create joint plot
    g = sns.JointGrid(
        x=entropies,
        y=mi_values,
        height=figsize[0],
        ratio=4,
        space=0.15
    )

    # Scatter with hue if available
    if agreement_diff is not None:
        cmap = sns.color_palette("coolwarm", as_cmap=True)
        g.plot_joint(
            sns.scatterplot,
            hue=agreement_diff,
            palette=cmap,
            s=18,
            linewidth=0.3,
            edgecolor="black",
            alpha=0.8
        )
        legend = g.ax_joint.get_legend()
        if legend is not None:
            legend.remove()  # cleaner
    else:
        g.plot_joint(sns.scatterplot, s=20, color='tab:blue', alpha=0.7)

    # Marginals
    g.plot_marginals(sns.histplot, kde=True, color='grey', alpha=0.5)

    # Regression
    sns.regplot(
        x=entropies,
        y=mi_values,
        scatter=False,
        ax=g.ax_joint,
        color='black',
        line_kws={"linestyle": "--", "linewidth": 1}
    )

    # Thresholds
    if entropy_thresh is not None:
        g.ax_joint.axvline(entropy_thresh, ls='--', color='grey', lw=1)
    if mi_thresh is not None:
        g.ax_joint.axhline(mi_thresh, ls='--', color='grey', lw=1)

    # Annotation in top-left, the important/key quadrant
    x_text = np.percentile(entropies, 5)
    y_text = np.percentile(mi_values, 95)
    g.ax_joint.text(x_text, y_text, 'High MI\nLow Entropy', fontsize=10,
                    fontweight='bold', color='black')

    # Labels and title
    g.set_axis_labels('Self-Entropy', 'Mutual Information', fontsize=11)
    g.ax_joint.set_title(f'Pearson ρ = {corr:.2f}', fontsize=12)
    g.ax_joint.tick_params(labelsize=9)

    plt.tight_layout()
    if os.path.dirname(save_path):
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
    plt.savefig(save_path, dpi=300)
    plt.close()
    return


method_names = {
    'model_ensemble': 'Model Ensemble',
    'wise_ft': 'Model Souping',
    'tcube': 'Entropy-based',
    'tcube_MI_bmm': 'Mutual Information',
}


def plot_delta_performance(
    dyn_v_stat_plot: dict,
    dyn_key: str = 'tcube_MI_bmm',
    figsize: tuple = (3, 3),
    save_path: str = 'delta_performance.png'
):
    """Bar-plot the gap between one method and the best of the others.

    Args:
        dyn_v_stat_plot (dict): must contain 'conditions' plus one sequence
            of per-condition scores for every key of `method_names`.
        dyn_key (str): key of the method whose gain is plotted.
        figsize (tuple): figure size in inches.
        save_path (str): where to save the figure.

    Returns:
        fig, ax (the figure has already been closed).
    """
    sns.set_style('white')
    conditions = np.array(dyn_v_stat_plot['conditions'])

    fig, ax = plt.subplots(
        1, 1,
        figsize=figsize,
        constrained_layout=True
    )

    # --- Δ Accuracy ---
    dyn_arr = np.array(dyn_v_stat_plot[dyn_key])
    other_keys = [k for k in method_names if k != dyn_key]
    others = np.vstack([dyn_v_stat_plot[k] for k in other_keys])
    # Per condition: score of `dyn_key` minus the best competing score.
    delta = dyn_arr - others.max(axis=0)

    positions = np.arange(len(conditions))
    palette = sns.color_palette("rocket", n_colors=len(delta))
    ax.bar(
        x=positions,
        height=delta,
        width=1.0,
        color=palette,
        linewidth=0,
        edgecolor=None,
        alpha=0.85,
    )
    ax.axhline(0, color='grey', linewidth=1)
    ax.set_ylabel(r'$\Delta$ (%)', fontsize=10)
    ax.set_xlabel('Distribution Shifts', fontsize=10)
    ax.set_xticks(positions)
    ax.set_xticklabels([''] * len(conditions))
    ax.tick_params(axis='x', length=3, width=1)
    ax.tick_params(axis='y', labelsize=9)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(True)
    ax.spines['bottom'].set_visible(True)
    ax.grid(False)

    if os.path.dirname(save_path):
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

    return fig, ax


def plot_lambda_histogram(
    lambda_dict: dict,
    bins: int = 50,
    figsize: tuple = (3, 3),
    save_path: str = None
):
    """
    Plot a single-condition histogram of sample-wise interpolation coefficients
    with custom aesthetics: white style, outward ticks, all four spines visible,
    bars coloured with the 'Blues' palette and a dashed black KDE line.

    Args:
        lambda_dict (dict): one-entry dict e.g. {'clean': tensor([...])}
        bins (int): number of histogram bins
        figsize (tuple): figure size in inches (w, h)
        save_path (str): optional path to save the figure; nothing is saved
            when it is None

    Returns:
        fig, ax

    Raises:
        ValueError: if `lambda_dict` does not hold exactly one torch.Tensor.
    """
    # Validate single key
    if len(lambda_dict) != 1:
        raise ValueError("lambda_dict must contain exactly one key.")
    condition, data = next(iter(lambda_dict.items()))
    if not isinstance(data, torch.Tensor):
        raise ValueError(f"lambda_dict['{condition}'] must be a torch.Tensor")

    # Prepare data
    values = data.detach().cpu().numpy().ravel()

    # Aesthetics setup
    sns.set_style("white")
    fig, ax = plt.subplots(figsize=figsize)

    # One 'Blues' shade per bin
    bar_colors = sns.color_palette("Blues", n_colors=bins)

    # Plot histogram
    plot = sns.histplot(
        values,
        bins=bins,
        ax=ax,
        edgecolor=None,
        alpha=0.85,
        kde=True,
        linewidth=0  # Set edge width to 0 for wider bars
    )
    if plot.lines:
        kde_line = plot.lines[0]
        kde_line.set_color('black')    # Set KDE line color to black
        kde_line.set_linestyle('--')   # Set KDE line style to dashed
        kde_line.set_linewidth(0.5)    # Set KDE line width to 0.5

    for patch, shade in zip(plot.patches, bar_colors):
        patch.set_facecolor(shade)

    # # Reference line at λ=0.5
    # ax.axvline(0.5, color="grey", ls="--", lw=1)

    # Titles & labels
    # ax.set_title((condition).replace('_',' ').capitalize(), fontsize=10, pad=6)
    ax.set_xlabel("Coefficient", fontsize=9)
    ax.set_ylabel("Frequency", fontsize=9)

    # Ticks: six evenly spaced x ticks, outward tick marks on both axes
    ax.set_xticks(np.round(np.linspace(values.min(), values.max(), num=6), 2))
    ax.tick_params(axis='x', labelsize=8)
    ax.tick_params(
        axis='x', which='both',
        bottom=True, top=False,
        length=4, direction='out'
    )
    ax.tick_params(
        axis='y', which='both',
        left=True, right=False,
        length=4, direction='out',
        labelsize=8
    )

    # Make all borders visible
    for spine in ['top', 'right', 'bottom', 'left']:
        ax.spines[spine].set_visible(True)

    plt.tight_layout()
    # save_path defaults to None; os.path.dirname(None) / savefig(None)
    # used to raise, so only save when a path was given.
    if save_path is not None:
        if os.path.dirname(save_path):
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.show()

    return fig, ax


def _percentile_limits(xs, ys, fallback_x, fallback_y):
    """Return ((xlow, xhigh), (ylow, yhigh)) as the 1st-99th percentiles of xs/ys.

    With fewer than two points the min/max of the fallback arrays is used.
    """
    if len(xs) > 1:
        xlow, xhigh = np.percentile(xs, [1, 99])
        ylow, yhigh = np.percentile(ys, [1, 99])
    else:
        xlow, xhigh = np.min(fallback_x), np.max(fallback_x)
        ylow, yhigh = np.min(fallback_y), np.max(fallback_y)
    return (xlow, xhigh), (ylow, yhigh)


def _draw_joint_panel(fig, top_slot, joint_slot, right_slot, xs, ys, limits,
                      rho_template, rho_fontsize):
    """Draw top histogram, scatter+regression and right histogram; return the joint axes.

    `rho_template` is formatted with `rho=<Pearson r>` and drawn inside the
    joint axes when at least two points are available.
    """
    (xlow, xhigh), (ylow, yhigh) = limits

    # Top histogram (over the scatter's x-range)
    ax_marg_x = fig.add_subplot(top_slot)
    sns.histplot(
        xs, bins=25, kde=True, ax=ax_marg_x,
        color='grey', alpha=0.4
    )
    ax_marg_x.set_xlim(xlow, xhigh)
    ax_marg_x.axis('off')  # remove all spines & ticks

    # Joint scatter
    ax_joint = fig.add_subplot(joint_slot)
    sns.scatterplot(
        x=xs, y=ys, s=25, color='violet',
        edgecolor='k', linewidth=0.2, alpha=0.7,
        ax=ax_joint
    )
    # A regression needs at least two points; regplot raises on an empty
    # subset (e.g. when no sample falls into a correctness category).
    if len(xs) > 1:
        sns.regplot(
            x=xs, y=ys, scatter=False, ax=ax_joint,
            line_kws={'linestyle': '--', 'color': 'black', 'linewidth': 1.25}
        )
    ax_joint.set_xlim(xlow, xhigh)
    ax_joint.set_ylim(ylow, yhigh)
    ax_joint.set_xticklabels([])
    ax_joint.set_yticklabels([])

    # Right histogram (over the scatter's y-range)
    ax_marg_y = fig.add_subplot(right_slot)
    sns.histplot(
        y=ys, bins=25, kde=True, ax=ax_marg_y,
        color='grey', alpha=0.4, orientation='horizontal'
    )
    ax_marg_y.set_ylim(ylow, yhigh)
    ax_marg_y.axis('off')

    # annotate Pearson ρ
    if len(xs) > 1:
        rho, _ = pearsonr(xs, ys)
        ax_joint.text(
            0.05, 0.90, rho_template.format(rho=rho),
            transform=ax_joint.transAxes,
            fontsize=rho_fontsize,
            bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.6)
        )
    return ax_joint


def _blend_x_into_y(xs, ys):
    """Return `alpha * xs + (1 - alpha) * ys` with alpha drawn from U(0.05, 0.1).

    NOTE(review): this mixes the x-values into the plotted MI values, so the
    scatter and the reported Pearson rho no longer describe the raw inputs,
    and results vary between calls unless NumPy is seeded. Kept to preserve
    existing figures -- please confirm it is intended.
    """
    alpha = np.random.uniform(0.05, 0.1)
    return alpha * xs + (1 - alpha) * ys


def _correctness_masks(reference, correct_pt, correct_ft, all_label):
    """Build the overall + TT/TF/FT/FF boolean masks, in plotting order."""
    # Cast to bool first: `~` on integer 0/1 arrays yields -1/-2 (both
    # truthy) and would silently produce wrong masks.
    cpt = _to_numpy(correct_pt).astype(bool)
    cft = _to_numpy(correct_ft).astype(bool)
    return {
        all_label: np.ones_like(reference, dtype=bool),
        'TrueTrue': np.logical_and(cpt, cft),
        'TrueFalse': np.logical_and(cpt, ~cft),
        'FalseTrue': np.logical_and(~cpt, cft),
        'FalseFalse': np.logical_and(~cpt, ~cft),
    }


def _plot_joint_by_correctness(x_values, mi_values, correct_pt, correct_ft,
                               x_label, figsize, save_path):
    """Shared body of the two `*_vs_mi_by_correctness` five-panel plots."""
    x_all = _to_numpy(x_values)
    y_all = _blend_x_into_y(x_all, _to_numpy(mi_values))
    masks = _correctness_masks(x_all, correct_pt, correct_ft, 'Entire Set')

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(
        2, 10,
        width_ratios=[4, 1] * 5,
        height_ratios=[0.2, 1],
        wspace=0.075, hspace=0.2
    )

    for i, (label, mask) in enumerate(masks.items()):
        xe = x_all[mask]
        ym = y_all[mask]
        valid = np.isfinite(xe) & np.isfinite(ym)
        xe, ym = xe[valid], ym[valid]

        # clamp to remove outliers
        limits = _percentile_limits(xe, ym, x_all, y_all)

        ax_joint = _draw_joint_panel(
            fig, gs[0, 2 * i], gs[1, 2 * i], gs[1, 2 * i + 1],
            xe, ym, limits,
            rho_template="$\\rho$={rho:.2f}", rho_fontsize=12,
        )

        # x label on every panel, y label only on the first one
        ax_joint.set_xlabel(x_label, fontsize=14)
        if i == 0:
            ax_joint.set_ylabel(_JS_YLABEL, fontsize=11)
        ax_joint.set_title(label, fontsize=14)

    plt.tight_layout()
    _ensure_parent_dir(save_path)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)


def plot_entropy_vs_mi_by_correctness(
    entropies: np.ndarray,
    mi_values: np.ndarray,
    correct_pt: np.ndarray,
    correct_ft: np.ndarray,
    figsize: tuple = (20, 4),
    save_path: str = 'mi_vs_entropy_by_correctness_all.png',
):
    """
    Plot sigmoid(JS) vs. H-ratio across 5 JointGrid-style panels:
    overall and TT/TF/FT/FF splits. Each panel limits its axes to the
    1-99 percentile range, draws a violet scatter with a dashed regression
    line and grey marginal histograms, displays Pearson ρ inside the joint
    axes and hides tick labels.

    Inputs may be NumPy arrays or tensors; the figure is written to
    `save_path` and closed.
    """
    _plot_joint_by_correctness(
        entropies, mi_values, correct_pt, correct_ft,
        _ENTROPY_RATIO_XLABEL, figsize, save_path,
    )


def plot_Xentropy_vs_mi_by_correctness(
    x_entropies: np.ndarray,
    mi_values: np.ndarray,
    correct_pt: np.ndarray,
    correct_ft: np.ndarray,
    figsize: tuple = (20, 4),
    save_path: str = 'mi_vs_entropy_by_correctness_all.png',
):
    """
    Plot sigmoid(JS) vs. CE-ratio across 5 JointGrid-style panels:
    overall and TT/TF/FT/FF splits. Each panel limits its axes to the
    1-99 percentile range, draws a violet scatter with a dashed regression
    line and grey marginal histograms, displays Pearson ρ inside the joint
    axes and hides tick labels.

    Inputs may be NumPy arrays or tensors; the figure is written to
    `save_path` and closed.
    """
    _plot_joint_by_correctness(
        x_entropies, mi_values, correct_pt, correct_ft,
        _CE_RATIO_XLABEL, figsize, save_path,
    )


def plot_xentropy_vs_mi_entire(
    x_entropies: np.ndarray,
    mi_values: np.ndarray,
    figsize: tuple = (5, 5),
    save_path: str = 'xent_vs_mi_entire.png',
):
    """
    Plot a single JointGrid-style panel of sigmoid(JS) vs. CE-ratio for the entire set.
    Top histogram, central scatter+regression, and right histogram.
    Limits the axes to the 1-99 percentile range, uses grey for histograms
    and violet for scatter, displays Pearson ρ inside the joint, no tick labels.
    """
    xe = _to_numpy(x_entropies)
    ym = _blend_x_into_y(xe, _to_numpy(mi_values))

    # Filter finite
    finite = np.isfinite(xe) & np.isfinite(ym)
    xe, ym = xe[finite], ym[finite]

    # Clamp to 1-99 percentile to remove outliers
    limits = _percentile_limits(xe, ym, xe, ym)

    # Set up figure & gridspec: 2 rows, 2 cols (width ratios 4:1, height ratios 0.2:1)
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(
        2, 2,
        width_ratios=[4, 1],
        height_ratios=[0.2, 1],
        wspace=0.05, hspace=0.05
    )

    ax_joint = _draw_joint_panel(
        fig, gs[0, 0], gs[1, 0], gs[1, 1],
        xe, ym, limits,
        rho_template="$\\rho$ = {rho:.2f}", rho_fontsize=10,
    )
    ax_joint.set_xlabel(_CE_RATIO_XLABEL, fontsize=14)
    ax_joint.set_ylabel(_JS_YLABEL, fontsize=11)

    plt.tight_layout()
    _ensure_parent_dir(save_path)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)


def plot_stacked_ce_vs_mi_bins(
    mi_values,
    ce_values_pt,
    ce_values_ft,
    bins: int = 12,
    figsize: tuple = (10, 5),
    save_path: str = 'ce_vs_mi_stacked_bins.png',
):
    """
    Plot stacked average cross-entropy CE for pretrained and fine-tuned models
    as a function of binned Mutual Information (min-max normalised to [0, 1]).
    Pretrained bars use the 'Reds' palette, fine-tuned bars the 'Blues' palette.

    Args:
        mi_values (array-like): Mutual information per sample.
        ce_values_pt (array-like): Cross-entropy for pretrained model per sample.
        ce_values_ft (array-like): Cross-entropy for fine-tuned model per sample.
        bins (int): Number of bins.
        figsize (tuple): Figure size.
        save_path (str): Path to save the plot.
    """
    mi = _to_numpy(mi_values).ravel()
    span = mi.max() - mi.min()
    # Constant MI has zero range; map it to 0 instead of dividing by zero.
    mi = (mi - mi.min()) / span if span > 0 else np.zeros_like(mi, dtype=float)
    ce_pt = _to_numpy(ce_values_pt).ravel()
    ce_ft = _to_numpy(ce_values_ft).ravel()

    # Bin edges and digitize (normalised MI spans exactly [0, 1])
    edges = np.linspace(0.0, 1.0, bins + 1)
    bin_idx = np.digitize(mi, edges, right=True) - 1
    bin_idx = np.clip(bin_idx, 0, bins - 1)

    # Compute mean CE per bin for both models (NaN for empty bins)
    mean_pt = []
    mean_ft = []
    for i in range(bins):
        in_bin = (bin_idx == i)
        has_samples = in_bin.any()
        mean_pt.append(ce_pt[in_bin].mean() if has_samples else np.nan)
        mean_ft.append(ce_ft[in_bin].mean() if has_samples else np.nan)

    # Prepare labels
    labels = [f"({lo:.2f},{hi:.2f}]" for lo, hi in zip(edges[:-1], edges[1:])]

    # Colors
    bottom_colors = sns.color_palette("Reds", bins)
    top_colors = sns.color_palette("Blues", bins)

    # Plot
    plt.figure(figsize=figsize)
    x = np.arange(bins)
    plt.bar(x, mean_pt, color=bottom_colors, label='CE Pretrained')
    plt.bar(x, mean_ft, bottom=mean_pt, color=top_colors, label='CE Fine-tuned')

    # Labels and aesthetics
    plt.xticks(x, labels, rotation=45, ha='right', fontsize=10)
    plt.xlabel("Mutual Information Bins", fontsize=12)
    plt.ylabel("Cross-Entropy Loss (CE)", fontsize=12)
    plt.legend(loc='upper right')
    sns.despine(trim=True)
    plt.tight_layout()

    # Save
    _ensure_parent_dir(save_path)
    plt.savefig(save_path, dpi=300)
    plt.close()


def plot_ce_vs_mi_by_correctness(
    ce_pt: np.ndarray,
    ce_ft: np.ndarray,
    mi_values: np.ndarray,
    correct_pt: np.ndarray,
    correct_ft: np.ndarray,
    figsize: tuple = (20, 4),
    save_path: str = 'ce_vs_mi_by_correctness.png',
):
    """
    Plot CE vs. Mutual Information across 5 subsets: All, TT, TF, FT, FF.
    For each panel: red scatter/regression for pretrained CE vs. MI,
    blue scatter/regression for fine-tuned CE vs. MI.
    Annotate Pearson ρ_pt and ρ_ft.

    Inputs may be NumPy arrays or tensors; regression lines and ρ are
    skipped for subsets with fewer than two samples.
    """
    ce_pt = _to_numpy(ce_pt)
    ce_ft = _to_numpy(ce_ft)
    mi = _to_numpy(mi_values)
    masks = _correctness_masks(mi, correct_pt, correct_ft, 'All')

    # (CE values, colour, rho subscript, vertical text position) per model
    series = (
        (ce_pt, 'tab:red', 'pt', 0.90),
        (ce_ft, 'tab:blue', 'ft', 0.80),
    )

    fig, axs = plt.subplots(1, 5, figsize=figsize, sharey=False)

    for ax, (label, mask) in zip(axs, masks.items()):
        y = mi[mask]

        # scatter + regression line of CE vs MI for each model
        for ce_all, color, _, _ in series:
            xs = ce_all[mask]
            ax.scatter(xs, y, c=color, s=20, alpha=0.7, edgecolor='k', linewidth=0.2)
            # regplot raises on an empty subset; it needs at least two points.
            if len(xs) > 1:
                sns.regplot(x=xs, y=y, scatter=False, ax=ax,
                            line_kws={'color': color, 'linestyle': '--', 'linewidth': 1.5})

        # compute and annotate Pearson correlations
        for ce_all, color, tag, y_pos in series:
            xs = ce_all[mask]
            if len(xs) > 1:
                rho, _ = pearsonr(xs, y)
                ax.text(0.05, y_pos, f"$\\rho_{{{tag}}}={rho:.2f}$",
                        transform=ax.transAxes, color=color, fontsize=10,
                        bbox=dict(boxstyle="round,pad=0.2", fc="white", alpha=0.6, ec="none"))

        ax.set_title(label, fontsize=12)
        ax.set_xlabel('Cross-Entropy Error', fontsize=11)
        # y label only on the first ('All') panel
        ax.set_ylabel('Mutual Information (JSD)' if label == 'All' else '',
                      fontsize=11) if label == 'All' else ax.set_ylabel('')
        ax.tick_params(labelsize=9)

    plt.tight_layout()
    _ensure_parent_dir(save_path)
    fig.savefig(save_path, dpi=300)
    plt.close(fig)


# def plot_case_study_mosaic(
#     clip_pt, clip_ft, dataloader, args,
#     n_per_cat=5,
#     figsize=(12, 8),
#     save_path=None
# ):
#     """
#     Build a mosaic with 4 rows (TT, TF, FT, FF) and n_per_cat columns,
#     showing original image, GT label, PT pred, FT pred.
#     """
#     device=f'cuda:{args.gpu}'
#     # 1) Collect all images & labels
#     imgs, labels = [], []
#     for x, y in dataloader:
#         imgs.append(x)
#         labels.append(y)
#     imgs = torch.cat(imgs, dim=0).to(device)    # (N, C, H, W)
#     labels = torch.cat(labels, dim=0).squeeze().to(device)  # (N,)

#     # 2) Run both models to get logits
#     clip_pt.eval(); clip_ft.eval()
#     with torch.no_grad():
#         logits_pt, _ = get_logits(clip_pt, dataloader, args, return_feats=False)
#         logits_ft, _ = get_logits(clip_ft, dataloader, args, return_feats=False)

#     # 3) Compute predictions and correctness masks
#     p_pt = torch.softmax(logits_pt, dim=1)
#     p_ft = torch.softmax(logits_ft, dim=1)
#     pred_pt = p_pt.argmax(dim=1)
#     pred_ft = p_ft.argmax(dim=1)
#     correct_pt = pred_pt.eq(labels)
#     correct_ft = pred_ft.eq(labels)

#     # 4) Define categories
#     cats = {
#         'TT': correct_pt & correct_ft,
#         'TF': correct_pt & ~correct_ft,
#         'FT': ~correct_pt & correct_ft,
#         'FF': ~correct_pt & ~correct_ft
#     }

#     # 5) Sample up to n_per_cat indices per category
#     selected = {}
#     for name, mask in cats.items():
#         idxs = mask.nonzero(as_tuple=True)[0]
#         if len(idxs) == 0:
#             selected[name] = []
#         else:
#             selected[name] = idxs[:n_per_cat]

#     # 6) Build the mosaic
#     fig, axes = plt.subplots(4, n_per_cat, figsize=figsize)
#     for row, (name, idxs) in enumerate(selected.items()):
#         for col in range(n_per_cat):
#             ax = axes[row, col]
#             ax.axis('off')
#             if col < len(idxs):
#                 idx = idxs[col].item()
#                 img = imgs[idx].cpu().permute(1, 2, 0).numpy()
#                 # if normalized, denormalize here...
#                 ax.imshow(img)
#                 gt = labels[idx].item()
#                 pt = pred_pt[idx].item()
#                 ft = pred_ft[idx].item()
#                 ax.set_title(f"{name}\nGT:{gt} PT:{pt} FT:{ft}", fontsize=8)
#             else:
#                 ax.set_facecolor('lightgray')

#     plt.tight_layout()
#     os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
#     fig.savefig(save_path, dpi=300)
#     plt.close(fig)


def js_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """
    Compute the Jensen-Shannon divergence between two probability distributions.

    Uses the natural logarithm; inputs are clipped to [1e-12, 1] before the
    logarithm so zero probabilities do not produce NaN/inf.
    """
    m = 0.5 * (p + q)
    # Use small epsilon to avoid division by zero
    p_safe = np.clip(p, 1e-12, 1)
    q_safe = np.clip(q, 1e-12, 1)
    m_safe = np.clip(m, 1e-12, 1)
    kl_pm = np.sum(p_safe * np.log(p_safe / m_safe))
    kl_qm = np.sum(q_safe * np.log(q_safe / m_safe))
    return 0.5 * (kl_pm + kl_qm)


def plot_confidence_vs_js(
    P_pt: np.ndarray,
    P_ft: np.ndarray,
    save_path: str
) -> None:
    """
    Plot combined confidence vs. JS divergence for two sets of model predictions.
    Dashed threshold lines mark the smallest combined confidence and the
    smallest JS divergence among disagreeing samples; they are omitted when
    the two models agree on every sample.

    Args:
        P_pt (np.ndarray): Pre-trained model probabilities, shape (N, C).
        P_ft (np.ndarray): Fine-tuned model probabilities, shape (N, C).
        save_path (str): File path where the figure will be saved.
    """
    # Convert to numpy
    P_pt = _to_numpy(P_pt)
    P_ft = _to_numpy(P_ft)

    # Compute combined confidence
    conf_pt = P_pt.max(axis=1)
    conf_ft = P_ft.max(axis=1)
    combined_confidence = 0.5 * (conf_pt + conf_ft)

    # Compute JS divergence for each sample
    js_values = np.array([js_divergence(p, q) for p, q in zip(P_pt, P_ft)])

    # Determine agreement vs. disagreement
    agree = np.argmax(P_pt, axis=1) == np.argmax(P_ft, axis=1)
    disagree = ~agree

    # Prepare colors
    disagree_color = sns.color_palette("Blues", 2)[1]  # dark blue
    agree_color = "violet"

    # Set up figure
    fig, ax = plt.subplots(figsize=(5, 5))

    # Scatter
    ax.scatter(
        combined_confidence[agree], js_values[agree],
        marker='o', s=250, label='Agreement',
        color=agree_color, edgecolor='k', linewidth=0.75, alpha=0.5
    )
    ax.scatter(
        combined_confidence[disagree], js_values[disagree],
        marker='P', s=250, label='Disagreement',
        color=disagree_color, edgecolor='k', linewidth=0.75, alpha=0.5
    )

    # Threshold lines at the first disagreement boundary. `.min()` on an
    # empty selection raises ValueError, so skip when nothing disagrees.
    if disagree.any():
        conf_thresh = combined_confidence[disagree].min()
        js_thresh = js_values[disagree].min()
        ax.axvline(x=conf_thresh, linestyle='--', color='gray')
        ax.axhline(y=js_thresh, linestyle='--', color='gray')

    # Axis limits with margin
    x_min, x_max = combined_confidence.min(), combined_confidence.max()
    y_min, y_max = js_values.min(), js_values.max()
    x_margin = (x_max - x_min) * 0.05
    y_margin = (y_max - y_min) * 0.05
    ax.set_xlim(x_min - x_margin, x_max + x_margin)
    ax.set_ylim(y_min - y_margin, y_max + y_margin)
    # ax.set_aspect('equal', 'box')

    ax.xaxis.set_major_locator(MaxNLocator(6))
    ax.yaxis.set_major_locator(MaxNLocator(6))
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))

    # Aesthetics: no inner grid, outside ticks
    ax.set_facecolor('white')
    ax.xaxis.set_tick_params(which='both', bottom=True, top=False, labelbottom=True, labelsize=13)
    ax.yaxis.set_tick_params(which='both', left=True, right=False, labelleft=True, labelsize=13)
    for spine in ax.spines.values():
        spine.set_visible(True)

    # Axis labels with bold mathbf and larger font
    ax.set_xlabel(r'$\mathbf{Combined\ Confidence\ }$'+"\n"+r'$\mathbf{=\ \frac{1}{2}(\max_i\ p_{pt}^{(i)}\ +\ \max_i\ p_{ft}^{(i)})}$', fontsize=13)
    ax.set_ylabel(r'$\mathbf{Divergence\ }$'+"\n"+r'$\mathbf{=\ \frac{1}{2}[KL(P_{pt}\|M)\ +\ KL(P_{ft}\|M)]}$', fontsize=13)

    # Title and legend with larger fonts
    # ax.set_title(
    #     'Combined Confidence vs. JS Divergence (Agreement in Violet, Disagreement in Blue)',
    #     fontsize=18
    # )
    ax.legend(fontsize=12, frameon=False, loc='best')

    # Ensure directory exists and save
    _ensure_parent_dir(save_path)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)