from typing import Tuple, List, Dict from dataclasses import dataclass import math import torch import torch.nn.functional as F from torch import nn from pydantic import BaseModel # Reuse the same building blocks as HRM/TRM from models.common import trunc_normal_init_ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear from models.sparse_embedding import CastedSparseEmbedding """ Global-Local Predictive Solver (GLPS) ------------------------------------ A light-weight control-policy on top of the HRM/TRM style blocks: - H1: global scan -> certainty map - L1: fill-obvious (lock stable cells) - H2: dependency scoring over remaining cells - L2: targeted refinement (masked updates) - H3: energy-based confidence -> (optional) one global propagate sweep -> halt This file keeps parameter growth tiny: a few heads + (optional) low-rank dependency scorer. """ @dataclass class GLPS_ACTV1InnerCarry: z_H: torch.Tensor z_L: torch.Tensor @dataclass class GLPS_ACTV1Carry: inner_carry: GLPS_ACTV1InnerCarry steps: torch.Tensor halted: torch.Tensor current_data: Dict[str, torch.Tensor] class GLPS_ACTV1Config(BaseModel): # Core IO / shapes batch_size: int seq_len: int puzzle_emb_ndim: int = 0 num_puzzle_identifiers: int = 1 vocab_size: int = 256 # Cycle schedule H_cycles: int = 3 # (scan -> refine -> check) typical L_cycles: int = 1 # Depth H_layers: int = 2 L_layers: int = 4 # Transformer config hidden_size: int = 512 expansion: float = 2.0 num_heads: int = 8 pos_encodings: str = "rope" rms_norm_eps: float = 1e-5 rope_theta: float = 10000.0 # ACT wrapper halt_max_steps: int = 4 halt_exploration_prob: float = 0.1 forward_dtype: str = "bfloat16" # Optional: use MLP on L instead of attention (matches HRM/TRM option) mlp_t: bool = False # ---- GLPS extras (tiny) ---- glps_enabled: bool = True glps_fill_obvious: bool = True glps_dep_graph: bool = True glps_token_masking: bool = True glps_global_propagate_on_low_conf: bool = True glps_tau_halt: float = 0.95 # final confidence to halt glps_tau_uncertain: float = 0.60 # cell-wise certainty threshold glps_max_targeted_iters: int = 2 # small number: 1-2 # Dependency scorer (low rank bilinear) dep_rank: int = 32 dep_topk: int = 8 class GLPSBlock(nn.Module): def __init__(self, config: GLPS_ACTV1Config) -> None: super().__init__() self.config = config if self.config.mlp_t: self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) self.mlp_t = SwiGLU( hidden_size=self.config.seq_len + self.puzzle_emb_len, # treat sequence as channel expansion=config.expansion, ) else: self.self_attn = Attention( hidden_size=config.hidden_size, head_dim=config.hidden_size // config.num_heads, num_heads=config.num_heads, num_key_value_heads=config.num_heads, causal=False, ) self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion) self.norm_eps = config.rms_norm_eps def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor: if self.config.mlp_t: # MLP over sequence dimension (mlp-t) hidden_states = hidden_states.transpose(1, 2) out = self.mlp_t(hidden_states) hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps) hidden_states = hidden_states.transpose(1, 2) else: hidden_states = rms_norm( hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps, ) out = self.mlp(hidden_states) hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps) return hidden_states class GLPSReasoningModule(nn.Module): """Reasoning stack with optional masked updates (only update uncertain tokens).""" def __init__(self, layers: List[GLPSBlock]): super().__init__() self.layers = torch.nn.ModuleList(layers) def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, update_mask: torch.Tensor = None, **kwargs) -> torch.Tensor: x = hidden_states for layer in self.layers: # Compute candidate update using injected context y = layer(hidden_states=x + input_injection, **kwargs) if update_mask is not None: # Convex blend keeps frozen tokens unchanged m = update_mask.to(x.dtype)[..., None] x = x + m * (y - x) else: x = y return x class GLPS_ACTV1_Inner(nn.Module): def __init__(self, config: GLPS_ACTV1Config) -> None: super().__init__() self.config = config self.forward_dtype = getattr(torch, self.config.forward_dtype) # I/O self.embed_scale = math.sqrt(self.config.hidden_size) embed_init_std = 1.0 / self.embed_scale self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype) self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False) self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True) # Puzzle emb (optional) — same convention as HRM/TRM self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div if self.config.puzzle_emb_ndim > 0: self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype) # Positional encodings if self.config.pos_encodings == "rope": self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, base=self.config.rope_theta) elif self.config.pos_encodings == "learned": self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype) # Reasoning stacks self.H_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.H_layers)]) self.L_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.L_layers)]) # Initial states (match HRM/TRM style) H_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0) L_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0) self.register_buffer("H_init", H_init, persistent=True) self.register_buffer("L_init", L_init, persistent=True) # GLPS small heads self.candidate_head = CastedLinear(self.config.hidden_size, 16, bias=True) # task-specific; for Sudoku you can slice to 9 self.certainty_head = CastedLinear(self.config.hidden_size, 1, bias=True) self.energy_head = CastedLinear(self.config.hidden_size, 1, bias=True) # Low-rank dependency scorer (shared) r = max(1, self.config.dep_rank) self.dep_q = CastedLinear(self.config.hidden_size, r, bias=False) self.dep_k = CastedLinear(self.config.hidden_size, r, bias=False) # Q head init like HRM/TRM (near-zero -> easier bootstrapping) with torch.no_grad(): self.q_head.weight.zero_() self.q_head.bias.fill_(-5) def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor): # Token embedding embedding = self.embed_tokens(input.to(torch.int32)) # Puzzle embeddings if self.config.puzzle_emb_ndim > 0: puzzle_embedding = self.puzzle_emb(puzzle_identifiers) pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1] if pad_count > 0: puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count)) embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2) # Position embeddings if self.config.pos_encodings == "learned": embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype)) return self.embed_scale * embedding def empty_carry(self, batch_size: int): return GLPS_ACTV1InnerCarry( z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), ) def reset_carry(self, reset_flag: torch.Tensor, carry: GLPS_ACTV1InnerCarry): # Explicitly expand buffers and mask to target shapes to avoid shape confusion B, L, D = carry.z_H.shape # Reduce/reset flag to per-batch boolean vector of shape [B] if reset_flag.ndim == 1 and reset_flag.shape[0] == B: reset_b = reset_flag.to(torch.bool) else: # If shape is [B, ...] reduce across non-batch dims; otherwise fallback to first B entries try: reset_b = reset_flag.reshape(B, -1).any(dim=1).to(torch.bool) except Exception: reset_b = reset_flag.reshape(-1)[:B].to(torch.bool) m = reset_b.view(B, 1, 1) mH = m.expand(B, L, D) mL = mH # same shape for z_L H_init_exp = self.H_init.expand(B, L, D) L_init_exp = self.L_init.expand(B, L, D) return GLPS_ACTV1InnerCarry( z_H=torch.where(mH, H_init_exp, carry.z_H), z_L=torch.where(mL, L_init_exp, carry.z_L), ) def _global_scan(self, z_L, z_H, input_embeddings, seq_info): # One light pass to gather global signals z_scan = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info) cand_logits = self.candidate_head(z_scan) # [B, L, C] certainty = torch.sigmoid(self.certainty_head(z_scan)) # [B, L, 1] return z_scan, cand_logits, certainty def _build_dep_mask(self, z_ctx, uncertain_mask: torch.Tensor): """Compute a dependency-based focus mask from a low-rank bilinear score. uncertain_mask: [B, L] boolean mask of cells that are currently uncertain Returns: dep_mask [B, L] boolean mask of cells to (re)update. """ B, L, D = z_ctx.shape # Project to low-rank space; use float32 to avoid dtype mismatches under torch.compile Q = self.dep_q(z_ctx).to(torch.float32) # [B, L, r] K = self.dep_k(z_ctx).to(torch.float32) # [B, L, r] r = max(1, int(Q.shape[-1])) sim = torch.matmul(Q, K.transpose(1, 2)) # [B, L, L] (float32) sim = sim / math.sqrt(r) # Aggregate influence from uncertain queries onto target tokens src = uncertain_mask.to(sim.dtype).unsqueeze(1) # [B, 1, L] influence = torch.matmul(src, sim).squeeze(1) # [B, L] # Top-k influenced tokens per batch topk = min(self.config.dep_topk, L) vals, idx = torch.topk(influence, k=topk, dim=-1) # [B, topk] dep_mask = torch.zeros_like(uncertain_mask) dep_mask.scatter_(1, idx, True) # Always include uncertain cells themselves dep_mask = dep_mask | uncertain_mask return dep_mask def forward(self, carry: GLPS_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]): seq_info = dict(cos_sin=getattr(self, "rotary_emb", None)() if hasattr(self, "rotary_emb") else None) # Encode inputs input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"]) # States z_H, z_L = carry.z_H, carry.z_L if not self.config.glps_enabled: # Fallback to an HRM-like single-cycle grad update for compatibility with torch.no_grad(): for _H in range(self.config.H_cycles - 1): for _L in range(self.config.L_cycles): z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info) z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info) # final grad step for _L in range(self.config.L_cycles): z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info) z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info) # Outputs new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) logits = self.lm_head(z_H)[:, self.puzzle_emb_len:] q_logits = self.q_head(z_H[:, 0]).to(torch.float32) conf = torch.zeros_like(q_logits[..., :1]) + 0.5 # neutral return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf # ===== GLPS path ===== # H1: global scan (cheap) with torch.no_grad(): z_scan, cand_logits, certainty = self._global_scan(z_L, z_H, input_embeddings, seq_info) # L1: fill-obvious -> compute stable vs uncertain masks if self.config.glps_fill_obvious: obvious_mask = (certainty >= self.config.glps_tau_uncertain) # [B, L, 1] else: obvious_mask = torch.zeros_like(certainty).bool() stable_mask = obvious_mask.squeeze(-1) # [B, L] uncertain_mask = ~stable_mask # [B, L] # H2: dependency prediction over remaining cells if self.config.glps_dep_graph: dep_mask = self._build_dep_mask(z_scan, uncertain_mask) # [B, L] else: dep_mask = uncertain_mask # L2: targeted refinement (a couple of masked iters) update_mask = dep_mask if self.config.glps_token_masking else None z = z_scan.detach() # use scanned context as start for _ in range(self.config.glps_max_targeted_iters): z = self.L_level(z, z_H + input_embeddings, update_mask=update_mask, **seq_info) # Refresh certainty to shrink mask (optional but cheap) cert_now = torch.sigmoid(self.certainty_head(z)).squeeze(-1) if self.config.glps_token_masking: update_mask = dep_mask & (cert_now < self.config.glps_tau_uncertain) # Merge into H and do a light H update with grad z_L = z z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info) # H3: energy/consistency -> confidence & optional global propagate energy = torch.sigmoid(self.energy_head(z_H)).mean(dim=1) # [B, 1] conf = 1.0 - energy need_sweep = (conf.squeeze(-1) < self.config.glps_tau_halt) if self.config.glps_global_propagate_on_low_conf and need_sweep.any(): # one final full sweep only for rows needing it maskB = need_sweep.view(-1, 1, 1) zL2 = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info) zH2 = self.H_level(z_H, zL2, update_mask=None, **seq_info) z_L = torch.where(maskB, zL2, z_L) z_H = torch.where(maskB, zH2, z_H) # Outputs new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) logits = self.lm_head(z_H)[:, self.puzzle_emb_len:] q_logits = self.q_head(z_H[:, 0]).to(torch.float32) return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf class GLPS_ACTV1(nn.Module): """ACT-style wrapper that mixes Q-halt with GLPS confidence.""" def __init__(self, config_dict: dict): super().__init__() self.config = GLPS_ACTV1Config(**config_dict) self.inner = GLPS_ACTV1_Inner(self.config) @property def puzzle_emb(self): return self.inner.puzzle_emb def initial_carry(self, batch: Dict[str, torch.Tensor]): batch_size = batch["inputs"].shape[0] return GLPS_ACTV1Carry( inner_carry=self.inner.empty_carry(batch_size), steps=torch.zeros((batch_size,), dtype=torch.int32), halted=torch.ones((batch_size,), dtype=torch.bool), # start halted to force reset on first pass current_data={k: torch.empty_like(v) for k, v in batch.items()} ) def forward(self, carry: GLPS_ACTV1Carry, batch: Dict[str, torch.Tensor]): # Reset halted seqs new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry) new_steps = torch.where(carry.halted, 0, carry.steps) new_current_data = {k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()} # Inner step new_inner_carry, logits, (q_halt_logits, q_continue_logits), conf = self.inner(new_inner_carry, new_current_data) outputs = { "logits": logits, "q_halt_logits": q_halt_logits, "q_continue_logits": q_continue_logits, "conf": conf.squeeze(-1), } with torch.no_grad(): new_steps = new_steps + 1 is_last_step = new_steps >= self.config.halt_max_steps # Combine halt signals: Q or confidence or last-step halted = is_last_step | (q_halt_logits > q_continue_logits) | (conf.squeeze(-1) >= self.config.glps_tau_halt) # Exploration during training only if self.training and (self.config.halt_max_steps > 1): min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1) halted = halted & (new_steps >= min_halt_steps) # Optional target for Q-learning (kept similar to HRM) next_conf = self.inner(new_inner_carry, new_current_data)[-1] outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, q_halt_logits, torch.maximum(q_halt_logits, q_continue_logits))) else: # During eval, always use max_steps to ensure consistent reasoning depth (same as TRM/HRM eval behavior) halted = is_last_step return GLPS_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs