|
|
|
|
|
"""
|
|
|
patch_sigma_env.py
|
|
|
|
|
|
Idempotent patcher for Sigma VLA experiments.
|
|
|
|
|
|
Patch goals:
|
|
|
1) LeRobot PI05Policy (modeling_pi05.py):
|
|
|
1.1 If ckpt omits embed_tokens.weight, tie embed_tokens.weight to lm_head.weight
|
|
|
*after* load_state_dict runs.
|
|
|
1.2 Ensure torch is imported if target file lacks it.
|
|
|
1.3 Downgrade the "incorrect transformer version" hard guard
|
|
|
(ValueError) to a WARNING so new GPU environments don't crash.
|
|
|
IMPORTANT: preserve indentation and patch only the intended guard.
|
|
|
|
|
|
2) LeRobot policies __init__ (lerobot/policies/__init__.py):
|
|
|
2.1 Make ONLY Groot/Diffusers-related imports optional (wrapped in try/except),
|
|
|
leaving all other exports untouched.
|
|
|
This prevents errors like: No module named 'triton.ops'
|
|
|
or diffusers/peft chain issues on fresh GPUs.
|
|
|
|
|
|
3) eval_sigma_vla_rollout.py (your /workspace eval script):
|
|
|
3.1 Force strict=False for PI05Policy.from_pretrained calls:
|
|
|
- strict=True -> strict=False
|
|
|
- if a PI05Policy load call has no strict arg, add strict=False
|
|
|
3.2 Ensure randomized subset evaluation is possible:
|
|
|
- add --shuffle arg if missing
|
|
|
- change DataLoader shuffle=False -> shuffle=getattr(args,"shuffle",False)
|
|
|
|
|
|
Safe to run multiple times; no-op if already patched.
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import re
|
|
|
import sys
|
|
|
import pathlib
|
|
|
from typing import Optional, Tuple, List
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _read_text(p: pathlib.Path) -> str:
|
|
|
return p.read_text(encoding="utf-8")
|
|
|
|
|
|
def _write_text(p: pathlib.Path, s: str) -> None:
|
|
|
p.write_text(s, encoding="utf-8")
|
|
|
|
|
|
def _search_file(
|
|
|
roots: List[os.PathLike],
|
|
|
filename: str,
|
|
|
must_contain: Optional[str] = None
|
|
|
) -> Optional[pathlib.Path]:
|
|
|
for r in roots:
|
|
|
r = pathlib.Path(r)
|
|
|
if not r.exists():
|
|
|
continue
|
|
|
for p in r.rglob(filename):
|
|
|
if must_contain and must_contain not in str(p):
|
|
|
continue
|
|
|
return p
|
|
|
return None
|
|
|
|
|
|
def _default_roots():
|
|
|
return [
|
|
|
"/workspace/lerobot/src",
|
|
|
"/workspace/lerobot",
|
|
|
pathlib.Path(sys.prefix)
|
|
|
/ "lib"
|
|
|
/ f"python{sys.version_info.major}.{sys.version_info.minor}"
|
|
|
/ "site-packages",
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def find_pi05_file() -> pathlib.Path:
    """Locate LeRobot's modeling_pi05.py.

    The PI05_FILE environment variable takes precedence when it points at an
    existing file; otherwise the default roots are searched for a
    modeling_pi05.py living under a /pi05/ directory.

    Raises FileNotFoundError when nothing is found.
    """
    override = os.getenv("PI05_FILE")
    if override:
        candidate = pathlib.Path(override)
        if candidate.exists():
            return candidate

    found = _search_file(_default_roots(), "modeling_pi05.py", must_contain="/pi05/")
    if found and found.exists():
        return found

    raise FileNotFoundError("modeling_pi05.py not found. Set PI05_FILE env var to its path.")
|
|
|
|
|
|
|
|
|
def ensure_torch_import(s: str) -> str:
    """Return *s* with ``import torch`` inserted if the module lacks one.

    The import is placed after a shebang line and after the module docstring,
    so the inserted line can never split the docstring or displace the shebang.
    Handles one-line docstrings (``\"\"\"doc\"\"\"``) and both triple-quote
    styles; the original code scanned for the closing quotes starting on the
    *next* line, so a one-line docstring could cause the import to land before
    the docstring or at an arbitrary later line containing three quotes.

    Returns *s* unchanged when a torch import is already present.
    """
    if re.search(r"(?m)^\s*import\s+torch\b", s) or re.search(r"(?m)^\s*from\s+torch\b", s):
        return s

    lines = s.splitlines(True)
    insert_idx = 0

    # Keep a shebang as the very first line.
    if lines and lines[0].startswith("#!"):
        insert_idx = 1

    # Skip past a module docstring, if one opens right here.
    if insert_idx < len(lines):
        stripped = lines[insert_idx].lstrip()
        for quote in ('"""', "'''"):
            if stripped.startswith(quote):
                if stripped.count(quote) >= 2:
                    # One-line docstring: opening and closing on the same line.
                    insert_idx += 1
                else:
                    # Multi-line: advance to the line holding the closing quotes.
                    i = insert_idx + 1
                    while i < len(lines) and quote not in lines[i]:
                        i += 1
                    if i < len(lines):
                        insert_idx = i + 1
                break

    lines.insert(insert_idx, "import torch # PATCH: required for embed/lm_head tying\n")
    return "".join(lines)
|
|
|
|
|
|
|
|
|
def patch_pi05_embed_tie(p: pathlib.Path) -> Tuple[bool, str]:
    """Inject code after PI05's load_state_dict call to tie embed_tokens to lm_head.

    The injected snippet runs only when the checkpoint omitted
    ``embed_tokens.weight`` (the key appears in ``missing_keys``).  Also makes
    sure the target file imports torch, since the snippet uses
    ``torch.no_grad()``.  Idempotent via an in-file marker comment.

    Returns (changed, message).
    """
    s = _read_text(p)
    s = ensure_torch_import(s)

    marker = "PATCH: tie embed_tokens to lm_head if ckpt omitted embed_tokens"
    if marker in s:
        # Write anyway: ensure_torch_import may have just added the import.
        _write_text(p, s)
        return False, f"PI05 embed-tie patch already present: {p}"

    # Anchor on the exact load_state_dict line; group(1) captures its indent.
    pat = r"(?m)^(\s*)missing_keys,\s*unexpected_keys\s*=\s*model\.load_state_dict\(\s*remapped_state_dict\s*,\s*strict\s*=\s*strict\s*\)\s*$"
    m = re.search(pat, s)
    if not m:
        # Still persist the (possibly) added torch import before bailing out.
        _write_text(p, s)
        return False, f"Could not find load_state_dict line to patch in PI05 file: {p}"

    indent = m.group(1)
    # The snippet reuses the anchor line's indentation so it lands in the same
    # scope as the load_state_dict call; the broad except keeps a failed tie
    # from aborting model loading.
    inject = (
        f"\n{indent}# --- PATCH: tie embed_tokens to lm_head if ckpt omitted embed_tokens ---\n"
        f"{indent}if any('embed_tokens.weight' in k for k in missing_keys):\n"
        f"{indent}    try:\n"
        f"{indent}        with torch.no_grad():\n"
        f"{indent}            embed = model.model.paligemma_with_expert.paligemma.model.language_model.embed_tokens\n"
        f"{indent}            lm_head = model.model.paligemma_with_expert.paligemma.lm_head\n"
        f"{indent}            if embed is not None and lm_head is not None:\n"
        f"{indent}                embed.weight = lm_head.weight # {marker}\n"
        f"{indent}    except Exception as _e:\n"
        f"{indent}        print('[patch_pi05] Could not tie embed_tokens to lm_head:', _e)\n"
    )

    # Append the snippet immediately after the matched line (first match only).
    s2 = re.sub(pat, lambda mm: mm.group(0) + inject, s, count=1)
    _write_text(p, s2)
    return True, f"Patched PI05 embed-tie in: {p}"
|
|
|
|
|
|
|
|
|
def patch_pi05_transformers_guard(p: pathlib.Path) -> Tuple[bool, str]:
    """Downgrade ONLY the PI05 "incorrect transformer version" hard guard.

    The ``raise ValueError(msg) from None`` statement whose nearby context
    (up to 8 preceding lines) mentions "incorrect transformer version" is
    replaced with a WARNING print, preserving its indentation.  All other
    raise statements are left untouched.  Idempotent: an in-file marker
    comment short-circuits repeat runs.

    Returns (changed, message).
    """
    text = _read_text(p)
    marker = "PATCH: downgrade transformer version guard"
    if marker in text:
        return False, f"PI05 transformers-guard patch already present: {p}"
    if "incorrect transformer version" not in text:
        return False, f"No transformers guard message found to patch in: {p}"

    src_lines = text.splitlines(True)
    guard_raise = re.compile(r"^(\s*)raise\s+ValueError\(\s*msg\s*\)\s*from\s*None\s*$")

    hit = None  # (line index, indentation) of the guard's raise statement
    for idx, candidate in enumerate(src_lines):
        matched = guard_raise.match(candidate)
        if matched is None:
            continue
        # Look back through a small window for the guard's distinctive message.
        context = "".join(src_lines[max(0, idx - 8):idx + 1]).lower()
        if "incorrect transformer version" in context:
            hit = (idx, matched.group(1))
            break

    if hit is None:
        return False, f"Guard raise line with context not found in: {p}"

    guard_idx, pad = hit
    src_lines[guard_idx] = (
        f"{pad}# --- PATCH: downgrade transformer version guard ---\n"
        f"{pad}print('[patch_pi05] WARNING:', msg) # {marker}\n"
        f"{pad}# continues execution despite version mismatch\n"
    )
    _write_text(p, "".join(src_lines))
    return True, f"Patched PI05 transformers guard (raise->warn) in: {p}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def find_policies_init() -> pathlib.Path:
    """Locate lerobot/policies/__init__.py.

    The POLICIES_INIT_FILE environment variable takes precedence when it
    points at an existing file; otherwise the default roots are searched for
    an __init__.py under a /lerobot/policies/ directory.

    Raises FileNotFoundError when nothing is found.
    """
    override = os.getenv("POLICIES_INIT_FILE")
    if override:
        candidate = pathlib.Path(override)
        if candidate.exists():
            return candidate

    found = _search_file(_default_roots(), "__init__.py", must_contain="/lerobot/policies/")
    if found and found.exists():
        return found

    raise FileNotFoundError("lerobot/policies/__init__.py not found. Set POLICIES_INIT_FILE env var.")
|
|
|
|
|
|
|
|
|
def patch_policies_optional_imports(p: pathlib.Path) -> Tuple[bool, str]:
    """
    Make ONLY Groot/Diffusers imports optional.

    This avoids wrapping unrelated exports/imports.

    Each consecutive run of groot-related import lines is wrapped in a single
    try/except that prints a warning instead of raising (e.g. for missing
    triton.ops or diffusers/peft chains).  Idempotent via an in-file marker.
    Returns (changed, message).
    """
    s = _read_text(p)
    marker = "PATCH: optional Groot/Diffusers imports"
    if marker in s:
        return False, f"Policies optional-import patch already present: {p}"

    lines = s.splitlines(True)

    def is_groot_line(line: str) -> bool:
        # Match "from .groot ..." (with or without a space after the dot) or
        # any "import ... groot ..." statement.
        # NOTE(review): only groot is matched here even though the marker says
        # Groot/Diffusers -- presumably diffusers is only pulled in via the
        # groot imports; confirm no direct diffusers import needs wrapping.
        return bool(re.search(r"^\s*from\s+\.\s*groot\b|^\s*from\s+\.groot\b|^\s*import\s+.*\bgroot\b", line))

    idxs = [i for i, l in enumerate(lines) if is_groot_line(l)]
    if not idxs:
        return False, f"No Groot imports found to wrap in: {p}"

    # Collapse consecutive matching line numbers into (start, end) groups so
    # each contiguous run of groot imports gets one try/except wrapper.
    groups = []
    start = prev = idxs[0]
    for i in idxs[1:]:
        if i == prev + 1:
            prev = i
        else:
            groups.append((start, prev))
            start = prev = i
    groups.append((start, prev))

    new_lines = []
    last_end = -1
    for (a, b) in groups:
        # Copy everything between the previous group and this one verbatim.
        new_lines.extend(lines[last_end + 1:a])

        # Wrap the group.
        # NOTE(review): lstrip() assumes each import fits on one line; a
        # parenthesized multi-line import would be split across the wrapper --
        # confirm the target __init__.py only uses single-line imports.
        new_lines.append("# --- PATCH: optional Groot/Diffusers imports ---\n")
        new_lines.append(f"try: # {marker}\n")
        for j in range(a, b + 1):
            new_lines.append("    " + lines[j].lstrip())
        new_lines.append("except Exception as _e:\n")
        new_lines.append("    print('[policies_init] WARNING: optional groot deps missing:', _e)\n")

        last_end = b

    # Tail after the final group.
    new_lines.extend(lines[last_end + 1:])

    s2 = "".join(new_lines)
    if s2 == s:
        return False, f"Policies file unchanged after optional-import attempt: {p}"

    _write_text(p, s2)
    return True, f"Patched policies __init__ optional imports in: {p}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def find_eval_file() -> pathlib.Path:
    """Locate eval_sigma_vla_rollout.py.

    Resolution order: the EVAL_FILE environment variable (when it points at an
    existing file), then the canonical /workspace location, then a recursive
    search under /workspace and /workspace/lerobot.

    Raises FileNotFoundError when nothing is found.
    """
    override = os.getenv("EVAL_FILE")
    if override:
        candidate = pathlib.Path(override)
        if candidate.exists():
            return candidate

    default = pathlib.Path("/workspace/eval_sigma_vla_rollout.py")
    if default.exists():
        return default

    found = _search_file(["/workspace", "/workspace/lerobot"], "eval_sigma_vla_rollout.py")
    if found and found.exists():
        return found

    raise FileNotFoundError("eval_sigma_vla_rollout.py not found. Set EVAL_FILE env var.")
|
|
|
|
|
|
|
|
|
def patch_eval_force_strict_false(p: pathlib.Path) -> Tuple[bool, str]:
    """Force strict=False on PI05Policy.from_pretrained calls in the eval script.

    Two rewrites are attempted:
      1. strict=True -> strict=False inside policy_cls.from_pretrained(...).
      2. The two known no-strict call shapes get ", strict=False" appended.

    Returns (changed, message); a marker comment makes repeat runs no-ops.
    """
    s = _read_text(p)
    marker = "PATCH: force strict=False for PI05Policy"

    # Rewrite explicit strict=True.
    # NOTE(review): [^)]* stops at the first ')' so a call containing nested
    # parentheses will not match -- confirm the eval script's call sites are flat.
    pat_strict_true = r"(policy_cls\.from_pretrained\([^)]*strict\s*=\s*)True(\s*[^)]*\))"
    s2, n_true = re.subn(pat_strict_true, r"\1False\2", s)

    def _add_strict_false_call(match: re.Match) -> str:
        # Append strict=False just before the closing paren, unless the call
        # already mentions strict in any form.
        call = match.group(0)
        if "strict" in call:
            return call
        return call[:-1] + ", strict=False)"

    # The two call shapes used by the eval script (positional and keyword repo_id).
    pat_no_strict_1 = r"policy_cls\.from_pretrained\(\s*repo_id\s*,\s*token\s*=\s*hf_token\s*\)"
    pat_no_strict_2 = r"policy_cls\.from_pretrained\(\s*pretrained_name_or_path\s*=\s*repo_id\s*,\s*token\s*=\s*hf_token\s*\)"

    s3, n_add1 = re.subn(pat_no_strict_1, _add_strict_false_call, s2)
    s4, n_add2 = re.subn(pat_no_strict_2, _add_strict_false_call, s3)

    changed = (n_true + n_add1 + n_add2) > 0
    if not changed:
        if marker in s:
            return False, f"Eval strict patch already present: {p}"
        return False, f"Eval already strict=False or no PI05 strict targets found: {p}"

    if marker not in s4:
        # Tag the first rewritten site so future runs can detect the patch.
        s4 = s4.replace("strict=False)", f"strict=False) # {marker}", 1)

    _write_text(p, s4)
    return True, f"Patched eval PI05 strict=False in: {p}"
|
|
|
|
|
|
|
|
|
def patch_eval_shuffle_support(p: pathlib.Path) -> Tuple[bool, str]:
    """Add --shuffle CLI support and wire it into the eval DataLoader.

    Two independent edits, each skipped when already present:
      1. Append a --shuffle store_true argument after the last existing
         parser.add_argument(...) line.
      2. Change the first DataLoader(... shuffle=False) kwarg to
         shuffle=getattr(args, "shuffle", False).

    Returns (changed, message).
    """
    s = _read_text(p)
    marker_arg = "PATCH: add --shuffle arg"
    marker_dl = "PATCH: DataLoader shuffle uses args.shuffle"

    changed = False

    # 1) Add the --shuffle argument if the parser does not define it yet.
    if re.search(r'add_argument\(\s*["\']--shuffle["\']', s) is None:
        # Single-line add_argument calls only; insert after the last one.
        # NOTE(review): the inserted line has no indentation, which assumes
        # parser.add_argument calls sit at module level -- confirm the eval
        # script does not build its parser inside a function.
        arg_pat = re.compile(r"(?m)^\s*parser\.add_argument\(.+?\)\s*$")
        matches = list(arg_pat.finditer(s))
        if matches:
            last = matches[-1]
            insert_pos = last.end()
            insert_text = (
                "\nparser.add_argument("
                "\"--shuffle\", action=\"store_true\", "
                "help=\"Shuffle dataset order to sample different subsets per seed.\")"
                f" # {marker_arg}\n"
            )
            s = s[:insert_pos] + insert_text + s[insert_pos:]
            changed = True

    # 2) Rewire the DataLoader shuffle flag, guarded by the marker comment.
    if marker_dl not in s:
        def _dl_repl(m: re.Match) -> str:
            prefix = m.group(1)
            return prefix + f'getattr(args, "shuffle", False) # {marker_dl}'

        # Lazily scan at most 1200 chars into the call for shuffle=False so
        # only the DataLoader's own kwarg is touched.
        pat_dl = re.compile(r"(?s)(DataLoader\([\s\S]{0,1200}?shuffle\s*=\s*)False")
        if pat_dl.search(s):
            s = pat_dl.sub(_dl_repl, s, count=1)
            changed = True

    if changed:
        _write_text(p, s)
        return True, f"Patched eval shuffle support in: {p}"

    return False, f"Eval shuffle support already present or no targets found: {p}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Run every patcher, reporting progress; each step fails independently.

    The original body repeated the same try/print/accumulate stanza five
    times; it is now driven by a (finder, patcher, label) table.  Output is
    byte-identical to the original for every success and failure path.
    """
    changed_any = False

    # (file locator, patch function, label used in the "skipped" message)
    steps = [
        (find_pi05_file, patch_pi05_embed_tie, "PI05 embed-tie patch"),
        (find_pi05_file, patch_pi05_transformers_guard, "PI05 transformers-guard patch"),
        (find_policies_init, patch_policies_optional_imports, "policies __init__ patch"),
        (find_eval_file, patch_eval_force_strict_false, "Eval strict patch"),
        (find_eval_file, patch_eval_shuffle_support, "Eval shuffle patch"),
    ]

    for finder, patcher, label in steps:
        try:
            changed, msg = patcher(finder())
            print(msg)
            changed_any |= changed
        except Exception as e:
            # A missing target file (or any patch failure) must not stop the
            # remaining patches from being attempted.
            print(f"[patch_sigma_env] {label} skipped:", e)

    if changed_any:
        print("[patch_sigma_env] Done. Patches applied.")
    else:
        print("[patch_sigma_env] Done. Nothing to change (already patched).")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Apply all patches when executed as a script (no-op on import).
    main()
|
|
|
|