Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,19 +1,21 @@
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
"""
|
| 3 |
-
Mahoon Legal AI — Causal-only Generation + Hybrid RAG + W&B
|
| 4 |
- پاسخزایی: Qwen2.5-7B, Llama-3.1-8B, Mistral-7B (همه causal)
|
| 5 |
- RAG: Chroma + BM25 + CrossEncoder reranker (gte-multilingual-reranker-base)
|
| 6 |
-
- Dataset: Builder (
|
| 7 |
- Training: SFT/LoRA سبک روی causal + W&B logging/Artifacts
|
| 8 |
-
-
|
|
|
|
| 9 |
|
|
|
|
| 10 |
"""
|
| 11 |
|
| 12 |
from __future__ import annotations
|
| 13 |
import os, sys, re, json, time, pickle, zipfile, warnings
|
| 14 |
from dataclasses import dataclass, field
|
| 15 |
from pathlib import Path
|
| 16 |
-
from typing import List, Dict, Optional
|
| 17 |
|
| 18 |
import numpy as np
|
| 19 |
import torch
|
|
@@ -21,7 +23,6 @@ from torch.utils.data import Dataset
|
|
| 21 |
from sklearn.model_selection import train_test_split
|
| 22 |
|
| 23 |
import gradio as gr
|
| 24 |
-
from packaging import version
|
| 25 |
warnings.filterwarnings("ignore")
|
| 26 |
|
| 27 |
# ====== ML & NLP ======
|
|
@@ -35,7 +36,6 @@ from transformers import (
|
|
| 35 |
import chromadb
|
| 36 |
from rank_bm25 import BM25Okapi
|
| 37 |
from sentence_transformers import CrossEncoder, SentenceTransformer, util as st_util
|
| 38 |
-
from langdetect import detect
|
| 39 |
|
| 40 |
# ========= Persian text normalization =========
|
| 41 |
ZWNJ = "\u200c"
|
|
@@ -50,7 +50,7 @@ def normalize_fa(s: str) -> str:
|
|
| 50 |
s = re.sub(r"[\u064B-\u065F\u0610-\u061A]", "", s) # حذف اعراب
|
| 51 |
trans = {ord(a): e for a, e in zip(AR_DIGITS + FA_DIGITS, EN_DIGITS * 2)}
|
| 52 |
s = s.translate(trans)
|
| 53 |
-
s = re.sub(r"\s*\u200c\s*", ZWNJ, s)
|
| 54 |
s = re.sub(r"\s+", " ", s).strip()
|
| 55 |
return s
|
| 56 |
|
|
@@ -79,9 +79,9 @@ class RAGConfig:
|
|
| 79 |
|
| 80 |
@dataclass
|
| 81 |
class TrainConfig:
|
| 82 |
-
base_model: str = "PartAI/Dorna-Llama3-8B-Instruct"
|
| 83 |
-
alt_model_1: str = "zpm/Llama-3.1-PersianQA"
|
| 84 |
-
hakim_model: str = "AI-Hoosh/HAKIM-7B"
|
| 85 |
hooshvareh_model: str = "HooshvareLab/llama-fa-7b-instruct"
|
| 86 |
output_dir: str = "./mahoon_causal_lora"
|
| 87 |
seed: int = 42
|
|
@@ -96,9 +96,9 @@ class TrainConfig:
|
|
| 96 |
eval_strategy: str = "epoch"
|
| 97 |
save_strategy: str = "epoch"
|
| 98 |
save_total_limit: int = 2
|
| 99 |
-
report_to: str = "wandb" #
|
| 100 |
max_grad_norm: float = 1.0
|
| 101 |
-
use_4bit: bool = True # QLoRA 4-bit
|
| 102 |
max_seq_len: int = 2048
|
| 103 |
|
| 104 |
@dataclass
|
|
@@ -108,7 +108,7 @@ class SystemConfig:
|
|
| 108 |
train: TrainConfig = field(default_factory=TrainConfig)
|
| 109 |
|
| 110 |
# ==========================
|
| 111 |
-
#
|
| 112 |
# ==========================
|
| 113 |
def set_seed_all(seed: int = 42):
|
| 114 |
import random
|
|
@@ -433,7 +433,6 @@ class TrainerManager:
|
|
| 433 |
max_grad_norm=self.cfg.train.max_grad_norm,
|
| 434 |
)
|
| 435 |
|
| 436 |
-
# ---------- Trainer + W&B callback ----------
|
| 437 |
callbacks = [EarlyStoppingCallback(early_stopping_patience=2)]
|
| 438 |
try:
|
| 439 |
if use_wandb:
|
|
@@ -451,7 +450,7 @@ class TrainerManager:
|
|
| 451 |
callbacks=callbacks,
|
| 452 |
)
|
| 453 |
|
| 454 |
-
# Optional
|
| 455 |
if use_wandb:
|
| 456 |
try:
|
| 457 |
import wandb
|
|
@@ -474,7 +473,6 @@ class TrainerManager:
|
|
| 474 |
trainer.save_model(self.cfg.train.output_dir)
|
| 475 |
self.loader.tokenizer.save_pretrained(self.cfg.train.output_dir)
|
| 476 |
|
| 477 |
-
# Log artifacts to W&B
|
| 478 |
if use_wandb:
|
| 479 |
try:
|
| 480 |
import wandb
|
|
@@ -588,7 +586,7 @@ class LegalApp:
|
|
| 588 |
set_seed_all(self.scfg.train.seed)
|
| 589 |
|
| 590 |
progress(0.30, desc="آمادهسازی دیتاستها و RAG (اختیاری)")
|
| 591 |
-
|
| 592 |
paths, use_rag=use_rag, use_wandb=use_wandb,
|
| 593 |
wandb_project=wandb_project, wandb_entity=wandb_entity, run_name=run_name
|
| 594 |
)
|
|
@@ -598,7 +596,10 @@ class LegalApp:
|
|
| 598 |
|
| 599 |
# Dataset Builder (از ماژول شما)
|
| 600 |
def build_dataset(self, raw_file, text_key: str, model_ckpt: str, batch_size: int, max_samples: int | None):
|
| 601 |
-
|
|
|
|
|
|
|
|
|
|
| 602 |
path = getattr(raw_file, "name", None) or getattr(raw_file, "path", None)
|
| 603 |
if not path: return None, "⚠️ فایل ورودی معتبر نیست."
|
| 604 |
try:
|
|
@@ -613,6 +614,24 @@ class LegalApp:
|
|
| 613 |
except Exception as e:
|
| 614 |
return None, f"❌ خطا در ساخت دیتاست: {e}"
|
| 615 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 616 |
# UI
|
| 617 |
def build_ui(self):
|
| 618 |
log_deps()
|
|
@@ -631,7 +650,7 @@ class LegalApp:
|
|
| 631 |
gr.Markdown("""
|
| 632 |
<div style='text-align:center;padding:18px'>
|
| 633 |
<h1 style='margin-bottom:4px'>ماحون — Persian Legal (Causal-only)</h1>
|
| 634 |
-
<p style='color:#666'>Hybrid RAG • Qwen/Llama/Mistral • Dataset Ops • W&B Training</p>
|
| 635 |
</div>
|
| 636 |
""")
|
| 637 |
|
|
@@ -725,7 +744,7 @@ class LegalApp:
|
|
| 725 |
wandb_project = gr.Textbox(value="mahoon-legal-ai", label="WANDB_PROJECT")
|
| 726 |
wandb_entity = gr.Textbox(value="", label="WANDB_ENTITY (اختیاری)")
|
| 727 |
run_name = gr.Textbox(value="mahoon_causal_lora", label="Run name")
|
| 728 |
-
gr.Markdown("
|
| 729 |
|
| 730 |
train_files = gr.Files(label="JSONL Files", file_count="multiple", file_types=[".jsonl"])
|
| 731 |
with gr.Row():
|
|
@@ -735,6 +754,19 @@ class LegalApp:
|
|
| 735 |
train_btn = gr.Button("شروع آموزش", variant="primary")
|
| 736 |
train_status = gr.Textbox(label="وضعیت آموزش", interactive=False)
|
| 737 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 738 |
# ---- Events ----
|
| 739 |
def _resolve_gen(choice: str, override: str) -> str:
|
| 740 |
return override.strip() if override.strip() else default_gen_models[choice]
|
|
@@ -778,6 +810,26 @@ class LegalApp:
|
|
| 778 |
outputs=train_status
|
| 779 |
)
|
| 780 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 781 |
return app
|
| 782 |
|
| 783 |
# ==========================
|
|
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
"""
|
| 3 |
+
Mahoon Legal AI — Causal-only Generation + Hybrid RAG + W&B Training + Weight Tuning
|
| 4 |
- پاسخزایی: Qwen2.5-7B, Llama-3.1-8B, Mistral-7B (همه causal)
|
| 5 |
- RAG: Chroma + BM25 + CrossEncoder reranker (gte-multilingual-reranker-base)
|
| 6 |
+
- Dataset Ops: Builder (از golden_builder) + Cleaner/Deduper
|
| 7 |
- Training: SFT/LoRA سبک روی causal + W&B logging/Artifacts
|
| 8 |
+
- Tuning: Weight Tuning با W&B Sweep (weights_sweep.py)
|
| 9 |
+
- UI: Gradio 5.47.0
|
| 10 |
|
| 11 |
+
نکته: در Settings → Secrets مقدار `WANDB_API_KEY` را ست کنید (مقدار واقعی؛ placeholder 🟡 نگذارید).
|
| 12 |
"""
|
| 13 |
|
| 14 |
from __future__ import annotations
|
| 15 |
import os, sys, re, json, time, pickle, zipfile, warnings
|
| 16 |
from dataclasses import dataclass, field
|
| 17 |
from pathlib import Path
|
| 18 |
+
from typing import List, Dict, Optional
|
| 19 |
|
| 20 |
import numpy as np
|
| 21 |
import torch
|
|
|
|
| 23 |
from sklearn.model_selection import train_test_split
|
| 24 |
|
| 25 |
import gradio as gr
|
|
|
|
| 26 |
warnings.filterwarnings("ignore")
|
| 27 |
|
| 28 |
# ====== ML & NLP ======
|
|
|
|
| 36 |
import chromadb
|
| 37 |
from rank_bm25 import BM25Okapi
|
| 38 |
from sentence_transformers import CrossEncoder, SentenceTransformer, util as st_util
|
|
|
|
| 39 |
|
| 40 |
# ========= Persian text normalization =========
|
| 41 |
ZWNJ = "\u200c"
|
|
|
|
| 50 |
s = re.sub(r"[\u064B-\u065F\u0610-\u061A]", "", s) # حذف اعراب
|
| 51 |
trans = {ord(a): e for a, e in zip(AR_DIGITS + FA_DIGITS, EN_DIGITS * 2)}
|
| 52 |
s = s.translate(trans)
|
| 53 |
+
s = re.sub(r"\s*\u200c\s*", ZWNJ, s)  # strip whitespace around ZWNJ (U+200C)
|
| 54 |
s = re.sub(r"\s+", " ", s).strip()
|
| 55 |
return s
|
| 56 |
|
|
|
|
| 79 |
|
| 80 |
@dataclass
|
| 81 |
class TrainConfig:
|
| 82 |
+
base_model: str = "PartAI/Dorna-Llama3-8B-Instruct"
|
| 83 |
+
alt_model_1: str = "zpm/Llama-3.1-PersianQA"
|
| 84 |
+
hakim_model: str = "AI-Hoosh/HAKIM-7B"
|
| 85 |
hooshvareh_model: str = "HooshvareLab/llama-fa-7b-instruct"
|
| 86 |
output_dir: str = "./mahoon_causal_lora"
|
| 87 |
seed: int = 42
|
|
|
|
| 96 |
eval_strategy: str = "epoch"
|
| 97 |
save_strategy: str = "epoch"
|
| 98 |
save_total_limit: int = 2
|
| 99 |
+
report_to: str = "wandb" # W&B
|
| 100 |
max_grad_norm: float = 1.0
|
| 101 |
+
use_4bit: bool = True # QLoRA 4-bit (در صورت افزودن PEFT/TRL)
|
| 102 |
max_seq_len: int = 2048
|
| 103 |
|
| 104 |
@dataclass
|
|
|
|
| 108 |
train: TrainConfig = field(default_factory=TrainConfig)
|
| 109 |
|
| 110 |
# ==========================
|
| 111 |
+
# Helpers
|
| 112 |
# ==========================
|
| 113 |
def set_seed_all(seed: int = 42):
|
| 114 |
import random
|
|
|
|
| 433 |
max_grad_norm=self.cfg.train.max_grad_norm,
|
| 434 |
)
|
| 435 |
|
|
|
|
| 436 |
callbacks = [EarlyStoppingCallback(early_stopping_patience=2)]
|
| 437 |
try:
|
| 438 |
if use_wandb:
|
|
|
|
| 450 |
callbacks=callbacks,
|
| 451 |
)
|
| 452 |
|
| 453 |
+
# Optional richer W&B init
|
| 454 |
if use_wandb:
|
| 455 |
try:
|
| 456 |
import wandb
|
|
|
|
| 473 |
trainer.save_model(self.cfg.train.output_dir)
|
| 474 |
self.loader.tokenizer.save_pretrained(self.cfg.train.output_dir)
|
| 475 |
|
|
|
|
| 476 |
if use_wandb:
|
| 477 |
try:
|
| 478 |
import wandb
|
|
|
|
| 586 |
set_seed_all(self.scfg.train.seed)
|
| 587 |
|
| 588 |
progress(0.30, desc="آمادهسازی دیتاستها و RAG (اختیاری)")
|
| 589 |
+
tm.train_causal(
|
| 590 |
paths, use_rag=use_rag, use_wandb=use_wandb,
|
| 591 |
wandb_project=wandb_project, wandb_entity=wandb_entity, run_name=run_name
|
| 592 |
)
|
|
|
|
| 596 |
|
| 597 |
# Dataset Builder (از ماژول شما)
|
| 598 |
def build_dataset(self, raw_file, text_key: str, model_ckpt: str, batch_size: int, max_samples: int | None):
|
| 599 |
+
try:
|
| 600 |
+
from golden_builder import load_json_or_jsonl, save_jsonl, GoldenBuilder
|
| 601 |
+
except Exception as e:
|
| 602 |
+
return None, f"❌ golden_builder.py یافت نشد/قابل import نیست: {e}"
|
| 603 |
path = getattr(raw_file, "name", None) or getattr(raw_file, "path", None)
|
| 604 |
if not path: return None, "⚠️ فایل ورودی معتبر نیست."
|
| 605 |
try:
|
|
|
|
| 614 |
except Exception as e:
|
| 615 |
return None, f"❌ خطا در ساخت دیتاست: {e}"
|
| 616 |
|
| 617 |
+
# Weight Tuning (W&B Sweep)
|
| 618 |
+
def run_weight_tune(self, f, tk, ms, runs, bs, proj, ent):
    """Launch a W&B weight-tuning sweep from the UI and report the outcome.

    The actual sweep is delegated to ``weights_sweep.run_sweep``; this method
    only validates the uploaded file, resolves the W&B project/entity, and
    converts any failure into a user-facing status message.

    Args:
        f: Uploaded file object; its path is read from ``.name`` or ``.path``.
        tk: Key of the text field, forwarded as ``text_key``.
        ms: Maximum number of samples (coerced with ``int``).
        runs: Number of sweep runs (coerced with ``int``, forwarded as ``count``).
        bs: Batch size (coerced with ``int``).
        proj: W&B project name; blank falls back to ``"mahoon-legal-ai"``.
        ent: W&B entity; blank means "not specified".

    Returns:
        A status string (success, warning, or error) for display in the UI.
        This method never raises; all errors are folded into the string.
    """
    data_path = getattr(f, "name", None) or getattr(f, "path", None)
    if not data_path:
        return "⚠️ فایل داده نامعتبر است."

    # Imported lazily so the rest of the app still works when the optional
    # sweep module is absent; any import-time failure is reported, not raised.
    try:
        from weights_sweep import run_sweep
    except Exception as e:
        return f"❌ weights_sweep.py یافت نشد/قابل import نیست: {e}"

    # Resolve project/entity once so the environment and the run_sweep
    # arguments always agree (previously a blank project produced the default
    # in the environment but an empty string in the call).
    project = (proj or "").strip() or "mahoon-legal-ai"
    entity = (ent or "").strip() or None

    # Assign instead of setdefault: with setdefault the first sweep's values
    # stuck for the lifetime of the process and later UI changes were ignored.
    os.environ["WANDB_PROJECT"] = project
    if entity:
        os.environ["WANDB_ENTITY"] = entity

    try:
        run_sweep(
            data_path=data_path,
            text_key=tk,
            max_samples=int(ms),
            batch_size=int(bs),
            project=project,
            entity=entity,
            count=int(runs),
        )
    except Exception as e:
        return f"❌ خطا در اجرای Sweep: {e}"
    return "✅ Sweep اجرا شد. بهترین Run را در W&B بررسی و وزنها را تثبیت کنید."
|
| 634 |
+
|
| 635 |
# UI
|
| 636 |
def build_ui(self):
|
| 637 |
log_deps()
|
|
|
|
| 650 |
gr.Markdown("""
|
| 651 |
<div style='text-align:center;padding:18px'>
|
| 652 |
<h1 style='margin-bottom:4px'>ماحون — Persian Legal (Causal-only)</h1>
|
| 653 |
+
<p style='color:#666'>Hybrid RAG • Qwen/Llama/Mistral • Dataset Ops • W&B Training • Weight Tuning</p>
|
| 654 |
</div>
|
| 655 |
""")
|
| 656 |
|
|
|
|
| 744 |
wandb_project = gr.Textbox(value="mahoon-legal-ai", label="WANDB_PROJECT")
|
| 745 |
wandb_entity = gr.Textbox(value="", label="WANDB_ENTITY (اختیاری)")
|
| 746 |
run_name = gr.Textbox(value="mahoon_causal_lora", label="Run name")
|
| 747 |
+
gr.Markdown("راهنما: در Settings → Secrets مقدار `WANDB_API_KEY` را تنظیم کنید (مقدار واقعی).")
|
| 748 |
|
| 749 |
train_files = gr.Files(label="JSONL Files", file_count="multiple", file_types=[".jsonl"])
|
| 750 |
with gr.Row():
|
|
|
|
| 754 |
train_btn = gr.Button("شروع آموزش", variant="primary")
|
| 755 |
train_status = gr.Textbox(label="وضعیت آموزش", interactive=False)
|
| 756 |
|
| 757 |
+
# --- Tab: Weight Tuning ---
|
| 758 |
+
with gr.Tab("Weight Tuning"):
|
| 759 |
+
gr.Markdown("تیون خودکار وزنهای موجودیت با W&B Sweep. ابتدا در Settings→Secrets مقدار `WANDB_API_KEY` را ست کنید.")
|
| 760 |
+
tune_file = gr.File(label="فایل داده (JSON/JSONL)", file_types=[".json",".jsonl"])
|
| 761 |
+
tune_text_key = gr.Textbox(value="متن_کامل", label="کلید متن")
|
| 762 |
+
tune_max_samples = gr.Slider(50, 400, value=120, step=10, label="حداکثر نمونه")
|
| 763 |
+
tune_runs = gr.Slider(4, 64, value=16, step=4, label="تعداد ران Sweep")
|
| 764 |
+
tune_batch = gr.Slider(1, 4, value=2, step=1, label="batch size Builder")
|
| 765 |
+
tune_proj = gr.Textbox(value="mahoon-legal-ai", label="WANDB_PROJECT")
|
| 766 |
+
tune_entity = gr.Textbox(value="", label="WANDB_ENTITY (اختیاری)")
|
| 767 |
+
run_tune = gr.Button("شروع Sweep", variant="primary")
|
| 768 |
+
tune_status = gr.Markdown()
|
| 769 |
+
|
| 770 |
# ---- Events ----
|
| 771 |
def _resolve_gen(choice: str, override: str) -> str:
|
| 772 |
return override.strip() if override.strip() else default_gen_models[choice]
|
|
|
|
| 810 |
outputs=train_status
|
| 811 |
)
|
| 812 |
|
| 813 |
+
clean_btn.click(
|
| 814 |
+
lambda f, th: (
|
| 815 |
+
(lambda _p, _out:
|
| 816 |
+
( _out,
|
| 817 |
+
f"✅ دیتاست پاک شد. تعداد رکوردهای نهایی: **{deduplicate_jsonl(_p, _out, sim_threshold=float(th))}**" )
|
| 818 |
+
)(
|
| 819 |
+
getattr(f, "name", None) or getattr(f, "path", None),
|
| 820 |
+
f"/tmp/cleaned_{int(time.time())}.jsonl"
|
| 821 |
+
) if (getattr(f, 'name', None) or getattr(f, 'path', None)) else (None, "⚠️ فایل نامعتبر.")
|
| 822 |
+
),
|
| 823 |
+
inputs=[raw_ds, sim_th],
|
| 824 |
+
outputs=[cleaned_out, clean_status]
|
| 825 |
+
)
|
| 826 |
+
|
| 827 |
+
run_tune.click(
|
| 828 |
+
lambda f, tk, ms, runs, bs, proj, ent: self.run_weight_tune(f, tk, ms, runs, bs, proj, ent),
|
| 829 |
+
inputs=[tune_file, tune_text_key, tune_max_samples, tune_runs, tune_batch, tune_proj, tune_entity],
|
| 830 |
+
outputs=tune_status
|
| 831 |
+
)
|
| 832 |
+
|
| 833 |
return app
|
| 834 |
|
| 835 |
# ==========================
|