Sigma / patch_sigma_env.py
ConorWang's picture
Upload 10 files
03426f9 verified
#!/usr/bin/env python3
"""
patch_sigma_env.py
Idempotent patcher for Sigma VLA experiments.
Patch goals:
1) LeRobot PI05Policy (modeling_pi05.py):
1.1 If ckpt omits embed_tokens.weight, tie embed_tokens.weight to lm_head.weight
*after* load_state_dict runs.
1.2 Ensure torch is imported if target file lacks it.
1.3 Downgrade the "incorrect transformer version" hard guard
(ValueError) to a WARNING so new GPU environments don't crash.
IMPORTANT: preserve indentation and patch only the intended guard.
2) LeRobot policies __init__ (lerobot/policies/__init__.py):
2.1 Make ONLY Groot/Diffusers-related imports optional (wrapped in try/except),
leaving all other exports untouched.
This prevents errors like: No module named 'triton.ops'
or diffusers/peft chain issues on fresh GPUs.
3) eval_sigma_vla_rollout.py (your /workspace eval script):
3.1 Force strict=False for PI05Policy.from_pretrained calls:
- strict=True -> strict=False
- if a PI05Policy load call has no strict arg, add strict=False
3.2 Ensure randomized subset evaluation is possible:
- add --shuffle arg if missing
- change DataLoader shuffle=False -> shuffle=getattr(args,"shuffle",False)
Safe to run multiple times; no-op if already patched.
"""
import os
import re
import sys
import pathlib
from typing import Optional, Tuple, List
# -------------------------
# Utilities
# -------------------------
def _read_text(p: pathlib.Path) -> str:
    """Return the full contents of *p*, decoded as UTF-8."""
    with p.open("r", encoding="utf-8") as handle:
        return handle.read()
def _write_text(p: pathlib.Path, s: str) -> None:
    """Overwrite *p* with the text *s*, encoded as UTF-8."""
    with p.open("w", encoding="utf-8") as handle:
        handle.write(s)
def _search_file(
    roots: List[os.PathLike],
    filename: str,
    must_contain: Optional[str] = None
) -> Optional[pathlib.Path]:
    """Return the first file named *filename* found under any of *roots*.

    Roots are tried in order; roots that do not exist are skipped. Within a
    root, files are visited in ``Path.rglob`` order.

    Args:
        roots: Directories to search (str or path-like).
        filename: Exact file name (rglob pattern) to look for.
        must_contain: Optional substring the candidate's path must contain.
            It is matched against the POSIX form of the path, so "/"-based
            hints such as ``"/pi05/"`` also work on Windows.

    Returns:
        The first matching path, or None if nothing matched.
    """
    for root in roots:
        base = pathlib.Path(root)
        if not base.exists():
            continue
        for candidate in base.rglob(filename):
            if must_contain and must_contain not in candidate.as_posix():
                continue
            return candidate
    return None
def _default_roots():
    """Return the directories searched, in order, for LeRobot source files.

    Order: the two ``/workspace`` checkouts, then
    ``<sys.prefix>/lib/pythonX.Y/site-packages`` (the usual venv/conda
    layout), then the ``purelib``/``platlib`` directories reported by
    :mod:`sysconfig`. The sysconfig entries cover layouts the hard-coded
    path misses. For example, Debian/Ubuntu system Pythons install into
    ``dist-packages``, and Windows uses ``Lib/site-packages``. Entries that
    duplicate an earlier one are dropped, so the common venv case is
    unchanged.
    """
    import sysconfig

    roots = [
        "/workspace/lerobot/src",
        "/workspace/lerobot",
        pathlib.Path(sys.prefix)
        / "lib"
        / f"python{sys.version_info.major}.{sys.version_info.minor}"
        / "site-packages",
    ]
    seen = {pathlib.Path(r) for r in roots}
    install_paths = sysconfig.get_paths()
    for key in ("purelib", "platlib"):
        extra = install_paths.get(key)
        if not extra:
            continue
        extra_path = pathlib.Path(extra)
        if extra_path not in seen:
            seen.add(extra_path)
            roots.append(extra_path)
    return roots
# -------------------------
# Patch 1: PI05Policy (LeRobot)
# -------------------------
def find_pi05_file() -> pathlib.Path:
    """Locate LeRobot's ``modeling_pi05.py``.

    The PI05_FILE environment variable wins when it names an existing path.
    Otherwise the default roots are searched for a ``modeling_pi05.py``
    whose path contains ``/pi05/``.

    Raises:
        FileNotFoundError: if neither source yields the file.
    """
    override = os.getenv("PI05_FILE")
    if override and pathlib.Path(override).exists():
        return pathlib.Path(override)
    found = _search_file(_default_roots(), "modeling_pi05.py", must_contain="/pi05/")
    if found is not None and found.exists():
        return found
    raise FileNotFoundError("modeling_pi05.py not found. Set PI05_FILE env var to its path.")
def ensure_torch_import(s: str) -> str:
    """Return *s* with a module-level ``import torch`` guaranteed.

    The embed-tie snippet injected by ``patch_pi05_embed_tie`` calls
    ``torch.no_grad()``, so the bare name ``torch`` must be bound at module
    scope. ``from torch import nn`` does not bind it, and neither does an
    import nested inside a function, so neither counts as "already imported".

    Placement of the new line:
    - after the module docstring and any ``from __future__`` imports, which
      Python requires to come first;
    - otherwise just before the first statement (including its decorators),
      so a shebang or license comments stay on top.

    If *s* does not parse with the running interpreter, a line-based
    fallback handles a shebang plus a (one- or multi-line) docstring.

    Args:
        s: Python source text.

    Returns:
        *s* unchanged if a top-level import already binds ``torch``,
        otherwise *s* with the import line inserted.
    """
    import ast

    import_line = "import torch # PATCH: required for embed/lm_head tying"
    # Split on "\n" only (str.splitlines also splits on \f, \x1c, \u2028, ...)
    # so list indices line up with the 1-based line numbers ast reports.
    parts = s.split("\n")

    try:
        tree = ast.parse(s)
    except (SyntaxError, ValueError):  # ValueError: null bytes on older Pythons
        tree = None

    if tree is not None:
        body = tree.body
        for node in body:
            if isinstance(node, ast.Import) and any(
                (alias.asname or alias.name.split(".")[0]) == "torch"
                for alias in node.names
            ):
                return s
        insert_idx = None
        pos = 0
        if ast.get_docstring(tree, clean=False) is not None:
            insert_idx = body[0].end_lineno
            pos = 1
        while (
            pos < len(body)
            and isinstance(body[pos], ast.ImportFrom)
            and body[pos].module == "__future__"
        ):
            insert_idx = body[pos].end_lineno
            pos += 1
        if insert_idx is None:
            if body:
                first = body[0]
                # Decorators sit above the `def`/`class` line that lineno points at.
                first_line = min(
                    [first.lineno]
                    + [d.lineno for d in getattr(first, "decorator_list", [])]
                )
                insert_idx = first_line - 1
            else:
                insert_idx = 1 if parts[0].startswith("#!") else 0
    else:
        # Unparseable source (e.g. syntax newer than this interpreter): only an
        # exact top-level `import torch` line counts; a duplicate is harmless.
        if re.search(r"(?m)^import\s+torch\s*(?:#.*)?$", s):
            return s
        insert_idx = 1 if parts[0].startswith("#!") else 0
        if insert_idx < len(parts):
            head = parts[insert_idx].lstrip()
            quote = next((q for q in ('"""', "'''") if head.startswith(q)), None)
            if quote is not None:
                if head.count(quote) >= 2:
                    # Docstring opens and closes on this same line.
                    insert_idx += 1
                else:
                    closing = next(
                        (k for k in range(insert_idx + 1, len(parts)) if quote in parts[k]),
                        None,
                    )
                    if closing is not None:
                        insert_idx = closing + 1

    parts.insert(insert_idx, import_line)
    return "\n".join(parts)
def patch_pi05_embed_tie(p: pathlib.Path) -> Tuple[bool, str]:
    """Inject embed_tokens -> lm_head weight tying after PI05's load_state_dict.

    Some checkpoints omit ``embed_tokens.weight``. This inserts a snippet
    directly after the line
    ``missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)``
    at the same indentation. When any missing key contains
    ``embed_tokens.weight``, the snippet rebinds the embedding's ``weight``
    to the lm_head weight under ``torch.no_grad()``. Errors inside the
    snippet are printed, not raised.

    ``import torch`` is ensured (via ``ensure_torch_import``) only when the
    snippet is already present or is being injected, because the snippet is
    what needs it. The file is rewritten only when its text actually changes.

    Args:
        p: Path to ``modeling_pi05.py``.

    Returns:
        ``(changed, message)``; ``changed`` is True iff the file was written.
    """
    original = _read_text(p)
    marker = "PATCH: tie embed_tokens to lm_head if ckpt omitted embed_tokens"

    if marker in original:
        s = ensure_torch_import(original)
        if s == original:
            return False, f"PI05 embed-tie patch already present: {p}"
        _write_text(p, s)
        return True, f"PI05 embed-tie patch already present; added missing torch import: {p}"

    s = ensure_torch_import(original)
    pat = r"(?m)^(\s*)missing_keys,\s*unexpected_keys\s*=\s*model\.load_state_dict\(\s*remapped_state_dict\s*,\s*strict\s*=\s*strict\s*\)\s*$"
    m = re.search(pat, s)
    if not m:
        # Nothing to hook into: leave the file untouched (no stray import).
        return False, f"Could not find load_state_dict line to patch in PI05 file: {p}"

    # `^(\s*)` under (?m) may also capture preceding blank lines; keep only the
    # indentation of the load_state_dict line itself.
    indent = m.group(1).rsplit("\n", 1)[-1]
    inject = (
        f"\n{indent}# --- PATCH: tie embed_tokens to lm_head if ckpt omitted embed_tokens ---\n"
        f"{indent}if any('embed_tokens.weight' in k for k in missing_keys):\n"
        f"{indent}    try:\n"
        f"{indent}        with torch.no_grad():\n"
        f"{indent}            embed = model.model.paligemma_with_expert.paligemma.model.language_model.embed_tokens\n"
        f"{indent}            lm_head = model.model.paligemma_with_expert.paligemma.lm_head\n"
        f"{indent}            if embed is not None and lm_head is not None:\n"
        f"{indent}                embed.weight = lm_head.weight # {marker}\n"
        f"{indent}    except Exception as _e:\n"
        f"{indent}        print('[patch_pi05] Could not tie embed_tokens to lm_head:', _e)\n"
    )
    _write_text(p, s[: m.end()] + inject + s[m.end():])
    return True, f"Patched PI05 embed-tie in: {p}"
def patch_pi05_transformers_guard(p: pathlib.Path) -> Tuple[bool, str]:
    """Downgrade ONLY PI05's "incorrect transformer version" guard to a warning.

    The guard is a ``raise ValueError(msg)`` line (optionally ``... from None``)
    whose preceding 8 lines mention "incorrect transformer version" in any
    letter case. That one line is replaced, at the same indentation, by a
    ``print`` of ``msg``, so execution continues. Every other
    ``raise ValueError(msg)`` is left alone.

    Args:
        p: Path to ``modeling_pi05.py``.

    Returns:
        ``(changed, message)``; the file is written only when changed.
    """
    s = _read_text(p)
    marker = "PATCH: downgrade transformer version guard"
    if marker in s:
        return False, f"PI05 transformers-guard patch already present: {p}"

    guard_text = "incorrect transformer version"
    # Case-insensitive, matching the per-raise window check below; the message
    # may start with a capital ("Incorrect transformer version ...").
    if guard_text not in s.lower():
        return False, f"No transformers guard message found to patch in: {p}"

    lines = s.splitlines(True)
    raise_pat = re.compile(r"^(\s*)raise\s+ValueError\(\s*msg\s*\)(?:\s*from\s*None)?\s*$")
    for i, line in enumerate(lines):
        m = raise_pat.match(line)
        if m is None:
            continue
        # Only the raise whose nearby context carries the guard message.
        window = "".join(lines[max(0, i - 8): i + 1]).lower()
        if guard_text not in window:
            continue
        indent = m.group(1)
        lines[i] = (
            f"{indent}# --- PATCH: downgrade transformer version guard ---\n"
            f"{indent}print('[patch_pi05] WARNING:', msg) # {marker}\n"
            f"{indent}# continues execution despite version mismatch\n"
        )
        _write_text(p, "".join(lines))
        return True, f"Patched PI05 transformers guard (raise->warn) in: {p}"

    return False, f"Guard raise line with context not found in: {p}"
# -------------------------
# Patch 2: LeRobot policies optional imports
# -------------------------
def find_policies_init() -> pathlib.Path:
    """Locate LeRobot's top-level ``lerobot/policies/__init__.py``.

    The POLICIES_INIT_FILE environment variable wins when it names an
    existing path. Otherwise each default root is searched for an
    ``__init__.py`` whose POSIX path ends with ``/lerobot/policies/__init__.py``.

    Raises:
        FileNotFoundError: if the file cannot be found either way.
    """
    env = os.getenv("POLICIES_INIT_FILE")
    if env:
        p = pathlib.Path(env)
        if p.exists():
            return p
    # Require the package's own __init__.py. A substring test on
    # "/lerobot/policies/" also accepts every policy subpackage's __init__.py
    # (e.g. .../lerobot/policies/act/__init__.py), leaving the result up to
    # directory traversal order.
    suffix = "/lerobot/policies/__init__.py"
    for root in _default_roots():
        base = pathlib.Path(root)
        if not base.exists():
            continue
        for candidate in base.rglob("__init__.py"):
            if candidate.as_posix().endswith(suffix):
                return candidate
    raise FileNotFoundError("lerobot/policies/__init__.py not found. Set POLICIES_INIT_FILE env var.")
def patch_policies_optional_imports(p: pathlib.Path) -> Tuple[bool, str]:
    """Make ONLY the Groot imports of ``lerobot/policies/__init__.py`` optional.

    Matching statements are ``from .groot...`` imports (any spacing after the
    dot) and ``import ...groot...`` lines. Everything else is left untouched.
    Each run of directly consecutive matching statements at the same
    indentation becomes::

        # --- PATCH: optional Groot/Diffusers imports ---
        try: # PATCH: optional Groot/Diffusers imports
            <original import statement(s)>
        except Exception as _e:
            print('[policies_init] WARNING: optional groot deps missing:', _e)

    Multi-line statements (parenthesised name lists or backslash
    continuations) are wrapped as a whole, and the wrapper keeps the
    indentation of the statement's first line, so the output stays valid
    Python.

    Args:
        p: Path to ``lerobot/policies/__init__.py``.

    Returns:
        ``(changed, message)``; the file is written only when changed.
    """
    s = _read_text(p)
    marker = "PATCH: optional Groot/Diffusers imports"
    if marker in s:
        return False, f"Policies optional-import patch already present: {p}"

    lines = s.splitlines(True)
    # Strict filter: only lines that import the groot submodule.
    groot_import = re.compile(r"^\s*from\s+\.\s*groot\b|^\s*import\s+.*\bgroot\b")

    def _indent(line: str) -> str:
        return line[: len(line) - len(line.lstrip())]

    def _statement_end(start: int) -> int:
        # Index of the last physical line of the statement starting at `start`:
        # follow unclosed parentheses and trailing backslashes.
        depth = 0
        end = start
        while True:
            code = lines[end].split("#", 1)[0]
            depth += code.count("(") - code.count(")")
            continues = depth > 0 or code.rstrip().endswith("\\")
            if not continues or end + 1 >= len(lines):
                return end
            end += 1

    # (first_line, last_line) spans of groot import statements; adjacent
    # statements with the same indentation share one try/except.
    spans: List[Tuple[int, int]] = []
    i = 0
    while i < len(lines):
        if groot_import.search(lines[i]) is None:
            i += 1
            continue
        end = _statement_end(i)
        if (
            spans
            and spans[-1][1] == i - 1
            and _indent(lines[spans[-1][0]]) == _indent(lines[i])
        ):
            spans[-1] = (spans[-1][0], end)
        else:
            spans.append((i, end))
        i = end + 1

    if not spans:
        return False, f"No Groot imports found to wrap in: {p}"

    new_lines: List[str] = []
    last_end = -1
    for first, last in spans:
        new_lines.extend(lines[last_end + 1:first])
        indent = _indent(lines[first])
        new_lines.append(f"{indent}# --- PATCH: optional Groot/Diffusers imports ---\n")
        new_lines.append(f"{indent}try: # {marker}\n")
        for j in range(first, last + 1):
            body = lines[j].lstrip()
            if not body.endswith("\n"):
                # Last line of a file without trailing newline (or a blank
                # continuation line): terminate it so `except` starts fresh.
                body += "\n"
            new_lines.append(f"{indent}    {body}")
        new_lines.append(f"{indent}except Exception as _e:\n")
        new_lines.append(f"{indent}    print('[policies_init] WARNING: optional groot deps missing:', _e)\n")
        last_end = last
    new_lines.extend(lines[last_end + 1:])

    s2 = "".join(new_lines)
    if s2 == s:
        return False, f"Policies file unchanged after optional-import attempt: {p}"
    _write_text(p, s2)
    return True, f"Patched policies __init__ optional imports in: {p}"
# -------------------------
# Patch 3: eval_sigma_vla_rollout.py
# -------------------------
def find_eval_file() -> pathlib.Path:
    """Locate the ``eval_sigma_vla_rollout.py`` evaluation script.

    Candidates are tried lazily, in this order:
    1. the EVAL_FILE environment variable, if set;
    2. ``/workspace/eval_sigma_vla_rollout.py``;
    3. the first match under ``/workspace`` or ``/workspace/lerobot``.

    The first candidate that exists is returned.

    Raises:
        FileNotFoundError: if no candidate exists.
    """
    def _candidates():
        override = os.getenv("EVAL_FILE")
        if override:
            yield pathlib.Path(override)
        yield pathlib.Path("/workspace/eval_sigma_vla_rollout.py")
        # Only reached (and the tree only walked) when the cheap checks fail.
        yield _search_file(["/workspace", "/workspace/lerobot"], "eval_sigma_vla_rollout.py")

    for candidate in _candidates():
        if candidate is not None and candidate.exists():
            return candidate
    raise FileNotFoundError("eval_sigma_vla_rollout.py not found. Set EVAL_FILE env var.")
def patch_eval_force_strict_false(p: pathlib.Path) -> Tuple[bool, str]:
    """Force ``strict=False`` on the eval script's ``policy_cls.from_pretrained`` calls.

    Two rewrites:
      1. ``strict=True`` inside a ``policy_cls.from_pretrained(...)`` call
         (with no nested parentheses before it) becomes ``strict=False``.
      2. The exact call shapes
         ``policy_cls.from_pretrained(repo_id, token=hf_token)`` and
         ``policy_cls.from_pretrained(pretrained_name_or_path=repo_id, token=hf_token)``
         gain a trailing ``strict=False`` argument.

    When something changed and the marker is not yet in the file, the marker
    comment is appended to the END of the line holding the first
    ``policy_cls.from_pretrained(... strict=False ...)`` call. Appending it
    at line end, not right after the ``)``, cannot comment out code that
    follows on the same line, such as a chained ``.to(device)`` or a comma.

    Args:
        p: Path to the eval script.

    Returns:
        ``(changed, message)``; the file is written only when changed.
    """
    s = _read_text(p)
    marker = "PATCH: force strict=False for PI05Policy"

    # 1) strict=True -> strict=False in PI05 loads
    pat_strict_true = r"(policy_cls\.from_pretrained\([^)]*strict\s*=\s*)True(\s*[^)]*\))"
    s2, n_true = re.subn(pat_strict_true, r"\1False\2", s)

    # 2) add strict=False if missing on PI05 loads
    def _add_strict_false_call(match: re.Match) -> str:
        call = match.group(0)
        if "strict" in call:
            return call
        return call[:-1] + ", strict=False)"

    pat_no_strict_1 = r"policy_cls\.from_pretrained\(\s*repo_id\s*,\s*token\s*=\s*hf_token\s*\)"
    pat_no_strict_2 = r"policy_cls\.from_pretrained\(\s*pretrained_name_or_path\s*=\s*repo_id\s*,\s*token\s*=\s*hf_token\s*\)"
    s3, n_add1 = re.subn(pat_no_strict_1, _add_strict_false_call, s2)
    s4, n_add2 = re.subn(pat_no_strict_2, _add_strict_false_call, s3)

    if n_true + n_add1 + n_add2 == 0:
        if marker in s:
            return False, f"Eval strict patch already present: {p}"
        return False, f"Eval already strict=False or no PI05 strict targets found: {p}"

    if marker not in s4:
        tagged = re.search(r"policy_cls\.from_pretrained\([^)]*strict\s*=\s*False[^)]*\)", s4)
        if tagged is not None:
            eol = s4.find("\n", tagged.end())
            if eol == -1:
                eol = len(s4)
            bol = s4.rfind("\n", 0, eol) + 1
            # A comment after a trailing backslash would be a syntax error.
            if not s4[bol:eol].rstrip().endswith("\\"):
                s4 = f"{s4[:eol]} # {marker}{s4[eol:]}"

    _write_text(p, s4)
    return True, f"Patched eval PI05 strict=False in: {p}"
def patch_eval_shuffle_support(p: pathlib.Path) -> Tuple[bool, str]:
    """Let the eval script evaluate a shuffled subset of the dataset.

    1. If no ``--shuffle`` argument exists, add
       ``parser.add_argument("--shuffle", action="store_true", ...)`` on the
       line after the last single-line ``parser.add_argument(...)`` call. It
       gets that call's indentation, so it also works when the parser is
       built inside a function.
    2. Inside the first ``DataLoader(...)`` call whose own parentheses
       contain a literal ``shuffle=False``, replace ``False`` with
       ``getattr(args, "shuffle", False)``. The marker comment goes at the
       end of that physical line, so it cannot comment out a trailing comma
       or further arguments on the same line.

    NOTE(review): the rewritten DataLoader argument expects a name ``args``
    in scope at that call site. Presumably this is the parsed CLI namespace;
    confirm for scripts other than eval_sigma_vla_rollout.py.

    Args:
        p: Path to the eval script.

    Returns:
        ``(changed, message)``; the file is written only when changed.
    """
    s = _read_text(p)
    marker_arg = "PATCH: add --shuffle arg"
    marker_dl = "PATCH: DataLoader shuffle uses args.shuffle"
    changed = False

    def _append_line_comment(text: str, pos: int, comment: str) -> str:
        # Append " # comment" at the end of the physical line containing pos,
        # unless that line ends in a backslash continuation.
        eol = text.find("\n", pos)
        if eol == -1:
            eol = len(text)
        bol = text.rfind("\n", 0, eol) + 1
        if text[bol:eol].rstrip().endswith("\\"):
            return text
        return f"{text[:eol]} # {comment}{text[eol:]}"

    def _call_end(text: str, open_idx: int) -> int:
        # Index just past the ')' that closes the '(' at open_idx
        # (len(text) if it never closes).
        depth = 0
        for k in range(open_idx, len(text)):
            ch = text[k]
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
                if depth == 0:
                    return k + 1
        return len(text)

    # 1) add CLI arg --shuffle if absent
    if re.search(r'add_argument\(\s*["\']--shuffle["\']', s) is None:
        arg_pat = re.compile(r"(?m)^([ \t]*)parser\.add_argument\(.+?\)[ \t]*$")
        matches = list(arg_pat.finditer(s))
        if matches:
            last = matches[-1]
            new_arg = (
                f"\n{last.group(1)}parser.add_argument("
                "\"--shuffle\", action=\"store_true\", "
                "help=\"Shuffle dataset order to sample different subsets per seed.\")"
                f" # {marker_arg}"
            )
            s = s[: last.end()] + new_arg + s[last.end():]
            changed = True

    # 2) DataLoader(... shuffle=False ...) -> args.shuffle
    if marker_dl not in s:
        shuffle_false = re.compile(r"(shuffle\s*=\s*)False\b")
        for call in re.finditer(r"DataLoader\(", s):
            open_idx = call.end() - 1
            hit = shuffle_false.search(s, call.end(), _call_end(s, open_idx))
            if hit is None:
                continue
            s = s[: hit.end(1)] + 'getattr(args, "shuffle", False)' + s[hit.end():]
            s = _append_line_comment(s, hit.end(1), marker_dl)
            changed = True
            break

    if changed:
        _write_text(p, s)
        return True, f"Patched eval shuffle support in: {p}"
    return False, f"Eval shuffle support already present or no targets found: {p}"
# -------------------------
# Main
# -------------------------
def main():
    """Run every patch step in order and print a report line for each.

    A step means: locate the target file, then patch it. Each step runs in
    its own try/except, so a missing file or a failing patch is printed as
    "skipped" and does not stop the remaining steps. A final summary line
    says whether any step reported a change.
    """
    steps = (
        (find_pi05_file, patch_pi05_embed_tie,
         "[patch_sigma_env] PI05 embed-tie patch skipped:"),
        (find_pi05_file, patch_pi05_transformers_guard,
         "[patch_sigma_env] PI05 transformers-guard patch skipped:"),
        (find_policies_init, patch_policies_optional_imports,
         "[patch_sigma_env] policies __init__ patch skipped:"),
        (find_eval_file, patch_eval_force_strict_false,
         "[patch_sigma_env] Eval strict patch skipped:"),
        (find_eval_file, patch_eval_shuffle_support,
         "[patch_sigma_env] Eval shuffle patch skipped:"),
    )

    any_applied = False
    for locate, apply_patch, skip_prefix in steps:
        try:
            did_change, report = apply_patch(locate())
            print(report)
            any_applied |= did_change
        except Exception as err:  # best-effort: one failed step must not abort the rest
            print(skip_prefix, err)

    if any_applied:
        print("[patch_sigma_env] Done. Patches applied.")
    else:
        print("[patch_sigma_env] Done. Nothing to change (already patched).")


if __name__ == "__main__":
    main()