hajimammad committed
Commit e606eca · verified · 1 Parent(s): c35b21c

Update golden_builder.py

Files changed (1)
  1. golden_builder.py +392 -100
golden_builder.py CHANGED
@@ -1,72 +1,166 @@
- # golden_builder.py
  # -*- coding: utf-8 -*-
- import json, re, logging, hashlib
  from dataclasses import dataclass
  from pathlib import Path
- from typing import Dict, List, Optional
- from collections import Counter

- import numpy as np
  import torch
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

  log = logging.getLogger("golden-builder")
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

- # ========= Utilities =========
- PERSIAN_MAP = {'ك':'ک','ى':'ی','ﻲ':'ی','ﯽ':'ی','أ':'ا','إ':'ا'}
- NOISE = [r"http[s]?://\S+", r"www\.\S+", r"\d{10,}", r"(.)\1{4,}", r"[^\u0600-\u06FF\s\d\.,;:!?()\"'\-]+"]

- def clean_text(s: str) -> str:
      if not isinstance(s, str): return ""
-     for a,b in PERSIAN_MAP.items(): s = s.replace(a,b)
-     for pat in NOISE: s = re.sub(pat, " ", s)
      s = re.sub(r"\s+", " ", s)
      s = re.sub(r"\.{2,}", "...", s)
      s = re.sub(r"\s+([،.;:!?])", r"\1", s)
      s = re.sub(r"([،.;:!?])(?=[^\s])", r"\1 ", s)
      return s.strip()

  def md5(s: str) -> str:
-     import hashlib as _h
-     return _h.md5(s.encode("utf-8")).hexdigest()

  def lex_diversity(s: str) -> float:
      toks = s.split()
      return 0.0 if not toks else len(set(toks))/len(toks)

- def has_repetition(s: str, n:int=3) -> bool:
      toks = s.split()
      if len(toks) < n: return False
      grams = [tuple(toks[i:i+n]) for i in range(len(toks)-n+1)]
      from collections import Counter
-     return any(c>2 for c in Counter(grams).values())

- # ========= Lightweight NER (regex spans for metadata) =========
  @dataclass
  class LegalEntity:
      text: str; category: str; start: int; end: int; weight: float

  class LegalEntityExtractor:
      def __init__(self):
-         self._defs = {
-             "STATUTE": ( [r"قانون\s+(?:اساسی|مدنی|کیفری|کار|تجارت|مجازات)",
-                           r"آیین\s+دادرسی\s+(?:مدنی|کیفری)",
-                           r"ماده\s+\d+", r"تبصره\s+\d+"], 1.0 ),
-             "COURT": ( [r"دیوان\s+(?:عالی|عدالت)", r"دادگاه\s+(?:عمومی|تجدیدنظر|انقلاب)", r"شعبه\s+\d+"], 0.9 ),
-             "CRIME": ( [r"کلاهبرداری|اختلاس|ارتشا|خیانت\s+در\s+امانت|جعل|سرقت|قتل"], 1.2 ),
-             "PENALTY": ( [r"حبس|جزای\s+نقدی|شلاق|قصاص|دیه|محرومیت\s+از\s+حقوق\s+اجتماعی"], 1.1 ),
-             "CIVIL": ( [r"قرارداد|عقد\s+(?:بیع|اجاره|رهن|نکاح)|خسارت|تعهد|ضمان|مطالبه"], 0.8 ),
-             "PROCED": ( [r"دادخواست|لایحه|شکواییه|ابلاغ|جلسه\s+دادرسی|کارشناسی|دلایل\s+اثباتی"], 0.7 ),
-             "PARTY": ( [r"خواهان|خوانده|شاکی|متهم|وکیل\s+دادگستری|دادستان|قاضی"], 0.6 ),
-             "BUSINESS": ( [r"شرکت\s+(?:سهامی|مسئولیت\s+محدود)|ورشکستگی|سهام|چک|سفته|برات"], 0.6 ),
          }
          self._patterns = []
-         for cat,(pats,w) in self._defs.items():
-             for p in pats:
-                 self._patterns.append( (re.compile(p, re.IGNORECASE), cat, w) )

-     def extract(self, text:str) -> List[LegalEntity]:
          out, seen = [], set()
          for rgx, cat, w in self._patterns:
              for m in rgx.finditer(text):
@@ -75,96 +169,296 @@ class LegalEntityExtractor:
                  seen.add((s,e))
                  out.append(LegalEntity(m.group(), cat, s, e, w))
          out.sort(key=lambda x: x.start)
          return out

- # ========= Builder =========
  class GoldenBuilder:
-     def __init__(self, model_name: str = "google/mt5-base", device: Optional[str] = None,
-                  min_len:int=40, max_len:int=160, min_entities:int=2):
          self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
          log.info("Device: %s", self.device)
          self.tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
-         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)
-         self.model.eval()
-         self.min_len, self.max_len = min_len, max_len
-         self.min_entities = min_entities
-         self._seen_hashes = set()
          self.ner = LegalEntityExtractor()

-     def _summarize_batch(self, texts: List[str], num_beams:int=6) -> List[str]:
          if not texts: return []
-         inputs = self.tok(texts, return_tensors="pt", truncation=True, padding=True, max_length=512).to(self.device)
-         with torch.no_grad():
-             ids = self.model.generate(
-                 **inputs,
-                 max_length=self.max_len, min_length=self.min_len,
-                 num_beams=num_beams, early_stopping=True,
-                 length_penalty=2.5, no_repeat_ngram_size=3, do_sample=False
-             )
-         return self.tok.batch_decode(ids, skip_special_tokens=True)
-
-     def _quality_gate(self, src:str, tgt:str, ents:List[LegalEntity]) -> bool:
          s_len, t_len = len(src.split()), len(tgt.split())
-         if not (30 <= s_len and 20 <= t_len <= 220): return False
-         comp = (t_len/(s_len+1e-8))
-         if not (0.12 <= comp <= 0.65): return False
-         if lex_diversity(tgt) < 0.4: return False
-         if has_repetition(tgt, 3): return False
-         if len(ents) < self.min_entities: return False
-         ent_density = (sum((e.end - e.start) for e in ents) / max(len(src),1)) * 100
-         if ent_density < 4.0: return False
          return True

-     def build(self, raw_items: List[Dict], text_key:str="متن_کامل", batch_size:int=4) -> List[Dict]:
-         rows, i = [], 0
          N = len(raw_items)
          while i < N:
              chunk = raw_items[i:i+batch_size]
-             cleaned = [clean_text(str(it.get(text_key, ""))) for it in chunk]
-             # de-dup + drop too-short texts
-             todo = []
              for c in cleaned:
-                 if len(c.split()) < 30:
-                     todo.append("")
                      continue
-                 h = md5(c)
-                 if h in self._seen_hashes:
-                     todo.append("")
-                     continue
-                 self._seen_hashes.add(h)
-                 todo.append(c)
-             # Summarize only valid items
-             inputs = [f"summarize: {t}" for t in todo if t]
-             outputs = self._summarize_batch(inputs) if inputs else []
-             k = 0
-             for c in todo:
-                 if not c: continue
                  tgt = clean_text(outputs[k]); k += 1
                  ents = self.ner.extract(c)
-                 if not self._quality_gate(c, tgt, ents):
-                     continue
-                 ents_payload = [{"text": e.text, "category": e.category, "start": e.start, "end": e.end, "weight": e.weight}
-                                 for e in ents[:20]]
-                 rows.append({
-                     "input": f"summarize: {c}",
-                     "output": tgt,
-                     "metadata": {
-                         "input_length": len(c.split()),
-                         "target_length": len(tgt.split())
-                     },
-                     "legal_entities": {
-                         "total_entities": len(ents),
-                         "categories": dict(Counter(e.category for e in ents)),
-                         "entities": ents_payload
-                     }
-                 })
              i += batch_size
          return rows

- # ========= I/O =========
  def load_json_or_jsonl(path: str) -> List[Dict]:
      p = Path(path)
      raw = p.read_text(encoding="utf-8").strip()
      try:
          data = json.loads(raw)
          return data if isinstance(data, list) else [data]
@@ -173,10 +467,8 @@ def load_json_or_jsonl(path: str) -> List[Dict]:
          for ln in raw.splitlines():
              ln = ln.strip()
              if not ln: continue
-             try:
-                 out.append(json.loads(ln))
-             except json.JSONDecodeError:
-                 pass
          return out

  def save_jsonl(rows: List[Dict], out_path: str):

  # -*- coding: utf-8 -*-
+ """
+ Golden Builder (Persian Legal) — Fast, Robust, W&B-enabled
+ - Compatible with your app (app.py): the GoldenBuilder class plus the load_json_or_jsonl / save_jsonl functions
+ - Improvements:
+   * Persian normalization and noise cleanup
+   * O(1) cache for summaries
+   * Bucketing by token length; prevents OOM
+   * autocast (bf16/fp16) for speed and VRAM efficiency
+   * Quality gate: length / lexical diversity / no n-gram repetition / entity density and weighted entity score
+   * Weights are read from legal_entity_weights.json (the Weight Tuning output)
+   * Optional W&B: metadata plus a dataset artifact of the output
+ """
+
+ import os, re, json, hashlib, logging, math, random
  from dataclasses import dataclass
  from pathlib import Path
+ from typing import Dict, List, Optional, Callable, Tuple

  import torch
+ import numpy as np
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

+ # =========================
+ # Logging
+ # =========================
  log = logging.getLogger("golden-builder")
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

+ # =========================
+ # Persian Normalization & Cleaning
+ # =========================
+ ZWNJ = "\u200c"
+ AR_DIGITS = "٠١٢٣٤٥٦٧٨٩"
+ FA_DIGITS = "۰۱۲۳۴۵۶۷۸۹"
+ EN_DIGITS = "0123456789"
+ TRANS_DIG = {ord(a): e for a, e in zip(AR_DIGITS + FA_DIGITS, EN_DIGITS * 2)}

+ def normalize_fa(s: str) -> str:
      if not isinstance(s, str): return ""
+     s = s.replace("\u064A", "ی").replace("\u0643", "ک")
+     s = s.translate(TRANS_DIG)
+     # strip Arabic diacritics and control characters
+     s = re.sub(r"[\u064B-\u065F\u0610-\u061A\u200B-\u200F\u202A-\u202E\uFEFF]", "", s)
+     # normalize ZWNJ
+     s = re.sub(r"\s*‌\s*", ZWNJ, s)
+     # collapse whitespace
+     s = re.sub(r"\s+", " ", s).strip()
+     return s
+
+ NOISE_PATTERNS = [
+     r"http[s]?://\S+",
+     r"www\.\S+",
+     r"\d{10,}",  # very long digit runs
+     r"(.)\1{4,}",  # stretched (elongated) characters
+     r"[^\u0600-\u06FF\s\d\.,;:!?()\"'\-]+",  # non-Persian characters / stray symbols
+ ]
+
+ def clean_text(s: str) -> str:
+     s = normalize_fa(s)
+     for pat in NOISE_PATTERNS:
+         s = re.sub(pat, " ", s)
      s = re.sub(r"\s+", " ", s)
      s = re.sub(r"\.{2,}", "...", s)
+     # punctuation spacing
      s = re.sub(r"\s+([،.;:!?])", r"\1", s)
      s = re.sub(r"([،.;:!?])(?=[^\s])", r"\1 ", s)
      return s.strip()

+ # =========================
+ # Utils
+ # =========================
  def md5(s: str) -> str:
+     return hashlib.md5(s.encode("utf-8")).hexdigest()

  def lex_diversity(s: str) -> float:
      toks = s.split()
      return 0.0 if not toks else len(set(toks))/len(toks)

+ def has_repetition(s: str, n: int = 3, thr: int = 2) -> bool:
      toks = s.split()
      if len(toks) < n: return False
      grams = [tuple(toks[i:i+n]) for i in range(len(toks)-n+1)]
      from collections import Counter
+     return any(c > thr for c in Counter(grams).values())
+
+ def set_all_seeds(seed: int = 42):
+     random.seed(seed); np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)

+ # =========================
+ # Lightweight Legal NER (Regex) with external weights
+ # =========================
  @dataclass
  class LegalEntity:
      text: str; category: str; start: int; end: int; weight: float

+ DEFAULT_WEIGHTS = {
+     "STATUTE": 1.0, "COURT": 0.9, "CRIME": 1.2,
+     "CIVIL": 0.8, "PROCED": 0.7, "PARTY": 0.6, "BUSINESS": 0.6
+ }
+
  class LegalEntityExtractor:
      def __init__(self):
+         defs = {
+             "STATUTE": ([
+                 r"قانون\s+(?:اساسی|مدنی|کیفری|کار|تجارت|مجازات|دریایی|هوایی)",
+                 r"آیین\s+دادرسی\s+(?:مدنی|کیفری|دادگاه‌های\s+عمومی|اداری)",
+                 r"ماده\s+\d+(?:\s+(?:تبصره|الحاقی|اصلاحی))?",
+                 r"تبصره\s+\d+",
+                 r"لایحه\s+قانونی|اصلاحیه"
+             ], DEFAULT_WEIGHTS["STATUTE"]),
+             "COURT": ([
+                 r"دیوان\s+(?:عالی|عدالت\s+اداری|محاسبات)",
+                 r"دادگاه\s+(?:عمومی|تجدیدنظر|انقلاب|نظامی|اطفال|خانواده)",
+                 r"شعبه\s+\d+(?:\s+دادگاه)?",
+                 r"هیئت\s+(?:منصفه|تخلفات|عمومی)"
+             ], DEFAULT_WEIGHTS["COURT"]),
+             "CRIME": ([
+                 r"کلاهبرداری|اختلاس|ارتشا|رشوه|خیانت\s+در\s+امانت",
+                 r"جعل(?:\s+(?:اسناد|امضا))?|سرقت(?:\s+(?:مشدد|ساده))?",
+                 r"قتل(?:\s+(?:عمد|شبه\s+عمد|خطای\s+محض))?",
+                 r"تصادف\s+منجر\s+به\s+فوت|قاچاق\s+(?:مواد\s+مخدر|کالا)|پولشویی"
+             ], DEFAULT_WEIGHTS["CRIME"]),
+             "CIVIL": ([
+                 r"قرارداد|عقد\s+(?:بیع|اجاره|رهن|نکاح|صلح|هبه|وکالت)",
+                 r"خسارت|تعهد|ضمان|مطالبه|وجه\s+التزام|فسخ|اقاله",
+                 r"مهریه|نفقه|حضانت|جهیزیه"
+             ], DEFAULT_WEIGHTS["CIVIL"]),
+             "PROCED": ([
+                 r"دادخواست|لایحه|شکوائیه|ابلاغ|جلسه\s+دادرسی|کارشناسی",
+                 r"دلایل\s+اثباتی|استماع\s+شهود|رأی|حکم|قرار"
+             ], DEFAULT_WEIGHTS["PROCED"]),
+             "PARTY": ([
+                 r"خواهان|خواندگان?|شاکی(?:ان)?|متهم(?:ین|ان)?|محکوم\s+(?:له|علیه)",
+                 r"وکیل\s+(?:دادگستری|پایه\s+یک)?|دادستان|بازپرس|قاضی|کارشناس\s+رسمی"
+             ], DEFAULT_WEIGHTS["PARTY"]),
+             "BUSINESS": ([
+                 r"شرکت\s+(?:سهامی|مسئولیت\s+محدود|تضامنی)|ورشکستگی|نکول|سهام",
+                 r"چک|سفته|برات|اوراق\s+بهادار|مجمع\s+عمومی"
+             ], DEFAULT_WEIGHTS["BUSINESS"])
          }
+
+         # override the defaults from the external weights file if it exists
+         learned = {}
+         try:
+             if os.path.exists("legal_entity_weights.json"):
+                 with open("legal_entity_weights.json","r",encoding="utf-8") as f:
+                     learned = json.load(f)
+         except Exception:
+             learned = {}
+
          self._patterns = []
+         for cat, (ps, w) in defs.items():
+             ww = float(learned.get(cat, w))
+             for p in ps:
+                 self._patterns.append((re.compile(p, re.IGNORECASE), cat, ww))
+         self._cache = {}

+     def extract(self, text: str) -> List[LegalEntity]:
+         h = md5(text)
+         if h in self._cache: return self._cache[h]
          out, seen = [], set()
          for rgx, cat, w in self._patterns:
              for m in rgx.finditer(text):

                  seen.add((s,e))
                  out.append(LegalEntity(m.group(), cat, s, e, w))
          out.sort(key=lambda x: x.start)
+         if len(self._cache) < 1000: self._cache[h] = out
          return out

+     def weighted_score(self, entities: List[LegalEntity]) -> float:
+         # sum of weights, boosted by the token length of each entity span
+         score = 0.0
+         for e in entities:
+             span_len = max(len(e.text.split()), 1)
+             score += e.weight * math.log1p(span_len)
+         return score
+
+ # =========================
+ # Golden Builder
+ # =========================
+ @dataclass
+ class GBConfig:
+     min_src_tokens: int = 30
+     min_tgt_tokens: int = 20
+     max_tgt_tokens: int = 220
+     target_minmax_ratio: Tuple[float,float] = (0.12, 0.65)  # len(tgt)/len(src)
+     min_lex_div: float = 0.40
+     ngram_repeat_n: int = 3
+     ngram_repeat_thr: int = 2
+     min_entity_count: int = 2
+     min_entity_weight_score: float = 2.0  # weighted-score threshold for acceptance
+
  class GoldenBuilder:
+     """
+     Drop-in replacement
+     """
+     def __init__(
+         self,
+         model_name: str = "google/mt5-base",
+         device: Optional[str] = None,
+         min_len: int = 40,
+         max_len: int = 160,
+         seed: int = 42
+     ):
+         set_all_seeds(seed)
          self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
          log.info("Device: %s", self.device)
+
          self.tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+         self.model.to(self.device).eval()
+
+         self.min_len = int(min_len)
+         self.max_len = int(max_len)
+
+         self.cfg = GBConfig()
          self.ner = LegalEntityExtractor()

+         # dtype & autocast setup
+         if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+             self._amp_dtype = torch.bfloat16
+         elif torch.cuda.is_available():
+             self._amp_dtype = torch.float16
+         else:
+             self._amp_dtype = torch.float32
+
+         # summary cache and seen-hash set
+         self._summary_cache: Dict[str, str] = {}
+         self._seen_hashes = set()
+
+         # optional W&B
+         self._wandb_on = bool(os.getenv("WANDB_API_KEY"))
+         self._wb_run = None
+         if self._wandb_on:
+             try:
+                 import wandb
+                 self._wb = wandb
+                 self._wb_run = wandb.init(
+                     project=os.getenv("WANDB_PROJECT","mahoon-legal-ai"),
+                     name="dataset_builder",
+                     config={"model_name": model_name, "min_len": self.min_len, "max_len": self.max_len}
+                 )
+             except Exception:
+                 self._wandb_on = False
+                 self._wb_run = None
+
+     # --------------------- I/O helpers ---------------------
+     def _encode(self, texts: List[str], max_length: int = 512):
+         return self.tok(
+             texts,
+             return_tensors="pt",
+             truncation=True,
+             padding=True,
+             max_length=max_length
+         ).to(self.device)
+
+     # --------------------- Batching & Caching ---------------------
+     def _summarize_uncached(self, items: List[Tuple[int, str]], num_beams: int = 6, batch_tokens: int = 1400) -> Dict[int, str]:
+         """
+         items: list of (original_index, text_with_prefix)
+         strategy: sort by length; greedy micro-batches under a token budget
+         returns: {original_index: summary}
+         """
+         if not items: return {}
+         # estimate token lengths
+         lens = [len(self.tok(t, add_special_tokens=False).input_ids) for _, t in items]
+         order = np.argsort(lens)  # shortest to longest
+
+         results: Dict[int, str] = {}
+         batch: List[Tuple[int, str]] = []
+         budget = 0
+
+         def flush_batch(B: List[Tuple[int,str]]):
+             if not B: return
+             idxs = [i for i,_ in B]
+             texts = [t for _,t in B]
+             inputs = self._encode(texts, max_length=512)
+             with torch.no_grad():
+                 with torch.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=self._amp_dtype):
+                     ids = self.model.generate(
+                         **inputs,
+                         max_length=self.max_len,
+                         min_length=self.min_len,
+                         num_beams=num_beams,
+                         length_penalty=2.5,
+                         no_repeat_ngram_size=3,
+                         early_stopping=True,
+                         do_sample=False
+                     )
+             outs = self.tok.batch_decode(ids, skip_special_tokens=True)
+             for i, gen in zip(idxs, outs):
+                 results[i] = gen
+
+         for idx in order:
+             oi, txt = items[idx]
+             tlen = lens[idx]
+             if budget + tlen > batch_tokens and batch:
+                 flush_batch(batch)
+                 batch, budget = [], 0
+             batch.append((oi, txt)); budget += tlen
+         if batch:
+             flush_batch(batch)
+         return results
+
+     def _summarize_batch(self, texts: List[str], num_beams: int = 6) -> List[str]:
+         """
+         Input: a list of texts (each already carrying the "summarize: ..." prefix)
+         Output: a list of summaries in the same order as the input
+         """
          if not texts: return []
+         results = [None] * len(texts)
+         uncached: List[Tuple[int,str]] = []
+         for i, t in enumerate(texts):
+             h = md5(t)
+             if h in self._summary_cache:
+                 results[i] = self._summary_cache[h]
+             else:
+                 uncached.append((i, t))
+         if uncached:
+             out_map = self._summarize_uncached(uncached, num_beams=num_beams)
+             for i, _ in uncached:
+                 results[i] = out_map.get(i, "")
+                 # update cache
+                 h = md5(texts[i])
+                 if len(self._summary_cache) < 10000 and results[i]:
+                     self._summary_cache[h] = results[i]
+         return [r or "" for r in results]
+
+     # --------------------- Quality Gate ---------------------
+     def _quality_gate(self, src: str, tgt: str, ents: List[LegalEntity]) -> bool:
          s_len, t_len = len(src.split()), len(tgt.split())
+         if s_len < self.cfg.min_src_tokens: return False
+         if not (self.cfg.min_tgt_tokens <= t_len <= self.cfg.max_tgt_tokens): return False
+         comp = t_len / (s_len + 1e-8)
+         if not (self.cfg.target_minmax_ratio[0] <= comp <= self.cfg.target_minmax_ratio[1]): return False
+         if lex_diversity(tgt) < self.cfg.min_lex_div: return False
+         if has_repetition(tgt, self.cfg.ngram_repeat_n, self.cfg.ngram_repeat_thr): return False
+
+         # entities: minimum count plus minimum weighted score
+         if len(ents) < self.cfg.min_entity_count: return False
+         wscore = self.ner.weighted_score(ents)
+         if wscore < self.cfg.min_entity_weight_score: return False
          return True

+     # --------------------- Public API ---------------------
+     def build(
+         self,
+         raw_items: List[Dict],
+         text_key: str = "متن_کامل",
+         batch_size: int = 4,
+         progress: Optional[Callable[[float, str], None]] = None
+     ) -> List[Dict]:
+         """
+         EXACT SAME signature (+ an optional progress callback for hooking into Gradio)
+         """
+         rows = []
          N = len(raw_items)
+         if progress: progress(0.0, "شروع ساخت دیتاست")
+         log.info(f"Starting build: N={N}, text_key='{text_key}'")
+
+         processed = passed = failed = skipped = 0
+         i = 0
          while i < N:
              chunk = raw_items[i:i+batch_size]
+             # pre-clean & filter
+             cleaned = []
+             for it in chunk:
+                 raw = it.get(text_key, "")
+                 txt = clean_text(str(raw))
+                 if len(txt.split()) < self.cfg.min_src_tokens:
+                     skipped += 1
+                     cleaned.append("")  # placeholder to keep positions aligned
+                 else:
+                     h = md5(txt)
+                     if h in self._seen_hashes:
+                         skipped += 1
+                         cleaned.append("")
+                     else:
+                         self._seen_hashes.add(h)
+                         cleaned.append(txt)
+
+             # prepare the summarization inputs
+             todo_texts = [f"summarize: {c}" for c in cleaned if c]
+             outputs = self._summarize_batch(todo_texts) if todo_texts else []
+             # map the outputs back onto cleaned
+             k = 0
              for c in cleaned:
+                 if not c:
                      continue
+                 processed += 1
                  tgt = clean_text(outputs[k]); k += 1
                  ents = self.ner.extract(c)
+                 if self._quality_gate(c, tgt, ents):
+                     passed += 1
+                     rows.append({
+                         "input": f"summarize: {c}",
+                         "output": tgt,
+                         "metadata": {
+                             "input_length": len(c.split()),
+                             "target_length": len(tgt.split()),
+                             "entity_count": len(ents),
+                             "entity_weight_score": self.ner.weighted_score(ents)
+                         },
+                         "legal_entities": [
+                             {"text": e.text, "category": e.category, "start": e.start, "end": e.end, "weight": e.weight}
+                             for e in (ents[:24])
+                         ]
+                     })
+                 else:
+                     failed += 1
+
              i += batch_size
+             if progress:
+                 msg = f"پیشرفت: {i}/{N} | معتبر: {len(rows)} | قبولی: {passed} | مردودی: {failed} | رد اولیه: {skipped}"
+                 progress(min(i/N, 0.99), msg)
+             if (i // max(batch_size,1)) % 10 == 0:
+                 log.info(f"Progress {i}/{N} | kept={len(rows)} pass_rate={passed/max(processed,1):.1%}")
+
+         # W&B logging
+         if self._wandb_on and self._wb_run is not None:
+             try:
+                 kept = len(rows)
+                 self._wb_run.summary.update({
+                     "dataset_examples": kept,
+                     "processed": processed,
+                     "passed": passed,
+                     "failed": failed,
+                     "skipped": skipped,
+                     "pass_rate": kept / max(processed, 1)
+                 })
+             except Exception:
+                 pass
+
+         if progress: progress(1.0, "اتمام ساخت دیتاست")
+         log.info(f"Build complete: kept={len(rows)} | processed={processed} | passed={passed} | failed={failed} | skipped={skipped}")
          return rows

+     def save_as_artifact(self, rows: List[Dict], out_path: str = "/tmp/golden_dataset.jsonl", artifact_name: str = "golden-dataset"):
+         """Optional: save the rows and log them to W&B as a dataset artifact."""
+         save_jsonl(rows, out_path)
+         if self._wandb_on and self._wb_run is not None:
+             try:
+                 art = self._wb.Artifact(artifact_name, type="dataset")
+                 art.add_file(out_path)
+                 self._wb_run.log_artifact(art)
+             except Exception:
+                 pass
+         return out_path
+
+ # =========================
+ # I/O helpers
+ # =========================
  def load_json_or_jsonl(path: str) -> List[Dict]:
      p = Path(path)
      raw = p.read_text(encoding="utf-8").strip()
+     # JSON or JSONL
      try:
          data = json.loads(raw)
          return data if isinstance(data, list) else [data]

          for ln in raw.splitlines():
              ln = ln.strip()
              if not ln: continue
+             try: out.append(json.loads(ln))
+             except json.JSONDecodeError: pass
          return out

  def save_jsonl(rows: List[Dict], out_path: str):
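
For reference, a minimal usage sketch of the module after this change. The file names, the example weight values, and the progress lambda below are illustrative and not part of the commit; the calls themselves follow the GoldenBuilder API shown above.

# sketch: build a golden dataset from a local JSON/JSONL export (hypothetical paths)
# if a legal_entity_weights.json file sits next to the script, it overrides the
# per-category defaults with a flat mapping, e.g. {"CRIME": 1.3, "PARTY": 0.5}
from golden_builder import GoldenBuilder, load_json_or_jsonl, save_jsonl

raw_items = load_json_or_jsonl("cases.json")            # list of dicts with a "متن_کامل" field
builder = GoldenBuilder(model_name="google/mt5-base")   # CUDA or CPU is picked automatically
rows = builder.build(raw_items, text_key="متن_کامل", batch_size=4,
                     progress=lambda p, msg: print(f"{p:.0%} {msg}"))
save_jsonl(rows, "golden_dataset.jsonl")                # or builder.save_as_artifact(rows) when W&B is enabled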