mahoon-legal-ai / weights_sweep.py
hajimammad's picture
Upload weights_sweep.py
0f23dca verified
# -*- coding: utf-8 -*-
"""
weights_sweep.py — Auto-tuning legal tag weights via W&B Sweeps
کارکرد اجرایی:
- هر ران: وزن‌ها از config → نوشتن legal_entity_weights.json → اجرای GoldenBuilder روی یک زیرمجموعه کوچک →
محاسبهٔ pass_rate (نرخ قبولی گیت کیفیت) → لاگ متریک‌ها و آرتیفکت‌ها روی W&B.
- در پایان Sweep، از داشبورد W&B «بهترین Run» را انتخاب کنید و وزن‌ها را تثبیت نمایید
(یا از آرتیفکت همان Run دانلود کنید و جایگزین legal_entity_weights.json نمایید).
پیش‌نیاز:
- فایل golden_builder.py در ریشه ریپو
- Secrets: WANDB_API_KEY در HF Spaces
- requirements: wandb, transformers, torch, ...
پارامترها (قابل‌تنظیم از طریق env یا UI):
- TUNE_DATA: مسیر فایل JSON/JSONL داده
- TUNE_TEXT_KEY: کلید متن در داده (پیش‌فرض "متن_کامل")
- TUNE_MAX_SAMPLES: تعداد نمونهٔ کوچک برای هر ران (پیش‌فرض 120)
- TUNE_BATCH: batch size Builder (پیش‌فرض 2)
- TUNE_COUNT: تعداد ران در sweep (پیش‌فرض 16)
- WANDB_PROJECT, WANDB_ENTITY: پروژه/ورک‌اسپیس W&B
"""
import os
import json
import random
from typing import Dict, List
import wandb
# فضای جست‌وجو: در صورت نیاز بازه‌ها را سخت‌گیرانه‌تر/وسیع‌تر کنید
SWEEP_CONFIG = {
"method": "bayes", # "random" یا "grid" هم قابل استفاده است
"metric": {"name": "pass_rate", "goal": "maximize"},
"parameters": {
"STATUTE": {"min": 0.8, "max": 1.4},
"COURT": {"min": 0.6, "max": 1.2},
"CRIME": {"min": 0.9, "max": 1.6},
"CIVIL": {"min": 0.5, "max": 1.2},
"PROCED": {"min": 0.5, "max": 1.0},
"PARTY": {"min": 0.4, "max": 0.9},
"BUSINESS": {"min": 0.4, "max": 0.9},
}
}
DEFAULT_TEXT_KEY = "متن_کامل"
def write_weights_file(weights: Dict[str, float], path: str = "legal_entity_weights.json"):
with open(path, "w", encoding="utf-8") as f:
json.dump({k: float(v) for k, v in weights.items()}, f, ensure_ascii=False, indent=2)
def sample_data(path: str, text_key: str, max_samples: int) -> List[dict]:
from golden_builder import load_json_or_jsonl
data = load_json_or_jsonl(path)
data = [r for r in data if isinstance(r, dict) and text_key in r and isinstance(r[text_key], str) and len(r[text_key].strip()) > 20]
random.shuffle(data)
return data[:max_samples]
def run_once(data_path: str, text_key: str, max_samples: int, batch_size: int):
"""
یک اجرای واحد Agent: وزن‌ها ← Builder → pass_rate
"""
cfg = wandb.config
weights = {
"STATUTE": cfg.STATUTE,
"COURT": cfg.COURT,
"CRIME": cfg.CRIME,
"CIVIL": cfg.CIVIL,
"PROCED": cfg.PROCED,
"PARTY": cfg.PARTY,
"BUSINESS": cfg.BUSINESS,
}
write_weights_file(weights) # این فایل توسط GoldenBuilder خوانده می‌شود
from golden_builder import GoldenBuilder, save_jsonl
rows_in = sample_data(data_path, text_key, max_samples=max_samples)
if not rows_in:
wandb.log({"pass_rate": 0.0, "kept": 0, "processed": 0})
wandb.summary.update({"weights": weights})
return
# برای سرعت/پایداری: mt5-base کافی است؛ اگر مدل دیگری می‌خواهید، پارامتر کنید
gb = GoldenBuilder(model_name="google/mt5-base")
rows_out = gb.build(rows_in, text_key=text_key, batch_size=batch_size)
processed = len(rows_in)
kept = len(rows_out)
pass_rate = kept / max(processed, 1)
# لاگ متریک‌ها + وزن‌ها
wandb.log({
"pass_rate": pass_rate,
"kept": kept,
"processed": processed
})
wandb.summary.update({"weights": weights})
# آرتیفکت خروجی نمونه (اختیاری ولی مفید برای ارزیابی کیفی)
outp = f"/tmp/gb_out_{wandb.run.id}.jsonl"
save_jsonl(rows_out, outp)
art = wandb.Artifact("gb-sample", type="dataset")
art.add_file(outp)
wandb.log_artifact(art)
def run_sweep(
data_path: str,
text_key: str = DEFAULT_TEXT_KEY,
max_samples: int = 120,
batch_size: int = 2,
project: str = "mahoon-legal-ai",
entity: str = None,
count: int = 16
):
os.environ.setdefault("WANDB_PROJECT", project)
if entity: os.environ.setdefault("WANDB_ENTITY", entity)
# ایجاد Sweep
sweep_id = wandb.sweep(SWEEP_CONFIG, project=os.getenv("WANDB_PROJECT", project), entity=os.getenv("WANDB_ENTITY", entity))
def _agent():
wandb.init(project=os.getenv("WANDB_PROJECT", project),
entity=os.getenv("WANDB_ENTITY", entity),
name="weights-tune")
run_once(data_path=data_path, text_key=text_key, max_samples=max_samples, batch_size=batch_size)
# اجرای تعداد مشخصی Agent-run
wandb.agent(sweep_id, function=_agent, count=count)
if __name__ == "__main__":
# اجرای خط فرمان/محلی:
# export WANDB_API_KEY=<توکن واقعی>
# python weights_sweep.py
data = os.getenv("TUNE_DATA", "./sample.jsonl")
text_key = os.getenv("TUNE_TEXT_KEY", DEFAULT_TEXT_KEY)
max_samples = int(os.getenv("TUNE_MAX_SAMPLES", "120"))
count = int(os.getenv("TUNE_COUNT", "16"))
batch_size = int(os.getenv("TUNE_BATCH", "2"))
project = os.getenv("WANDB_PROJECT", "mahoon-legal-ai")
entity = os.getenv("WANDB_ENTITY", None)
run_sweep(
data_path=data,
text_key=text_key,
max_samples=max_samples,
batch_size=batch_size,
project=project,
entity=entity,
count=count
)