from __future__ import annotations import math from typing import Any, Dict, Optional, Tuple import os import torch import torch.nn as nn import torch.nn.functional as F from transformers import AutoModel, PreTrainedModel try: from huggingface_hub import snapshot_download except Exception: # pragma: no cover - optional dependency on HF hub download snapshot_download = None try: from .configuration_rapido import RapidoNerConfig except Exception: # when loaded by HF, configuration is part of the same repo from configuration_rapido import RapidoNerConfig class AttentionPool(nn.Module): """Learnable attention-based pooling for mention spans. Matches training module signature/state dict: `attention.weight`, `attention.bias`. Supports pooling over [B,T,H] + [B,T] or [B,M,T,H] + [B,M,T]. """ def __init__(self, hidden_size: int) -> None: super().__init__() self.attention = nn.Linear(hidden_size, 1) nn.init.zeros_(self.attention.weight) if self.attention.bias is not None: nn.init.zeros_(self.attention.bias) def forward(self, H: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: if H.dim() == 3 and mask.dim() == 2: scores = self.attention(H).squeeze(-1) # [B,T] scores = scores.masked_fill(~mask.bool(), -1e9) weights = torch.softmax(scores, dim=-1).unsqueeze(-1) # [B,T,1] return (H * weights).sum(1) # [B,H] elif H.dim() == 4 and mask.dim() == 3: B, M, T, H_dim = H.shape H_flat = H.view(B * M, T, H_dim) mask_flat = mask.view(B * M, T) scores = self.attention(H_flat).squeeze(-1) scores = scores.masked_fill(~mask_flat.bool(), -1e9) weights = torch.softmax(scores, dim=-1).unsqueeze(-1) pooled = (H_flat * weights).sum(1) return pooled.view(B, M, H_dim) else: raise ValueError(f"Unexpected shapes: H dim={H.dim()} mask dim={mask.dim()}") class EnhancedEntityProjector(nn.Module): """Matches training projector: Linear -> LayerNorm -> GELU -> Dropout -> Linear, + optional residual. Always returns L2-normalized embeddings. """ def __init__( self, in_dim: int, out_dim: int = 768, bottleneck: int = 768, use_residual: bool = True, use_layernorm: bool = True, dropout_p: float = 0.1, ) -> None: super().__init__() self.use_residual = use_residual bn = min(bottleneck, in_dim) layers = [ nn.Linear(in_dim, bn, bias=False), nn.LayerNorm(bn) if use_layernorm else nn.Identity(), nn.GELU(), nn.Dropout(float(dropout_p)), nn.Linear(bn, out_dim, bias=True), ] self.net = nn.Sequential(*layers) # Zero-init final linear for residual-friendly start last = self.net[-1] if isinstance(last, nn.Linear): nn.init.zeros_(last.weight) if last.bias is not None: nn.init.zeros_(last.bias) self.residual_proj = nn.Linear(in_dim, out_dim, bias=False) if (use_residual and in_dim != out_dim) else None def forward(self, x: torch.Tensor) -> torch.Tensor: out = self.net(x) if self.use_residual: residual = self.residual_proj(x) if self.residual_proj is not None else x out = out + residual return F.normalize(out, dim=-1) class TypeAwareProjector(nn.Module): """Type-conditioned projector: base EnhancedEntityProjector followed by FiLM modulation. State dict keys: `base.net.*`, `type_emb.weight`. """ def __init__( self, in_dim: int, out_dim: int, num_types: int, *, bottleneck: int = 256, use_residual: bool = True, use_layernorm: bool = True, dropout_p: float = 0.1, gamma_init: float = 0.0, beta_init: float = 0.0, ) -> None: super().__init__() self.out_dim = out_dim self.base = EnhancedEntityProjector( in_dim=in_dim, out_dim=out_dim, bottleneck=bottleneck, use_residual=use_residual, use_layernorm=use_layernorm, dropout_p=dropout_p, ) self.type_emb = nn.Embedding(int(num_types), int(2 * out_dim)) with torch.no_grad(): self.type_emb.weight.zero_() if gamma_init != 0.0 or beta_init != 0.0: self.type_emb.weight[:, :out_dim].fill_(float(gamma_init)) self.type_emb.weight[:, out_dim:].fill_(float(beta_init)) def forward(self, x: torch.Tensor, type_ids: torch.Tensor) -> torch.Tensor: base = self.base(x) if type_ids is None: return base tb = self.type_emb(type_ids.long()) gamma, beta = tb.split(self.out_dim, dim=-1) y = base * (1.0 + gamma) + beta return F.normalize(y, dim=-1) class RapidoForTokenClassificationAndEntity(PreTrainedModel): """Custom remote-coded model wrapping a HF backbone with NER + entity projection. Usage: model = AutoModel.from_pretrained(repo_id, trust_remote_code=True) tok = AutoTokenizer.from_pretrained(model.config.backbone_model_name_or_path, trust_remote_code=True) out = model(input_ids, attention_mask, mention_mask=mask) logits = out["logits"]; embeddings = out.get("entity_embeddings") """ config_class = RapidoNerConfig def __init__(self, config: RapidoNerConfig) -> None: # type: ignore[override] super().__init__(config) self.num_labels = int(config.num_labels) self.id2label = {int(k): str(v) for k, v in (getattr(config, "id2label", {}) or {}).items()} self._use_crf = bool(getattr(config, "use_crf", False)) self.crf = None # Backbone bb_ref = config.backbone_model_name_or_path load_kwargs: Dict[str, Any] = {"trust_remote_code": True} commit_hash = getattr(config, "_commit_hash", None) def _download_repo(repo_id: Optional[str], rel: Optional[str]) -> Optional[str]: if snapshot_download is None or not repo_id or not rel: return None try: root = snapshot_download( repo_id, repo_type="model", revision=commit_hash, allow_patterns=[f"{rel}/*"], local_files_only=False, ) candidate = os.path.join(root, rel) return candidate if os.path.isdir(candidate) else None except Exception: return None if isinstance(bb_ref, str) and not os.path.isdir(bb_ref): repo_hint = getattr(config, "name_or_path", None) or getattr(config, "_name_or_path", None) if "/" not in bb_ref: local_path = _download_repo(repo_hint if isinstance(repo_hint, str) else None, bb_ref) if local_path is not None: bb_ref = local_path else: if isinstance(repo_hint, str) and repo_hint: load_kwargs["subfolder"] = bb_ref bb_ref = repo_hint else: parts = bb_ref.split("/") if len(parts) > 2: repo_id = "/".join(parts[:2]) sub_path = "/".join(parts[2:]) local_path = _download_repo(repo_id, sub_path) if local_path is not None: bb_ref = local_path else: load_kwargs["subfolder"] = sub_path bb_ref = repo_id if os.environ.get("RAPIDO_DEBUG_BACKBONE"): print(f"[Rapido] loading backbone from {bb_ref} with kwargs {load_kwargs}") self.backbone = AutoModel.from_pretrained(bb_ref, **load_kwargs) hidden = getattr(self.backbone.config, "hidden_size", None) or getattr(self.backbone.config, "hidden_dim", None) if hidden is None: raise ValueError("Cannot infer hidden size from backbone.config") first_param = next(self.backbone.parameters()) _dtype = first_param.dtype _device = first_param.device # Heads self.dropout = nn.Dropout(float(config.dropout)) self.ner_head = nn.Linear(hidden, self.num_labels).to(device=_device, dtype=torch.float32) self.attention_pool = AttentionPool(hidden).to(device=_device, dtype=torch.float32) if self._use_crf: try: start_ok, end_ok, trans_ok = build_bio_constraints(self.id2label) except Exception: start_ok = end_ok = trans_ok = None # type: ignore[assignment] self.crf = LinearCRF( self.num_labels, start_ok=start_ok, end_ok=end_ok, trans_ok=trans_ok, ).to(device=_device) # Projector: optionally type-aware self._type_aware = bool(getattr(config, "type_aware_proj", False)) if self._type_aware: num_types = int(getattr(config, "num_types", 1) or 1) self.projector = TypeAwareProjector( in_dim=hidden, out_dim=int(config.proj_out), num_types=num_types, bottleneck=min(hidden, 256), use_residual=True, use_layernorm=True, dropout_p=float(config.dropout), ).to(device=_device, dtype=torch.float32) else: self.projector = EnhancedEntityProjector( in_dim=hidden, out_dim=int(config.proj_out), bottleneck=min(hidden, 768), dropout_p=float(config.dropout), use_residual=True, use_layernorm=True, ).to(device=_device, dtype=torch.float32) self._dropout_p = float(config.dropout) # set to eval mode by default self.eval() def encode_tokens(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: out = self.backbone(input_ids=input_ids, attention_mask=attention_mask) H = out.last_hidden_state return self.dropout(H) def forward_ner(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: H = self.encode_tokens(input_ids, attention_mask) if H.dtype != self.ner_head.weight.dtype: H = H.to(self.ner_head.weight.dtype) return self.ner_head(H) def encode_mentions_with_attention(self, H: torch.Tensor, mention_mask: torch.Tensor) -> torch.Tensor: if H.dim() != 3: raise ValueError("Expected hidden states [B,T,D]") try: pool_dtype = next(self.attention_pool.parameters()).dtype except StopIteration: pool_dtype = H.dtype if H.dtype != pool_dtype: H = H.to(pool_dtype) if mention_mask.dim() == 2: return self.attention_pool(H, mention_mask) if mention_mask.dim() == 3: B, T, D = H.shape M = mention_mask.size(1) H_rep = H.unsqueeze(1).expand(B, M, T, D) return self.attention_pool(H_rep, mention_mask) raise ValueError("Unexpected mention_mask shape") def project_mentions(self, span_embeddings: torch.Tensor, type_ids: Optional[torch.Tensor] = None) -> torch.Tensor: if self._type_aware and isinstance(self.projector, TypeAwareProjector): if type_ids is None: type_ids = torch.zeros(span_embeddings.size(0), dtype=torch.long, device=span_embeddings.device) elif type_ids.dim() > 1: type_ids = type_ids.view(-1) return self.projector(span_embeddings, type_ids) return self.projector(span_embeddings) def enable_type_aware_projector(self, num_types: int) -> None: hidden = getattr(self.backbone.config, "hidden_size", None) or getattr(self.backbone.config, "hidden_dim", None) if hidden is None: hidden = next(self.backbone.parameters()).shape[-1] try: out_dim = self.projector.out_dim # type: ignore[attr-defined] except AttributeError: out_dim = getattr(self.projector, "net", None)[-1].out_features if hasattr(self.projector, "net") else hidden self.projector = TypeAwareProjector( in_dim=hidden, out_dim=int(out_dim), num_types=int(num_types), bottleneck=min(hidden, 256), use_residual=True, use_layernorm=True, dropout_p=float(self._dropout_p), ).to(device=next(self.backbone.parameters()).device, dtype=torch.float32) self._type_aware = True def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, mention_mask: Optional[torch.Tensor] = None, type_ids: Optional[torch.Tensor] = None, **kwargs: Any, ) -> Dict[str, torch.Tensor]: # Encode H = self.encode_tokens(input_ids, attention_mask if attention_mask is not None else torch.ones_like(input_ids)) # NER logits H_logits = H if H.dtype == self.ner_head.weight.dtype else H.to(self.ner_head.weight.dtype) logits = self.ner_head(H_logits) out: Dict[str, torch.Tensor] = {"logits": logits} # Optional entity embeddings if mention_mask is not None: if mention_mask.dim() == 2: spans = self.encode_mentions_with_attention(H, mention_mask) emb = self.project_mentions(spans, type_ids) elif mention_mask.dim() == 3: spans = self.encode_mentions_with_attention(H, mention_mask) if isinstance(spans, torch.Tensor) and spans.dim() == 3: B, M, D = spans.shape spans_flat = spans.reshape(B * M, D) if type_ids is not None: type_ids_flat = type_ids.view(-1) else: type_ids_flat = None proj = self.project_mentions(spans_flat, type_ids_flat) emb = proj.view(B, M, -1) else: emb = self.project_mentions(spans, type_ids) else: raise ValueError("mention_mask must be [B,T] or [B,M,T]") out["entity_embeddings"] = emb # Optional training loss (simple CE) if labels is not None: loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100) out["loss"] = loss return out def crf_neg_log_likelihood( self, emissions: torch.Tensor, labels: torch.Tensor, mask: Optional[torch.Tensor] = None, ignore_index: int = -100, ) -> torch.Tensor: if self.crf is None: raise RuntimeError("CRF not initialized on model") return self.crf.neg_log_likelihood(emissions, labels, mask=mask, ignore_index=ignore_index) def crf_decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> list[list[int]]: if self.crf is None: raise RuntimeError("CRF not initialized on model") return self.crf.decode(emissions, mask) # --- Minimal CRF implementation + BIO constraints (for inference parity) --- def _parse_bio(label: str): if label == "O": return "O", None if "-" in label: p, t = label.split("-", 1) p = p.upper() return (p if p in {"B", "I"} else "B"), t return "B", label def build_bio_constraints(id2label: Dict[int, str]): C = 1 + max(id2label.keys()) if id2label else 1 start_ok = torch.ones(C, dtype=torch.bool) end_ok = torch.ones(C, dtype=torch.bool) trans_ok = torch.ones(C, C, dtype=torch.bool) prefixes = {} types = {} for i in range(C): p, t = _parse_bio(id2label.get(i, "O")) prefixes[i] = p types[i] = t for j in range(C): if prefixes[j] == "I": start_ok[j] = False for i in range(C): pi, ti = prefixes[i], types[i] for j in range(C): pj, tj = prefixes[j], types[j] ok = True if pj == "I": ok = (pi in {"B", "I"}) and (ti == tj) and (tj is not None) trans_ok[i, j] = ok return start_ok, end_ok, trans_ok class LinearCRF(nn.Module): def __init__( self, num_tags: int, *, start_ok: Optional[torch.Tensor] = None, end_ok: Optional[torch.Tensor] = None, trans_ok: Optional[torch.Tensor] = None, constraint_weight: float = 1e4, ) -> None: super().__init__() self.num_tags = int(num_tags) self.start_transitions = nn.Parameter(torch.zeros(self.num_tags)) self.end_transitions = nn.Parameter(torch.zeros(self.num_tags)) self.transitions = nn.Parameter(torch.zeros(self.num_tags, self.num_tags)) self.register_buffer("start_ok", start_ok if start_ok is not None else torch.ones(self.num_tags, dtype=torch.bool)) self.register_buffer("end_ok", end_ok if end_ok is not None else torch.ones(self.num_tags, dtype=torch.bool)) self.register_buffer("trans_ok", trans_ok if trans_ok is not None else torch.ones(self.num_tags, self.num_tags, dtype=torch.bool)) self.constraint_penalty = float(constraint_weight) def _constrain(self): neg_inf = -self.constraint_penalty start = self.start_transitions.masked_fill(~self.start_ok, neg_inf) end = self.end_transitions.masked_fill(~self.end_ok, neg_inf) trans = self.transitions.masked_fill(~self.trans_ok, neg_inf) return start, end, trans @torch.no_grad() def decode(self, emissions: torch.Tensor, mask: torch.Tensor): emissions = emissions.float() B, T, C = emissions.shape start, end, trans = self._constrain() paths = [] for b in range(B): valid_idx = torch.nonzero(mask[b], as_tuple=False).view(-1) if valid_idx.numel() == 0: paths.append([]) continue delta = start + emissions[b, valid_idx[0]] # [C] psi = emissions.new_full((valid_idx.numel(), C), -1, dtype=torch.long) for ti in range(1, valid_idx.numel()): t = int(valid_idx[ti].item()) prev = delta.unsqueeze(1) + trans best_prev, best_tag = prev.max(dim=0) delta = best_prev + emissions[b, t] psi[ti] = best_tag delta = delta + end last = int(torch.argmax(delta).item()) out = [last] for ti in range(valid_idx.numel() - 1, 0, -1): last = int(psi[ti, last].item()) out.append(last) out.reverse() paths.append(out) return paths