""" Standalone utilities and lightweight Hugging Face integrations for running Provence reranker checkpoints. `OpenProvenceModel` provides a self-contained wrapper that can be copied next to a checkpoint and executed without installing the full ``open_provence`` package. In addition, this module now exposes `OpenProvenceConfig`, `OpenProvenceForSequenceClassification`, and `OpenProvenceForTokenClassification` so that checkpoints can be loaded via ``transformers.AutoModel`` without shipping extra modeling files. Keep this module self-contained—avoid intra-package imports—so exported checkpoints remain portable. """ from __future__ import annotations import contextlib import logging import math import os import platform import re import warnings from collections import OrderedDict, defaultdict from collections.abc import Callable, Iterable, Mapping, Sequence from copy import deepcopy from dataclasses import dataclass from pathlib import Path from time import perf_counter from typing import Any, TypeAlias, cast import numpy as np import torch import transformers.utils.logging as hf_logging from torch import FloatTensor, Tensor, nn from torch.utils.data import DataLoader, Dataset from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer from transformers.configuration_utils import PretrainedConfig from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput from transformers.modeling_utils import PreTrainedModel from transformers.utils.generic import ModelOutput try: import nltk from nltk.tokenize import PunktSentenceTokenizer except ImportError as exc: # pragma: no cover - mandatory dependency raise ImportError( "modeling_open_provence_standalone.py requires `nltk`. Install via `uv add nltk`." ) from exc LOGGER = logging.getLogger(__name__) DEFAULT_SPLITTER_LANGUAGE = "auto" # Updated during export; keep marker for tooling DEFAULT_PROCESS_THRESHOLD = 0.1 # Default pruning threshold when config does not specify one _PROGRESS_BAR_ENABLED = True def enable_progress_bar() -> None: """Enable progress output for preprocessing and inference helpers.""" global _PROGRESS_BAR_ENABLED _PROGRESS_BAR_ENABLED = True def disable_progress_bar() -> None: """Disable progress output for preprocessing and inference helpers.""" global _PROGRESS_BAR_ENABLED _PROGRESS_BAR_ENABLED = False def is_progress_bar_enabled() -> bool: """Return True when progress output should be shown.""" return _PROGRESS_BAR_ENABLED def _default_preprocess_workers() -> int: """Infer a reasonable default number of preprocessing workers.""" cpu_total: int | None = None try: # pragma: no cover - optional dependency import psutil cpu_total = psutil.cpu_count(logical=False) or psutil.cpu_count(logical=True) except Exception: cpu_total = os.cpu_count() if cpu_total is None: return 0 return max(0, int(cpu_total) - 1) _ENGLISH_SENTENCE_TOKENIZER: PunktSentenceTokenizer | None = None DEFAULT_ENGLISH_SENTENCE_MAX_CHARS = 1200 _ENGLISH_LANGUAGE_ALIASES = { "en", "english", "en-us", "en_gb", "en-gb", "en_us", } _BULLET_PREFIX_RE = re.compile( r"""^\s*(?:[\-\*\u2022•]+|\d{1,4}[:.)]|[A-Za-z]{1}[:.)])\s+""", re.UNICODE, ) _WORD_TOKEN_RE = re.compile(r"[A-Za-z0-9']+") _TABLE_ROW_RE = re.compile(r"^\s*\|") _NUMERIC_HEADING_RE = re.compile(r"^\s*\d{3,}[:\-]") SUPPORTED_SPLITTER_LANGUAGES = {"ja", "en", "auto"} def _is_kana_letter_cp(cp: int) -> bool: """Return True when code point corresponds to a kana letter.""" if 0x3041 <= cp <= 0x3096: # Hiragana letters (ぁ-ゖ) return True if 0x30A1 <= cp <= 0x30FA: # Katakana letters (ァ-ヺ) return True if 0x31F0 <= cp <= 0x31FF: # Katakana phonetic extensions (ㇰ-ㇿ) return True if 0xFF71 <= cp <= 0xFF9D: # Half-width katakana letters (ア-ン) return True return False def is_japanese_fast(text: str, window: int = 500, min_kana_per_window: int = 1) -> bool: """Heuristic that quickly classifies text as Japanese when kana density is high.""" if not text: return False if text.isascii(): return False required = math.ceil(len(text) / window) * min_kana_per_window if required <= 0: return False count = 0 for ch in text: cp = ord(ch) if cp > 0x7F and _is_kana_letter_cp(cp): count += 1 if count >= required: return True return False warnings.filterwarnings("ignore", message="Flash Attention 2 only supports") os.environ.setdefault("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1") os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") _transformers_logger = logging.getLogger("transformers.modeling_utils") _dynamic_module_logger = logging.getLogger("transformers.dynamic_module_utils") class _SuppressTransformersWarnings(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: # pragma: no cover - log hygiene message = record.getMessage() if "Flash Attention 2 only supports" in message: return False if "`torch_dtype` is deprecated" in message: return False return True _transformers_logger.addFilter(_SuppressTransformersWarnings()) class _SuppressDynamicModuleWarnings(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: # pragma: no cover - log hygiene message = record.getMessage() if "The module name" in message and "is not a valid Python identifier" in message: return False if "The module name" in message and "is a reserved keyword" in message: return False return True _dynamic_module_logger.addFilter(_SuppressDynamicModuleWarnings()) _LOGGING_CONFIGURED = False def _ensure_transformers_logging_configured() -> None: """Configure transformers logging once to suppress noisy warnings in standalone mode.""" global _LOGGING_CONFIGURED if _LOGGING_CONFIGURED: return hf_logging.set_verbosity_error() _LOGGING_CONFIGURED = True def _supports_flash_attention() -> bool: """Return True when CUDA is available and we optimistically enable FlashAttention v2.""" if not torch.cuda.is_available(): return False try: pass # type: ignore[import-not-found] except Exception: return False return True def _select_default_torch_dtype(device: str | None) -> torch.dtype | str | None: """Select a sensible default dtype based on the target device.""" if not device: return None normalized = str(device).lower() if normalized == "cuda" and torch.cuda.is_available(): supports_bf16 = getattr(torch.cuda, "is_bf16_supported", None) try: if callable(supports_bf16) and supports_bf16(): return torch.bfloat16 except Exception: pass return torch.float16 if normalized == "mps": return "auto" if normalized == "cpu": system = platform.system() machine = platform.machine().lower() if system == "Darwin" and machine in {"arm64", "aarch64"}: return "auto" return None def _coerce_dtype_for_torch_to(value: torch.dtype | str | None) -> torch.dtype | None: """Convert user/config provided dtype hints into torch.dtype for Module.to.""" if value is None or isinstance(value, torch.dtype): return value normalized = str(value).strip().lower() if normalized == "auto": return None # Accept common dtype aliases used by Transformers configs/CLI flags. alias_map: dict[str, torch.dtype] = { "float32": torch.float32, "fp32": torch.float32, "32": torch.float32, "float16": torch.float16, "fp16": torch.float16, "half": torch.float16, "bfloat16": torch.bfloat16, "bf16": torch.bfloat16, } resolved = alias_map.get(normalized) if resolved is None: raise TypeError(f"Unsupported dtype value for torch.to(): {value!r}") return resolved def _mps_is_available() -> bool: backend = getattr(torch, "backends", None) if backend is None: return False mps = getattr(backend, "mps", None) if mps is None: return False try: return bool(mps.is_available()) except Exception: return False def auto_detect_device() -> torch.device: system = platform.system() machine = platform.machine().lower() if system == "Darwin" and machine in {"arm64", "aarch64"} and _mps_is_available(): return torch.device("mps") if torch.cuda.is_available(): return torch.device("cuda") if _mps_is_available(): return torch.device("mps") return torch.device("cpu") def _validate_device(candidate: torch.device) -> None: if candidate.type == "cuda": if not torch.cuda.is_available(): raise ValueError("CUDA device requested but CUDA is not available.") if candidate.index is not None: total = torch.cuda.device_count() if candidate.index < 0 or candidate.index >= total: raise ValueError( f"CUDA device index {candidate.index} out of range (count={total})." ) elif candidate.type == "mps": if not _mps_is_available(): raise ValueError("MPS device requested but MPS backend is not available.") def resolve_inference_device(device: str | torch.device | None) -> torch.device: if isinstance(device, torch.device): candidate = device elif device is None: return auto_detect_device() else: normalized = str(device).strip().lower() if not normalized or normalized == "auto": return auto_detect_device() if normalized == "cpu": candidate = torch.device("cpu") elif normalized.startswith("cuda"): candidate = torch.device(normalized) elif normalized.startswith("mps"): candidate = torch.device("mps") else: raise ValueError(f"Unsupported device specification: {device!r}") _validate_device(candidate) return candidate try: from fast_bunkai import FastBunkai except ImportError: # pragma: no cover - optional dependency FastBunkai = None _FAST_BUNKAI = None if FastBunkai is not None: # pragma: no branch try: _FAST_BUNKAI = FastBunkai() except Exception as exc: # pragma: no cover - runtime safety raise RuntimeError("Failed to initialize FastBunkai sentence splitter") from exc @dataclass class OpenProvenceHeadConfig: """Lightweight configuration for the pruning head.""" hidden_size: int = 768 num_labels: int = 2 classifier_dropout: float = 0.1 sentence_pooling: str = "mean" use_weighted_pooling: bool = False def __init__(self, **kwargs: Any) -> None: self.hidden_size = int(kwargs.pop("hidden_size", 768)) self.num_labels = int(kwargs.pop("num_labels", 2)) self.classifier_dropout = float(kwargs.pop("classifier_dropout", 0.1)) self.sentence_pooling = kwargs.pop("sentence_pooling", "mean") self.use_weighted_pooling = bool(kwargs.pop("use_weighted_pooling", False)) # Store any additional fields for completeness for key, value in kwargs.items(): setattr(self, key, value) @dataclass(frozen=True) class ProcessPerformanceTrace: """Structured runtime telemetry for `OpenProvenceModel.process` calls.""" preprocess_seconds: float = 0.0 assembly_seconds: float = 0.0 inference_seconds: float = 0.0 postprocess_seconds: float = 0.0 total_seconds: float = 0.0 sentence_collect_seconds: float = 0.0 sentence_normalize_seconds: float = 0.0 tokenize_seconds: float = 0.0 fragment_split_seconds: float = 0.0 fragment_decode_seconds: float = 0.0 def as_dict(self) -> dict[str, float]: return { "preprocess_seconds": float(self.preprocess_seconds), "assembly_seconds": float(self.assembly_seconds), "inference_seconds": float(self.inference_seconds), "postprocess_seconds": float(self.postprocess_seconds), "total_seconds": float(self.total_seconds), "sentence_collect_seconds": float(self.sentence_collect_seconds), "sentence_normalize_seconds": float(self.sentence_normalize_seconds), "tokenize_seconds": float(self.tokenize_seconds), "fragment_split_seconds": float(self.fragment_split_seconds), "fragment_decode_seconds": float(self.fragment_decode_seconds), } class OpenProvenceHead(nn.Module): """Minimal pruning head used by Provence pruning checkpoints.""" def __init__(self, config: OpenProvenceHeadConfig): super().__init__() self.config = config self.num_labels = getattr(config, "num_labels", 2) self.sentence_pooling = getattr(config, "sentence_pooling", "mean") self.use_weighted_pooling = getattr(config, "use_weighted_pooling", False) dropout_prob = float(getattr(config, "classifier_dropout", 0.1)) self.dropout = nn.Dropout(dropout_prob) hidden_size = int(getattr(config, "hidden_size", 768)) self.classifier = nn.Linear(hidden_size, self.num_labels) if self.use_weighted_pooling: self.pooling_weights = nn.Linear(hidden_size, 1) self._init_weights() def _init_weights(self) -> None: nn.init.xavier_uniform_(self.classifier.weight) nn.init.zeros_(self.classifier.bias) if hasattr(self, "pooling_weights"): nn.init.xavier_uniform_(self.pooling_weights.weight) nn.init.zeros_(self.pooling_weights.bias) def forward( self, *, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, sentence_boundaries: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: """Produce token-level pruning logits.""" _ = attention_mask # not required for current inference path _ = sentence_boundaries hidden_states = self.dropout(hidden_states) logits = self.classifier(hidden_states) return {"logits": logits} @dataclass class OpenProvenceRawPrediction: """Container for raw pruning outputs.""" query: str contexts: list[str] ranking_score: float | None pruning_probs: np.ndarray context_ranges: list[tuple[int, int]] # Type alias for sentence splitter functions SentenceSplitter = Callable[[str], list[str]] def _get_english_sentence_tokenizer() -> PunktSentenceTokenizer: global _ENGLISH_SENTENCE_TOKENIZER if _ENGLISH_SENTENCE_TOKENIZER is None: try: tokenizer = nltk.data.load("tokenizers/punkt/english.pickle") except LookupError as exc: # pragma: no cover - requires punkt download raise LookupError( "Missing NLTK punkt tokenizer data. Run `python -m nltk.downloader punkt`." ) from exc if not isinstance(tokenizer, PunktSentenceTokenizer): raise TypeError(f"Expected PunktSentenceTokenizer, got {type(tokenizer).__name__}.") _ENGLISH_SENTENCE_TOKENIZER = tokenizer return _ENGLISH_SENTENCE_TOKENIZER def _looks_like_bullet_line(line: str) -> bool: return bool(_BULLET_PREFIX_RE.match(line)) def _iter_english_blocks(text: str) -> Iterable[tuple[str, int, int]]: """Yield text blocks with their span indices for English sentence segmentation.""" if not text: return total_len = len(text) lines = text.splitlines(keepends=True) if not lines: block = text if block: yield block, 0, total_len return accumulated = 0 current_parts: list[str] = [] current_start = 0 for line in lines: line_start = accumulated accumulated += len(line) plain_line = line.rstrip("\r\n") if _looks_like_bullet_line(plain_line) and current_parts: block_text = "".join(current_parts) if block_text: block_end = current_start + len(block_text) yield block_text, current_start, block_end current_parts = [line] current_start = line_start else: if not current_parts: current_start = line_start current_parts.append(line) if current_parts: block_text = "".join(current_parts) if block_text: block_end = current_start + len(block_text) yield block_text, current_start, block_end if accumulated < total_len: remainder = text[accumulated:] if remainder: yield remainder, accumulated, total_len def _split_overlong_sentence( sentence: str, max_chars: int = DEFAULT_ENGLISH_SENTENCE_MAX_CHARS, *, preserve_whitespace: bool = False, ) -> list[str]: if preserve_whitespace: working = sentence else: working = sentence.strip() if not working: return [] if len(working) <= max_chars: return [working if preserve_whitespace else working.strip()] chunks: list[str] = [] start = 0 length = len(working) punctuation = ".?!;:\n" while start < length: target = min(start + max_chars, length) # Prefer a newline boundary when available within the window to keep list items concise. newline_idx = working.rfind("\n", start + 1, target) boundary = None if newline_idx != -1 and newline_idx >= start + 1: boundary = newline_idx + 1 if boundary is None or boundary <= start: for idx in range(target, start, -1): if working[idx - 1] in punctuation: boundary = idx break if boundary is None or boundary <= start: boundary = target chunk = working[start:boundary] if not preserve_whitespace: chunk = chunk.strip() if chunk: chunks.append(chunk) start = boundary return chunks or ([working] if preserve_whitespace else [working.strip()]) def _split_multiline_sentence(text: str, strip_sentences: bool) -> list[str]: if "\n" not in text: return [text.strip() if strip_sentences else text] segments = text.splitlines(keepends=not strip_sentences) meaningful = [segment for segment in segments if segment.strip()] if len(meaningful) <= 1: return [text.strip() if strip_sentences else text] # Skip splitting when the sentence already contains clear punctuation across lines. punctuation_count = sum(1 for ch in text if ch in ".?!") if punctuation_count >= len(meaningful): return [text.strip() if strip_sentences else text] # Avoid splitting when any line is excessively long (likely already handled elsewhere). if any(len(seg.strip()) > DEFAULT_ENGLISH_SENTENCE_MAX_CHARS for seg in meaningful): return [text.strip() if strip_sentences else text] processed: list[str] = [] for segment in meaningful: if strip_sentences: value = segment.strip() if value: processed.append(value) else: processed.append(segment) if processed: return processed return [text.strip() if strip_sentences else text] def _collect_candidate_sentences( example: Mapping[str, Any], splitter: SentenceSplitter ) -> list[str]: """Collect sentences from prefixes, manual overrides, or by splitting the context text.""" prefix_sentences = example.get("prefix_sentences") or [] manual_sentences = example.get("manual_sentences") context_text = str(example.get("context_text", "")) sentences: list[str] = [str(s) for s in prefix_sentences if s is not None] if manual_sentences is not None: sentences.extend(str(s) for s in manual_sentences if s is not None) else: sentences.extend(str(s) for s in splitter(context_text) if s is not None) return sentences def _fallback_sentence(context_text: str, strip_sentences: bool) -> str: if not strip_sentences: return context_text stripped = context_text.strip() return stripped or context_text def _normalize_sentences( raw_sentences: Sequence[str], context_text: str, strip_sentences: bool ) -> list[str]: sentences: list[str] = [] for entry in raw_sentences: text = str(entry) if not text: continue segmented = _split_multiline_sentence(text, strip_sentences) for segment in segmented: if strip_sentences: if segment: sentences.append(segment) else: if segment: sentences.append(segment) if sentences: return sentences return [_fallback_sentence(context_text, strip_sentences)] def _tokenize_sentences(tokenizer: Any, sentences: Sequence[str]) -> list[list[int]]: if not sentences: return [] tokenized = tokenizer( list(sentences), add_special_tokens=False, return_attention_mask=False, ) return tokenized.get("input_ids", []) if isinstance(tokenized, Mapping) else [] def _tokenize_sentences_with_context( tokenizer: Any, sentences: Sequence[str], prefix_count: int, context_text: str, *, strip_sentences: bool, ) -> list[list[int]]: return _tokenize_sentences(tokenizer, sentences) def _split_token_lists( token_lists: Sequence[Sequence[int]], max_fragment_tokens: int, *, keep_sentence_boundaries: bool = False, ) -> list[tuple[list[int], int, int, int]]: fragments: list[tuple[list[int], int, int, int]] = [] global_index = 0 step = max(1, int(max_fragment_tokens)) for sentence_index, token_ids in enumerate(token_lists): tokens = list(token_ids) if not tokens: continue if keep_sentence_boundaries and len(tokens) <= max_fragment_tokens: fragments.append((tokens, int(sentence_index), 0, global_index)) global_index += 1 continue for fragment_index, start in enumerate(range(0, len(tokens), step)): fragment_tokens = tokens[start : start + step] if not fragment_tokens: continue fragments.append( (fragment_tokens, int(sentence_index), int(fragment_index), global_index) ) global_index += 1 return fragments def _collect_sentences_for_job( example: Mapping[str, Any], splitter: SentenceSplitter, strip_sentences: bool, ) -> tuple[list[str], float, float]: context_text = str(example.get("context_text", "")) cached_sentences = example.get("cached_sentences") if cached_sentences is not None: sentences = [str(sentence) for sentence in cached_sentences] return sentences, 0.0, 0.0 start = perf_counter() raw_sentences = _collect_candidate_sentences(example, splitter) sentence_collect_time = perf_counter() - start start = perf_counter() sentences = _normalize_sentences(raw_sentences, context_text, strip_sentences) sentence_normalize_time = perf_counter() - start return sentences, sentence_collect_time, sentence_normalize_time def _tokenize_sentences_for_examples( tokenizer: Any, sentences_nested: Sequence[Sequence[str]], cached_token_lists: Sequence[Any] | None, ) -> tuple[list[list[list[int]]], list[float]]: result_token_ids: list[list[list[int]] | None] = [] timings: list[float | None] = [] sentences_to_tokenize: list[str] = [] mapping: list[tuple[int, int]] = [] total_examples = len(sentences_nested) cached_token_lists = cached_token_lists or [None] * total_examples for example_index, (sentences, cached_tokens) in enumerate( zip(sentences_nested, cached_token_lists) ): if cached_tokens is not None: token_lists = [[int(token) for token in tokens] for tokens in cached_tokens] result_token_ids.append(token_lists) timings.append(0.0) continue if sentences: mapping.append((example_index, len(sentences))) sentences_to_tokenize.extend(sentences) result_token_ids.append(None) timings.append(None) if sentences_to_tokenize: start = perf_counter() tokenized = tokenizer( sentences_to_tokenize, add_special_tokens=False, return_attention_mask=False, ) tokenize_time = perf_counter() - start input_ids = tokenized.get("input_ids", []) pointer = 0 total_sentences = len(sentences_to_tokenize) time_per_sentence = tokenize_time / total_sentences if total_sentences else 0.0 for example_index, sentence_count in mapping: slice_ids = input_ids[pointer : pointer + sentence_count] pointer += sentence_count result_token_ids[example_index] = [ [int(token) for token in tokens] for tokens in slice_ids ] timings[example_index] = time_per_sentence * sentence_count finalized_token_ids: list[list[list[int]]] = [] finalized_timings: list[float] = [] for tokens, timing in zip(result_token_ids, timings): finalized_token_ids.append(tokens or []) finalized_timings.append(float(timing or 0.0)) return finalized_token_ids, finalized_timings def _build_fragment_payload( tokenizer: Any, sentences: Sequence[str], token_lists: Sequence[Sequence[int]], context_text: str, max_fragment_tokens: int, strip_sentences: bool, respect_sentence_boundaries: bool, ) -> tuple[dict[str, Any], float, float]: normalized_tokens = [[int(token) for token in tokens] for tokens in token_lists] start = perf_counter() fragments = _split_token_lists( normalized_tokens, max_fragment_tokens, keep_sentence_boundaries=respect_sentence_boundaries, ) fragment_split_time = perf_counter() - start if not fragments: fallback_source = _fallback_sentence(context_text, strip_sentences) fallback_tokens = tokenizer.encode(fallback_source, add_special_tokens=False) fragments = [(list(fallback_tokens), 0, 0, 0)] start = perf_counter() fragment_payload = _decode_and_filter_fragments( tokenizer, fragments, strip_sentences=strip_sentences, ) fragment_decode_time = perf_counter() - start if not fragment_payload["fragment_token_ids"]: tokens, sentence_idx, fragment_idx, global_idx = fragments[0] decoded_text = tokenizer.decode( tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False, ) processed_text = decoded_text.strip() if strip_sentences else decoded_text fragment_payload = { "fragment_texts": [processed_text], "fragment_token_ids": [list(tokens)], "fragment_sentence_index": [sentence_idx], "fragment_fragment_index": [fragment_idx], "fragment_global_index": [global_idx], } return fragment_payload, fragment_split_time, fragment_decode_time def _decode_and_filter_fragments( tokenizer: Any, fragments: Sequence[tuple[list[int], int, int, int]], *, strip_sentences: bool, ) -> dict[str, list[Any]]: if not fragments: return { "fragment_texts": [], "fragment_token_ids": [], "fragment_sentence_index": [], "fragment_fragment_index": [], "fragment_global_index": [], } token_sequences = [tokens for tokens, _, _, _ in fragments] fragment_texts = tokenizer.batch_decode( token_sequences, skip_special_tokens=True, clean_up_tokenization_spaces=False, ) filtered_tokens: list[list[int]] = [] filtered_texts: list[str] = [] sentence_indices: list[int] = [] fragment_indices: list[int] = [] global_indices: list[int] = [] for text, (tokens, sentence_idx, fragment_idx, global_idx) in zip(fragment_texts, fragments): processed_text = text.strip() if strip_sentences else text if strip_sentences: if not processed_text: continue else: if not text: continue filtered_tokens.append(list(tokens)) filtered_texts.append(processed_text) sentence_indices.append(sentence_idx) fragment_indices.append(fragment_idx) global_indices.append(global_idx) return { "fragment_texts": filtered_texts, "fragment_token_ids": filtered_tokens, "fragment_sentence_index": sentence_indices, "fragment_fragment_index": fragment_indices, "fragment_global_index": global_indices, } def _fragmentize_single_job( tokenizer: Any, job: dict[str, Any], *, max_fragment_tokens: int, splitter: SentenceSplitter, strip_sentences: bool, respect_sentence_boundaries: bool, ) -> dict[str, Any]: sentences, collect_time, normalize_time = _collect_sentences_for_job( job, splitter, strip_sentences, ) token_ids_nested, tokenize_timings = _tokenize_sentences_for_examples( tokenizer, [sentences], [job.get("cached_token_lists")], ) token_lists = token_ids_nested[0] if not token_lists: cached_lists = job.get("cached_token_lists") token_lists = ( [[int(token) for token in tokens] for tokens in cached_lists] if cached_lists else [] ) fragment_payload, fragment_split_time, fragment_decode_time = _build_fragment_payload( tokenizer=tokenizer, sentences=sentences, token_lists=token_lists, context_text=str(job.get("context_text", "")), max_fragment_tokens=max_fragment_tokens, strip_sentences=strip_sentences, respect_sentence_boundaries=respect_sentence_boundaries, ) entry = { "sentences": sentences, "timing_sentence_collect": collect_time, "timing_sentence_normalize": normalize_time, "timing_tokenize": tokenize_timings[0], "timing_fragment_split": fragment_split_time, "timing_fragment_decode": fragment_decode_time, } entry.update(fragment_payload) return entry def _preprocess_collate_fn( batch: Sequence[tuple[dict[str, Any], dict[str, Any]]], ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: if not batch: return [], [] jobs, entries = zip(*batch) return list(jobs), list(entries) class _PreprocessDataset(Dataset): """Map-style dataset that fragmentizes preprocessing jobs.""" def __init__( self, jobs: Sequence[dict[str, Any]], tokenizer: Any, splitter: SentenceSplitter, max_fragment_tokens: int, strip_sentences: bool, respect_sentence_boundaries: bool, ) -> None: self._jobs = list(jobs) self._tokenizer = tokenizer self._splitter = splitter self._max_fragment_tokens = max_fragment_tokens self._strip_sentences = strip_sentences self._respect_sentence_boundaries = respect_sentence_boundaries def __len__(self) -> int: return len(self._jobs) def __getitem__(self, index: int) -> tuple[dict[str, Any], dict[str, Any]]: job = self._jobs[index] entry = _fragmentize_single_job( self._tokenizer, job, max_fragment_tokens=self._max_fragment_tokens, splitter=self._splitter, strip_sentences=self._strip_sentences, respect_sentence_boundaries=self._respect_sentence_boundaries, ) return job, entry @dataclass class _FragmentRecord: """Metadata for a context fragment produced during long-context splitting.""" text: str sentence_index: int fragment_index: int global_index: int token_length: int token_ids: list[int] def fast_bunkai_sentence_splitter(text: str) -> list[str]: """Split sentences with fast-bunkai. Raises if the library is unavailable.""" if _FAST_BUNKAI is None: raise RuntimeError( "fast-bunkai is not installed. Install `fast-bunkai` or provide a custom sentence_splitter " "(e.g. `simple_sentence_splitter`)." ) sentences = [sentence for sentence in _FAST_BUNKAI(text) if sentence] if sentences: return sentences return [text] if text else [] def simple_sentence_splitter(text: str) -> list[str]: """Lightweight regex-based sentence splitter for Japanese text.""" if not text: return [] pattern = re.compile(r".+?(?:。|!|?|!|\?|\n|$)", re.S) sentences = [match for match in pattern.findall(text) if match] if sentences: return sentences return [text] if text else [] def create_english_sentence_splitter( max_chars: int = DEFAULT_ENGLISH_SENTENCE_MAX_CHARS, ) -> SentenceSplitter: """Factory for English sentence splitters that preserve whitespace and newlines. Processing pipeline (executed for every call of the returned splitter): 1. `_iter_english_blocks` walks the source text line-by-line, grouping adjacent lines while respecting bullet-style headings. This yields blocks together with their start/end byte offsets so we always know where we are in the original string. 2. Each block is tokenised with NLTK's Punkt model (`span_tokenize`). The spans are mapped back to absolute offsets (`global_start`/`global_end`). We stretch the end offset across trailing whitespace so that paragraph boundaries keep their newline markers. 3. Every raw segment is routed through `_split_overlong_sentence`, which trims *nothing* but ensures no fragment exceeds ``max_chars``. When Punkt does not emit any spans (e.g., extremely long strings without punctuation), the whole block is handed directly to this fallback splitter so we still return manageable chunks. 4. Empty segments and whitespace-only fragments are skipped. If the whole text reduces to whitespace we fall back to returning the stripped source. This design guarantees that: * sentence boundaries preserve the original whitespace/newline layout, * sections and lists stay intact because block slicing mirrors the input, and * even pathological long sentences are clipped deterministically at ``max_chars`` before downstream tokenisation. """ if max_chars <= 0: raise ValueError("max_chars must be positive") def _split_text(text: str) -> list[str]: if not text: return [] tokenizer = _get_english_sentence_tokenizer() sentences: list[str] = [] for block_text, block_start, block_end in _iter_english_blocks(text): if not block_text: continue try: spans = list(tokenizer.span_tokenize(block_text)) except LookupError as exc: # pragma: no cover - requires punkt download raise LookupError( "Missing NLTK punkt tokenizer. Run `python -m nltk.downloader punkt`." ) from exc if not spans: segment = text[block_start:block_end] if segment.strip(): sentences.extend( _split_overlong_sentence( segment, max_chars=max_chars, preserve_whitespace=True, ) ) continue for span_start, span_end in spans: global_start = block_start + span_start global_end = block_start + span_end extended_end = global_end while extended_end < block_end and text[extended_end].isspace(): extended_end += 1 segment = text[global_start:extended_end] if segment and segment.strip(): sentences.extend( _split_overlong_sentence( segment, max_chars=max_chars, preserve_whitespace=True, ) ) if sentences: return sentences fallback = text.strip() return [fallback] if fallback else [] return _split_text _DEFAULT_ENGLISH_SENTENCE_SPLITTER = create_english_sentence_splitter() def english_sentence_splitter(text: str) -> list[str]: """Default English sentence splitter using the module's configured limit.""" return _DEFAULT_ENGLISH_SENTENCE_SPLITTER(text) def create_auto_sentence_splitter( *, japanese_splitter: SentenceSplitter = fast_bunkai_sentence_splitter, english_splitter: SentenceSplitter = english_sentence_splitter, kana_window: int = 500, min_kana_per_window: int = 1, ) -> SentenceSplitter: """Return a splitter that detects Japanese text via kana density before splitting.""" def _split_text(text: str) -> list[str]: if is_japanese_fast(text, window=kana_window, min_kana_per_window=min_kana_per_window): return japanese_splitter(text) return english_splitter(text) return _split_text def _fragmentize_example( # pyright: ignore[reportUnusedFunction] example: dict[str, Any], tokenizer, max_fragment_tokens: int, splitter: SentenceSplitter, strip_sentences: bool, *, respect_sentence_boundaries: bool = False, ) -> dict[str, Any]: """Fragmentize a single context example for parallel preprocessing.""" context_text = str(example.get("context_text", "")) cached_sentences = example.get("cached_sentences") cached_token_lists = example.get("cached_token_lists") timer_start = perf_counter() if cached_sentences is not None: sentences = [str(sentence) for sentence in cached_sentences] sentence_collect_time = 0.0 sentence_normalize_time = 0.0 else: raw_sentences = _collect_candidate_sentences(example, splitter) sentence_collect_time = perf_counter() - timer_start timer_start = perf_counter() sentences = _normalize_sentences(raw_sentences, context_text, strip_sentences) sentence_normalize_time = perf_counter() - timer_start prefix_sentences = example.get("prefix_sentences") or [] if cached_token_lists is not None: token_lists = [[int(token) for token in tokens] for tokens in cached_token_lists] tokenize_time = 0.0 else: timer_start = perf_counter() token_lists = _tokenize_sentences_with_context( tokenizer, sentences, len(prefix_sentences), context_text, strip_sentences=strip_sentences, ) tokenize_time = perf_counter() - timer_start timer_start = perf_counter() fragments = _split_token_lists( token_lists, max_fragment_tokens, keep_sentence_boundaries=respect_sentence_boundaries, ) fragment_split_time = perf_counter() - timer_start if not fragments: timer_start = perf_counter() fallback_source = _fallback_sentence(context_text, strip_sentences) fallback_tokens = tokenizer.encode(fallback_source, add_special_tokens=False) tokenize_time += perf_counter() - timer_start fragments = [(list(fallback_tokens), 0, 0, 0)] sentences = [fallback_source] timer_start = perf_counter() fragment_payload = _decode_and_filter_fragments( tokenizer, fragments, strip_sentences=strip_sentences, ) decode_time = perf_counter() - timer_start if not fragment_payload["fragment_token_ids"]: tokens, sentence_idx, fragment_idx, global_idx = fragments[0] timer_start = perf_counter() decoded_text = tokenizer.decode( tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False, ) decode_time += perf_counter() - timer_start processed_text = decoded_text.strip() if strip_sentences else decoded_text fragment_payload = { "fragment_texts": [processed_text], "fragment_token_ids": [list(tokens)], "fragment_sentence_index": [sentence_idx], "fragment_fragment_index": [fragment_idx], "fragment_global_index": [global_idx], } return { "sentences": sentences, "fragment_texts": fragment_payload["fragment_texts"], "fragment_sentence_index": fragment_payload["fragment_sentence_index"], "fragment_fragment_index": fragment_payload["fragment_fragment_index"], "fragment_global_index": fragment_payload["fragment_global_index"], "fragment_token_ids": fragment_payload["fragment_token_ids"], "timing_sentence_collect": sentence_collect_time, "timing_sentence_normalize": sentence_normalize_time, "timing_tokenize": tokenize_time, "timing_fragment_split": fragment_split_time, "timing_fragment_decode": decode_time, } class OpenProvenceConfig(PretrainedConfig): """Configuration metadata for OpenProvence checkpoints.""" model_type = "open_provence" def __init__( self, mode: str = "reranking_pruning", base_model_name_or_path: str | None = None, base_model_config: dict[str, Any] | PretrainedConfig | None = None, tokenizer_name_or_path: str | None = None, pruning_config: dict | None = None, max_length: int = 512, num_labels: int | None = None, num_pruning_labels: int | None = None, encoder_architecture: str | None = None, **kwargs: Any, ) -> None: raw_default_threadshold = kwargs.pop("default_threadshold", None) alt_default_threshold = kwargs.pop("default_threshold", None) # Backwards compatibility: drop deprecated language hints from historical configs. kwargs.pop("splitter_default_language", None) kwargs.pop("standalone_process_default_language", None) super().__init__(**kwargs) self.mode = mode if isinstance(base_model_config, PretrainedConfig): base_model_config = base_model_config.to_dict() self.base_model_name_or_path = base_model_name_or_path self.base_model_config = dict(base_model_config) if base_model_config is not None else None self.tokenizer_name_or_path = tokenizer_name_or_path self.pruning_config = pruning_config or {} self.max_length = max_length self.encoder_architecture = encoder_architecture self.num_labels = 1 if num_labels is None else num_labels self.num_pruning_labels = 2 if num_pruning_labels is None else num_pruning_labels self.default_threadshold = None if raw_default_threadshold is not None: try: self.default_threadshold = float(raw_default_threadshold) except (TypeError, ValueError) as exc: raise TypeError( "Config value 'default_threadshold' must be a numeric type convertible to float." ) from exc elif alt_default_threshold is not None: warnings.warn( "Config key 'default_threshold' detected. Did you intend 'default_threadshold'? " "Using the provided value for backwards compatibility.", RuntimeWarning, stacklevel=2, ) try: self.default_threadshold = float(alt_default_threshold) except (TypeError, ValueError) as exc: raise TypeError( "Config value 'default_threshold' must be a numeric type convertible to float." ) from exc self.default_threshold = self.default_threadshold class OpenProvencePreTrainedModel(PreTrainedModel): """Base class implementing the shared Provence reranker backbone.""" config_class = OpenProvenceConfig base_model_prefix = "open_provence" def __init__( self, config: OpenProvenceConfig, *model_args: Any, device: str | torch.device | None = None, **model_kwargs: Any, ) -> None: _ensure_transformers_logging_configured() cleaned_kwargs = dict(model_kwargs) cleaned_kwargs.pop("device", None) resolved_device: torch.device | None = None if device is not None: try: resolved_device = resolve_inference_device(device) except ValueError as exc: class_name = self.__class__.__name__ raise ValueError( f"Invalid device specification for {class_name}: {device!r}" ) from exc super().__init__(config, *model_args, **cleaned_kwargs) self.max_length = config.max_length self.num_labels = config.num_labels self.num_pruning_labels = config.num_pruning_labels self.default_splitter_language = DEFAULT_SPLITTER_LANGUAGE self._runtime_device = torch.device("cpu") self.base_model_config = self._build_base_model_config(config) self.ranking_model = AutoModelForSequenceClassification.from_config(self.base_model_config) self.pruning_head = OpenProvenceHead(OpenProvenceHeadConfig(**config.pruning_config)) self.tokenizer = self._init_tokenizer(config) self._manual_special_tokens_required = False self._manual_cls_token_id: int | None = None self._manual_sep_token_id: int | None = None self._update_tokenizer_runtime() self.default_threshold = self._resolve_default_threshold(config) self.eval() if resolved_device is not None: self.to(device=resolved_device) def _build_base_model_config(self, config: OpenProvenceConfig) -> PretrainedConfig: if config.base_model_config: config_dict = deepcopy(config.base_model_config) model_type = config_dict.pop("model_type", None) if model_type is None: raise ValueError( "base_model_config must include 'model_type' to rebuild the backbone." ) base_config = AutoConfig.for_model(model_type, **config_dict) else: base_reference = ( config.base_model_name_or_path or config._name_or_path or config.encoder_architecture ) if not base_reference: raise ValueError( "OpenProvenceConfig must define base_model_config or base_model_name_or_path." ) base_config = AutoConfig.from_pretrained(base_reference, trust_remote_code=True) base_config.num_labels = config.num_labels return base_config def _init_tokenizer(self, config: OpenProvenceConfig): tokenizer_reference = ( config.tokenizer_name_or_path or config._name_or_path or config.base_model_name_or_path ) if not tokenizer_reference: raise ValueError("Unable to determine tokenizer reference for OpenProvence model.") try: tokenizer = AutoTokenizer.from_pretrained(tokenizer_reference) except Exception as exc: # pragma: no cover - surface failure to caller raise RuntimeError( f"Failed to initialize tokenizer from '{tokenizer_reference}'." ) from exc return tokenizer def _update_tokenizer_runtime(self, max_length_override: int | None = None) -> None: if self.tokenizer is None: return upper_bound = max(getattr(self.tokenizer, "model_max_length", 0) or 0, 1_000_000) if max_length_override is not None and max_length_override > 0: upper_bound = max(upper_bound, int(max_length_override)) elif self.max_length and self.max_length > 0: upper_bound = max(upper_bound, int(self.max_length)) self.tokenizer.model_max_length = upper_bound def _update_runtime_defaults(self) -> None: tokenizer = cast(Any, self.tokenizer) special_map = cast(Mapping[str, Any], getattr(tokenizer, "special_tokens_map", {})) self._manual_special_tokens_required = self._requires_manual_special_tokens() # type: ignore[reportCallIssue] if self._manual_special_tokens_required: self._manual_cls_token_id = self._resolve_special_token_id( getattr(tokenizer, "cls_token_id", None), special_map.get("cls_token_id"), getattr(tokenizer, "bos_token_id", None), special_map.get("bos_token_id"), ) # type: ignore[reportCallIssue] self._manual_sep_token_id = self._resolve_special_token_id( getattr(tokenizer, "sep_token_id", None), special_map.get("sep_token_id"), getattr(tokenizer, "eos_token_id", None), special_map.get("eos_token_id"), ) # type: ignore[reportCallIssue] else: self._manual_cls_token_id = None self._manual_sep_token_id = None def _resolve_default_threshold(self, config: OpenProvenceConfig) -> float: value = getattr(config, "default_threadshold", None) if value is None: return DEFAULT_PROCESS_THRESHOLD try: return float(value) except (TypeError, ValueError) as exc: # pragma: no cover - config validation raise TypeError( "OpenProvenceConfig.default_threadshold must be numeric when provided." ) from exc def to(self, *args: Any, **kwargs: Any) -> OpenProvencePreTrainedModel: # type: ignore[override] result = super().to(*args, **kwargs) candidate = kwargs.get("device") if kwargs else None if candidate is None and args: candidate = args[0] if candidate is not None: self._runtime_device = torch.device(candidate) return cast("OpenProvencePreTrainedModel", result) def get_input_embeddings(self): return self.ranking_model.get_input_embeddings() def set_input_embeddings(self, value): self.ranking_model.set_input_embeddings(value) def load_state_dict(self, state_dict: Mapping[str, torch.Tensor], strict: bool = True): # type: ignore[override] converted = self._convert_legacy_state_dict(state_dict) return super().load_state_dict(converted, strict=strict) @staticmethod def _convert_legacy_state_dict( state_dict: Mapping[str, torch.Tensor], ) -> Mapping[str, torch.Tensor]: if any(key.startswith("ranking_model.") for key in state_dict): return state_dict converted: OrderedDict[str, torch.Tensor] = OrderedDict() for key, value in state_dict.items(): if key.startswith("pruning_head."): converted[key] = value else: converted[f"ranking_model.{key}"] = value return converted class OpenProvenceModel(OpenProvencePreTrainedModel): """Lightweight wrapper around the Provence reranker checkpoint.""" def __init__( self, config: OpenProvenceConfig, *model_args: Any, device: str | torch.device | None = None, **model_kwargs: Any, ) -> None: super().__init__(config, *model_args, device=device, **model_kwargs) self.default_splitter_language = DEFAULT_SPLITTER_LANGUAGE self._update_tokenizer_runtime() self._update_runtime_defaults() def _resolve_process_threshold(self, threshold: float | None) -> float: if threshold is None: resolved = getattr(self, "default_threshold", DEFAULT_PROCESS_THRESHOLD) if resolved is None: resolved = DEFAULT_PROCESS_THRESHOLD else: resolved = threshold try: return float(resolved) except (TypeError, ValueError) as exc: raise TypeError("Resolved threshold must be numeric.") from exc def _resolve_special_token_id(self, *candidates: int | None) -> int | None: for candidate in candidates: if isinstance(candidate, int): return candidate return None def _requires_manual_special_tokens(self) -> bool: """Detect tokenizers (e.g., ModernBERT) that omit special tokens in build_inputs.""" tokenizer = cast(Any, self.tokenizer) try: query_tokens = tokenizer.encode("open provence query", add_special_tokens=False) context_tokens = tokenizer.encode("open provence document", add_special_tokens=False) except Exception: # pragma: no cover - tokenizer specific errors return False if not query_tokens or not context_tokens: return False built = tokenizer.build_inputs_with_special_tokens(query_tokens, context_tokens) built = [int(token) for token in built] special_map = cast(Mapping[str, Any], getattr(tokenizer, "special_tokens_map", {})) cls_candidates = [ getattr(tokenizer, "cls_token_id", None), special_map.get("cls_token_id"), getattr(tokenizer, "bos_token_id", None), special_map.get("bos_token_id"), ] cls_candidates = [value for value in cls_candidates if isinstance(value, int)] sep_candidates = [ getattr(tokenizer, "sep_token_id", None), special_map.get("sep_token_id"), getattr(tokenizer, "eos_token_id", None), special_map.get("eos_token_id"), ] sep_candidates = [value for value in sep_candidates if isinstance(value, int)] missing_cls = bool(cls_candidates) and not any(token in cls_candidates for token in built) missing_sep = bool(sep_candidates) and not any(token in sep_candidates for token in built) return missing_cls or missing_sep @staticmethod def _extract_model_output(outputs: Any, key: str) -> torch.Tensor: candidate: torch.Tensor | None = None if isinstance(outputs, Mapping): candidate = outputs.get(key) if candidate is None and key == "ranking_logits": candidate = outputs.get("logits") if candidate is None: candidate = getattr(outputs, key, None) if candidate is None and key == "ranking_logits": candidate = getattr(outputs, "logits", None) if candidate is None: raise KeyError(f"{key} not found in model outputs") return candidate @classmethod def from_pretrained( cls, pretrained_model_name_or_path: str | Path, *, device: str | torch.device | None = None, trust_remote_code: bool = True, max_length: int | None = None, torch_dtype: torch.dtype | str | None = None, **kwargs: Any, ) -> OpenProvenceModel: """Load a finetuned Provence reranker with pruning head.""" _ensure_transformers_logging_configured() try: resolved_device = resolve_inference_device(device) except ValueError as exc: raise ValueError( f"Invalid device specification for OpenProvenceModel: {device!r}" ) from exc resolved_device_str = str(resolved_device).lower() if "torch_dtype" in kwargs and "dtype" not in kwargs: kwargs["dtype"] = kwargs.pop("torch_dtype") target_dtype = kwargs.get("dtype") if target_dtype is None and torch_dtype is not None: target_dtype = torch_dtype if target_dtype is None: dtype_hint = _select_default_torch_dtype(resolved_device_str) if dtype_hint is not None: target_dtype = dtype_hint attn_impl = kwargs.get("attn_implementation") want_flash_attention = False if resolved_device_str.startswith("cuda"): if _supports_flash_attention(): want_flash_attention = True if target_dtype is None: bf16_supported = getattr(torch.cuda, "is_bf16_supported", lambda: False)() target_dtype = torch.bfloat16 if bf16_supported else torch.float16 if attn_impl is None: attn_impl = "flash_attention_2" else: if attn_impl is None: attn_impl = "eager" elif resolved_device_str.startswith("mps"): if attn_impl is None: attn_impl = "eager" if target_dtype is not None: kwargs["dtype"] = target_dtype if attn_impl is not None: kwargs["attn_implementation"] = attn_impl def _apply_config_overrides(target: Any) -> None: attn_impl = kwargs.get("attn_implementation") if attn_impl is not None and hasattr(target, "config"): setattr(target.config, "attn_implementation", attn_impl) dtype_value = kwargs.get("dtype") if dtype_value is not None and hasattr(target, "config"): setattr(target.config, "torch_dtype", dtype_value) try: model = super().from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs, ) except Exception: if not want_flash_attention: raise kwargs["attn_implementation"] = "eager" kwargs["dtype"] = torch.float32 model = super().from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs, ) requested_dtype = kwargs.get("dtype") _apply_config_overrides(model) if hasattr(model, "ranking_model"): _apply_config_overrides(getattr(model, "ranking_model")) dtype_for_to = _coerce_dtype_for_torch_to(requested_dtype) if dtype_for_to is not None: model.to(device=resolved_device, dtype=dtype_for_to) else: model.to(resolved_device) if max_length is not None: model.max_length = int(max_length) if hasattr(model.config, "max_length"): model.config.max_length = int(max_length) model._update_tokenizer_runtime(max_length_override=max_length) model._update_runtime_defaults() model.eval() return model def forward( self, input_ids: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, return_dict: bool | None = None, **kwargs: Any, ) -> ModelOutput | tuple[torch.Tensor, ...]: """Run the ranking backbone and pruning head.""" if input_ids is None: raise ValueError("input_ids must be provided") effective_return_dict = return_dict if return_dict is not None else True attention_mask = ( attention_mask.to(self._runtime_device) if attention_mask is not None else None ) input_ids = input_ids.to(self._runtime_device) outputs = self.ranking_model( input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, return_dict=True, **kwargs, ) ranking_logits = cast(FloatTensor, outputs.logits) hidden_states = outputs.hidden_states[-1] pruning_inputs = hidden_states head_param = next(self.pruning_head.parameters(), None) if head_param is not None and pruning_inputs.dtype != head_param.dtype: pruning_inputs = pruning_inputs.to(head_param.dtype) pruning_outputs = self.pruning_head( hidden_states=pruning_inputs, attention_mask=attention_mask, ) pruning_logits = cast(Tensor, pruning_outputs["logits"]) loss_tensor: torch.Tensor | None = None if labels is not None: if self.config.num_labels == 1: loss_fct = nn.BCEWithLogitsLoss() loss_tensor = loss_fct(ranking_logits.view(-1), labels.float()) else: loss_fct = nn.CrossEntropyLoss() loss_tensor = loss_fct( ranking_logits.view(-1, self.config.num_labels), labels.view(-1) ) loss_output: FloatTensor | None if loss_tensor is None: loss_output = None else: loss_output = cast(FloatTensor, loss_tensor.to(dtype=ranking_logits.dtype)) result = SequenceClassifierOutput( loss=loss_output, logits=ranking_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) setattr(result, "pruning_logits", pruning_logits) setattr(result, "ranking_logits", ranking_logits) if not effective_return_dict: output: tuple[torch.Tensor, ...] = (ranking_logits, pruning_logits) if loss_output is not None: return (loss_output,) + output return output return result @torch.no_grad() def get_raw_predictions( self, query: str, contexts: Iterable[str], ) -> OpenProvenceRawPrediction: """Compute token-level keep probabilities for a single context list.""" batch_result = self.get_raw_predictions_batch(query, [list(contexts)]) return batch_result[0] def get_raw_predictions_batch( self, query: str | Sequence[str], contexts_batch: Sequence[Sequence[str]], batch_size: int | None = None, ) -> list[OpenProvenceRawPrediction]: """Compute raw predictions for multiple context lists. Supports either a single query string shared across the batch or a sequence of per-sample queries matching ``contexts_batch``. """ if not contexts_batch: return [] sep_token = self.tokenizer.sep_token or "" if batch_size is None or batch_size <= 0: batch_size = len(contexts_batch) if isinstance(query, Sequence) and not isinstance(query, str): query_list = [str(entry) for entry in query] if len(query_list) != len(contexts_batch): raise ValueError( "When providing multiple queries, their count must match contexts_batch." ) else: query_list = [str(query)] * len(contexts_batch) results: list[OpenProvenceRawPrediction] = [] for start in range(0, len(contexts_batch), batch_size): chunk = contexts_batch[start : start + batch_size] chunk_queries = query_list[start : start + batch_size] chunk_combined = [ chunk_queries[idx] + sep_token + "".join(contexts) for idx, contexts in enumerate(chunk) ] encoding = self.tokenizer( chunk_combined, padding=True, truncation=True, max_length=self.max_length, return_tensors="pt", ) encoding = {key: value.to(self._runtime_device) for key, value in encoding.items()} model_outputs = self.forward(return_dict=True, **encoding) ranking_logits = self._extract_model_output(model_outputs, "ranking_logits") pruning_logits = self._extract_model_output(model_outputs, "pruning_logits") ranking_logits = ranking_logits.detach().cpu() pruning_logits = pruning_logits.detach().cpu() if ranking_logits.dtype != torch.float32: ranking_logits = ranking_logits.to(dtype=torch.float32) if pruning_logits.dtype != torch.float32: pruning_logits = pruning_logits.to(dtype=torch.float32) for idx, contexts in enumerate(chunk): if len(contexts) == 0: continue logits = ranking_logits[idx] if logits.ndim == 0 or logits.numel() == 1: ranking_score = torch.sigmoid(logits.flatten())[0].item() else: ranking_score = torch.sigmoid(logits[..., 0]).item() pruning_logit = pruning_logits[idx] pruning_probs = torch.softmax(pruning_logit, dim=-1).numpy() if pruning_probs.ndim == 2 and pruning_probs.shape[1] == 2: pruning_probs = pruning_probs[:, 1] elif pruning_probs.ndim == 1: pruning_probs = pruning_probs else: pruning_probs = pruning_probs.reshape(-1) context_ranges = self._context_ranges_from_contexts(chunk_queries[idx], contexts) results.append( OpenProvenceRawPrediction( query=chunk_queries[idx], contexts=list(contexts), ranking_score=ranking_score, pruning_probs=pruning_probs, context_ranges=context_ranges, ) ) return results def predict_with_thresholds( self, query: str, contexts: Iterable[str], thresholds: Iterable[float], *, use_majority: bool = False, ) -> dict[str, Any]: """Return keep/delete decisions for each context under the thresholds.""" raw = self.get_raw_predictions(query, contexts) predictions: dict[float, list[int]] = {} for threshold in thresholds: context_predictions: list[int] = [] for start, end in raw.context_ranges: segment = raw.pruning_probs[start:end] if segment.size == 0: context_predictions.append(1) continue if use_majority: kept_tokens = np.count_nonzero(segment > threshold) context_predictions.append(1 if kept_tokens >= (segment.size / 2) else 0) else: mean_prob = float(segment.mean()) context_predictions.append(1 if mean_prob > threshold else 0) predictions[threshold] = context_predictions return { "query": raw.query, "contexts": raw.contexts, "ranking_score": raw.ranking_score, "predictions": predictions, "context_ranges": raw.context_ranges, "pruning_probs": raw.pruning_probs, } def _compute_context_ranges( self, query: str, contexts: list[str], pruning_probs: np.ndarray, ) -> list[tuple[int, int]]: """Reconstruct token spans for each context string.""" sep_token = self.tokenizer.sep_token or "" prefix = query + sep_token context_boundaries: list[int] = [] for idx in range(len(contexts)): cumulative_text = prefix + "".join(contexts[: idx + 1]) cumulative_encoding = self.tokenizer( cumulative_text, padding=False, truncation=True, max_length=self.max_length, return_tensors="pt", ) input_ids = cast(Tensor, cumulative_encoding["input_ids"]) context_boundaries.append(int(input_ids.shape[1])) prefix_encoding = self.tokenizer( prefix, padding=False, truncation=False, return_tensors="pt", ) prefix_len = int(cast(Tensor, prefix_encoding["input_ids"]).shape[1]) context_ranges: list[tuple[int, int]] = [] prev = prefix_len total = pruning_probs.shape[0] for boundary in context_boundaries: end = min(boundary, total) context_ranges.append((prev, end)) prev = end return context_ranges def _context_ranges_from_contexts( self, query: str, contexts: Sequence[str], ) -> list[tuple[int, int]]: """Compute token index ranges for a list of contexts given a query.""" if not contexts: return [] sep_token = self.tokenizer.sep_token or "" prefix = query + sep_token cumulative_texts = [] for idx in range(len(contexts)): cumulative_texts.append(prefix + "".join(contexts[: idx + 1])) boundaries: list[int] = [] for text in cumulative_texts: encoding = self.tokenizer( text, padding=False, truncation=True, max_length=self.max_length, return_tensors="pt", ) input_ids = cast(Tensor, encoding["input_ids"]) boundaries.append(int(input_ids.shape[1])) prefix_encoding = self.tokenizer( prefix, padding=False, truncation=False, return_tensors="pt", ) prefix_len = int(cast(Tensor, prefix_encoding["input_ids"]).shape[1]) ranges: list[tuple[int, int]] = [] prev = prefix_len for boundary in boundaries: ranges.append((prev, boundary)) prev = boundary return ranges def _resolve_prefix_sentences( self, title_spec: None | str | list[str] | list[list[str]], context_idx: int, ) -> tuple[list[str], bool]: """Determine prefix sentences and whether the first context sentence is a title.""" prefix_sentences: list[str] = [] title_is_first_sentence = False if title_spec == "first_sentence": title_is_first_sentence = True elif isinstance(title_spec, list): if title_spec and isinstance(title_spec[0], list): raw_title = title_spec[context_idx] if context_idx < len(title_spec) else None if raw_title: prefix_sentences.extend( [ title.strip() for title in raw_title if isinstance(title, str) and title.strip() ] ) else: raw_title = title_spec[context_idx] if context_idx < len(title_spec) else None if isinstance(raw_title, str) and raw_title.strip(): prefix_sentences.append(raw_title.strip()) elif isinstance(title_spec, str) and title_spec.strip(): prefix_sentences.append(title_spec.strip()) if prefix_sentences: last_idx = len(prefix_sentences) - 1 prefix_sentences[last_idx] = prefix_sentences[last_idx].rstrip("\n") + "\n" return prefix_sentences, title_is_first_sentence def _resolve_sentence_splitter( self, splitter: SentenceSplitter | Mapping[str, SentenceSplitter] | None, language: str | None, ) -> SentenceSplitter: if isinstance(splitter, Mapping): if language is None: raise ValueError("language must be provided when sentence_splitter is a mapping") if language in splitter: return splitter[language] raise ValueError(f"No sentence splitter registered for language '{language}'") if callable(splitter): return splitter default_language = getattr(self, "default_splitter_language", None) lang = language if language is not None else default_language if lang is None: lang = "auto" lang_normalized = str(lang).lower() if lang_normalized == "auto": return create_auto_sentence_splitter() if lang_normalized == "ja": return fast_bunkai_sentence_splitter if lang_normalized == "en": return english_sentence_splitter raise ValueError( f"Unsupported language code for sentence splitting: '{lang}'. Supported values are 'auto', 'en', and 'ja'." ) def _run_sequential_fragmentize( self, jobs: list[dict[str, Any]], *, max_fragment_tokens: int, splitter: SentenceSplitter, show_progress: bool, strip_sentences: bool, respect_sentence_boundaries: bool, ) -> list[dict[str, Any]]: processed_entries: list[dict[str, Any]] = [] if not jobs: return processed_entries progress = None if show_progress and is_progress_bar_enabled(): try: from tqdm import tqdm # pragma: no cover - optional dependency except Exception: # pragma: no cover - tqdm may be unavailable progress = None else: progress = tqdm(total=len(jobs), desc="Preprocess") for job in jobs: entry = _fragmentize_single_job( self.tokenizer, job, max_fragment_tokens=max_fragment_tokens, splitter=splitter, strip_sentences=strip_sentences, respect_sentence_boundaries=respect_sentence_boundaries, ) processed_entries.append(entry) if progress is not None: progress.update(1) if progress is not None: progress.close() return processed_entries def _truncate_fragment(self, fragment: _FragmentRecord, max_tokens: int) -> _FragmentRecord: if max_tokens <= 0: max_tokens = 1 if fragment.token_length <= max_tokens: return fragment new_tokens = fragment.token_ids[:max_tokens] new_text = self.tokenizer.decode( new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False, ) return _FragmentRecord( text=new_text, sentence_index=fragment.sentence_index, fragment_index=fragment.fragment_index, global_index=fragment.global_index, token_length=len(new_tokens), token_ids=list(new_tokens), ) def _prepare_block_inputs( self, query_tokens: Sequence[int], fragments: Sequence[_FragmentRecord], ) -> tuple[list[int], list[int], list[int] | None, list[tuple[int, int]]]: query_list = [int(token) for token in query_tokens] context_tokens: list[int] = [] for fragment in fragments: context_tokens.extend(int(token) for token in fragment.token_ids) built_with_specials = self.tokenizer.build_inputs_with_special_tokens( query_list, context_tokens ) built_with_specials = [int(token) for token in built_with_specials] manual_override = getattr(self, "_manual_special_tokens_required", False) manual_cls_token = getattr(self, "_manual_cls_token_id", None) manual_sep_token = getattr(self, "_manual_sep_token_id", None) if manual_override: # Some tokenizers, notably ModernBERT, omit CLS/SEP when provided with pre-tokenised # input. We rebuild the sequence manually so that downstream code sees consistent # boundaries without ever converting back to strings. input_ids: list[int] = [] if manual_cls_token is not None: input_ids.append(manual_cls_token) input_ids.extend(int(token) for token in query_list) if manual_sep_token is not None: input_ids.append(manual_sep_token) input_ids.extend(int(token) for token in context_tokens) if manual_sep_token is not None and context_tokens: input_ids.append(manual_sep_token) else: # Most tokenizers already handle special tokens correctly, so we can reuse the # sequence they produce directly. if built_with_specials: input_ids = built_with_specials else: input_ids = [int(token) for token in query_list] input_ids.extend(int(token) for token in context_tokens) attention_mask = [1] * len(input_ids) token_type_ids: list[int] | None try: token_type_ids = self.tokenizer.create_token_type_ids_from_sequences( query_list, context_tokens, ) except Exception: token_type_ids = None else: if token_type_ids is not None: token_type_ids = [int(token) for token in token_type_ids] def _find_subsequence_start( haystack: Sequence[int], needle: Sequence[int], ) -> int: if not needle: return -1 needle_list = list(needle) limit = len(haystack) - len(needle_list) + 1 for idx in range(max(limit, 0)): if haystack[idx : idx + len(needle_list)] == needle_list: return idx return -1 ranges: list[tuple[int, int]] = [] if context_tokens: context_start = _find_subsequence_start(input_ids, context_tokens) if context_start < 0: prefix_ids = self.tokenizer.build_inputs_with_special_tokens(query_list, []) context_start = len(prefix_ids) cursor = context_start for fragment in fragments: start = cursor cursor += len(fragment.token_ids) ranges.append((start, cursor)) else: ranges = [] if token_type_ids is not None and len(token_type_ids) < len(input_ids): pad_value = token_type_ids[-1] if token_type_ids else 0 token_type_ids = token_type_ids + [pad_value] * (len(input_ids) - len(token_type_ids)) if token_type_ids is None: token_type_ids = [0] * len(input_ids) context_start = ranges[0][0] if context_tokens else len(input_ids) for idx in range(context_start, len(input_ids)): token_type_ids[idx] = 1 return input_ids, attention_mask, token_type_ids, ranges def _precompute_sentences_and_tokens( self, context_text: str, prefix_sentences: list[str], manual_sentences: list[str] | None, splitter: SentenceSplitter, strip_sentences: bool, ) -> tuple[list[str], list[list[int]]]: example_payload = { "context_text": context_text, "prefix_sentences": prefix_sentences, "manual_sentences": manual_sentences, } raw_sentences = _collect_candidate_sentences(example_payload, splitter) sentences = _normalize_sentences(raw_sentences, context_text, strip_sentences) token_lists = _tokenize_sentences_with_context( self.tokenizer, sentences, len(prefix_sentences), context_text, strip_sentences=strip_sentences, ) return sentences, token_lists def _assemble_blocks_from_fragments( self, query_token_length: int, sep_token_length: int, fragments: list[_FragmentRecord], ) -> list[list[_FragmentRecord]]: if not fragments: return [] available_len = self.max_length - 2 # [CLS], [SEP] base_len = query_token_length + sep_token_length max_fragment_capacity = max(1, available_len - base_len) blocks: list[list[_FragmentRecord]] = [] current_block: list[_FragmentRecord] = [] current_len = base_len for fragment in fragments: fragment_len = fragment.token_length if current_len + fragment_len <= available_len: current_block.append(fragment) current_len += fragment_len continue if current_block: blocks.append(current_block) current_block = [] current_len = base_len truncated_fragment = self._truncate_fragment(fragment, max_fragment_capacity) current_block.append(truncated_fragment) current_len = base_len + truncated_fragment.token_length if current_block: blocks.append(current_block) return blocks def _normalize_inputs( self, question: str | Sequence[str], context: ContextInput, ) -> tuple[list[str], list[list[Any]], str]: """Normalize input structures for process().""" if isinstance(question, str): queries = [question] else: queries = [str(q) for q in question] def _is_sequence(value: Any) -> bool: return isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)) def _normalize_context_collection(values: Sequence[Any]) -> list[Any]: normalized: list[Any] = [] for item in values: if _is_sequence(item): normalized.append([str(element) for element in item]) else: normalized.append(str(item)) return normalized if isinstance(context, str): context_structure = "str" contexts: list[list[Any]] = [[context]] elif not _is_sequence(context): raise ValueError("Unsupported context format") elif len(queries) == 1: normalized_contexts = _normalize_context_collection(context) context_structure = "list" contexts = [normalized_contexts] else: context_sequence = list(context) all_scalars = all(not _is_sequence(entry) for entry in context_sequence) if all_scalars: if len(context_sequence) != len(queries): raise ValueError("Number of contexts must match number of queries") context_structure = "aligned" contexts = [[str(entry)] for entry in context_sequence] else: context_structure = "nested" normalized_nested: list[list[Any]] = [] for entry in context_sequence: if not _is_sequence(entry): raise ValueError("Number of context lists must match number of queries") normalized_nested.append(_normalize_context_collection(entry)) contexts = normalized_nested if context_structure == "list" and len(queries) != 1: raise ValueError("Single list of contexts requires a single query") if context_structure == "nested" and len(contexts) != len(queries): raise ValueError("Number of context lists must match number of queries") if context_structure == "str" and len(queries) != 1: raise ValueError("Single context string requires a single query") if context_structure in {"str", "list"}: contexts = [contexts[0]] return queries, contexts, context_structure def _prepare_titles( self, title: None | str | Sequence[str] | Sequence[Sequence[str]], queries: list[str], contexts: list[list[str]], ) -> list[Any]: """Normalize title inputs for process().""" n_queries = len(queries) if title is None: return [None] * n_queries if isinstance(title, str): if title == "first_sentence": return ["first_sentence"] * n_queries return [[title for _ in ctxs] for ctxs in contexts] if isinstance(title, Sequence): normalized: list[Any] = [] for entry in title: if isinstance(entry, Sequence) and not isinstance(entry, str): normalized.append([str(value) for value in entry]) else: normalized.append(str(entry)) if n_queries == 1 and all(isinstance(item, str) for item in normalized): return [[str(item) for item in normalized]] if len(normalized) == n_queries and all(isinstance(item, list) for item in normalized): return [list(map(str, item)) for item in normalized] # type: ignore[list-item] if len(normalized) == n_queries and all(isinstance(item, str) for item in normalized): return [[value for _ in contexts[idx]] for idx, value in enumerate(normalized)] raise ValueError("Unsupported title format") def _extract_first_line_titles( self, contexts: list[list[Any]], ) -> tuple[list[list[Any]], list[list[str]]]: """Split the first non-empty line from each context as a title candidate.""" updated_contexts: list[list[Any]] = [] extracted_titles: list[list[str]] = [] for context_group in contexts: group_titles: list[str] = [] updated_group: list[Any] = [] for entry in context_group: if isinstance(entry, list): normalized = [str(value) for value in entry] title_candidate = "" remainder: list[str] = [] for idx, segment in enumerate(normalized): if segment.strip(): title_candidate = segment.rstrip("\r\n") remainder = normalized[idx + 1 :] break else: remainder = normalized group_titles.append(title_candidate) updated_group.append(remainder) else: text_entry = str(entry) title_candidate = "" remainder_text = "" if text_entry: lines = text_entry.splitlines(keepends=True) remainder_segments: list[str] = [] for idx, line in enumerate(lines): if line.strip(): title_candidate = line.rstrip("\r\n") remainder_segments = lines[idx + 1 :] break else: remainder_segments = lines remainder_text = "".join(remainder_segments) group_titles.append(title_candidate) updated_group.append(remainder_text) extracted_titles.append(group_titles) updated_contexts.append(updated_group) return updated_contexts, extracted_titles def _resolve_titles( self, queries: list[str], contexts: list[list[Any]], title: None | str | Sequence[str] | Sequence[Sequence[str]], *, first_line_as_title: bool, ) -> tuple[list[list[Any]], list[Any]]: """Resolve title inputs, optionally extracting first lines from contexts.""" title_payload: None | str | Sequence[str] | Sequence[Sequence[str]] if first_line_as_title: if title not in (None, "first_sentence"): raise ValueError( "first_line_as_title=True cannot be combined with an explicit title override." ) contexts, extracted_titles = self._extract_first_line_titles(contexts) title_payload = extracted_titles else: title_payload = title titles = self._prepare_titles(title_payload, queries, contexts) return contexts, titles def _build_preprocess_jobs( self, queries: list[str], contexts: list[list[Any]], titles: list[Any], splitter: SentenceSplitter, *, strip_sentences: bool, show_progress: bool, ) -> tuple[list[dict[str, Any]], list[list[int]]]: """Construct preprocessing jobs and cache query token ids.""" preprocess_jobs: list[dict[str, Any]] = [] query_token_ids: list[list[int]] = [] total_contexts = sum(len(context_collection) for context_collection in contexts) progress = None if show_progress and is_progress_bar_enabled() and total_contexts: try: from tqdm import tqdm # pragma: no cover - optional dependency except Exception: # pragma: no cover - tqdm may be unavailable progress = None else: progress = tqdm(total=total_contexts, desc="Prepare contexts") for query_idx, query_text in enumerate(queries): query_tokens = self.tokenizer.encode(query_text, add_special_tokens=False) query_token_ids.append(query_tokens) title_spec = titles[query_idx] for context_idx, context_entry in enumerate(contexts[query_idx]): if isinstance(context_entry, list): manual_sentences = [str(s) for s in context_entry if str(s).strip()] context_text = "".join(manual_sentences) else: manual_sentences = None context_text = context_entry prefix_sentences, title_is_first_sentence = self._resolve_prefix_sentences( title_spec, context_idx, ) cached_sentences, cached_token_lists = self._precompute_sentences_and_tokens( context_text, prefix_sentences, manual_sentences, splitter, strip_sentences, ) prefix_count = len(prefix_sentences) if cached_token_lists is not None: prefix_token_counts = [ len(tokens) for tokens in cached_token_lists[:prefix_count] ] else: prefix_token_counts = [ len(self.tokenizer.encode(sentence, add_special_tokens=False)) if sentence else 0 for sentence in prefix_sentences ] preprocess_jobs.append( { "query_idx": query_idx, "context_idx": context_idx, "context_text": context_text, "prefix_sentences": prefix_sentences, "title_is_first_sentence": title_is_first_sentence, "prefix_token_counts": prefix_token_counts, "manual_sentences": manual_sentences, "cached_sentences": cached_sentences, "cached_token_lists": cached_token_lists, } ) if progress is not None: progress.update(1) if progress is not None: progress.close() return preprocess_jobs, query_token_ids def _resolve_preprocess_workers(self, override: int | None) -> int: if override is not None: return max(0, int(override)) env_value = os.getenv("OPEN_PROVENCE_PREPROCESS_WORKERS") if env_value: try: parsed = int(env_value) except ValueError: parsed = 0 if parsed > 0: return parsed return _default_preprocess_workers() def _estimate_device_memory_bytes(self) -> int | None: override_gb = os.getenv("OPEN_PROVENCE_DEVICE_MEMORY_GB") if override_gb: try: parsed = float(override_gb) except ValueError: parsed = None else: if parsed > 0: return int(parsed * (1024**3)) device = getattr(self, "_runtime_device", None) if not isinstance(device, torch.device): return None if device.type == "cuda": try: index = device.index if device.index is not None else torch.cuda.current_device() except Exception: index = None if index is None: return None try: props = torch.cuda.get_device_properties(index) except Exception: return None total = getattr(props, "total_memory", None) return int(total) if total is not None else None return None def _auto_tune_preprocess_loader( self, *, total_jobs: int, inference_batch_size: int, current_workers: int, current_preprocess_batch: int, current_prefetch: int | None, workers_explicit: bool, batch_explicit: bool, prefetch_explicit: bool, ) -> tuple[int, int, int | None]: # NOTE: This helper encapsulates several heuristics that evolved from # manual benchmarking. Adding a comment here keeps the expectations # close to the code, so future refactors know which behaviours must # stay stable (and are covered by tests). jobs_count = max(0, int(total_jobs)) workers = max(0, int(current_workers)) preprocess_batch = max(1, int(current_preprocess_batch)) prefetch_factor = current_prefetch if prefetch_explicit else None if not workers_explicit: cpu_limit = max(0, _default_preprocess_workers()) workers = min(workers or cpu_limit, cpu_limit) if jobs_count < 2_000: workers = 0 elif workers == 0 and cpu_limit > 0: workers = min(cpu_limit, 4) if jobs_count: workers = min(workers, jobs_count) if not batch_explicit: device_bytes = self._estimate_device_memory_bytes() cap_from_device: int | None = None if device_bytes: device_gb = device_bytes / float(1024**3) if device_gb < 12: cap_from_device = 64 elif device_gb < 20: cap_from_device = 128 else: cap_from_device = 192 fallback_cap = min(96, max(32, inference_batch_size)) target_cap = cap_from_device or fallback_cap preprocess_batch = min(preprocess_batch, target_cap) preprocess_batch = min(preprocess_batch, max(1, inference_batch_size)) if jobs_count: preprocess_batch = min(preprocess_batch, jobs_count) if workers <= 0: workers = 0 if workers == 0 and not prefetch_explicit: prefetch_factor = None elif workers > 0 and not prefetch_explicit: prefetch_factor = max(2, min(8, math.ceil(preprocess_batch / workers))) return workers, preprocess_batch, prefetch_factor def _run_preprocess_pipeline( self, jobs: list[dict[str, Any]], max_fragment_tokens: int, splitter: SentenceSplitter, show_progress: bool, strip_sentences: bool, *, respect_sentence_boundaries: bool, ) -> tuple[list[dict[str, Any]], float]: """Execute the preprocessing pipeline and return processed entries with timing.""" preprocess_start = perf_counter() processed_entries = self._run_sequential_fragmentize( jobs, max_fragment_tokens=max_fragment_tokens, splitter=splitter, show_progress=show_progress, strip_sentences=strip_sentences, respect_sentence_boundaries=respect_sentence_boundaries, ) preprocess_time = perf_counter() - preprocess_start return processed_entries, preprocess_time def _assemble_inference_inputs( self, preprocess_jobs: list[dict[str, Any]], processed_entries: list[dict[str, Any]], query_token_ids: list[list[int]], sep_token_ids: list[int], ) -> tuple[ dict[tuple[int, int], dict[str, Any]], list[dict[str, Any]], dict[str, float], float, ]: """Convert processed entries into inference jobs and aggregate timing metrics.""" contexts_info: dict[tuple[int, int], dict[str, Any]] = {} inference_jobs: list[dict[str, Any]] = [] timing_totals = { "sentence_collect_seconds": 0.0, "sentence_normalize_seconds": 0.0, "tokenize_seconds": 0.0, "fragment_split_seconds": 0.0, "fragment_decode_seconds": 0.0, } def _consume_timing(payload: dict[str, Any], key: str) -> float: value = payload.pop(key, 0.0) if isinstance(value, (list, tuple)): value = sum(value) try: return float(value) except (TypeError, ValueError): return 0.0 assembly_start = perf_counter() for job, processed in zip(preprocess_jobs, processed_entries): job.pop("cached_sentences", None) job.pop("cached_token_lists", None) timing_totals["sentence_collect_seconds"] += _consume_timing( processed, "timing_sentence_collect" ) timing_totals["sentence_normalize_seconds"] += _consume_timing( processed, "timing_sentence_normalize" ) timing_totals["tokenize_seconds"] += _consume_timing(processed, "timing_tokenize") timing_totals["fragment_split_seconds"] += _consume_timing( processed, "timing_fragment_split" ) timing_totals["fragment_decode_seconds"] += _consume_timing( processed, "timing_fragment_decode" ) fragment_texts = processed.get("fragment_texts", []) sentence_indices = processed.get("fragment_sentence_index", []) fragment_indices = processed.get("fragment_fragment_index", []) global_indices = processed.get("fragment_global_index", []) token_id_lists = processed.get("fragment_token_ids", []) fragments: list[_FragmentRecord] = [] for idx, text in enumerate(fragment_texts): tokens = list(token_id_lists[idx]) if idx < len(token_id_lists) else [] fragments.append( _FragmentRecord( text=text, sentence_index=int(sentence_indices[idx]) if idx < len(sentence_indices) else 0, fragment_index=int(fragment_indices[idx]) if idx < len(fragment_indices) else 0, global_index=int(global_indices[idx]) if idx < len(global_indices) else idx, token_length=len(tokens), token_ids=tokens, ) ) sentences: list[str] = processed.get("sentences", []) query_idx = job["query_idx"] context_idx = job["context_idx"] prefix_len = len(job.get("prefix_sentences", [])) prefix_token_counts = job.get("prefix_token_counts", []) blocks = self._assemble_blocks_from_fragments( len(query_token_ids[query_idx]), len(sep_token_ids), fragments ) contexts_info[(query_idx, context_idx)] = { "sentences": sentences, "fragments": fragments, "blocks": blocks, "prefix_length": prefix_len, "prefix_sentences": job.get("prefix_sentences", []), "prefix_token_counts": prefix_token_counts, "title_is_first_sentence": job.get("title_is_first_sentence", False), "original_text": job["context_text"], "raw_blocks": [], } for block_idx, block in enumerate(blocks): inference_jobs.append( { "query_idx": query_idx, "context_idx": context_idx, "block_idx": block_idx, "texts": [fragment.text for fragment in block], } ) assembly_time = perf_counter() - assembly_start return contexts_info, inference_jobs, timing_totals, assembly_time def _run_inference_batches( self, inference_jobs: list[dict[str, Any]], batch_size: int, queries: list[str], query_token_ids: list[list[int]], contexts_info: dict[tuple[int, int], dict[str, Any]], *, show_inference_progress: bool, show_progress: bool, ) -> float: """Execute model inference over prepared jobs and attach raw predictions.""" inference_time = 0.0 total_inference_jobs = len(inference_jobs) progress_bar: Any | None = None if not total_inference_jobs: return inference_time if show_inference_progress: from tqdm import tqdm # inline import to avoid dependency when unused total_batches = (total_inference_jobs + batch_size - 1) // batch_size progress_bar = tqdm( range(0, total_inference_jobs, batch_size), total=total_batches, desc="Model inference", unit="batch", leave=False, ) batch_indices: Iterable[int] = progress_bar else: batch_indices = range(0, total_inference_jobs, batch_size) pad_token_raw = getattr(self.tokenizer, "pad_token_id", None) pad_token_id = int(pad_token_raw) if pad_token_raw is not None else 0 for start in batch_indices: chunk_jobs = inference_jobs[start : start + batch_size] if not chunk_jobs: continue chunk_queries = [queries[job["query_idx"]] for job in chunk_jobs] chunk_context_texts = [job["texts"] for job in chunk_jobs] chunk_query_tokens = [query_token_ids[job["query_idx"]] for job in chunk_jobs] prepared_inputs: list[dict[str, Any]] = [] ranges_per_job: list[list[tuple[int, int]]] = [] for job_entry, query_tokens_entry in zip(chunk_jobs, chunk_query_tokens): block_fragments = contexts_info[ (job_entry["query_idx"], job_entry["context_idx"]) ]["blocks"][job_entry["block_idx"]] ( input_ids_prepared, attention_mask_prepared, token_type_ids, context_ranges, ) = self._prepare_block_inputs( query_tokens_entry, block_fragments, ) prepared_inputs.append( { "input_ids": input_ids_prepared, "attention_mask": attention_mask_prepared, "token_type_ids": token_type_ids, } ) ranges_per_job.append(context_ranges) max_len = ( max(len(entry["input_ids"]) for entry in prepared_inputs) if prepared_inputs else 0 ) input_tensor = torch.full( (len(prepared_inputs), max_len), pad_token_id, dtype=torch.long, device=self._runtime_device, ) attention_tensor = torch.zeros( (len(prepared_inputs), max_len), dtype=torch.long, device=self._runtime_device, ) token_type_tensor: torch.Tensor | None = ( torch.zeros( (len(prepared_inputs), max_len), dtype=torch.long, device=self._runtime_device ) if any(entry.get("token_type_ids") for entry in prepared_inputs) else None ) for tensor_idx, entry in enumerate(prepared_inputs): ids_list = entry["input_ids"] attn_list = entry["attention_mask"] seq_len = len(ids_list) if seq_len == 0: continue input_tensor[tensor_idx, :seq_len] = torch.tensor( ids_list, dtype=torch.long, device=self._runtime_device, ) attention_tensor[tensor_idx, :seq_len] = torch.tensor( attn_list if attn_list else [1] * seq_len, dtype=torch.long, device=self._runtime_device, ) if token_type_tensor is not None: type_ids = entry.get("token_type_ids") or [0] * seq_len if len(type_ids) > seq_len: type_ids = type_ids[:seq_len] if len(type_ids) < seq_len: type_ids = list(type_ids) + [type_ids[-1]] * (seq_len - len(type_ids)) token_type_tensor[tensor_idx, :seq_len] = torch.tensor( type_ids, dtype=torch.long, device=self._runtime_device, ) infer_start = perf_counter() model_inputs = { "input_ids": input_tensor, "attention_mask": attention_tensor, } if token_type_tensor is not None: model_inputs["token_type_ids"] = token_type_tensor model_outputs = self.forward(return_dict=True, **model_inputs) inference_time += perf_counter() - infer_start ranking_logits = ( self._extract_model_output(model_outputs, "ranking_logits").detach().cpu() ) pruning_logits = ( self._extract_model_output(model_outputs, "pruning_logits").detach().cpu() ) if ranking_logits.dtype != torch.float32: ranking_logits = ranking_logits.to(dtype=torch.float32) if pruning_logits.dtype != torch.float32: pruning_logits = pruning_logits.to(dtype=torch.float32) for job_dict, raw_query, raw_contexts, ranges, rank_logits, prune_logits in zip( chunk_jobs, chunk_queries, chunk_context_texts, ranges_per_job, ranking_logits, pruning_logits, ): if rank_logits.ndim == 0 or rank_logits.numel() == 1: ranking_score = torch.sigmoid(rank_logits.flatten())[0].item() else: ranking_score = torch.sigmoid(rank_logits[..., 0]).item() pruning_probs = torch.softmax(prune_logits, dim=-1).numpy() if pruning_probs.ndim == 2 and pruning_probs.shape[1] == 2: pruning_probs = pruning_probs[:, 1] elif pruning_probs.ndim == 1: pruning_probs = pruning_probs else: pruning_probs = pruning_probs.reshape(-1) contexts_info[(job_dict["query_idx"], job_dict["context_idx"])][ "raw_blocks" ].append( ( job_dict["block_idx"], OpenProvenceRawPrediction( query=raw_query, contexts=list(raw_contexts), ranking_score=ranking_score, pruning_probs=pruning_probs, context_ranges=ranges, ), ) ) if progress_bar is not None: try: progress_bar.close() except Exception: # pragma: no cover - harmless pass if show_progress: try: progress_bar.write( f"Model inference time: {inference_time:.2f}s " f"({total_inference_jobs} blocks)" ) except Exception: # pragma: no cover - best effort fallback print( f"[OpenProvenceModel] Model inference took {inference_time:.2f}s " f"({total_inference_jobs} blocks)", flush=True, ) return inference_time def _postprocess_contexts( self, queries: list[str], contexts: list[list[Any]], contexts_info: dict[tuple[int, int], dict[str, Any]], *, threshold: float, always_select_title: bool, use_best_reranker_score: bool, sentence_probability_groups_requested: bool, collect_sentence_texts: bool, first_line_as_title: bool, zero_score_when_empty: bool, ) -> tuple[ list[list[str]], list[list[float | None]], list[list[float]], list[list[list[str]]] | None, list[list[list[str]]] | None, list[list[Any]], list[list[list[float]]] | None, float, ]: """Aggregate pruning outputs into user-facing structures.""" post_start = perf_counter() pruned_contexts: list[list[str]] = [] reranking_scores: list[list[float | None]] = [] compression_rates: list[list[float]] = [] if collect_sentence_texts: kept_sentences: list[list[list[str]]] | None = [] removed_sentences: list[list[list[str]]] | None = [] else: kept_sentences = None removed_sentences = None title_values: list[list[Any]] = [] sentence_probability_groups: list[list[list[float]]] | None = ( [] if sentence_probability_groups_requested else None ) for query_idx, _ in enumerate(queries): query_pruned: list[str] = [] query_scores: list[float | None] = [] query_compression: list[float] = [] query_kept: list[list[str]] | None = [] if collect_sentence_texts else None query_removed: list[list[str]] | None = [] if collect_sentence_texts else None query_titles: list[Any] = [] query_sentence_probabilities: list[list[float]] | None = ( [] if sentence_probability_groups is not None else None ) for context_idx, context_entry in enumerate(contexts[query_idx]): info = contexts_info.get((query_idx, context_idx)) prefix_sentences_value: Sequence[str] = () if info: raw_prefix = info.get("prefix_sentences", []) if isinstance(raw_prefix, str): prefix_sentences_value = (raw_prefix,) elif isinstance(raw_prefix, Sequence): prefix_sentences_value = tuple(str(item) for item in raw_prefix) if first_line_as_title and prefix_sentences_value: if len(prefix_sentences_value) == 1: fallback_title: Any = prefix_sentences_value[0] else: fallback_title = list(prefix_sentences_value) else: fallback_title = None context_sentence_probs: list[float] | None = ( [] if sentence_probability_groups is not None else None ) if not info or not info.get("fragments"): query_pruned.append(context_entry) query_scores.append(None) query_compression.append(0.0) if query_kept is not None: query_kept.append([context_entry] if context_entry else []) if query_removed is not None: query_removed.append([]) query_titles.append(fallback_title) if query_sentence_probabilities is not None: query_sentence_probabilities.append(context_sentence_probs or []) continue blocks = info["blocks"] raw_blocks = sorted(info["raw_blocks"], key=lambda x: x[0]) if not blocks or not raw_blocks: query_pruned.append(context_entry) query_scores.append(None) query_compression.append(0.0) if query_kept is not None: query_kept.append(info["sentences"]) if query_removed is not None: query_removed.append([]) query_titles.append(fallback_title) if context_sentence_probs is not None: context_sentence_probs.extend([1.0] * len(info["sentences"])) if query_sentence_probabilities is not None: query_sentence_probabilities.append(context_sentence_probs or []) continue fragment_scores: dict[int, list[float]] = defaultdict(list) ranking_score: float | None = None for (_, raw), block in zip(raw_blocks, blocks): block_probs = raw.pruning_probs ranges = raw.context_ranges prefix_counts = contexts_info[(query_idx, context_idx)].get( "prefix_token_counts", [] ) for fragment, (start, end) in zip(block, ranges): offset = sum(prefix_counts[: fragment.sentence_index]) start = max(0, start - offset) end = max(start, end - offset) end = min(end, len(block_probs)) start = min(start, len(block_probs)) mean_prob = 1.0 if end <= start else float(block_probs[start:end].mean()) fragment_scores[fragment.global_index].append(mean_prob) if raw.ranking_score is not None: if use_best_reranker_score: if ranking_score is None: ranking_score = raw.ranking_score else: ranking_score = max(ranking_score, raw.ranking_score) else: if ranking_score is None: ranking_score = raw.ranking_score sentence_scores: dict[int, list[float]] = defaultdict(list) for fragment in info["fragments"]: if fragment.global_index in fragment_scores: sentence_scores[fragment.sentence_index].extend( fragment_scores[fragment.global_index] ) kept_sentence_texts: list[str] = [] removed_sentence_texts: list[str] = [] sentences = info["sentences"] prefix_len = info["prefix_length"] title_sentence_index: int | None = None sentence_keep_flags: list[bool] = [] if always_select_title: if prefix_len > 0: title_sentence_index = 0 elif info.get("title_is_first_sentence") and len(sentences) > prefix_len: title_sentence_index = prefix_len sentence_avg_probabilities: list[float] = [] has_sentence_above_threshold = False for sentence_index in range(len(sentences)): probabilities = sentence_scores.get(sentence_index) avg_probability = float(np.mean(probabilities)) if probabilities else 0.0 avg_probability = max(0.0, min(avg_probability, 1.0)) sentence_avg_probabilities.append(avg_probability) if avg_probability > threshold: has_sentence_above_threshold = True force_keep_title = ( title_sentence_index is not None and has_sentence_above_threshold ) for sentence_index in range(len(sentences)): avg_probability = sentence_avg_probabilities[sentence_index] keep_flag = avg_probability > threshold if force_keep_title and sentence_index == title_sentence_index: keep_flag = True sentence_keep_flags.append(keep_flag) if context_sentence_probs is not None: context_sentence_probs.append(avg_probability) kept_sentence_texts = [ sentences[idx] for idx, keep in enumerate(sentence_keep_flags) if keep ] removed_sentence_texts = [ sentences[idx] for idx, keep in enumerate(sentence_keep_flags) if not keep ] content_kept_sentences = [ sentences[idx] for idx, keep in enumerate(sentence_keep_flags) if idx >= prefix_len and keep ] pruned_text = "".join(content_kept_sentences) original_text = info["original_text"] original_length = max(len(original_text), 1) compression = (len(original_text) - len(pruned_text)) / original_length * 100.0 if zero_score_when_empty and not pruned_text.strip(): ranking_score = 0.0 prefix_sentences_value = info.get("prefix_sentences", []) if prefix_sentences_value: if len(prefix_sentences_value) == 1: title_value = prefix_sentences_value[0] else: title_value = list(prefix_sentences_value) else: title_value = None query_pruned.append(pruned_text) query_scores.append(ranking_score) query_compression.append(compression) if query_kept is not None: query_kept.append(kept_sentence_texts) if query_removed is not None: query_removed.append(removed_sentence_texts) query_titles.append(title_value) if query_sentence_probabilities is not None: query_sentence_probabilities.append(context_sentence_probs or []) pruned_contexts.append(query_pruned) reranking_scores.append(query_scores) compression_rates.append(query_compression) if kept_sentences is not None and query_kept is not None: kept_sentences.append(query_kept) if removed_sentences is not None and query_removed is not None: removed_sentences.append(query_removed) title_values.append(query_titles) if ( sentence_probability_groups is not None and query_sentence_probabilities is not None ): sentence_probability_groups.append(query_sentence_probabilities) post_time = perf_counter() - post_start return ( pruned_contexts, reranking_scores, compression_rates, kept_sentences, removed_sentences, title_values, sentence_probability_groups, post_time, ) def _apply_reordering( self, pruned_contexts: list[list[str]], reranking_scores: list[list[float | None]], compression_rates: list[list[float]], kept_sentences: list[list[list[str]]] | None, removed_sentences: list[list[list[str]]] | None, title_values: list[list[Any]], sentence_probability_groups: list[list[list[float]]] | None, *, top_k: int | None, ) -> tuple[ list[list[str]], list[list[float | None]], list[list[float]], list[list[list[str]]] | None, list[list[list[str]]] | None, list[list[Any]], list[list[list[float]]] | None, ]: """Reorder contexts by reranker score and apply optional top-k truncation.""" if not pruned_contexts: return ( pruned_contexts, reranking_scores, compression_rates, kept_sentences, removed_sentences, title_values, sentence_probability_groups, ) if top_k is None: effective_top_k = None else: effective_top_k = max(0, int(top_k)) reordered_pruned: list[list[str]] = [] reordered_scores: list[list[float | None]] = [] reordered_compression: list[list[float]] = [] reordered_kept: list[list[list[str]]] | None = [] if kept_sentences is not None else None reordered_removed: list[list[list[str]]] | None = ( [] if removed_sentences is not None else None ) reordered_titles: list[list[Any]] = [] reordered_probs: list[list[list[float]]] | None = ( [] if sentence_probability_groups is not None else None ) for query_idx, scores in enumerate(reranking_scores): if not scores: reordered_pruned.append(pruned_contexts[query_idx]) reordered_scores.append(scores) reordered_compression.append(compression_rates[query_idx]) if reordered_kept is not None and kept_sentences is not None: reordered_kept.append(kept_sentences[query_idx]) if reordered_removed is not None and removed_sentences is not None: reordered_removed.append(removed_sentences[query_idx]) reordered_titles.append(title_values[query_idx]) if reordered_probs is not None: reordered_probs.append( sentence_probability_groups[query_idx] if sentence_probability_groups is not None else [] ) continue def _score_key(idx: int) -> float: value = scores[idx] if value is None: return float("-inf") return float(value) ranking_indices = sorted(range(len(scores)), key=_score_key, reverse=True) if effective_top_k is None: limited_indices = ranking_indices else: limited_indices = ranking_indices[:effective_top_k] reordered_pruned.append([pruned_contexts[query_idx][idx] for idx in limited_indices]) reordered_scores.append([scores[idx] for idx in limited_indices]) reordered_compression.append( [compression_rates[query_idx][idx] for idx in limited_indices] ) if reordered_kept is not None and kept_sentences is not None: reordered_kept.append([kept_sentences[query_idx][idx] for idx in limited_indices]) if reordered_removed is not None and removed_sentences is not None: reordered_removed.append( [removed_sentences[query_idx][idx] for idx in limited_indices] ) reordered_titles.append([title_values[query_idx][idx] for idx in limited_indices]) if reordered_probs is not None: reordered_probs.append( [sentence_probability_groups[query_idx][idx] for idx in limited_indices] if sentence_probability_groups is not None else [] ) return ( reordered_pruned, reordered_scores, reordered_compression, reordered_kept, reordered_removed, reordered_titles, reordered_probs if reordered_probs is not None else None, ) def process( self, question: str | Sequence[str], context: str | Sequence[str] | Sequence[Sequence[str]], title: None | str | Sequence[str] | Sequence[Sequence[str]] = "first_sentence", first_line_as_title: bool = False, *, batch_size: int = 32, threshold: float | None = None, always_select_title: bool = False, reorder: bool = False, top_k: int | None = None, sentence_splitter: SentenceSplitter | Mapping[str, SentenceSplitter] | None = None, language: str | None = None, use_best_reranker_score: bool = True, zero_score_when_empty: bool = True, show_progress: bool = True, debug_messages: bool | Callable[[str], None] = False, enable_warnings: bool = True, strip_sentences: bool = False, respect_sentence_boundaries: bool = False, return_sentence_metrics: bool = False, return_sentence_texts: bool = False, show_inference_progress: bool | None = None, preprocess_workers: int | None = None, preprocess_batch_size: int | None = None, torch_dataloader_kwargs: Mapping[str, Any] | None = None, ) -> dict[str, Any]: """Prune long contexts by chunking them while preserving sentence boundaries. Args: question: Query text or list of queries. context: Context text(s) corresponding to each query. title: Optional title sentences to prepend. Use "first_sentence" to reuse the initial sentence per context (legacy default). first_line_as_title: When True, split the first non-empty line of each context and treat it as the title. Cannot be combined with explicit title overrides. batch_size: GPU batch size for inference. threshold: Pruning probability threshold. When omitted, the method first attempts to read ``self.config.default_threadshold`` (legacy spelling) from the checkpoint's ``config.json``. If that field is absent, the module constant ``DEFAULT_PROCESS_THRESHOLD`` (set to ``0.1``) is used. always_select_title: Force keeping title sentence. reorder: When True, sort contexts for each query by descending reranker score. top_k: When set along with ``reorder=True``, keep only the first ``top_k`` contexts per query after sorting. sentence_splitter: Callable that splits text into sentences or a mapping from language code to splitter. If omitted, the ``language`` parameter selects one of the built-in splitters. language: Language code used when choosing the default splitter or resolving a splitter mapping. When None, ``"auto"`` is assumed, which automatically handles Japanese and English text. Supported values remain ``"auto"``, ``"ja"`` (fast-bunkai), and ``"en"`` (NLTK Punkt with additional heuristics) for backwards compatibility. use_best_reranker_score: When True (default), store the highest reranker score among all processed blocks for each context. When False, keep the score from the first block only (original behaviour). If all sentences are discarded, the reranker score is set to 0.0 when ``zero_score_when_empty`` is enabled. zero_score_when_empty: When True (default), force the reranker score to ``0.0`` when the pruned context becomes empty after stripping whitespace. Disable to preserve the original score even when no sentences are kept. show_progress: When True, display progress bars for preprocessing and inference stages. debug_messages: Enable verbose timing diagnostics. When True, messages are logged via this module's logger. Provide a callable to redirect messages elsewhere. Timing summaries are also attached to the return payload. enable_warnings: Suppress warning output from dependencies when set to False. strip_sentences: When True, trim sentence text with `strip()` after splitting and filter out blank sentences (legacy behaviour). When False (default), preserve leading and trailing whitespace for downstream scoring. respect_sentence_boundaries: When True, keep each sentence produced by the splitter as a single fragment whenever it fits within the model's maximum token window, only falling back to token-level splitting when a sentence exceeds the allowed length. return_sentence_metrics: When True, include per-sentence probabilities in the response payload under ``sentence_probabilities``. return_sentence_texts: When True, include ``kept_sentences`` / ``removed_sentences`` in the response payload. Defaults to False to minimise payload size. preprocess_workers: Number of DataLoader worker processes to use while fragmentizing contexts. When None, respects the ``OPEN_PROVENCE_PREPROCESS_WORKERS`` environment variable and defaults to 0 (main-process preprocessing). preprocess_batch_size: Number of contexts processed per preprocessing batch. Defaults to ``batch_size`` when omitted. torch_dataloader_kwargs: Optional mapping forwarded directly to the preprocessing ``DataLoader`` to fine-tune worker behaviour (e.g., setting a custom ``worker_init_fn`` or pinning strategy). .. caution:: Input shape determines how batching behaves. Passing ``question: str`` with ``context: List[str]`` is interpreted as *one* query paired with multiple documents. To batch distinct question–context pairs, provide ``question: List[str]`` and ``context: List[str]`` of equal length. If you supply ``context: List[List[str]]`` the inner lists are assumed to be pre-split sentences and the sentence splitter is skipped—use this form only when you have already segmented the text yourself. """ progress_restore: Callable[[], None] | None = None original_progress_enabled = is_progress_bar_enabled() if show_progress and not original_progress_enabled: enable_progress_bar() progress_restore = disable_progress_bar elif not show_progress and original_progress_enabled: disable_progress_bar() progress_restore = enable_progress_bar try: batch_size = max(1, batch_size) threshold = self._resolve_process_threshold(threshold) start_total = perf_counter() splitter = OpenProvenceModel._resolve_sentence_splitter( self, sentence_splitter, language ) debug_callback: Callable[[str], None] | None if isinstance(debug_messages, bool): debug_callback = LOGGER.info if debug_messages else None elif callable(debug_messages): debug_callback = debug_messages else: raise TypeError( "debug_messages must be a bool or a callable that accepts a string" ) def _log_debug(message: str) -> None: if debug_callback is not None: debug_callback(message) if show_inference_progress is None: show_inference_progress = show_progress warnings_cm: contextlib.AbstractContextManager[Any] warnings_entered = False if enable_warnings: warnings_cm = contextlib.nullcontext() else: # pragma: no cover - depends on caller preference warnings_cm = warnings.catch_warnings() warnings_cm.__enter__() warnings.simplefilter("ignore") warnings_entered = True preprocess_time = 0.0 assembly_time = 0.0 inference_time = 0.0 post_time = 0.0 timing_totals: dict[str, float] = { "sentence_collect_seconds": 0.0, "sentence_normalize_seconds": 0.0, "tokenize_seconds": 0.0, "fragment_split_seconds": 0.0, "fragment_decode_seconds": 0.0, } queries: list[str] = [] contexts: list[list[Any]] = [] structure = "str" preprocess_jobs: list[dict[str, Any]] = [] query_token_ids: list[list[int]] = [] contexts_info: dict[tuple[int, int], dict[str, Any]] = {} pruned_contexts: list[list[str]] = [] reranking_scores: list[list[float | None]] = [] compression_rates: list[list[float]] = [] kept_sentences: list[list[list[str]]] | None = None removed_sentences: list[list[list[str]]] | None = None title_values: list[list[Any]] = [] sentence_probability_groups: list[list[list[float]]] | None = None try: queries, contexts, structure = OpenProvenceModel._normalize_inputs( self, question, context ) contexts, titles = self._resolve_titles( queries, contexts, title, first_line_as_title=first_line_as_title, ) if respect_sentence_boundaries: max_fragment_tokens = max(16, self.max_length - 2) else: max_fragment_tokens = max(16, self.max_length // 2) sep_token_ids = self.tokenizer.encode( self.tokenizer.sep_token or "", add_special_tokens=False ) preprocess_jobs, query_token_ids = self._build_preprocess_jobs( queries, contexts, titles, splitter, strip_sentences=strip_sentences, show_progress=show_progress, ) resolved_workers = self._resolve_preprocess_workers(preprocess_workers) preprocess_batch = max(1, int(preprocess_batch_size or batch_size)) dataset = _PreprocessDataset( preprocess_jobs, self.tokenizer, splitter, max_fragment_tokens, strip_sentences, respect_sentence_boundaries, ) loader_kwargs: dict[str, Any] = { "batch_size": preprocess_batch, "shuffle": False, "num_workers": resolved_workers, "collate_fn": _preprocess_collate_fn, "pin_memory": False, "persistent_workers": resolved_workers > 0, } total_jobs = len(preprocess_jobs) workers_explicit = preprocess_workers is not None batch_explicit = preprocess_batch_size is not None prefetch_explicit = False if not workers_explicit and preprocess_workers is None: env_workers_raw = os.getenv("OPEN_PROVENCE_PREPROCESS_WORKERS") if env_workers_raw: try: workers_explicit = int(env_workers_raw) > 0 except ValueError: workers_explicit = False if torch_dataloader_kwargs: custom_kwargs = dict(torch_dataloader_kwargs) if "num_workers" in custom_kwargs: workers_explicit = True if "batch_size" in custom_kwargs: batch_explicit = True if "prefetch_factor" in custom_kwargs: prefetch_explicit = True loader_kwargs.update(custom_kwargs) resolved_workers = int(loader_kwargs.get("num_workers", resolved_workers)) preprocess_batch = int(loader_kwargs.get("batch_size", preprocess_batch)) current_prefetch_raw = loader_kwargs.get("prefetch_factor") current_prefetch: int | None if isinstance(current_prefetch_raw, (int, float)): current_prefetch = int(current_prefetch_raw) elif isinstance(current_prefetch_raw, str) and current_prefetch_raw.isdigit(): current_prefetch = int(current_prefetch_raw) else: current_prefetch = None if "multiprocessing_context" in loader_kwargs: loader_kwargs.pop("multiprocessing_context") ( resolved_workers, preprocess_batch, tuned_prefetch, ) = self._auto_tune_preprocess_loader( total_jobs=total_jobs, inference_batch_size=batch_size, current_workers=resolved_workers, current_preprocess_batch=preprocess_batch, current_prefetch=current_prefetch, workers_explicit=workers_explicit, batch_explicit=batch_explicit, prefetch_explicit=prefetch_explicit, ) loader_kwargs["num_workers"] = resolved_workers loader_kwargs["batch_size"] = preprocess_batch loader_kwargs["persistent_workers"] = resolved_workers > 0 if tuned_prefetch is not None: loader_kwargs["prefetch_factor"] = tuned_prefetch elif not prefetch_explicit and "prefetch_factor" in loader_kwargs: loader_kwargs.pop("prefetch_factor", None) loader = DataLoader(dataset, **loader_kwargs) if debug_callback is not None: _log_debug( "[OpenProvenceModel] " f"preprocess_workers={resolved_workers} " f"preprocess_batch={preprocess_batch} " f"default_workers={_default_preprocess_workers()}" ) total_blocks_processed = 0 loader_iter = iter(loader) shutdown_workers = getattr(loader_iter, "_shutdown_workers", None) try: for jobs_batch, entries_batch in loader_iter: if not jobs_batch: continue ( batch_contexts, batch_inference_jobs, batch_timing_totals, batch_assembly, ) = self._assemble_inference_inputs( jobs_batch, entries_batch, query_token_ids, sep_token_ids, ) assembly_time += batch_assembly preprocess_time += sum(batch_timing_totals.values()) for key, value in batch_timing_totals.items(): timing_totals[key] += value for key, info in batch_contexts.items(): existing = contexts_info.get(key) if existing is None: contexts_info[key] = info continue existing_raw = existing.setdefault("raw_blocks", []) existing_raw.extend(info.get("raw_blocks", [])) if not batch_inference_jobs: continue inference_time += self._run_inference_batches( batch_inference_jobs, batch_size, queries, query_token_ids, contexts_info, show_inference_progress=False, show_progress=show_progress, ) total_blocks_processed += len(batch_inference_jobs) finally: if shutdown_workers is not None: shutdown_workers() if show_progress and total_blocks_processed: message = ( f"[OpenProvenceModel] Model inference time: {inference_time:.2f}s " f"({total_blocks_processed} blocks)" ) if debug_callback is None: print(message, flush=True) else: _log_debug(message) ( pruned_contexts, reranking_scores, compression_rates, kept_sentences, removed_sentences, title_values, sentence_probability_groups, post_time, ) = self._postprocess_contexts( queries, contexts, contexts_info, threshold=threshold, always_select_title=always_select_title, use_best_reranker_score=use_best_reranker_score, sentence_probability_groups_requested=return_sentence_metrics, collect_sentence_texts=return_sentence_texts, first_line_as_title=first_line_as_title, zero_score_when_empty=zero_score_when_empty, ) finally: if warnings_entered: # pragma: no cover - depends on caller preference warnings_cm.__exit__(None, None, None) total_time = perf_counter() - start_total performance_trace = ProcessPerformanceTrace( preprocess_seconds=preprocess_time, assembly_seconds=assembly_time, inference_seconds=inference_time, postprocess_seconds=post_time, total_seconds=total_time, sentence_collect_seconds=timing_totals.get("sentence_collect_seconds", 0.0), sentence_normalize_seconds=timing_totals.get("sentence_normalize_seconds", 0.0), tokenize_seconds=timing_totals.get("tokenize_seconds", 0.0), fragment_split_seconds=timing_totals.get("fragment_split_seconds", 0.0), fragment_decode_seconds=timing_totals.get("fragment_decode_seconds", 0.0), ) timing_summary = performance_trace.as_dict() timing_line = ( "Timing: " f"preprocess={performance_trace.preprocess_seconds:.2f}s " f"[collect={performance_trace.sentence_collect_seconds:.2f}s " f"normalize={performance_trace.sentence_normalize_seconds:.2f}s " f"tokenize={performance_trace.tokenize_seconds:.2f}s " f"fragment_split={performance_trace.fragment_split_seconds:.2f}s " f"fragment_decode={performance_trace.fragment_decode_seconds:.2f}s] " f"assembly={performance_trace.assembly_seconds:.2f}s " f"inference={performance_trace.inference_seconds:.2f}s " f"postprocess={performance_trace.postprocess_seconds:.2f}s " f"total={performance_trace.total_seconds:.2f}s" ) _log_debug(f"[OpenProvenceModel] {timing_line}") if reorder: ( pruned_contexts, reranking_scores, compression_rates, kept_sentences, removed_sentences, title_values, sentence_probability_groups, ) = self._apply_reordering( pruned_contexts, reranking_scores, compression_rates, kept_sentences, removed_sentences, title_values, sentence_probability_groups, top_k=top_k, ) pruned_output: Any = pruned_contexts score_output: Any = reranking_scores compression_output: Any = compression_rates kept_output: Any = kept_sentences if kept_sentences is not None else None removed_output: Any = removed_sentences if removed_sentences is not None else None title_output: Any = title_values sentence_prob_output: Any = sentence_probability_groups if structure == "str" and pruned_contexts: pruned_output = pruned_contexts[0][0] if pruned_contexts[0] else "" score_output = reranking_scores[0][0] if reranking_scores[0] else None compression_output = compression_rates[0][0] if compression_rates[0] else 0.0 if kept_sentences is not None: kept_output = kept_sentences[0][0] if kept_sentences[0] else [] if removed_sentences is not None: removed_output = removed_sentences[0][0] if removed_sentences[0] else [] title_output = title_values[0][0] if title_values[0] else None if ( sentence_probability_groups is not None and sentence_probability_groups and sentence_probability_groups[0] ): sentence_prob_output = sentence_probability_groups[0][0] elif structure == "list" and pruned_contexts: pruned_output = pruned_contexts[0] score_output = reranking_scores[0] compression_output = compression_rates[0] if kept_sentences is not None: kept_output = kept_sentences[0] if removed_sentences is not None: removed_output = removed_sentences[0] title_output = title_values[0] if sentence_probability_groups is not None: sentence_prob_output = ( sentence_probability_groups[0] if sentence_probability_groups else [] ) elif structure == "aligned" and pruned_contexts: pruned_output = [entry[0] if entry else "" for entry in pruned_contexts] score_output = [scores[0] if scores else None for scores in reranking_scores] compression_output = [rates[0] if rates else 0.0 for rates in compression_rates] if kept_sentences is not None: kept_output = [values[0] if values else [] for values in kept_sentences] if removed_sentences is not None: removed_output = [values[0] if values else [] for values in removed_sentences] title_output = [values[0] if values else None for values in title_values] if sentence_probability_groups is not None: sentence_prob_output = [ values[0] if values else [] for values in sentence_probability_groups ] result_payload = { "pruned_context": pruned_output, "reranking_score": score_output, "compression_rate": compression_output, "title": title_output, "timing": timing_summary, "performance_trace": performance_trace, } if kept_output is not None: result_payload["kept_sentences"] = kept_output if removed_output is not None: result_payload["removed_sentences"] = removed_output if sentence_prob_output is not None: result_payload["sentence_probabilities"] = sentence_prob_output return result_payload finally: if progress_restore is not None: progress_restore() # Hugging Face integration ------------------------------------------------- class OpenProvenceForSequenceClassification(OpenProvenceModel): """Sequence classification wrapper compatible with transformers.AutoModel.""" def forward( self, input_ids: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, return_dict: bool | None = None, **kwargs: Any, ): return super().forward( input_ids=input_ids, attention_mask=attention_mask, labels=labels, return_dict=return_dict, **kwargs, ) class OpenProvenceForTokenClassification(OpenProvenceModel): """Token classification wrapper that exposes pruning logits.""" def __init__( self, config: OpenProvenceConfig, *model_args: Any, device: str | torch.device | None = None, **model_kwargs: Any, ) -> None: super().__init__(config, *model_args, device=device, **model_kwargs) self.num_labels = config.num_pruning_labels def forward( self, input_ids: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, return_dict: bool | None = None, **kwargs: Any, ): effective_return_dict = return_dict if return_dict is not None else True base_output = super().forward( input_ids=input_ids, attention_mask=attention_mask, labels=None, return_dict=True, **kwargs, ) classifier_output = cast(SequenceClassifierOutput, base_output) pruning_logits = cast(Tensor, getattr(classifier_output, "pruning_logits")) ranking_logits = cast(Tensor, getattr(classifier_output, "ranking_logits")) loss = None if labels is not None: loss_fct = nn.CrossEntropyLoss() if attention_mask is not None: active_loss = attention_mask.view(-1) == 1 active_logits = pruning_logits.view(-1, self.num_labels)[active_loss] active_labels = labels.view(-1)[active_loss] if active_logits.numel() > 0: loss = loss_fct(active_logits, active_labels) else: loss = torch.tensor(0.0, device=pruning_logits.device, requires_grad=True) else: loss = loss_fct(pruning_logits.view(-1, self.num_labels), labels.view(-1)) if not effective_return_dict: output: tuple[torch.Tensor, ...] = (pruning_logits,) if loss is not None: return (loss,) + output return output logits_output = cast(FloatTensor, pruning_logits) loss_output: FloatTensor | None = None if loss is not None: loss_output = cast(FloatTensor, loss.to(dtype=logits_output.dtype)) result = TokenClassifierOutput( loss=loss_output, logits=logits_output, hidden_states=classifier_output.hidden_states, attentions=classifier_output.attentions, ) setattr(result, "ranking_logits", ranking_logits) return result OpenProvenceEncoderConfig = OpenProvenceConfig OpenProvenceEncoderForSequenceClassification = OpenProvenceForSequenceClassification OpenProvenceEncoderForTokenClassification = OpenProvenceForTokenClassification __all__ = [ "OpenProvenceModel", "OpenProvenceRawPrediction", "OpenProvenceConfig", "OpenProvenceForSequenceClassification", "OpenProvenceForTokenClassification", ] ContextItem: TypeAlias = str | Sequence[str] ContextInput: TypeAlias = str | Sequence[ContextItem] | Sequence[Sequence[ContextItem]]