#!/usr/bin/env python3
"""
patch_sigma_env.py

Idempotent patcher for Sigma VLA experiments.

Patch goals:

1) LeRobot PI05Policy (modeling_pi05.py):
   1.1 If the checkpoint omits embed_tokens.weight, tie embed_tokens.weight
       to lm_head.weight *after* load_state_dict runs.
   1.2 Ensure torch is imported if the target file lacks it. The import is
       placed after any shebang, module docstring and __future__ imports.
   1.3 Downgrade the "incorrect transformer version" hard guard (ValueError)
       to a WARNING so new GPU environments don't crash.
       IMPORTANT: preserve indentation and patch only the intended guard.

2) LeRobot policies __init__ (lerobot/policies/__init__.py):
   2.1 Make ONLY the Groot imports optional, leaving all other exports
       untouched. Each run of consecutive groot import statements, including
       multi-line parenthesised ones, is wrapped in try/except at its
       original indentation. This is meant to avoid errors such as
       "No module named 'triton.ops'" or diffusers/peft import-chain failures
       on fresh GPU environments.

3) eval_sigma_vla_rollout.py (your /workspace eval script):
   3.1 Force strict=False for policy_cls.from_pretrained calls:
       - strict=True -> strict=False
       - add strict=False to the recognised load calls that lack a strict arg
   3.2 Make randomized subset evaluation possible:
       - add a --shuffle arg if missing, at the indentation of the last
         single-line parser.add_argument(...) call
       - change the first DataLoader(..., shuffle=False, ...) to
         shuffle=getattr(args, "shuffle", False)

Marker comments are only ever appended at the END of a physical line, so they
can never comment out code that follows on the same line.

Safe to run multiple times; a no-op if everything is already patched.
"""
import os
import pathlib
import re
import sys
import sysconfig
from typing import Callable, Dict, List, Optional, Tuple


# -------------------------
# Utilities
# -------------------------
def _read_text(p: pathlib.Path) -> str:
    """Return the contents of *p* decoded as UTF-8."""
    return p.read_text(encoding="utf-8")


def _write_text(p: pathlib.Path, s: str) -> None:
    """Overwrite *p* with *s* encoded as UTF-8."""
    p.write_text(s, encoding="utf-8")


def _leading_ws(line: str) -> str:
    """Return the run of spaces/tabs at the start of *line*."""
    return line[: len(line) - len(line.lstrip(" \t"))]


def _append_line_comment(text: str, anchor: re.Pattern, comment: str) -> str:
    """Append ``" # <comment>"`` to the first line matching *anchor* that can take one.

    Some lines are skipped:
    - lines that already contain ``#``, so existing comments and strings with
      ``#`` are left alone;
    - lines ending in a backslash continuation.

    This way the new comment never swallows code or breaks a continuation.
    Returns *text* unchanged when no suitable line exists.
    """
    lines = text.splitlines(True)
    for i, line in enumerate(lines):
        body = line.rstrip("\r\n")
        if not anchor.search(body) or "#" in body or body.rstrip().endswith("\\"):
            continue
        lines[i] = f"{body} # {comment}{line[len(body):]}"
        return "".join(lines)
    return text


def _env_path(var: str) -> Optional[pathlib.Path]:
    """Return ``$var`` as a path if it is set and exists, else None.

    A warning is printed when the variable is set but points nowhere, so a
    typo doesn't silently fall back to searching.
    """
    value = os.getenv(var)
    if not value:
        return None
    p = pathlib.Path(value)
    if p.exists():
        return p
    print(f"[patch_sigma_env] WARNING: {var}={value!r} does not exist; falling back to search.")
    return None


def _search_file(
    roots: List[os.PathLike],
    filename: str,
    must_contain: Optional[str] = None,
    must_end_with: Optional[str] = None,
) -> Optional[pathlib.Path]:
    """Return the first file named *filename* found under any of *roots*.

    Candidate paths are compared in POSIX form (``/`` separators), so the
    filters behave the same on every OS.

    Args:
        roots: Directories searched in order; ones that are not directories
            are skipped.
        filename: Name (glob pattern) handed to ``Path.rglob``.
        must_contain: If given, the candidate's POSIX path must contain it.
        must_end_with: If given, the candidate's POSIX path must end with it.

    Returns:
        The first matching path, or None.
    """
    for root in roots:
        root = pathlib.Path(root)
        if not root.is_dir():
            continue
        for candidate in root.rglob(filename):
            posix = candidate.as_posix()
            if must_contain and must_contain not in posix:
                continue
            if must_end_with and not posix.endswith(must_end_with):
                continue
            return candidate
    return None


def _default_roots() -> List[os.PathLike]:
    """Directories searched (in order) for LeRobot sources.

    The interpreter's site-packages comes from ``sysconfig`` ("purelib"),
    which is correct even where the layout is not ``lib/pythonX.Y``
    (e.g. Windows). The historical ``<prefix>/lib/pythonX.Y/site-packages``
    guess is kept as a fallback.
    """
    legacy_site = (
        pathlib.Path(sys.prefix)
        / "lib"
        / f"python{sys.version_info.major}.{sys.version_info.minor}"
        / "site-packages"
    )
    roots: List[os.PathLike] = [
        pathlib.Path("/workspace/lerobot/src"),
        pathlib.Path("/workspace/lerobot"),
    ]
    for extra in (sysconfig.get_paths().get("purelib"), legacy_site):
        if extra and pathlib.Path(extra) not in roots:
            roots.append(pathlib.Path(extra))
    return roots


# -------------------------
# Patch 1: PI05Policy (LeRobot)
# -------------------------
_TORCH_IMPORT_RE = re.compile(r"(?m)^[ \t]*(?:import[ \t]+torch\b|from[ \t]+torch\b)")
_DOCSTRING_OPEN_RE = re.compile("^[rRuU]?(\"\"\"|''')")
_FUTURE_IMPORT_RE = re.compile(r"^from[ \t]+__future__[ \t]+import\b")

_EMBED_TIE_MARKER = "PATCH: tie embed_tokens to lm_head if ckpt omitted embed_tokens"
# Indent is captured with [ \t]* (not \s*) so preceding blank lines never leak into it.
_LOAD_STATE_DICT_RE = re.compile(
    r"(?m)^([ \t]*)missing_keys,\s*unexpected_keys\s*=\s*model\.load_state_dict\("
    r"\s*remapped_state_dict\s*,\s*strict\s*=\s*strict\s*\)[ \t]*$"
)

_GUARD_TEXT = "incorrect transformer version"
_GUARD_MARKER = "PATCH: downgrade transformer version guard"
_GUARD_RAISE_RE = re.compile(r"^([ \t]*)raise\s+ValueError\(\s*msg\s*\)\s*from\s*None\s*$")
_GUARD_LOOKBACK = 8  # how many lines above the raise may hold the guard message


def find_pi05_file() -> pathlib.Path:
    """Locate LeRobot's ``modeling_pi05.py``.

    ``$PI05_FILE`` wins when it points at an existing file. Otherwise the
    default roots are searched for a ``modeling_pi05.py`` inside a ``pi05/``
    directory.

    Raises:
        FileNotFoundError: if no candidate is found.
    """
    env = _env_path("PI05_FILE")
    if env is not None:
        return env
    p = _search_file(_default_roots(), "modeling_pi05.py", must_contain="/pi05/")
    if p and p.exists():
        return p
    raise FileNotFoundError("modeling_pi05.py not found. Set PI05_FILE env var to its path.")


def _module_header_end(lines: List[str]) -> int:
    """Return the index at which a new top-level import can safely be inserted.

    The index lies just past, in order:
    - the shebang line, if any;
    - the module docstring, if any (single- or multi-line);
    - any ``from __future__`` imports, which Python requires to precede every
      other statement.

    Blank lines and whole-line comments between those elements are skipped.
    """
    n = len(lines)

    def next_code(k: int) -> int:
        # Skip blank lines and whole-line comments.
        while k < n and (not lines[k].strip() or lines[k].lstrip().startswith("#")):
            k += 1
        return k

    insert_at = 1 if lines and lines[0].startswith("#!") else 0
    i = next_code(0)

    if i < n:
        m = _DOCSTRING_OPEN_RE.match(lines[i])
        if m:
            quote = m.group(1)
            close = i
            # Closing quotes on the opening line => single-line docstring.
            if quote not in lines[i][m.end():]:
                close = i + 1
                while close < n and quote not in lines[close]:
                    close += 1
            if close >= n:
                return insert_at  # unterminated docstring: don't guess
            insert_at = close + 1
            i = next_code(close + 1)

    while i < n and _FUTURE_IMPORT_RE.match(lines[i]):
        end = i
        if "(" in lines[i] and ")" not in lines[i]:
            # Parenthesised __future__ import spanning several lines.
            while end < n and ")" not in lines[end]:
                end += 1
        insert_at = min(end, n - 1) + 1
        i = next_code(insert_at)
    return insert_at


def ensure_torch_import(s: str) -> str:
    """Return *s* with a top-level ``import torch`` added if no line imports torch.

    The import is inserted after the shebang, the module docstring and any
    ``__future__`` imports. It therefore cannot displace the docstring, land
    inside a function body, or precede a ``__future__`` import.
    """
    if _TORCH_IMPORT_RE.search(s):
        return s
    lines = s.splitlines(True)
    idx = _module_header_end(lines)
    if idx > 0 and not lines[idx - 1].endswith("\n"):
        lines[idx - 1] += "\n"  # file ended without a newline exactly where we insert
    lines.insert(idx, "import torch # PATCH: required for embed/lm_head tying\n")
    return "".join(lines)


def patch_pi05_embed_tie(p: pathlib.Path) -> Tuple[bool, str]:
    """Inject the embed_tokens -> lm_head weight tie after PI05's load_state_dict call.

    The injected code runs only when ``missing_keys`` mentions
    ``embed_tokens.weight``. It needs ``torch``, so a torch import is added
    alongside it when missing.

    Returns:
        ``(changed, message)``. The file is written only when its text
        actually changes. ``changed`` is True exactly when it was written,
        including when only a missing ``import torch`` was added next to an
        existing patch.
    """
    original = _read_text(p)
    marker = _EMBED_TIE_MARKER

    if marker in original:
        s = ensure_torch_import(original)
        if s != original:
            _write_text(p, s)
            return True, f"PI05 embed-tie patch already present; added missing torch import: {p}"
        return False, f"PI05 embed-tie patch already present: {p}"

    m = _LOAD_STATE_DICT_RE.search(original)
    if not m:
        # Nothing to inject, so the import the injection would need isn't added either.
        return False, f"Could not find load_state_dict line to patch in PI05 file: {p}"

    indent = m.group(1)
    inject = (
        f"\n{indent}# --- PATCH: tie embed_tokens to lm_head if ckpt omitted embed_tokens ---\n"
        f"{indent}if any('embed_tokens.weight' in k for k in missing_keys):\n"
        f"{indent}    try:\n"
        f"{indent}        with torch.no_grad():\n"
        f"{indent}            embed = model.model.paligemma_with_expert.paligemma.model.language_model.embed_tokens\n"
        f"{indent}            lm_head = model.model.paligemma_with_expert.paligemma.lm_head\n"
        f"{indent}            if embed is not None and lm_head is not None:\n"
        f"{indent}                embed.weight = lm_head.weight # {marker}\n"
        f"{indent}    except Exception as _e:\n"
        f"{indent}        print('[patch_pi05] Could not tie embed_tokens to lm_head:', _e)\n"
    )
    s = original[: m.end()] + inject + original[m.end():]
    # Added after the injection so the match offsets above stay valid.
    s = ensure_torch_import(s)
    _write_text(p, s)
    return True, f"Patched PI05 embed-tie in: {p}"


def patch_pi05_transformers_guard(p: pathlib.Path) -> Tuple[bool, str]:
    """Downgrade ONLY the PI05 "incorrect transformer version" hard guard to a warning.

    Strategy:
    - Find lines of the form ``raise ValueError(msg) from None``.
    - Patch only the first one whose preceding ``_GUARD_LOOKBACK`` lines
      (case-insensitively) contain "incorrect transformer version".
    - Replace it, at the same indentation, with a ``print`` of ``msg``.

    Returns ``(changed, message)``.
    """
    s = _read_text(p)
    marker = _GUARD_MARKER
    if marker in s:
        return False, f"PI05 transformers-guard patch already present: {p}"
    # Same case-insensitive test as the per-window check below.
    if _GUARD_TEXT not in s.lower():
        return False, f"No transformers guard message found to patch in: {p}"

    lines = s.splitlines(True)
    for i, line in enumerate(lines):
        m = _GUARD_RAISE_RE.match(line)
        if not m:
            continue
        window = "".join(lines[max(0, i - _GUARD_LOOKBACK): i + 1]).lower()
        if _GUARD_TEXT not in window:
            continue
        indent = m.group(1)
        lines[i] = (
            f"{indent}# --- PATCH: downgrade transformer version guard ---\n"
            f"{indent}print('[patch_pi05] WARNING:', msg) # {marker}\n"
            f"{indent}# continues execution despite version mismatch\n"
        )
        _write_text(p, "".join(lines))
        return True, f"Patched PI05 transformers guard (raise->warn) in: {p}"

    return False, f"Guard raise line with context not found in: {p}"


# -------------------------
# Patch 2: LeRobot policies optional imports
# -------------------------
_POLICIES_MARKER = "PATCH: optional Groot/Diffusers imports"
# Only statements importing the groot submodule are considered.
_GROOT_IMPORT_RE = re.compile(r"^\s*from\s+\.\s*groot\b|^\s*import\s+.*\bgroot\b")


def find_policies_init() -> pathlib.Path:
    """Locate ``lerobot/policies/__init__.py`` (``$POLICIES_INIT_FILE`` wins).

    The path must END with ``/lerobot/policies/__init__.py``, so a sub-package
    ``__init__.py`` such as ``lerobot/policies/pi05/__init__.py`` is never picked.

    Raises:
        FileNotFoundError: if no candidate is found.
    """
    env = _env_path("POLICIES_INIT_FILE")
    if env is not None:
        return env
    p = _search_file(_default_roots(), "__init__.py", must_end_with="/lerobot/policies/__init__.py")
    if p and p.exists():
        return p
    raise FileNotFoundError("lerobot/policies/__init__.py not found. Set POLICIES_INIT_FILE env var.")


def _statement_end(lines: List[str], start: int) -> int:
    """Return the index of the last physical line of the statement starting at *start*.

    Follows unbalanced parentheses and trailing-backslash continuations. Text
    after ``#`` is ignored when counting parens; parens inside string literals
    are not, which is acceptable for import statements.
    """
    depth = 0
    for k in range(start, len(lines)):
        code = lines[k].split("#", 1)[0]
        depth += code.count("(") - code.count(")")
        if depth <= 0 and not code.rstrip().endswith("\\"):
            return k
    return len(lines) - 1


def patch_policies_optional_imports(p: pathlib.Path) -> Tuple[bool, str]:
    """Make ONLY the groot import statements optional; nothing else is wrapped.

    Runs of consecutive groot import statements that share one indentation
    are wrapped as a unit:
    - ``try:``/``except`` are emitted at the statements' own indentation;
    - every physical line of a statement, including parenthesised
      continuations, goes inside the ``try``.

    This keeps the file valid Python.

    Returns ``(changed, message)``.
    """
    s = _read_text(p)
    marker = _POLICIES_MARKER
    if marker in s:
        return False, f"Policies optional-import patch already present: {p}"

    lines = s.splitlines(True)

    # (first_line, last_line) of each complete groot import statement.
    spans: List[Tuple[int, int]] = []
    k = 0
    while k < len(lines):
        if _GROOT_IMPORT_RE.search(lines[k]):
            end = _statement_end(lines, k)
            spans.append((k, end))
            k = end + 1
        else:
            k += 1
    if not spans:
        return False, f"No Groot imports found to wrap in: {p}"

    # Merge back-to-back statements at the same indentation into one try block.
    groups: List[Tuple[int, int]] = []
    for start, end in spans:
        if (
            groups
            and start == groups[-1][1] + 1
            and _leading_ws(lines[start]) == _leading_ws(lines[groups[-1][0]])
        ):
            groups[-1] = (groups[-1][0], end)
        else:
            groups.append((start, end))

    new_lines: List[str] = []
    last_end = -1
    for a, b in groups:
        new_lines.extend(lines[last_end + 1:a])
        indent = _leading_ws(lines[a])
        new_lines.append(f"{indent}# --- PATCH: optional Groot/Diffusers imports ---\n")
        new_lines.append(f"{indent}try: # {marker}\n")
        for j in range(a, b + 1):
            line = lines[j] if lines[j].endswith("\n") else lines[j] + "\n"
            # Prefixing (not re-indenting) keeps relative indentation of continuations.
            new_lines.append("    " + line)
        new_lines.append(f"{indent}except Exception as _e:\n")
        new_lines.append(f"{indent}    print('[policies_init] WARNING: optional groot deps missing:', _e)\n")
        last_end = b
    new_lines.extend(lines[last_end + 1:])

    s2 = "".join(new_lines)
    if s2 == s:
        return False, f"Policies file unchanged after optional-import attempt: {p}"
    _write_text(p, s2)
    return True, f"Patched policies __init__ optional imports in: {p}"


# -------------------------
# Patch 3: eval_sigma_vla_rollout.py
# -------------------------
_STRICT_MARKER = "PATCH: force strict=False for PI05Policy"
# [^)]* keeps the match inside the call's argument list; \b rejects e.g. "restrict=".
_STRICT_TRUE_RE = re.compile(r"(policy_cls\.from_pretrained\([^)]*\bstrict\s*=\s*)True\b")
_NO_STRICT_RES = (
    re.compile(r"policy_cls\.from_pretrained\(\s*repo_id\s*,\s*token\s*=\s*hf_token\s*\)"),
    re.compile(
        r"policy_cls\.from_pretrained\(\s*pretrained_name_or_path\s*=\s*repo_id\s*,"
        r"\s*token\s*=\s*hf_token\s*\)"
    ),
)
_FROM_PRETRAINED_LINE_RE = re.compile(r"policy_cls\.from_pretrained\(")

_SHUFFLE_ARG_MARKER = "PATCH: add --shuffle arg"
_SHUFFLE_DL_MARKER = "PATCH: DataLoader shuffle uses args.shuffle"
_SHUFFLE_EXPR = 'getattr(args, "shuffle", False)'
_HAS_SHUFFLE_ARG_RE = re.compile(r'add_argument\(\s*["\']--shuffle["\']')
_ADD_ARG_LINE_RE = re.compile(r"(?m)^([ \t]*)parser\.add_argument\(.*\)[ \t]*$")
# Arguments may contain one level of nested parens, but the match can never
# run past the DataLoader call's own closing paren.
_DL_SHUFFLE_FALSE_RE = re.compile(r"(DataLoader\((?:[^()]|\([^()]*\))*?\bshuffle\s*=\s*)False\b")


def find_eval_file() -> pathlib.Path:
    """Locate ``eval_sigma_vla_rollout.py``.

    Lookup order:
    1. ``$EVAL_FILE``;
    2. ``/workspace/eval_sigma_vla_rollout.py``;
    3. a search under ``/workspace``.

    Raises:
        FileNotFoundError: if no candidate is found.
    """
    env = _env_path("EVAL_FILE")
    if env is not None:
        return env
    p = pathlib.Path("/workspace/eval_sigma_vla_rollout.py")
    if p.exists():
        return p
    pp = _search_file(["/workspace", "/workspace/lerobot"], "eval_sigma_vla_rollout.py")
    if pp and pp.exists():
        return pp
    raise FileNotFoundError("eval_sigma_vla_rollout.py not found. Set EVAL_FILE env var.")


def patch_eval_force_strict_false(p: pathlib.Path) -> Tuple[bool, str]:
    """Force ``strict=False`` on ``policy_cls.from_pretrained`` calls.

    Two rewrites are applied:
    - ``strict=True`` becomes ``strict=False``;
    - the two recognised ``(repo_id, token=hf_token)`` call shapes gain
      ``, strict=False``.

    The marker comment goes at the END of the first suitable
    ``policy_cls.from_pretrained(`` line, so trailing code such as
    ``.to(device)`` is never commented out.

    Returns ``(changed, message)``.
    """
    s = _read_text(p)
    marker = _STRICT_MARKER

    s2, n_true = _STRICT_TRUE_RE.subn(r"\1False", s)

    def _add_strict_false_call(match: re.Match) -> str:
        call = match.group(0)
        if "strict" in call:
            return call
        return call[:-1] + ", strict=False)"

    n_added = 0
    for pat in _NO_STRICT_RES:
        s2, n = pat.subn(_add_strict_false_call, s2)
        n_added += n

    if n_true + n_added == 0:
        if marker in s:
            return False, f"Eval strict patch already present: {p}"
        return False, f"Eval already strict=False or no PI05 strict targets found: {p}"

    if marker not in s2:
        s2 = _append_line_comment(s2, _FROM_PRETRAINED_LINE_RE, marker)
    _write_text(p, s2)
    return True, f"Patched eval PI05 strict=False in: {p}"


def patch_eval_shuffle_support(p: pathlib.Path) -> Tuple[bool, str]:
    """Add a ``--shuffle`` CLI flag and route it into the first ``DataLoader``.

    1. If no ``--shuffle`` argument exists, insert one on its own line. It
       goes right after the last single-line ``parser.add_argument(...)``
       call with balanced parens, at that call's indentation, so parsers
       built inside a function stay valid.
    2. Replace the literal ``shuffle=False`` in the first ``DataLoader(...)``
       call with ``getattr(args, "shuffle", False)``. The marker comment is
       appended at the end of that physical line.

    Returns ``(changed, message)``.
    """
    s = _read_text(p)
    marker_arg = _SHUFFLE_ARG_MARKER
    marker_dl = _SHUFFLE_DL_MARKER
    changed = False

    # 1) add CLI arg --shuffle if absent
    if _HAS_SHUFFLE_ARG_RE.search(s) is None:
        candidates = [
            m for m in _ADD_ARG_LINE_RE.finditer(s)
            if m.group(0).count("(") == m.group(0).count(")")
        ]
        if candidates:
            last = candidates[-1]
            new_line = (
                f"{last.group(1)}parser.add_argument("
                "\"--shuffle\", action=\"store_true\", "
                "help=\"Shuffle dataset order to sample different subsets per seed.\")"
                f" # {marker_arg}\n"
            )
            pos = last.end()
            if pos < len(s) and s[pos] == "\n":
                s = s[: pos + 1] + new_line + s[pos + 1:]
            else:  # the matched call was the file's last line, without a newline
                s = s[:pos] + "\n" + new_line + s[pos:]
            changed = True

    # 2) DataLoader(... shuffle=False ...) -> args.shuffle
    #    (the expression check keeps this idempotent even if the marker was not placed)
    if marker_dl not in s and _SHUFFLE_EXPR not in s:
        m = _DL_SHUFFLE_FALSE_RE.search(s)
        if m:
            s = s[: m.start()] + m.group(1) + _SHUFFLE_EXPR + s[m.end():]
            s = _append_line_comment(s, re.compile(re.escape(_SHUFFLE_EXPR)), marker_dl)
            changed = True

    if changed:
        _write_text(p, s)
        return True, f"Patched eval shuffle support in: {p}"
    return False, f"Eval shuffle support already present or no targets found: {p}"


# -------------------------
# Main
# -------------------------
def main() -> None:
    """Run every patch in order, best-effort, and print a summary line.

    A failing lookup or patch is reported and skipped, never fatal. Each
    target file is located at most once: results and lookup errors are
    cached per finder, because the searches walk large trees.
    """
    steps: List[
        Tuple[str, Callable[[], pathlib.Path], Callable[[pathlib.Path], Tuple[bool, str]]]
    ] = [
        ("PI05 embed-tie patch", find_pi05_file, patch_pi05_embed_tie),
        ("PI05 transformers-guard patch", find_pi05_file, patch_pi05_transformers_guard),
        ("policies __init__ patch", find_policies_init, patch_policies_optional_imports),
        ("Eval strict patch", find_eval_file, patch_eval_force_strict_false),
        ("Eval shuffle patch", find_eval_file, patch_eval_shuffle_support),
    ]
    located: Dict[Callable[[], pathlib.Path], object] = {}
    changed_any = False

    for label, finder, patcher in steps:
        try:
            if finder not in located:
                try:
                    located[finder] = finder()
                except Exception as e:  # remember the failure; don't search again
                    located[finder] = e
            target = located[finder]
            if isinstance(target, Exception):
                raise target
            changed, msg = patcher(target)
            print(msg)
            changed_any |= changed
        except Exception as e:  # best-effort: one failing patch must not block the rest
            print(f"[patch_sigma_env] {label} skipped:", e)

    if changed_any:
        print("[patch_sigma_env] Done. Patches applied.")
    else:
        print("[patch_sigma_env] Done. Nothing to change (already patched).")


if __name__ == "__main__":
    main()