# TCube_Merging / baselines.py
import torch
import copy
import numpy as np
from scipy.stats import pearsonr
from t_cube import evaluate_model


def evaluate_slerp(clip_pt, sd_pt, sd_ft, dataloader, args, alpha=0.5):
    """
    SLERP (spherical linear interpolation) between pretrained (pt) and
    fine-tuned (ft) weights, applied per key on the flattened tensors.
    alpha=0 -> pt only; alpha=1 -> ft only.
    """
    model = copy.deepcopy(clip_pt)
    merged_sd = {}
    for k in sd_pt.keys():
        # Non-floating-point entries (e.g. integer buffers) cannot be
        # meaningfully interpolated; carry over the pretrained value.
        if not torch.is_floating_point(sd_pt[k]):
            merged_sd[k] = sd_pt[k]
            continue
        w1 = sd_pt[k].flatten().float()
        w2 = sd_ft[k].flatten().float()
        # Angle between the two flattened weight vectors, via cosine similarity.
        cos_val = torch.dot(w1, w2) / (w1.norm() * w2.norm() + 1e-8)
        omega = torch.acos(torch.clamp(cos_val, -1 + 1e-6, 1 - 1e-6))
        sin_omega = torch.sin(omega)
        if sin_omega < 1e-6:
            # Nearly parallel vectors: SLERP degenerates to linear interpolation.
            w_interp = (1 - alpha) * w1 + alpha * w2
        else:
            w_interp = (torch.sin((1 - alpha) * omega) / sin_omega) * w1 + \
                       (torch.sin(alpha * omega) / sin_omega) * w2
        merged_sd[k] = w_interp.view_as(sd_pt[k])
    model.load_state_dict(merged_sd)
    return evaluate_model(model, dataloader, args)
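

# A minimal sanity check of the per-key SLERP above. `_demo_slerp` is a
# hypothetical helper, not part of the original file: it shows that for
# unit-norm inputs SLERP keeps the result on the unit sphere, whereas plain
# linear interpolation shrinks the norm.
def _demo_slerp(alpha=0.5):
    torch.manual_seed(0)
    w1 = torch.randn(128)
    w1 = w1 / w1.norm()
    w2 = torch.randn(128)
    w2 = w2 / w2.norm()
    omega = torch.acos(torch.clamp(torch.dot(w1, w2), -1 + 1e-6, 1 - 1e-6))
    w_slerp = (torch.sin((1 - alpha) * omega) * w1
               + torch.sin(alpha * omega) * w2) / torch.sin(omega)
    w_lerp = (1 - alpha) * w1 + alpha * w2
    # SLERP norm stays ~1.0; LERP norm drops below 1 for non-parallel inputs.
    print(f"slerp norm: {w_slerp.norm().item():.4f}, "
          f"lerp norm: {w_lerp.norm().item():.4f}")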


def evaluate_m3(clip_pt, sd_pt, sd_ft, dataloader, args):
    """
    M^3 (Mixup Model Merge): sample lambda ~ Uniform(0, 1) and linearly
    interpolate between the pretrained and fine-tuned state dicts.
    """
    model = copy.deepcopy(clip_pt)
    lam = np.random.rand()  # fresh mixing coefficient on every call
    merged_sd = {k: lam * sd_ft[k] + (1 - lam) * sd_pt[k]
                 for k in sd_pt.keys()}
    model.load_state_dict(merged_sd)
    return evaluate_model(model, dataloader, args)
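

# Hypothetical usage sketch, not part of the original file: M^3 on a pair of
# toy state dicts. Since lambda ~ Uniform(0, 1) has mean 0.5, the merge
# averaged over many draws approaches the midpoint of pt and ft.
def _demo_m3(n_draws=1000):
    sd_pt = {"w": torch.zeros(4)}
    sd_ft = {"w": torch.ones(4)}
    acc = torch.zeros(4)
    for _ in range(n_draws):
        lam = np.random.rand()
        acc += lam * sd_ft["w"] + (1 - lam) * sd_pt["w"]
    print("mean merged weight:", acc / n_draws)  # ~0.5 everywhere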


def evaluate_task_arithmetic(clip_pt, sd_pt, sd_ft, dataloader, args):
    """
    Task Arithmetic: add the task vector (ft - pt) to the pretrained weights
    with scaling coefficient 2, i.e. pt + 2*(ft - pt) = 2*ft - pt.
    """
    model = copy.deepcopy(clip_pt)
    merged_sd = {k: 2 * sd_ft[k] - sd_pt[k] for k in sd_pt.keys()}
    model.load_state_dict(merged_sd)
    return evaluate_model(model, dataloader, args)
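

# Hypothetical check, not part of the original file: 2*ft - pt is exactly
# pt + lam*(ft - pt) with scaling coefficient lam = 2, i.e. task arithmetic
# with a fixed task-vector scale of 2.
def _demo_task_arithmetic():
    pt = torch.tensor([1.0, 2.0])
    ft = torch.tensor([1.5, 1.0])
    merged = 2 * ft - pt
    task_vector_form = pt + 2 * (ft - pt)
    assert torch.allclose(merged, task_vector_form)
    print("merged:", merged)


if __name__ == "__main__":
    # Run the demo sketches above (hypothetical, for illustration only).
    _demo_slerp()
    _demo_m3()
    _demo_task_arithmetic()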