Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """ | |
| weights_sweep.py — Auto-tuning legal tag weights via W&B Sweeps | |
| کارکرد اجرایی: | |
| - هر ران: وزنها از config → نوشتن legal_entity_weights.json → اجرای GoldenBuilder روی یک زیرمجموعه کوچک → | |
| محاسبهٔ pass_rate (نرخ قبولی گیت کیفیت) → لاگ متریکها و آرتیفکتها روی W&B. | |
| - در پایان Sweep، از داشبورد W&B «بهترین Run» را انتخاب کنید و وزنها را تثبیت نمایید | |
| (یا از آرتیفکت همان Run دانلود کنید و جایگزین legal_entity_weights.json نمایید). | |
| پیشنیاز: | |
| - فایل golden_builder.py در ریشه ریپو | |
| - Secrets: WANDB_API_KEY در HF Spaces | |
| - requirements: wandb, transformers, torch, ... | |
| پارامترها (قابلتنظیم از طریق env یا UI): | |
| - TUNE_DATA: مسیر فایل JSON/JSONL داده | |
| - TUNE_TEXT_KEY: کلید متن در داده (پیشفرض "متن_کامل") | |
| - TUNE_MAX_SAMPLES: تعداد نمونهٔ کوچک برای هر ران (پیشفرض 120) | |
| - TUNE_BATCH: batch size Builder (پیشفرض 2) | |
| - TUNE_COUNT: تعداد ران در sweep (پیشفرض 16) | |
| - WANDB_PROJECT, WANDB_ENTITY: پروژه/ورکاسپیس W&B | |
| """ | |
| import os | |
| import json | |
| import random | |
| from typing import Dict, List | |
| import wandb | |
| # فضای جستوجو: در صورت نیاز بازهها را سختگیرانهتر/وسیعتر کنید | |
| SWEEP_CONFIG = { | |
| "method": "bayes", # "random" یا "grid" هم قابل استفاده است | |
| "metric": {"name": "pass_rate", "goal": "maximize"}, | |
| "parameters": { | |
| "STATUTE": {"min": 0.8, "max": 1.4}, | |
| "COURT": {"min": 0.6, "max": 1.2}, | |
| "CRIME": {"min": 0.9, "max": 1.6}, | |
| "CIVIL": {"min": 0.5, "max": 1.2}, | |
| "PROCED": {"min": 0.5, "max": 1.0}, | |
| "PARTY": {"min": 0.4, "max": 0.9}, | |
| "BUSINESS": {"min": 0.4, "max": 0.9}, | |
| } | |
| } | |
| DEFAULT_TEXT_KEY = "متن_کامل" | |
| def write_weights_file(weights: Dict[str, float], path: str = "legal_entity_weights.json"): | |
| with open(path, "w", encoding="utf-8") as f: | |
| json.dump({k: float(v) for k, v in weights.items()}, f, ensure_ascii=False, indent=2) | |
| def sample_data(path: str, text_key: str, max_samples: int) -> List[dict]: | |
| from golden_builder import load_json_or_jsonl | |
| data = load_json_or_jsonl(path) | |
| data = [r for r in data if isinstance(r, dict) and text_key in r and isinstance(r[text_key], str) and len(r[text_key].strip()) > 20] | |
| random.shuffle(data) | |
| return data[:max_samples] | |
| def run_once(data_path: str, text_key: str, max_samples: int, batch_size: int): | |
| """ | |
| یک اجرای واحد Agent: وزنها ← Builder → pass_rate | |
| """ | |
| cfg = wandb.config | |
| weights = { | |
| "STATUTE": cfg.STATUTE, | |
| "COURT": cfg.COURT, | |
| "CRIME": cfg.CRIME, | |
| "CIVIL": cfg.CIVIL, | |
| "PROCED": cfg.PROCED, | |
| "PARTY": cfg.PARTY, | |
| "BUSINESS": cfg.BUSINESS, | |
| } | |
| write_weights_file(weights) # این فایل توسط GoldenBuilder خوانده میشود | |
| from golden_builder import GoldenBuilder, save_jsonl | |
| rows_in = sample_data(data_path, text_key, max_samples=max_samples) | |
| if not rows_in: | |
| wandb.log({"pass_rate": 0.0, "kept": 0, "processed": 0}) | |
| wandb.summary.update({"weights": weights}) | |
| return | |
| # برای سرعت/پایداری: mt5-base کافی است؛ اگر مدل دیگری میخواهید، پارامتر کنید | |
| gb = GoldenBuilder(model_name="google/mt5-base") | |
| rows_out = gb.build(rows_in, text_key=text_key, batch_size=batch_size) | |
| processed = len(rows_in) | |
| kept = len(rows_out) | |
| pass_rate = kept / max(processed, 1) | |
| # لاگ متریکها + وزنها | |
| wandb.log({ | |
| "pass_rate": pass_rate, | |
| "kept": kept, | |
| "processed": processed | |
| }) | |
| wandb.summary.update({"weights": weights}) | |
| # آرتیفکت خروجی نمونه (اختیاری ولی مفید برای ارزیابی کیفی) | |
| outp = f"/tmp/gb_out_{wandb.run.id}.jsonl" | |
| save_jsonl(rows_out, outp) | |
| art = wandb.Artifact("gb-sample", type="dataset") | |
| art.add_file(outp) | |
| wandb.log_artifact(art) | |
| def run_sweep( | |
| data_path: str, | |
| text_key: str = DEFAULT_TEXT_KEY, | |
| max_samples: int = 120, | |
| batch_size: int = 2, | |
| project: str = "mahoon-legal-ai", | |
| entity: str = None, | |
| count: int = 16 | |
| ): | |
| os.environ.setdefault("WANDB_PROJECT", project) | |
| if entity: os.environ.setdefault("WANDB_ENTITY", entity) | |
| # ایجاد Sweep | |
| sweep_id = wandb.sweep(SWEEP_CONFIG, project=os.getenv("WANDB_PROJECT", project), entity=os.getenv("WANDB_ENTITY", entity)) | |
| def _agent(): | |
| wandb.init(project=os.getenv("WANDB_PROJECT", project), | |
| entity=os.getenv("WANDB_ENTITY", entity), | |
| name="weights-tune") | |
| run_once(data_path=data_path, text_key=text_key, max_samples=max_samples, batch_size=batch_size) | |
| # اجرای تعداد مشخصی Agent-run | |
| wandb.agent(sweep_id, function=_agent, count=count) | |
| if __name__ == "__main__": | |
| # اجرای خط فرمان/محلی: | |
| # export WANDB_API_KEY=<توکن واقعی> | |
| # python weights_sweep.py | |
| data = os.getenv("TUNE_DATA", "./sample.jsonl") | |
| text_key = os.getenv("TUNE_TEXT_KEY", DEFAULT_TEXT_KEY) | |
| max_samples = int(os.getenv("TUNE_MAX_SAMPLES", "120")) | |
| count = int(os.getenv("TUNE_COUNT", "16")) | |
| batch_size = int(os.getenv("TUNE_BATCH", "2")) | |
| project = os.getenv("WANDB_PROJECT", "mahoon-legal-ai") | |
| entity = os.getenv("WANDB_ENTITY", None) | |
| run_sweep( | |
| data_path=data, | |
| text_key=text_key, | |
| max_samples=max_samples, | |
| batch_size=batch_size, | |
| project=project, | |
| entity=entity, | |
| count=count | |
| ) | |