pi05_embed_tie.patch
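This patch touches the checkpoint-loading path of `PI05Policy` in lerobot: when a checkpoint ships without `embed_tokens.weight`, the loader ties the PaliGemma input embeddings to the `lm_head` weights after `load_state_dict`, instead of leaving the embedding matrix randomly initialized.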
diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py
index b017bbc5..d6290da6 100644
--- a/src/lerobot/policies/pi05/modeling_pi05.py
+++ b/src/lerobot/policies/pi05/modeling_pi05.py
@@ -989,6 +989,13 @@ class PI05Policy(PreTrainedPolicy):
         if remap_count > 0:
             print(f"Remapped {remap_count} state dict keys")
         # Load the remapped state dict into the model
         missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
+
+        # --- FIX: tie embed_tokens to lm_head if embed_tokens missing in ckpt ---
+        if any("embed_tokens.weight" in k for k in missing_keys):
+            with torch.no_grad():
+                embed = model.model.paligemma_with_expert.paligemma.model.language_model.embed_tokens
+                lm_head = model.model.paligemma_with_expert.paligemma.lm_head
+                embed.weight = lm_head.weight
         return model
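
The assignment works because `nn.Embedding` stores its weight as a `(vocab_size, hidden_size)` Parameter, the same shape as a bias-free `nn.Linear` LM head, so pointing one module's `weight` at the other makes both layers share a single tensor. A minimal self-contained sketch of the same idiom (the module names and sizes below are illustrative stand-ins, not the actual PI05 classes):

import torch
import torch.nn as nn

# Illustrative stand-ins for the PaliGemma embedding and LM head;
# sizes are kept small so the snippet runs anywhere.
vocab_size, hidden_size = 32, 8
embed = nn.Embedding(vocab_size, hidden_size)             # weight: (32, 8)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # weight: (32, 8)

# Tie the weights: both modules now hold the same Parameter object,
# so a checkpoint that only ships lm_head.weight also populates the
# embedding, and any gradients accumulate into the one shared tensor.
with torch.no_grad():
    embed.weight = lm_head.weight

# Both modules point at the same underlying storage.
assert embed.weight.data_ptr() == lm_head.weight.data_ptr()

Note that the patch applies the tie only when `embed_tokens.weight` appears in `missing_keys`, so checkpoints that do contain the embedding weights load unchanged.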