mahoon-legal-ai / golden_builder.py
hajimammad's picture
Update golden_builder.py
e606eca verified
raw
history blame
19.5 kB
# -*- coding: utf-8 -*-
"""
Golden Builder (Persian Legal) — Fast, Robust, W&B-enabled
- سازگار با اپ شما (app.py): کلاس GoldenBuilder + توابع load_json_or_jsonl / save_jsonl
- بهبودها:
* نرمال‌سازی فارسی، پاکسازی نویز
* کش O(1) برای خلاصه‌ها
* باکت‌بندی برحسب طول توکن؛ جلوگیری از OOM
* autocast (bf16/fp16) برای سرعت و بهره‌وری VRAM
* گیت کیفیت: طول/تنوع/عدم تکرار n-gram/چگالی و امتیاز وزنی موجودیت
* وزن‌ها از legal_entity_weights.json خوانده می‌شود (خروجی Weight Tuning)
* W&B اختیاری: متادیتا + آرتیفکت دیتاست خروجی
"""
import os, re, json, hashlib, logging, math, random
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Callable, Tuple
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# =========================
# Logging
# =========================
log = logging.getLogger("golden-builder")
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# =========================
# Persian Normalization & Cleaning
# =========================
ZWNJ = "\u200c"
AR_DIGITS = "٠١٢٣٤٥٦٧٨٩"
FA_DIGITS = "۰۱۲۳۴۵۶۷۸۹"
EN_DIGITS = "0123456789"
TRANS_DIG = {ord(a): e for a, e in zip(AR_DIGITS + FA_DIGITS, EN_DIGITS * 2)}
def normalize_fa(s: str) -> str:
if not isinstance(s, str): return ""
s = s.replace("\u064A", "ی").replace("\u0643", "ک")
s = s.translate(TRANS_DIG)
# حذف اعراب/کنترل‌ها
s = re.sub(r"[\u064B-\u065F\u0610-\u061A\u200B-\u200F\u202A-\u202E\uFEFF]", "", s)
# ZWNJ یکنواخت
s = re.sub(r"\s*‌\s*", ZWNJ, s)
# فاصله‌ها
s = re.sub(r"\s+", " ", s).strip()
return s
NOISE_PATTERNS = [
r"http[s]?://\S+",
r"www\.\S+",
r"\d{10,}", # رشته‌های عددی خیلی بلند
r"(.)\1{4,}", # کشیده‌ها
r"[^\u0600-\u06FF\s\d\.,;:!?()\"'\-]+", # کاراکترهای غیر فارسی/علائم
]
def clean_text(s: str) -> str:
s = normalize_fa(s)
for pat in NOISE_PATTERNS:
s = re.sub(pat, " ", s)
s = re.sub(r"\s+", " ", s)
s = re.sub(r"\.{2,}", "...", s)
# فاصله‌گذاری علائم
s = re.sub(r"\s+([،.;:!?])", r"\1", s)
s = re.sub(r"([،.;:!?])(?=[^\s])", r"\1 ", s)
return s.strip()
# =========================
# Utils
# =========================
def md5(s: str) -> str:
return hashlib.md5(s.encode("utf-8")).hexdigest()
def lex_diversity(s: str) -> float:
toks = s.split()
return 0.0 if not toks else len(set(toks))/len(toks)
def has_repetition(s: str, n: int = 3, thr: int = 2) -> bool:
toks = s.split()
if len(toks) < n: return False
grams = [tuple(toks[i:i+n]) for i in range(len(toks)-n+1)]
from collections import Counter
return any(c > thr for c in Counter(grams).values())
def set_all_seeds(seed: int = 42):
random.seed(seed); np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
# =========================
# Lightweight Legal NER (Regex) with external weights
# =========================
@dataclass
class LegalEntity:
text: str; category: str; start: int; end: int; weight: float
DEFAULT_WEIGHTS = {
"STATUTE": 1.0, "COURT": 0.9, "CRIME": 1.2,
"CIVIL": 0.8, "PROCED": 0.7, "PARTY": 0.6, "BUSINESS": 0.6
}
class LegalEntityExtractor:
def __init__(self):
defs = {
"STATUTE": ([
r"قانون\s+(?:اساسی|مدنی|کیفری|کار|تجارت|مجازات|دریایی|هوایی)",
r"آیین\s+دادرسی\s+(?:مدنی|کیفری|دادگاه‌های\s+عمومی|اداری)",
r"ماده\s+\d+(?:\s+(?:تبصره|الحاقی|اصلاحی))?",
r"تبصره\s+\d+",
r"لایحه\s+قانونی|اصلاحیه"
], DEFAULT_WEIGHTS["STATUTE"]),
"COURT": ([
r"دیوان\s+(?:عالی|عدالت\s+اداری|محاسبات)",
r"دادگاه\s+(?:عمومی|تجدیدنظر|انقلاب|نظامی|اطفال|خانواده)",
r"شعبه\s+\d+(?:\s+دادگاه)?",
r"هیئت\s+(?:منصفه|تخلفات|عمومی)"
], DEFAULT_WEIGHTS["COURT"]),
"CRIME": ([
r"کلاهبرداری|اختلاس|ارتشا|رشوه|خیانت\s+در\s+امانت",
r"جعل(?:\s+(?:اسناد|امضا))?|سرقت(?:\s+(?:مشدد|ساده))?",
r"قتل(?:\s+(?:عمد|شبه\s+عمد|خطای\s+محض))?",
r"تصادف\s+منجر\s+به\s+فوت|قاچاق\s+(?:مواد\s+مخدر|کالا)|پولشویی"
], DEFAULT_WEIGHTS["CRIME"]),
"CIVIL": ([
r"قرارداد|عقد\s+(?:بیع|اجاره|رهن|نکاح|صلح|هبه|وکالت)",
r"خسارت|تعهد|ضمان|مطالبه|وجه\s+التزام|فسخ|اقاله",
r"مهریه|نفقه|حضانت|جهیزیه"
], DEFAULT_WEIGHTS["CIVIL"]),
"PROCED": ([
r"دادخواست|لایحه|شکوائیه|ابلاغ|جلسه\s+دادرسی|کارشناسی",
r"دلایل\s+اثباتی|استماع\s+شهود|رأی|حکم|قرار"
], DEFAULT_WEIGHTS["PROCED"]),
"PARTY": ([
r"خواهان|خواندگان?|شاکی(?:ان)?|متهم(?:ین|ان)?|محکوم\s+(?:له|علیه)",
r"وکیل\s+(?:دادگستری|پایه\s+یک)?|دادستان|بازپرس|قاضی|کارشناس\s+رسمی"
], DEFAULT_WEIGHTS["PARTY"]),
"BUSINESS": ([
r"شرکت\s+(?:سهامی|مسئولیت\s+محدود|تضامنی)|ورشکستگی|نکول|سهام",
r"چک|سفته|برات|اوراق\s+بهادار|مجمع\s+عمومی"
], DEFAULT_WEIGHTS["BUSINESS"])
}
# Override از فایل خارجی اگر موجود
learned = {}
try:
if os.path.exists("legal_entity_weights.json"):
with open("legal_entity_weights.json","r",encoding="utf-8") as f:
learned = json.load(f)
except Exception:
learned = {}
self._patterns = []
for cat, (ps, w) in defs.items():
ww = float(learned.get(cat, w))
for p in ps:
self._patterns.append((re.compile(p, re.IGNORECASE), cat, ww))
self._cache = {}
def extract(self, text: str) -> List[LegalEntity]:
h = md5(text)
if h in self._cache: return self._cache[h]
out, seen = [], set()
for rgx, cat, w in self._patterns:
for m in rgx.finditer(text):
s,e = m.span()
if (s,e) in seen: continue
seen.add((s,e))
out.append(LegalEntity(m.group(), cat, s, e, w))
out.sort(key=lambda x: x.start)
if len(self._cache) < 1000: self._cache[h] = out
return out
def weighted_score(self, entities: List[LegalEntity]) -> float:
# جمع وزن‌ها با طول توکن‌های موجودیت به عنوان تقویت‌کننده
score = 0.0
for e in entities:
span_len = max(len(e.text.split()), 1)
score += e.weight * math.log1p(span_len)
return score
# =========================
# Golden Builder
# =========================
@dataclass
class GBConfig:
min_src_tokens: int = 30
min_tgt_tokens: int = 20
max_tgt_tokens: int = 220
target_minmax_ratio: Tuple[float,float] = (0.12, 0.65) # len(tgt)/len(src)
min_lex_div: float = 0.40
ngram_repeat_n: int = 3
ngram_repeat_thr: int = 2
min_entity_count: int = 2
min_entity_weight_score: float = 2.0 # آستانه امتیاز وزنی برای قبولی
class GoldenBuilder:
"""
Drop-in replacement
"""
def __init__(
self,
model_name: str = "google/mt5-base",
device: Optional[str] = None,
min_len: int = 40,
max_len: int = 160,
seed: int = 42
):
set_all_seeds(seed)
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
log.info("Device: %s", self.device)
self.tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
self.model.to(self.device).eval()
self.min_len = int(min_len)
self.max_len = int(max_len)
self.cfg = GBConfig()
self.ner = LegalEntityExtractor()
# dtype & autocast تنظیم
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
self._amp_dtype = torch.bfloat16
elif torch.cuda.is_available():
self._amp_dtype = torch.float16
else:
self._amp_dtype = torch.float32
# کش خلاصه‌ها و seen
self._summary_cache: Dict[str, str] = {}
self._seen_hashes = set()
# W&B اختیاری
self._wandb_on = bool(os.getenv("WANDB_API_KEY"))
self._wb_run = None
if self._wandb_on:
try:
import wandb
self._wb = wandb
self._wb_run = wandb.init(
project=os.getenv("WANDB_PROJECT","mahoon-legal-ai"),
name="dataset_builder",
config={"model_name": model_name, "min_len": self.min_len, "max_len": self.max_len}
)
except Exception:
self._wandb_on = False
self._wb_run = None
# --------------------- I/O helpers ---------------------
def _encode(self, texts: List[str], max_length: int = 512):
return self.tok(
texts,
return_tensors="pt",
truncation=True,
padding=True,
max_length=max_length
).to(self.device)
# --------------------- Batching & Caching ---------------------
def _summarize_uncached(self, items: List[Tuple[int, str]], num_beams: int = 6, batch_tokens: int = 1400) -> Dict[int, str]:
"""
items: list of (original_index, text_with_prefix)
strategy: sort by length; greedy micro-batches under token budget
returns: {original_index: summary}
"""
if not items: return {}
# تخمین طول توکنی
lens = [len(self.tok(t, add_special_tokens=False).input_ids) for _, t in items]
order = np.argsort(lens) # از کوتاه به بلند
results: Dict[int, str] = {}
batch: List[Tuple[int, str]] = []
budget = 0
def flush_batch(B: List[Tuple[int,str]]):
if not B: return
idxs = [i for i,_ in B]
texts = [t for _,t in B]
inputs = self._encode(texts, max_length=512)
with torch.no_grad():
with torch.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=self._amp_dtype):
ids = self.model.generate(
**inputs,
max_length=self.max_len,
min_length=self.min_len,
num_beams=num_beams,
length_penalty=2.5,
no_repeat_ngram_size=3,
early_stopping=True,
do_sample=False
)
outs = self.tok.batch_decode(ids, skip_special_tokens=True)
for i, gen in zip(idxs, outs):
results[i] = gen
for idx in order:
oi, txt = items[idx]
tlen = lens[idx]
if budget + tlen > batch_tokens and batch:
flush_batch(batch)
batch, budget = [], 0
batch.append((oi, txt)); budget += tlen
if batch:
flush_batch(batch)
return results
def _summarize_batch(self, texts: List[str], num_beams: int = 6) -> List[str]:
"""
ورودی: لیست متن‌ها (هر متن شامل prefix "summarize: ...")
خروجی: لیست خلاصه‌ها به همان ترتیب ورودی
"""
if not texts: return []
results = [None] * len(texts)
uncached: List[Tuple[int,str]] = []
for i, t in enumerate(texts):
h = md5(t)
if h in self._summary_cache:
results[i] = self._summary_cache[h]
else:
uncached.append((i, t))
if uncached:
out_map = self._summarize_uncached(uncached, num_beams=num_beams)
for i, _ in uncached:
results[i] = out_map.get(i, "")
# update cache
h = md5(texts[i])
if len(self._summary_cache) < 10000 and results[i]:
self._summary_cache[h] = results[i]
return [r or "" for r in results]
# --------------------- Quality Gate ---------------------
def _quality_gate(self, src: str, tgt: str, ents: List[LegalEntity]) -> bool:
s_len, t_len = len(src.split()), len(tgt.split())
if s_len < self.cfg.min_src_tokens: return False
if not (self.cfg.min_tgt_tokens <= t_len <= self.cfg.max_tgt_tokens): return False
comp = t_len / (s_len + 1e-8)
if not (self.cfg.target_minmax_ratio[0] <= comp <= self.cfg.target_minmax_ratio[1]): return False
if lex_diversity(tgt) < self.cfg.min_lex_div: return False
if has_repetition(tgt, self.cfg.ngram_repeat_n, self.cfg.ngram_repeat_thr): return False
# موجودیت‌ها: حداقل تعداد + حداقل امتیاز وزنی
if len(ents) < self.cfg.min_entity_count: return False
wscore = self.ner.weighted_score(ents)
if wscore < self.cfg.min_entity_weight_score: return False
return True
# --------------------- Public API ---------------------
def build(
self,
raw_items: List[Dict],
text_key: str = "متن_کامل",
batch_size: int = 4,
progress: Optional[Callable[[float, str], None]] = None
) -> List[Dict]:
"""
EXACT SAME signature (+progress اختیاری برای اتصال به Gradio)
"""
rows = []
N = len(raw_items)
if progress: progress(0.0, "شروع ساخت دیتاست")
log.info(f"Starting build: N={N}, text_key='{text_key}'")
processed = passed = failed = skipped = 0
i = 0
while i < N:
chunk = raw_items[i:i+batch_size]
# pre-clean & filter
cleaned = []
for it in chunk:
raw = it.get(text_key, "")
txt = clean_text(str(raw))
if len(txt.split()) < self.cfg.min_src_tokens:
skipped += 1
cleaned.append("") # placeholder برای چینش
else:
h = md5(txt)
if h in self._seen_hashes:
skipped += 1
cleaned.append("")
else:
self._seen_hashes.add(h)
cleaned.append(txt)
# آماده‌سازی ورودی‌های summary
todo_texts = [f"summarize: {c}" for c in cleaned if c]
outputs = self._summarize_batch(todo_texts) if todo_texts else []
# بازچینی خروجی‌ها روی cleaned
k = 0
for c in cleaned:
if not c:
continue
processed += 1
tgt = clean_text(outputs[k]); k += 1
ents = self.ner.extract(c)
if self._quality_gate(c, tgt, ents):
passed += 1
rows.append({
"input": f"summarize: {c}",
"output": tgt,
"metadata": {
"input_length": len(c.split()),
"target_length": len(tgt.split()),
"entity_count": len(ents),
"entity_weight_score": self.ner.weighted_score(ents)
},
"legal_entities": [
{"text": e.text, "category": e.category, "start": e.start, "end": e.end, "weight": e.weight}
for e in (ents[:24])
]
})
else:
failed += 1
i += batch_size
if progress:
msg = f"پیشرفت: {i}/{N} | معتبر: {len(rows)} | قبولی: {passed} | مردودی: {failed} | رد اولیه: {skipped}"
progress(min(i/N, 0.99), msg)
if (i // max(batch_size,1)) % 10 == 0:
log.info(f"Progress {i}/{N} | kept={len(rows)} pass_rate={passed/max(processed,1):.1%}")
# W&B logging
if self._wandb_on and self._wb_run is not None:
try:
kept = len(rows)
self._wb_run.summary.update({
"dataset_examples": kept,
"processed": processed,
"passed": passed,
"failed": failed,
"skipped": skipped,
"pass_rate": kept / max(processed, 1)
})
except Exception:
pass
if progress: progress(1.0, "اتمام ساخت دیتاست")
log.info(f"Build complete: kept={len(rows)} | processed={processed} | passed={passed} | failed={failed} | skipped={skipped}")
return rows
def save_as_artifact(self, rows: List[Dict], out_path: str = "/tmp/golden_dataset.jsonl", artifact_name: str = "golden-dataset"):
"""اختیاری: خروجی را ذخیره و به W&B آرتیفکت کنید."""
save_jsonl(rows, out_path)
if self._wandb_on and self._wb_run is not None:
try:
art = self._wb.Artifact(artifact_name, type="dataset")
art.add_file(out_path)
self._wb_run.log_artifact(art)
except Exception:
pass
return out_path
# =========================
# I/O helpers
# =========================
def load_json_or_jsonl(path: str) -> List[Dict]:
p = Path(path)
raw = p.read_text(encoding="utf-8").strip()
# JSON یا JSONL
try:
data = json.loads(raw)
return data if isinstance(data, list) else [data]
except json.JSONDecodeError:
out = []
for ln in raw.splitlines():
ln = ln.strip()
if not ln: continue
try: out.append(json.loads(ln))
except json.JSONDecodeError: pass
return out
def save_jsonl(rows: List[Dict], out_path: str):
p = Path(out_path); p.parent.mkdir(parents=True, exist_ok=True)
with p.open("w", encoding="utf-8") as f:
for r in rows:
f.write(json.dumps(r, ensure_ascii=False) + "\n")