|
|
import torch |
|
|
import copy |
|
|
import numpy as np |
|
|
from scipy.stats import pearsonr |
|
|
from t_cube import evaluate_model |
|
|
|
|
|
def evaluate_slerp(clip_pt, sd_pt, sd_ft, dataloader, args, alpha=0.5):
    """
    SLERP (spherical linear interpolation) between pretrained (pt) and fine-tuned (ft) weights.

    Each parameter tensor is flattened, interpolated along the great circle
    between the two flattened weight vectors, reshaped back, and loaded into a
    deep copy of the pretrained model, which is then evaluated.

    alpha=0 -> pt only; alpha=1 -> ft only.

    Args:
        clip_pt: pretrained model instance (deep-copied; not mutated).
        sd_pt: state dict of the pretrained weights.
        sd_ft: state dict of the fine-tuned weights (same keys/shapes as sd_pt).
        dataloader: evaluation data, forwarded to evaluate_model.
        args: evaluation config, forwarded to evaluate_model.
        alpha: interpolation coefficient in [0, 1].

    Returns:
        Whatever evaluate_model returns for the merged model.
    """
    model = copy.deepcopy(clip_pt)
    merged_sd = {}

    for k in sd_pt.keys():
        w1 = sd_pt[k].flatten().float()
        w2 = sd_ft[k].flatten().float()

        # Angle between the two weight vectors; the eps in the denominator
        # guards against a zero-norm tensor (e.g. an all-zero bias).
        cos_val = torch.dot(w1, w2) / (w1.norm() * w2.norm() + 1e-8)
        # Clamp to the exact acos domain [-1, 1]. The previous clamp to
        # (-1+1e-6, 1-1e-6) injected a spurious minimum angle (~1.4e-3 rad)
        # for identical tensors and made the degenerate branch below
        # unreachable (sin(1.4e-3) > 1e-6 always held).
        omega = torch.acos(torch.clamp(cos_val, -1.0, 1.0))
        sin_omega = torch.sin(omega)

        if sin_omega < 1e-6:
            # Nearly (anti)parallel vectors: SLERP is numerically unstable
            # (division by ~0), so fall back to plain linear interpolation.
            w_interp = (1-alpha)*w1 + alpha*w2
        else:
            w_interp = (torch.sin((1-alpha)*omega)/sin_omega)*w1 + \
                       (torch.sin(alpha*omega)/sin_omega)*w2

        # Restore the original shape AND dtype: the .float() upcast above
        # would otherwise leave fp16 weights / integer buffers as float32.
        merged_sd[k] = w_interp.view_as(sd_pt[k]).to(sd_pt[k].dtype)

    model.load_state_dict(merged_sd)
    return evaluate_model(model, dataloader, args)
|
|
|
|
|
|
|
|
def evaluate_m3(clip_pt, sd_pt, sd_ft, dataloader, args):
    """
    M^3 (Mixup Model Merge): draw a mixing coefficient lambda ~ Uniform(0, 1)
    and linearly blend the pretrained and fine-tuned state dicts with it.

    The blended weights are loaded into a deep copy of the pretrained model,
    which is then evaluated. Non-deterministic: each call samples a new lambda
    from numpy's global RNG.
    """
    merged_model = copy.deepcopy(clip_pt)
    mix = np.random.rand()

    blended = {}
    for name in sd_pt:
        blended[name] = mix * sd_ft[name] + (1 - mix) * sd_pt[name]

    merged_model.load_state_dict(blended)
    return evaluate_model(merged_model, dataloader, args)
|
|
|
|
|
|
|
|
def evaluate_task_arithmetic(clip_pt, sd_pt, sd_ft, dataloader, args):
    """
    Task Arithmetic: extrapolate one full step along the task vector (ft - pt)
    beyond the fine-tuned weights, i.e. merge to 2*ft - pt per tensor.

    The extrapolated weights are loaded into a deep copy of the pretrained
    model, which is then evaluated.
    """
    merged_model = copy.deepcopy(clip_pt)
    extrapolated = {name: 2 * sd_ft[name] - sd_pt[name] for name in sd_pt}
    merged_model.load_state_dict(extrapolated)
    return evaluate_model(merged_model, dataloader, args)
|
|
|