hajimammad committed on
Commit
c35b21c
·
verified ·
1 Parent(s): 37f7902

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -20
app.py CHANGED
@@ -1,19 +1,21 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- Mahoon Legal AI — Causal-only Generation + Hybrid RAG + W&B-integrated Training
4
  - پاسخ‌زایی: Qwen2.5-7B, Llama-3.1-8B, Mistral-7B (همه causal)
5
  - RAG: Chroma + BM25 + CrossEncoder reranker (gte-multilingual-reranker-base)
6
- - Dataset: Builder (بر اساس golden_builder شما) + Cleaner/Deduper
7
  - Training: SFT/LoRA سبک روی causal + W&B logging/Artifacts
8
- - UI: Gradio 5.47.0 (چهار تب: مشاوره، ایندکس قوانین، ساخت دیتاست، پاکسازی دیتاست، آموزش)
 
9
 
 
10
  """
11
 
12
  from __future__ import annotations
13
  import os, sys, re, json, time, pickle, zipfile, warnings
14
  from dataclasses import dataclass, field
15
  from pathlib import Path
16
- from typing import List, Dict, Optional, Tuple
17
 
18
  import numpy as np
19
  import torch
@@ -21,7 +23,6 @@ from torch.utils.data import Dataset
21
  from sklearn.model_selection import train_test_split
22
 
23
  import gradio as gr
24
- from packaging import version
25
  warnings.filterwarnings("ignore")
26
 
27
  # ====== ML & NLP ======
@@ -35,7 +36,6 @@ from transformers import (
35
  import chromadb
36
  from rank_bm25 import BM25Okapi
37
  from sentence_transformers import CrossEncoder, SentenceTransformer, util as st_util
38
- from langdetect import detect
39
 
40
  # ========= Persian text normalization =========
41
  ZWNJ = "\u200c"
@@ -50,7 +50,7 @@ def normalize_fa(s: str) -> str:
50
  s = re.sub(r"[\u064B-\u065F\u0610-\u061A]", "", s) # حذف اعراب
51
  trans = {ord(a): e for a, e in zip(AR_DIGITS + FA_DIGITS, EN_DIGITS * 2)}
52
  s = s.translate(trans)
53
- s = re.sub(r"\s*‌\s*", ZWNJ, s) # نرمال‌سازی ZWNJ
54
  s = re.sub(r"\s+", " ", s).strip()
55
  return s
56
 
@@ -79,9 +79,9 @@ class RAGConfig:
79
 
80
  @dataclass
81
  class TrainConfig:
82
- base_model: str = "PartAI/Dorna-Llama3-8B-Instruct" # قابل تغییر در UI
83
- alt_model_1: str = "zpm/Llama-3.1-PersianQA" # قابل تغییر در UI
84
- hakim_model: str = "AI-Hoosh/HAKIM-7B" # به‌روزرسانی در UI
85
  hooshvareh_model: str = "HooshvareLab/llama-fa-7b-instruct"
86
  output_dir: str = "./mahoon_causal_lora"
87
  seed: int = 42
@@ -96,9 +96,9 @@ class TrainConfig:
96
  eval_strategy: str = "epoch"
97
  save_strategy: str = "epoch"
98
  save_total_limit: int = 2
99
- report_to: str = "wandb" # W&B
100
  max_grad_norm: float = 1.0
101
- use_4bit: bool = True # QLoRA 4-bit
102
  max_seq_len: int = 2048
103
 
104
  @dataclass
@@ -108,7 +108,7 @@ class SystemConfig:
108
  train: TrainConfig = field(default_factory=TrainConfig)
109
 
110
  # ==========================
111
- # Utils & deps logging
112
  # ==========================
113
  def set_seed_all(seed: int = 42):
114
  import random
@@ -433,7 +433,6 @@ class TrainerManager:
433
  max_grad_norm=self.cfg.train.max_grad_norm,
434
  )
435
 
436
- # ---------- Trainer + W&B callback ----------
437
  callbacks = [EarlyStoppingCallback(early_stopping_patience=2)]
438
  try:
439
  if use_wandb:
@@ -451,7 +450,7 @@ class TrainerManager:
451
  callbacks=callbacks,
452
  )
453
 
454
- # Optional manual init for richer metadata
455
  if use_wandb:
456
  try:
457
  import wandb
@@ -474,7 +473,6 @@ class TrainerManager:
474
  trainer.save_model(self.cfg.train.output_dir)
475
  self.loader.tokenizer.save_pretrained(self.cfg.train.output_dir)
476
 
477
- # Log artifacts to W&B
478
  if use_wandb:
479
  try:
480
  import wandb
@@ -588,7 +586,7 @@ class LegalApp:
588
  set_seed_all(self.scfg.train.seed)
589
 
590
  progress(0.30, desc="آماده‌سازی دیتاست‌ها و RAG (اختیاری)")
591
- out = tm.train_causal(
592
  paths, use_rag=use_rag, use_wandb=use_wandb,
593
  wandb_project=wandb_project, wandb_entity=wandb_entity, run_name=run_name
594
  )
@@ -598,7 +596,10 @@ class LegalApp:
598
 
599
  # Dataset Builder (از ماژول شما)
600
  def build_dataset(self, raw_file, text_key: str, model_ckpt: str, batch_size: int, max_samples: int | None):
601
- from golden_builder import load_json_or_jsonl, save_jsonl, GoldenBuilder
 
 
 
602
  path = getattr(raw_file, "name", None) or getattr(raw_file, "path", None)
603
  if not path: return None, "⚠️ فایل ورودی معتبر نیست."
604
  try:
@@ -613,6 +614,24 @@ class LegalApp:
613
  except Exception as e:
614
  return None, f"❌ خطا در ساخت دیتاست: {e}"
615
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
616
  # UI
617
  def build_ui(self):
618
  log_deps()
@@ -631,7 +650,7 @@ class LegalApp:
631
  gr.Markdown("""
632
  <div style='text-align:center;padding:18px'>
633
  <h1 style='margin-bottom:4px'>ماحون — Persian Legal (Causal-only)</h1>
634
- <p style='color:#666'>Hybrid RAG • Qwen/Llama/Mistral • Dataset Ops • W&B Training</p>
635
  </div>
636
  """)
637
 
@@ -725,7 +744,7 @@ class LegalApp:
725
  wandb_project = gr.Textbox(value="mahoon-legal-ai", label="WANDB_PROJECT")
726
  wandb_entity = gr.Textbox(value="", label="WANDB_ENTITY (اختیاری)")
727
  run_name = gr.Textbox(value="mahoon_causal_lora", label="Run name")
728
- gr.Markdown("**راهنمای توکن W&B**: در Settings → Secrets مقدار `WANDB_API_KEY` را برابر با **🟡** قرار دهید.")
729
 
730
  train_files = gr.Files(label="JSONL Files", file_count="multiple", file_types=[".jsonl"])
731
  with gr.Row():
@@ -735,6 +754,19 @@ class LegalApp:
735
  train_btn = gr.Button("شروع آموزش", variant="primary")
736
  train_status = gr.Textbox(label="وضعیت آموزش", interactive=False)
737
 
 
 
 
 
 
 
 
 
 
 
 
 
 
738
  # ---- Events ----
739
  def _resolve_gen(choice: str, override: str) -> str:
740
  return override.strip() if override.strip() else default_gen_models[choice]
@@ -778,6 +810,26 @@ class LegalApp:
778
  outputs=train_status
779
  )
780
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
781
  return app
782
 
783
  # ==========================
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ Mahoon Legal AI — Causal-only Generation + Hybrid RAG + W&B Training + Weight Tuning
4
  - پاسخ‌زایی: Qwen2.5-7B, Llama-3.1-8B, Mistral-7B (همه causal)
5
  - RAG: Chroma + BM25 + CrossEncoder reranker (gte-multilingual-reranker-base)
6
+ - Dataset Ops: Builder (از golden_builder) + Cleaner/Deduper
7
  - Training: SFT/LoRA سبک روی causal + W&B logging/Artifacts
8
+ - Tuning: Weight Tuning با W&B Sweep (weights_sweep.py)
9
+ - UI: Gradio 5.47.0
10
 
11
+ نکته: در Settings → Secrets مقدار `WANDB_API_KEY` را ست کنید (مقدار واقعی؛ placeholder 🟡 نگذارید).
12
  """
13
 
14
  from __future__ import annotations
15
  import os, sys, re, json, time, pickle, zipfile, warnings
16
  from dataclasses import dataclass, field
17
  from pathlib import Path
18
+ from typing import List, Dict, Optional
19
 
20
  import numpy as np
21
  import torch
 
23
  from sklearn.model_selection import train_test_split
24
 
25
  import gradio as gr
 
26
  warnings.filterwarnings("ignore")
27
 
28
  # ====== ML & NLP ======
 
36
  import chromadb
37
  from rank_bm25 import BM25Okapi
38
  from sentence_transformers import CrossEncoder, SentenceTransformer, util as st_util
 
39
 
40
  # ========= Persian text normalization =========
41
  ZWNJ = "\u200c"
 
50
  s = re.sub(r"[\u064B-\u065F\u0610-\u061A]", "", s) # حذف اعراب
51
  trans = {ord(a): e for a, e in zip(AR_DIGITS + FA_DIGITS, EN_DIGITS * 2)}
52
  s = s.translate(trans)
53
+ s = re.sub(r"\s*‌\s*", ZWNJ, s) # ZWNJ
54
  s = re.sub(r"\s+", " ", s).strip()
55
  return s
56
 
 
79
 
80
  @dataclass
81
  class TrainConfig:
82
+ base_model: str = "PartAI/Dorna-Llama3-8B-Instruct"
83
+ alt_model_1: str = "zpm/Llama-3.1-PersianQA"
84
+ hakim_model: str = "AI-Hoosh/HAKIM-7B"
85
  hooshvareh_model: str = "HooshvareLab/llama-fa-7b-instruct"
86
  output_dir: str = "./mahoon_causal_lora"
87
  seed: int = 42
 
96
  eval_strategy: str = "epoch"
97
  save_strategy: str = "epoch"
98
  save_total_limit: int = 2
99
+ report_to: str = "wandb" # W&B
100
  max_grad_norm: float = 1.0
101
+ use_4bit: bool = True # QLoRA 4-bit (در صورت افزودن PEFT/TRL)
102
  max_seq_len: int = 2048
103
 
104
  @dataclass
 
108
  train: TrainConfig = field(default_factory=TrainConfig)
109
 
110
  # ==========================
111
+ # Helpers
112
  # ==========================
113
  def set_seed_all(seed: int = 42):
114
  import random
 
433
  max_grad_norm=self.cfg.train.max_grad_norm,
434
  )
435
 
 
436
  callbacks = [EarlyStoppingCallback(early_stopping_patience=2)]
437
  try:
438
  if use_wandb:
 
450
  callbacks=callbacks,
451
  )
452
 
453
+ # Optional richer W&B init
454
  if use_wandb:
455
  try:
456
  import wandb
 
473
  trainer.save_model(self.cfg.train.output_dir)
474
  self.loader.tokenizer.save_pretrained(self.cfg.train.output_dir)
475
 
 
476
  if use_wandb:
477
  try:
478
  import wandb
 
586
  set_seed_all(self.scfg.train.seed)
587
 
588
  progress(0.30, desc="آماده‌سازی دیتاست‌ها و RAG (اختیاری)")
589
+ tm.train_causal(
590
  paths, use_rag=use_rag, use_wandb=use_wandb,
591
  wandb_project=wandb_project, wandb_entity=wandb_entity, run_name=run_name
592
  )
 
596
 
597
  # Dataset Builder (از ماژول شما)
598
  def build_dataset(self, raw_file, text_key: str, model_ckpt: str, batch_size: int, max_samples: int | None):
599
+ try:
600
+ from golden_builder import load_json_or_jsonl, save_jsonl, GoldenBuilder
601
+ except Exception as e:
602
+ return None, f"❌ golden_builder.py یافت نشد/قابل import نیست: {e}"
603
  path = getattr(raw_file, "name", None) or getattr(raw_file, "path", None)
604
  if not path: return None, "⚠️ فایل ورودی معتبر نیست."
605
  try:
 
614
  except Exception as e:
615
  return None, f"❌ خطا در ساخت دیتاست: {e}"
616
 
617
+ # Weight Tuning (W&B Sweep)
618
+ def run_weight_tune(self, f, tk, ms, runs, bs, proj, ent):
619
+ p = getattr(f, "name", None) or getattr(f, "path", None)
620
+ if not p:
621
+ return "⚠️ فایل داده نامعتبر است."
622
+ try:
623
+ from weights_sweep import run_sweep
624
+ except Exception as e:
625
+ return f"❌ weights_sweep.py یافت نشد/قابل import نیست: {e}"
626
+ os.environ.setdefault("WANDB_PROJECT", proj or "mahoon-legal-ai")
627
+ if ent: os.environ.setdefault("WANDB_ENTITY", ent)
628
+ try:
629
+ run_sweep(data_path=p, text_key=tk, max_samples=int(ms), batch_size=int(bs),
630
+ project=proj, entity=ent, count=int(runs))
631
+ return "✅ Sweep اجرا شد. بهترین Run را در W&B بررسی و وزن‌ها را تثبیت کنید."
632
+ except Exception as e:
633
+ return f"❌ خطا در اجرای Sweep: {e}"
634
+
635
  # UI
636
  def build_ui(self):
637
  log_deps()
 
650
  gr.Markdown("""
651
  <div style='text-align:center;padding:18px'>
652
  <h1 style='margin-bottom:4px'>ماحون — Persian Legal (Causal-only)</h1>
653
+ <p style='color:#666'>Hybrid RAG • Qwen/Llama/Mistral • Dataset Ops • W&B Training • Weight Tuning</p>
654
  </div>
655
  """)
656
 
 
744
  wandb_project = gr.Textbox(value="mahoon-legal-ai", label="WANDB_PROJECT")
745
  wandb_entity = gr.Textbox(value="", label="WANDB_ENTITY (اختیاری)")
746
  run_name = gr.Textbox(value="mahoon_causal_lora", label="Run name")
747
+ gr.Markdown("راهنما: در Settings → Secrets مقدار `WANDB_API_KEY` را تنظیم کنید (مقدار واقعی).")
748
 
749
  train_files = gr.Files(label="JSONL Files", file_count="multiple", file_types=[".jsonl"])
750
  with gr.Row():
 
754
  train_btn = gr.Button("شروع آموزش", variant="primary")
755
  train_status = gr.Textbox(label="وضعیت آموزش", interactive=False)
756
 
757
+ # --- Tab: Weight Tuning ---
758
+ with gr.Tab("Weight Tuning"):
759
+ gr.Markdown("تیون خودکار وزن‌های موجودیت با W&B Sweep. ابتدا در Settings→Secrets مقدار `WANDB_API_KEY` را ست کنید.")
760
+ tune_file = gr.File(label="فایل داده (JSON/JSONL)", file_types=[".json",".jsonl"])
761
+ tune_text_key = gr.Textbox(value="متن_کامل", label="کلید متن")
762
+ tune_max_samples = gr.Slider(50, 400, value=120, step=10, label="حداکثر نمونه")
763
+ tune_runs = gr.Slider(4, 64, value=16, step=4, label="تعداد ران Sweep")
764
+ tune_batch = gr.Slider(1, 4, value=2, step=1, label="batch size Builder")
765
+ tune_proj = gr.Textbox(value="mahoon-legal-ai", label="WANDB_PROJECT")
766
+ tune_entity = gr.Textbox(value="", label="WANDB_ENTITY (اختیاری)")
767
+ run_tune = gr.Button("شروع Sweep", variant="primary")
768
+ tune_status = gr.Markdown()
769
+
770
  # ---- Events ----
771
  def _resolve_gen(choice: str, override: str) -> str:
772
  return override.strip() if override.strip() else default_gen_models[choice]
 
810
  outputs=train_status
811
  )
812
 
813
+ clean_btn.click(
814
+ lambda f, th: (
815
+ (lambda _p, _out:
816
+ ( _out,
817
+ f"✅ دیتاست پاک شد. تعداد رکوردهای نهایی: **{deduplicate_jsonl(_p, _out, sim_threshold=float(th))}**" )
818
+ )(
819
+ getattr(f, "name", None) or getattr(f, "path", None),
820
+ f"/tmp/cleaned_{int(time.time())}.jsonl"
821
+ ) if (getattr(f, 'name', None) or getattr(f, 'path', None)) else (None, "⚠️ فایل نامعتبر.")
822
+ ),
823
+ inputs=[raw_ds, sim_th],
824
+ outputs=[cleaned_out, clean_status]
825
+ )
826
+
827
+ run_tune.click(
828
+ lambda f, tk, ms, runs, bs, proj, ent: self.run_weight_tune(f, tk, ms, runs, bs, proj, ent),
829
+ inputs=[tune_file, tune_text_key, tune_max_samples, tune_runs, tune_batch, tune_proj, tune_entity],
830
+ outputs=tune_status
831
+ )
832
+
833
  return app
834
 
835
  # ==========================