# src/index/chunker_v2.py
"""
Advanced chunking system with policy-driven behavior
Implements sentence-safe packing, table coalescing, and quality gates
"""
from __future__ import annotations
import re, json, hashlib, logging
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Callable
logger = logging.getLogger(__name__)
# ---------- Policy ----------
@dataclass
class ChunkPolicy:
target_tokens: int = 300
max_tokens: int = 500
min_tokens: int = 80
overlap_tokens: int = 45
ensure_sentence_boundary: bool = True
drop_if_short: int = 15
keep_section_intact: bool = False
pack_rows: int = 0
bullet_join: bool = False
drop_patterns: Tuple[re.Pattern, ...] = ()
version: str = "2"
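# Illustrative only: a policy can also be built directly in code, bypassing YAML.
# The values and pattern below are arbitrary examples, not project settings.
#
#     strict = ChunkPolicy(target_tokens=250, max_tokens=400, overlap_tokens=30,
#                          drop_patterns=(re.compile(r"^Page \d+$"),))
#     # fields not passed keep the defaults above, e.g. strict.min_tokens == 80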
def load_policy(yaml_path: Path) -> Tuple[ChunkPolicy, Dict[str, ChunkPolicy]]:
import yaml
cfg = yaml.safe_load(Path(yaml_path).read_text())
def compile_patterns(items):
return tuple(re.compile(p) for p in items or [])
base = cfg.get("default", {})
base_cp = ChunkPolicy(
target_tokens=base.get("target_tokens", 300),
max_tokens=base.get("max_tokens", 500),
min_tokens=base.get("min_tokens", 80),
overlap_tokens=base.get("overlap_tokens", 45),
ensure_sentence_boundary=base.get("ensure_sentence_boundary", True),
drop_if_short=base.get("drop_if_short", 15),
drop_patterns=compile_patterns(base.get("drop_patterns")),
version=str(cfg.get("version", "2"))
)
overrides = {}
for sec, o in (cfg.get("section_overrides") or {}).items():
tmp = ChunkPolicy(
target_tokens=o.get("target_tokens", base_cp.target_tokens),
max_tokens=o.get("max_tokens", base_cp.max_tokens),
min_tokens=o.get("min_tokens", base_cp.min_tokens),
overlap_tokens=o.get("overlap_tokens", base_cp.overlap_tokens),
ensure_sentence_boundary=o.get("ensure_sentence_boundary", base_cp.ensure_sentence_boundary),
drop_if_short=o.get("drop_if_short", base_cp.drop_if_short),
keep_section_intact=o.get("keep_section_intact", False),
pack_rows=o.get("pack_rows", 0),
bullet_join=o.get("bullet_join", False),
drop_patterns=base_cp.drop_patterns, # inherit base
version=base_cp.version
)
overrides[sec] = tmp
return base_cp, overrides
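# Hedged sketch of a configs/chunking.yaml this loader would accept. Only the key
# names are grounded (they mirror the .get() calls above); the values and the
# table_row override are made-up examples.
#
#     version: "2"
#     default:
#       target_tokens: 300
#       max_tokens: 500
#       min_tokens: 80
#       overlap_tokens: 45
#       ensure_sentence_boundary: true
#       drop_patterns:
#         - "^\\s*Downloaded from"
#     section_overrides:
#       table_row:
#         pack_rows: 5
#
# Loading it returns the base policy plus the per-section overrides:
#
#     base_policy, per_section = load_policy(Path("configs/chunking.yaml"))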
# ---------- Token length ----------
def make_token_len_fn(model_name: Optional[str] = None) -> Callable[[str], int]:
tok = None
if model_name:
try:
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
except Exception:
tok = None
if tok:
return lambda s: len(tok.encode(s, add_special_tokens=False))
try:
import tiktoken
enc = tiktoken.get_encoding("cl100k_base")
return lambda s: len(enc.encode(s))
except Exception:
return lambda s: max(1, len(s) // 4)
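# Usage sketch: the returned callable is simply "text -> approximate token count",
# so the three backends (HF fast tokenizer, tiktoken cl100k_base, or the
# ~4-characters-per-token heuristic) are interchangeable from the caller's side.
#
#     tlen = make_token_len_fn()             # no model name: tiktoken or heuristic
#     tlen("Endobronchial valve placement")  # -> a small int; exact value is backend-dependent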
# ---------- Sentence segmentation ----------
def make_sentence_splitter() -> Callable[[str], List[str]]:
try:
import spacy
nlp = spacy.blank("en")
nlp.add_pipe("sentencizer")
# Increase max_length to handle large documents
nlp.max_length = 2000000 # 2M characters
return lambda t: [s.text.strip() for s in nlp(t).sents if s.text.strip()]
except Exception:
pass
try:
import nltk
nltk.download("punkt", quiet=True)
from nltk.tokenize import sent_tokenize
return lambda t: [s.strip() for s in sent_tokenize(t) if s.strip()]
except Exception:
pass
# Fallback: simple regex-based splitter
def simple_sent_split(text):
        # Basic sentence splitting on runs of terminal punctuation (. ! ?), with the
        # delimiter reattached to the preceding text; avoids complex lookbehinds
sentences = []
current = []
# Split by periods, exclamations, questions
parts = re.split(r'([.!?]+)', text)
for i in range(0, len(parts)-1, 2):
if i+1 < len(parts):
sent = parts[i] + parts[i+1]
sent = sent.strip()
if sent:
sentences.append(sent)
# Handle last part if it exists
if len(parts) % 2 == 1 and parts[-1].strip():
sentences.append(parts[-1].strip())
return [s for s in sentences if s]
return simple_sent_split
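# Usage sketch: whichever backend was importable (spaCy sentencizer, NLTK punkt,
# or the regex fallback), the result is a plain "text -> list of sentences" callable.
#
#     split = make_sentence_splitter()
#     split("Pleuroscopy was performed. No complications occurred.")
#     # -> ["Pleuroscopy was performed.", "No complications occurred."]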
# ---------- Cleaning ----------
def preclean(text: str, patterns: Tuple[re.Pattern, ...]) -> str:
lines = []
for ln in text.splitlines():
if any(p.search(ln) for p in patterns):
continue
lines.append(ln)
out = "\n".join(lines)
out = re.sub(r"(\w)-\n(\w)", r"\1\2", out) # dehyphen
out = re.sub(r"[ \t]+", " ", out)
out = re.sub(r"\s+\n", "\n", out)
out = out.replace("\u00a0", " ")
return out.strip()
def is_garble(s: str) -> bool:
return bool(re.fullmatch(r"(?:/C\d{2,3}){6,}", s.strip()))
def normalize_for_hash(s: str) -> str:
s = re.sub(r"\s+", " ", s).strip().lower()
# Replace smart quotes with regular quotes
s = s.replace('"', '"').replace('"', '"').replace(''', "'").replace(''', "'")
return s
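# Worked example of the cleaning helpers (the drop pattern is made up, not a
# project setting):
#
#     pats = (re.compile(r"^Downloaded from "),)
#     preclean("Downloaded from journal.org\nBroncho-\nscopy is safe.", pats)
#     # -> "Bronchoscopy is safe."   (boilerplate line dropped, hyphenation repaired)
#     normalize_for_hash("  Bronchoscopy   is SAFE. ") == normalize_for_hash("bronchoscopy is safe.")  # True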
# ---------- Chunking core ----------
@dataclass
class Chunk:
text: str
start: int
end: int
token_count: int
index: int
issues: List[str]
def pack_sentences(sentences: List[Tuple[str,int,int]],
policy: ChunkPolicy, tlen: Callable[[str], int]) -> List[Chunk]:
chunks: List[Chunk] = []
i = 0
idx = 0
while i < len(sentences):
buf, start_char, end_char, toks, issues = [], sentences[i][1], sentences[i][2], 0, []
j = i
while j < len(sentences):
s, s0, s1 = sentences[j]
stoks = tlen(s)
if toks + stoks > policy.max_tokens and buf:
break
buf.append(s)
toks += stoks
end_char = s1
            j += 1
            if toks >= policy.target_tokens:
                break
        # j now points at the first sentence not included in this chunk.
        # Calculate overlap for the next chunk: walk back from the end of the
        # buffer until roughly overlap_tokens are covered, and restart there
        if buf and policy.overlap_tokens > 0 and j < len(sentences):
            ov, k = 0, len(buf) - 1
            while k >= 0 and ov < policy.overlap_tokens:
                ov += tlen(buf[k])
                k -= 1
            next_i = max(i + 1, i + (k + 1))
        else:
            next_i = j
txt = " ".join(buf).strip()
if not txt:
i = next_i
continue
# Check for issues
if not re.search(r'[.!?]"?\s*$', txt) and policy.ensure_sentence_boundary:
issues.append("mid_sentence_end")
if is_garble(txt):
issues.append("garbled_pdf")
if tlen(txt) < policy.min_tokens and len(buf) == 1:
issues.append("very_short")
chunks.append(Chunk(text=txt, start=start_char, end=end_char,
token_count=toks, index=idx, issues=issues))
idx += 1
i = next_i
return chunks
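# Worked sketch of the packer using the explicit 4-chars-per-token heuristic (all
# numbers are illustrative): target_tokens=4 with overlap_tokens=3 should yield two
# chunks that share the middle sentence as overlap.
#
#     sents = [("Alpha one.", 0, 10), ("Beta two three.", 11, 26), ("Gamma four.", 27, 38)]
#     pol = ChunkPolicy(target_tokens=4, max_tokens=12, min_tokens=1, overlap_tokens=3)
#     pack_sentences(sents, pol, lambda s: max(1, len(s) // 4))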
def coalesce_table_rows(rows: List[str], policy: ChunkPolicy,
tlen: Callable[[str], int]) -> List[str]:
if policy.pack_rows <= 1:
return rows
out, cur = [], []
for r in rows:
cur.append(r.strip())
if len(cur) >= policy.pack_rows:
out.append("; ".join(cur))
cur = []
if cur:
out.append("; ".join(cur))
return out
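# Example: with pack_rows=3, six single-line rows collapse into two packed rows
# (row contents are purely illustrative):
#
#     rows = ["valve | 4.0mm", "valve | 5.5mm", "stent | 12mm",
#             "stent | 14mm", "spigot | 6mm", "spigot | 7mm"]
#     coalesce_table_rows(rows, ChunkPolicy(pack_rows=3), make_token_len_fn())
#     # -> ["valve | 4.0mm; valve | 5.5mm; stent | 12mm",
#     #     "stent | 14mm; spigot | 6mm; spigot | 7mm"]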
# ---------- Public API ----------
def chunk_document(doc: Dict, policy_base: ChunkPolicy,
overrides: Dict[str, ChunkPolicy],
tlen: Callable[[str], int]) -> List[Dict]:
sec_type = doc.get("section_type", "general")
policy = overrides.get(sec_type, policy_base)
raw_text = doc.get("text", "")
text = preclean(raw_text, policy.drop_patterns)
if not text:
return []
# Handle table rows specially
if sec_type == "table_row":
rows = [ln for ln in text.splitlines() if ln.strip()]
rows = coalesce_table_rows(rows, policy, tlen)
text = "\n".join(rows)
# Split into sentences
splitter = make_sentence_splitter()
sentences, pos = [], 0
for sent in splitter(text):
start = text.find(sent, pos)
if start == -1:
start = pos
end = start + len(sent)
sentences.append((sent, start, end))
pos = end
# Pack sentences into chunks
chunks = pack_sentences(sentences, policy, tlen)
# Deduplicate
seen = set()
out = []
for ch in chunks:
norm = normalize_for_hash(ch.text)
h = hashlib.md5(norm.encode()).hexdigest()
if h in seen:
ch.issues.append("duplicate")
continue
seen.add(h)
chunk_dict = {
"doc_id": doc.get("doc_id", ""),
"section_type": sec_type,
"chunk_id": f"{doc.get('doc_id', '')}:{ch.index}",
"index_in_doc": ch.index,
"text": ch.text,
"token_count": ch.token_count,
"start_char": ch.start,
"end_char": ch.end,
"policy_version": policy.version,
"issues": ch.issues,
}
# Add metadata from document
meta = doc.get("meta", {})
if meta:
chunk_dict.update(meta)
out.append(chunk_dict)
return out
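# Sketch of the document dict this function expects; the field names mirror the
# .get() calls above, while the content and metadata keys are invented examples:
#
#     doc = {"doc_id": "paper_001", "section_type": "general",
#            "text": "Rigid bronchoscopy was performed. The airway was patent.",
#            "meta": {"year": 2020}}
#     base, overrides = load_policy(Path("configs/chunking.yaml"))
#     chunk_document(doc, base, overrides, make_token_len_fn())
#     # -> list of dicts with chunk_id "paper_001:<n>", token_count, char offsets,
#     #    policy_version, issues, plus the merged "year" metadata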
def chunk_stream(docs: Iterable[Dict], policy_path: Path,
tokenizer_model: Optional[str] = None,
out_path: Optional[Path] = None) -> List[Dict]:
base, overrides = load_policy(policy_path)
tlen = make_token_len_fn(tokenizer_model)
all_chunks, qa_rows = [], []
for doc in docs:
chunks = chunk_document(doc, base, overrides, tlen)
for c in chunks:
qa_rows.append({
"doc_id": c["doc_id"],
"section_type": c["section_type"],
"index_in_doc": c["index_in_doc"],
"token_count": c["token_count"],
"ends_with_punct": not ("mid_sentence_end" in c["issues"]),
"issues": "|".join(c["issues"]),
"text_preview": (c["text"][:200] + "…") if len(c["text"]) > 200 else c["text"]
})
all_chunks.extend(chunks)
if out_path:
# Write chunks
with out_path.open("w", encoding="utf-8") as f:
for c in all_chunks:
f.write(json.dumps(c, ensure_ascii=False) + "\n")
# Write QA CSV
if qa_rows:
import csv
qa_path = out_path.with_suffix(".qa.csv")
with qa_path.open("w", newline="", encoding="utf-8") as f:
w = csv.DictWriter(f, fieldnames=list(qa_rows[0].keys()))
w.writeheader()
w.writerows(qa_rows)
logger.info("Wrote %s and %s", out_path, qa_path)
return all_chunks
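# Programmatic usage sketch (all paths are placeholders): chunk_stream returns the
# chunk dicts and, when out_path is given, also writes a JSONL of chunks plus a
# sibling <name>.qa.csv with the per-chunk quality columns built above.
#
#     docs = [json.loads(l) for l in Path("data/sections.jsonl").read_text(encoding="utf-8").splitlines() if l.strip()]
#     chunks = chunk_stream(docs, Path("configs/chunking.yaml"),
#                           tokenizer_model=None, out_path=Path("data/chunks.jsonl"))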
if __name__ == "__main__":
import argparse
p = argparse.ArgumentParser()
p.add_argument("--in_jsonl", type=Path, required=True)
p.add_argument("--out_jsonl", type=Path, required=True)
p.add_argument("--policy", type=Path, default=Path("configs/chunking.yaml"))
p.add_argument("--tokenizer", type=str, default=None,
help="HF model name or leave empty for heuristic/tiktoken")
args = p.parse_args()
# Load documents
docs = []
with args.in_jsonl.open("r", encoding="utf-8") as f:
for line in f:
if line.strip():
docs.append(json.loads(line))
# Process chunks
chunk_stream(docs, args.policy, tokenizer_model=args.tokenizer, out_path=args.out_jsonl)
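# Example invocation (file paths are placeholders; the script only needs the
# stdlib plus whichever optional tokenizer/splitter backends are installed):
#
#     python src/index/chunker_v2.py \
#         --in_jsonl data/sections.jsonl \
#         --out_jsonl data/chunks.jsonl \
#         --policy configs/chunking.yaml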