File size: 1,020 Bytes
03426f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py
index b017bbc5..d6290da6 100644
--- a/src/lerobot/policies/pi05/modeling_pi05.py
+++ b/src/lerobot/policies/pi05/modeling_pi05.py
@@ -989,6 +989,13 @@ class PI05Policy(PreTrainedPolicy):
             if remap_count > 0:
                 print(f"Remapped {remap_count} state dict keys")
 
             # Load the remapped state dict into the model
             missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
+
+            # --- FIX: tie embed_tokens to lm_head if embed_tokens missing in ckpt ---
+            if any("embed_tokens.weight" in k for k in missing_keys):
+                with torch.no_grad():
+                    embed = model.model.paligemma_with_expert.paligemma.model.language_model.embed_tokens
+                    lm_head = model.model.paligemma_with_expert.paligemma.lm_head
+                    embed.weight = lm_head.weight
 
             return model