import torch


def set_stand_in(pipe, train=False, model_path=None):
    # Attach (or re-initialize) a LoRA adapter on every DiT self-attention block.
    for block in pipe.dit.blocks:
        block.self_attn.init_lora(train)
    # Optionally load pretrained Stand-In LoRA weights into those adapters.
    if model_path is not None:
        print(f"Loading Stand-In weights from: {model_path}")
        load_lora_weights_into_pipe(pipe, model_path)
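
# For reference, a minimal sketch of the per-projection LoRA container that
# init_lora() and the loader below assume: an object exposing `down` and `up`
# linear maps, so its weights surface under keys such as
# "blocks.{i}.self_attn.q_loras.down.weight". The class below is an
# illustrative assumption, not the project's actual module.
import torch.nn as nn


class LoRAAdapter(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, dim, rank=16):
        super().__init__()
        self.down = nn.Linear(dim, rank, bias=False)  # yields .down.weight
        self.up = nn.Linear(rank, dim, bias=False)    # yields .up.weight

    def forward(self, x):
        # Low-rank residual delta added to the frozen attention projection.
        return self.up(self.down(x))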


def load_lora_weights_into_pipe(pipe, ckpt_path, strict=True):
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # Some checkpoints nest the weights under a "state_dict" entry.
    state_dict = ckpt.get("state_dict", ckpt)

    # Map expected checkpoint keys to the live LoRA parameter tensors.
    model = {}
    for i, block in enumerate(pipe.dit.blocks):
        prefix = f"blocks.{i}.self_attn."
        attn = block.self_attn
        for name in ["q_loras", "k_loras", "v_loras"]:
            lora = getattr(attn, name, None)
            for sub in ["down", "up"]:
                key = f"{prefix}{name}.{sub}.weight"
                if lora is not None and hasattr(lora, sub):
                    model[key] = getattr(lora, sub).weight
                elif strict:
                    raise KeyError(f"Missing module: {key}")

    # Copy matching checkpoint tensors into the pipeline in place.
    for k, param in state_dict.items():
        if k not in model:
            if strict:
                raise KeyError(f"Unexpected key in ckpt: {k}")
            continue
        if model[k].shape != param.shape:
            if strict:
                raise ValueError(
                    f"Shape mismatch: {k} | {model[k].shape} vs {param.shape}"
                )
            continue
        model[k].data.copy_(param)
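
# --- Usage sketch (illustrative; names other than the two functions above
# are assumptions, and the checkpoint path is a placeholder) ---
#
#   pipe = build_pipeline(...)  # hypothetical factory for the DiT pipeline
#   set_stand_in(pipe, train=False, model_path="checkpoints/Stand-In.ckpt")
#
# With strict=True (the default), load_lora_weights_into_pipe raises on any
# missing module, unexpected key, or shape mismatch instead of skipping it.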