|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -989,6 +989,13 @@ class PI05Policy(PreTrainedPolicy):
         if remap_count > 0:
             print(f"Remapped {remap_count} state dict keys")

         # Load the remapped state dict into the model
         missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
+
+        # --- FIX: tie embed_tokens to lm_head if embed_tokens missing in ckpt ---
+        if any("embed_tokens.weight" in k for k in missing_keys):
+            with torch.no_grad():
+                embed = model.model.paligemma_with_expert.paligemma.model.language_model.embed_tokens
+                lm_head = model.model.paligemma_with_expert.paligemma.lm_head
+                embed.weight = lm_head.weight

         return model
|
|
|