hajimammad commited on
Commit
0f23dca
·
verified ·
1 Parent(s): e606eca

Upload weights_sweep.py

Browse files
Files changed (1) hide show
  1. weights_sweep.py +151 -0
weights_sweep.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ weights_sweep.py — Auto-tuning legal tag weights via W&B Sweeps
4
+
5
+ کارکرد اجرایی:
6
+ - هر ران: وزن‌ها از config → نوشتن legal_entity_weights.json → اجرای GoldenBuilder روی یک زیرمجموعه کوچک →
7
+ محاسبهٔ pass_rate (نرخ قبولی گیت کیفیت) → لاگ متریک‌ها و آرتیفکت‌ها روی W&B.
8
+ - در پایان Sweep، از داشبورد W&B «بهترین Run» را انتخاب کنید و وزن‌ها را تثبیت نمایید
9
+ (یا از آرتیفکت همان Run دانلود کنید و جایگزین legal_entity_weights.json نمایید).
10
+
11
+ پیش‌نیاز:
12
+ - فایل golden_builder.py در ریشه ریپو
13
+ - Secrets: WANDB_API_KEY در HF Spaces
14
+ - requirements: wandb, transformers, torch, ...
15
+
16
+ پارامترها (قابل‌تنظیم از طریق env یا UI):
17
+ - TUNE_DATA: مسیر فایل JSON/JSONL داده
18
+ - TUNE_TEXT_KEY: کلید متن در داده (پیش‌فرض "متن_کامل")
19
+ - TUNE_MAX_SAMPLES: تعداد نمونهٔ کوچک برای هر ران (پیش‌فرض 120)
20
+ - TUNE_BATCH: batch size Builder (پیش‌فرض 2)
21
+ - TUNE_COUNT: تعداد ران در sweep (پیش‌فرض 16)
22
+ - WANDB_PROJECT, WANDB_ENTITY: پروژه/ورک‌اسپیس W&B
23
+ """
24
+
25
+ import os
26
+ import json
27
+ import random
28
+ from typing import Dict, List
29
+
30
+ import wandb
31
+
32
+ # فضای جست‌وجو: در صورت نیاز بازه‌ها را سخت‌گیرانه‌تر/وسیع‌تر کنید
33
+ SWEEP_CONFIG = {
34
+ "method": "bayes", # "random" یا "grid" هم قابل استفاده است
35
+ "metric": {"name": "pass_rate", "goal": "maximize"},
36
+ "parameters": {
37
+ "STATUTE": {"min": 0.8, "max": 1.4},
38
+ "COURT": {"min": 0.6, "max": 1.2},
39
+ "CRIME": {"min": 0.9, "max": 1.6},
40
+ "CIVIL": {"min": 0.5, "max": 1.2},
41
+ "PROCED": {"min": 0.5, "max": 1.0},
42
+ "PARTY": {"min": 0.4, "max": 0.9},
43
+ "BUSINESS": {"min": 0.4, "max": 0.9},
44
+ }
45
+ }
46
+
47
+ DEFAULT_TEXT_KEY = "متن_کامل"
48
+
49
+ def write_weights_file(weights: Dict[str, float], path: str = "legal_entity_weights.json"):
50
+ with open(path, "w", encoding="utf-8") as f:
51
+ json.dump({k: float(v) for k, v in weights.items()}, f, ensure_ascii=False, indent=2)
52
+
53
+ def sample_data(path: str, text_key: str, max_samples: int) -> List[dict]:
54
+ from golden_builder import load_json_or_jsonl
55
+ data = load_json_or_jsonl(path)
56
+ data = [r for r in data if isinstance(r, dict) and text_key in r and isinstance(r[text_key], str) and len(r[text_key].strip()) > 20]
57
+ random.shuffle(data)
58
+ return data[:max_samples]
59
+
60
+ def run_once(data_path: str, text_key: str, max_samples: int, batch_size: int):
61
+ """
62
+ یک اجرای واحد Agent: وزن‌ها ← Builder → pass_rate
63
+ """
64
+ cfg = wandb.config
65
+ weights = {
66
+ "STATUTE": cfg.STATUTE,
67
+ "COURT": cfg.COURT,
68
+ "CRIME": cfg.CRIME,
69
+ "CIVIL": cfg.CIVIL,
70
+ "PROCED": cfg.PROCED,
71
+ "PARTY": cfg.PARTY,
72
+ "BUSINESS": cfg.BUSINESS,
73
+ }
74
+ write_weights_file(weights) # این فایل توسط GoldenBuilder خوانده می‌شود
75
+
76
+ from golden_builder import GoldenBuilder, save_jsonl
77
+ rows_in = sample_data(data_path, text_key, max_samples=max_samples)
78
+
79
+ if not rows_in:
80
+ wandb.log({"pass_rate": 0.0, "kept": 0, "processed": 0})
81
+ wandb.summary.update({"weights": weights})
82
+ return
83
+
84
+ # برای سرعت/پایداری: mt5-base کافی است؛ اگر مدل دیگری می‌خواهید، پارامتر کنید
85
+ gb = GoldenBuilder(model_name="google/mt5-base")
86
+ rows_out = gb.build(rows_in, text_key=text_key, batch_size=batch_size)
87
+
88
+ processed = len(rows_in)
89
+ kept = len(rows_out)
90
+ pass_rate = kept / max(processed, 1)
91
+
92
+ # لاگ متریک‌ها + وزن‌ها
93
+ wandb.log({
94
+ "pass_rate": pass_rate,
95
+ "kept": kept,
96
+ "processed": processed
97
+ })
98
+ wandb.summary.update({"weights": weights})
99
+
100
+ # آرتیفکت خروجی نمونه (اختیاری ولی مفید برای ارزیابی کیفی)
101
+ outp = f"/tmp/gb_out_{wandb.run.id}.jsonl"
102
+ save_jsonl(rows_out, outp)
103
+ art = wandb.Artifact("gb-sample", type="dataset")
104
+ art.add_file(outp)
105
+ wandb.log_artifact(art)
106
+
107
+ def run_sweep(
108
+ data_path: str,
109
+ text_key: str = DEFAULT_TEXT_KEY,
110
+ max_samples: int = 120,
111
+ batch_size: int = 2,
112
+ project: str = "mahoon-legal-ai",
113
+ entity: str = None,
114
+ count: int = 16
115
+ ):
116
+ os.environ.setdefault("WANDB_PROJECT", project)
117
+ if entity: os.environ.setdefault("WANDB_ENTITY", entity)
118
+
119
+ # ایجاد Sweep
120
+ sweep_id = wandb.sweep(SWEEP_CONFIG, project=os.getenv("WANDB_PROJECT", project), entity=os.getenv("WANDB_ENTITY", entity))
121
+
122
+ def _agent():
123
+ wandb.init(project=os.getenv("WANDB_PROJECT", project),
124
+ entity=os.getenv("WANDB_ENTITY", entity),
125
+ name="weights-tune")
126
+ run_once(data_path=data_path, text_key=text_key, max_samples=max_samples, batch_size=batch_size)
127
+
128
+ # اجرای تعداد مشخصی Agent-run
129
+ wandb.agent(sweep_id, function=_agent, count=count)
130
+
131
+ if __name__ == "__main__":
132
+ # اجرای خط فرمان/محلی:
133
+ # export WANDB_API_KEY=<توکن واقعی>
134
+ # python weights_sweep.py
135
+ data = os.getenv("TUNE_DATA", "./sample.jsonl")
136
+ text_key = os.getenv("TUNE_TEXT_KEY", DEFAULT_TEXT_KEY)
137
+ max_samples = int(os.getenv("TUNE_MAX_SAMPLES", "120"))
138
+ count = int(os.getenv("TUNE_COUNT", "16"))
139
+ batch_size = int(os.getenv("TUNE_BATCH", "2"))
140
+ project = os.getenv("WANDB_PROJECT", "mahoon-legal-ai")
141
+ entity = os.getenv("WANDB_ENTITY", None)
142
+
143
+ run_sweep(
144
+ data_path=data,
145
+ text_key=text_key,
146
+ max_samples=max_samples,
147
+ batch_size=batch_size,
148
+ project=project,
149
+ entity=entity,
150
+ count=count
151
+ )