# -*- coding: utf-8 -*- """ weights_sweep.py — Auto-tuning legal tag weights via W&B Sweeps کارکرد اجرایی: - هر ران: وزن‌ها از config → نوشتن legal_entity_weights.json → اجرای GoldenBuilder روی یک زیرمجموعه کوچک → محاسبهٔ pass_rate (نرخ قبولی گیت کیفیت) → لاگ متریک‌ها و آرتیفکت‌ها روی W&B. - در پایان Sweep، از داشبورد W&B «بهترین Run» را انتخاب کنید و وزن‌ها را تثبیت نمایید (یا از آرتیفکت همان Run دانلود کنید و جایگزین legal_entity_weights.json نمایید). پیش‌نیاز: - فایل golden_builder.py در ریشه ریپو - Secrets: WANDB_API_KEY در HF Spaces - requirements: wandb, transformers, torch, ... پارامترها (قابل‌تنظیم از طریق env یا UI): - TUNE_DATA: مسیر فایل JSON/JSONL داده - TUNE_TEXT_KEY: کلید متن در داده (پیش‌فرض "متن_کامل") - TUNE_MAX_SAMPLES: تعداد نمونهٔ کوچک برای هر ران (پیش‌فرض 120) - TUNE_BATCH: batch size Builder (پیش‌فرض 2) - TUNE_COUNT: تعداد ران در sweep (پیش‌فرض 16) - WANDB_PROJECT, WANDB_ENTITY: پروژه/ورک‌اسپیس W&B """ import os import json import random from typing import Dict, List import wandb # فضای جست‌وجو: در صورت نیاز بازه‌ها را سخت‌گیرانه‌تر/وسیع‌تر کنید SWEEP_CONFIG = { "method": "bayes", # "random" یا "grid" هم قابل استفاده است "metric": {"name": "pass_rate", "goal": "maximize"}, "parameters": { "STATUTE": {"min": 0.8, "max": 1.4}, "COURT": {"min": 0.6, "max": 1.2}, "CRIME": {"min": 0.9, "max": 1.6}, "CIVIL": {"min": 0.5, "max": 1.2}, "PROCED": {"min": 0.5, "max": 1.0}, "PARTY": {"min": 0.4, "max": 0.9}, "BUSINESS": {"min": 0.4, "max": 0.9}, } } DEFAULT_TEXT_KEY = "متن_کامل" def write_weights_file(weights: Dict[str, float], path: str = "legal_entity_weights.json"): with open(path, "w", encoding="utf-8") as f: json.dump({k: float(v) for k, v in weights.items()}, f, ensure_ascii=False, indent=2) def sample_data(path: str, text_key: str, max_samples: int) -> List[dict]: from golden_builder import load_json_or_jsonl data = load_json_or_jsonl(path) data = [r for r in data if isinstance(r, dict) and text_key in r and isinstance(r[text_key], str) and len(r[text_key].strip()) > 20] random.shuffle(data) return data[:max_samples] def run_once(data_path: str, text_key: str, max_samples: int, batch_size: int): """ یک اجرای واحد Agent: وزن‌ها ← Builder → pass_rate """ cfg = wandb.config weights = { "STATUTE": cfg.STATUTE, "COURT": cfg.COURT, "CRIME": cfg.CRIME, "CIVIL": cfg.CIVIL, "PROCED": cfg.PROCED, "PARTY": cfg.PARTY, "BUSINESS": cfg.BUSINESS, } write_weights_file(weights) # این فایل توسط GoldenBuilder خوانده می‌شود from golden_builder import GoldenBuilder, save_jsonl rows_in = sample_data(data_path, text_key, max_samples=max_samples) if not rows_in: wandb.log({"pass_rate": 0.0, "kept": 0, "processed": 0}) wandb.summary.update({"weights": weights}) return # برای سرعت/پایداری: mt5-base کافی است؛ اگر مدل دیگری می‌خواهید، پارامتر کنید gb = GoldenBuilder(model_name="google/mt5-base") rows_out = gb.build(rows_in, text_key=text_key, batch_size=batch_size) processed = len(rows_in) kept = len(rows_out) pass_rate = kept / max(processed, 1) # لاگ متریک‌ها + وزن‌ها wandb.log({ "pass_rate": pass_rate, "kept": kept, "processed": processed }) wandb.summary.update({"weights": weights}) # آرتیفکت خروجی نمونه (اختیاری ولی مفید برای ارزیابی کیفی) outp = f"/tmp/gb_out_{wandb.run.id}.jsonl" save_jsonl(rows_out, outp) art = wandb.Artifact("gb-sample", type="dataset") art.add_file(outp) wandb.log_artifact(art) def run_sweep( data_path: str, text_key: str = DEFAULT_TEXT_KEY, max_samples: int = 120, batch_size: int = 2, project: str = "mahoon-legal-ai", entity: str = None, count: int = 16 ): os.environ.setdefault("WANDB_PROJECT", project) if entity: os.environ.setdefault("WANDB_ENTITY", entity) # ایجاد Sweep sweep_id = wandb.sweep(SWEEP_CONFIG, project=os.getenv("WANDB_PROJECT", project), entity=os.getenv("WANDB_ENTITY", entity)) def _agent(): wandb.init(project=os.getenv("WANDB_PROJECT", project), entity=os.getenv("WANDB_ENTITY", entity), name="weights-tune") run_once(data_path=data_path, text_key=text_key, max_samples=max_samples, batch_size=batch_size) # اجرای تعداد مشخصی Agent-run wandb.agent(sweep_id, function=_agent, count=count) if __name__ == "__main__": # اجرای خط فرمان/محلی: # export WANDB_API_KEY=<توکن واقعی> # python weights_sweep.py data = os.getenv("TUNE_DATA", "./sample.jsonl") text_key = os.getenv("TUNE_TEXT_KEY", DEFAULT_TEXT_KEY) max_samples = int(os.getenv("TUNE_MAX_SAMPLES", "120")) count = int(os.getenv("TUNE_COUNT", "16")) batch_size = int(os.getenv("TUNE_BATCH", "2")) project = os.getenv("WANDB_PROJECT", "mahoon-legal-ai") entity = os.getenv("WANDB_ENTITY", None) run_sweep( data_path=data, text_key=text_key, max_samples=max_samples, batch_size=batch_size, project=project, entity=entity, count=count )