Spaces:
Sleeping
Sleeping
thadillo
committed on
Commit
·
1377fb1
1
Parent(s):
e6341fe
Deploy to HF Spaces: Model selection + Fine-tuning updates
Browse files
- Add model selection (7+ transformer models)
- Add zero-shot model selection (3 NLI models)
- Improve fine-tuning with head-only and LoRA modes
- Add training run management (export/delete)
- Configure for HF Spaces (port 7860, persistent storage)
- Update database schema for model tracking
- Add comprehensive AI model presets
- .gitignore +3 -0
- .hfignore +75 -0
- app/analyzer.py +14 -4
- app/fine_tuning/model_manager.py +17 -3
- app/fine_tuning/model_presets.py +168 -0
- app/fine_tuning/trainer.py +72 -7
- app/models/models.py +1 -1
- app/routes/admin.py +394 -9
- app/templates/admin/training.html +274 -28
- requirements.txt +2 -2
.gitignore
CHANGED
|
@@ -33,3 +33,6 @@ instance/
|
|
| 33 |
# OS
|
| 34 |
.DS_Store
|
| 35 |
Thumbs.db
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
# OS
|
| 34 |
.DS_Store
|
| 35 |
Thumbs.db
|
| 36 |
+
|
| 37 |
+
# Models
|
| 38 |
+
models/finetuned/
|
.hfignore
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
venv/
|
| 8 |
+
ENV/
|
| 9 |
+
env/
|
| 10 |
+
.venv
|
| 11 |
+
|
| 12 |
+
# Environment files
|
| 13 |
+
.env
|
| 14 |
+
.env.*
|
| 15 |
+
|
| 16 |
+
# Local development
|
| 17 |
+
instance/
|
| 18 |
+
*.db
|
| 19 |
+
*.sqlite
|
| 20 |
+
*.sqlite3
|
| 21 |
+
|
| 22 |
+
# Models and cache (will be generated on HF)
|
| 23 |
+
models/finetuned/*
|
| 24 |
+
.cache/
|
| 25 |
+
*.pth
|
| 26 |
+
*.bin
|
| 27 |
+
*.onnx
|
| 28 |
+
|
| 29 |
+
# Git
|
| 30 |
+
.git/
|
| 31 |
+
.gitignore
|
| 32 |
+
.gitattributes
|
| 33 |
+
|
| 34 |
+
# IDE
|
| 35 |
+
.vscode/
|
| 36 |
+
.idea/
|
| 37 |
+
*.swp
|
| 38 |
+
*.swo
|
| 39 |
+
*~
|
| 40 |
+
|
| 41 |
+
# OS
|
| 42 |
+
.DS_Store
|
| 43 |
+
Thumbs.db
|
| 44 |
+
|
| 45 |
+
# Documentation (keep only README.md)
|
| 46 |
+
DEPLOYMENT.md
|
| 47 |
+
QUICKSTART.md
|
| 48 |
+
PROJECT_STRUCTURE.md
|
| 49 |
+
MIGRATION_SUMMARY.md
|
| 50 |
+
Claude's Plan.md
|
| 51 |
+
AI_MODEL_COMPARISON.md
|
| 52 |
+
TRAINING_STRATEGY.md
|
| 53 |
+
ZERO_SHOT_MODEL_SELECTION.md
|
| 54 |
+
HF_DEPLOYMENT_CHECKLIST.md
|
| 55 |
+
|
| 56 |
+
# Test files
|
| 57 |
+
test_*.py
|
| 58 |
+
mock_data*.json
|
| 59 |
+
|
| 60 |
+
# Local-specific files
|
| 61 |
+
Dockerfile
|
| 62 |
+
docker-compose.yml
|
| 63 |
+
.dockerignore
|
| 64 |
+
gunicorn_config.py
|
| 65 |
+
run.py
|
| 66 |
+
start.sh
|
| 67 |
+
|
| 68 |
+
# Keep these for HF:
|
| 69 |
+
# - Dockerfile (will be copied from Dockerfile.hf)
|
| 70 |
+
# - README.md (will be copied from README_HF.md)
|
| 71 |
+
# - app_hf.py
|
| 72 |
+
# - wsgi.py
|
| 73 |
+
# - requirements.txt
|
| 74 |
+
# - app/ directory
|
| 75 |
+
|
app/analyzer.py
CHANGED
|
@@ -90,7 +90,8 @@ class SubmissionAnalyzer:
|
|
| 90 |
finetuned_path,
|
| 91 |
num_labels=len(self.categories),
|
| 92 |
id2label=self.id2label,
|
| 93 |
-
label2id=self.label2id
|
|
|
|
| 94 |
)
|
| 95 |
self.model.eval()
|
| 96 |
self.model_type = 'finetuned'
|
|
@@ -102,14 +103,23 @@ class SubmissionAnalyzer:
|
|
| 102 |
|
| 103 |
# Load base zero-shot model
|
| 104 |
try:
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
self.classifier = pipeline(
|
| 107 |
"zero-shot-classification",
|
| 108 |
-
model=
|
| 109 |
device=-1 # Use CPU (-1), change to 0 for GPU
|
| 110 |
)
|
| 111 |
self.model_type = 'base'
|
| 112 |
-
|
|
|
|
| 113 |
except Exception as e:
|
| 114 |
logger.error(f"Error loading model: {e}")
|
| 115 |
raise
|
|
|
|
| 90 |
finetuned_path,
|
| 91 |
num_labels=len(self.categories),
|
| 92 |
id2label=self.id2label,
|
| 93 |
+
label2id=self.label2id,
|
| 94 |
+
ignore_mismatched_sizes=True
|
| 95 |
)
|
| 96 |
self.model.eval()
|
| 97 |
self.model_type = 'finetuned'
|
|
|
|
| 103 |
|
| 104 |
# Load base zero-shot model
|
| 105 |
try:
|
| 106 |
+
# Get selected zero-shot model from settings
|
| 107 |
+
from app.models.models import Settings
|
| 108 |
+
from app.fine_tuning.model_presets import get_model_preset
|
| 109 |
+
|
| 110 |
+
zero_shot_model_key = Settings.get_setting('zero_shot_model', 'bart-large-mnli')
|
| 111 |
+
model_preset = get_model_preset(zero_shot_model_key)
|
| 112 |
+
zero_shot_model_id = model_preset['model_id']
|
| 113 |
+
|
| 114 |
+
logger.info(f"Loading zero-shot classification model: {zero_shot_model_id}...")
|
| 115 |
self.classifier = pipeline(
|
| 116 |
"zero-shot-classification",
|
| 117 |
+
model=zero_shot_model_id,
|
| 118 |
device=-1 # Use CPU (-1), change to 0 for GPU
|
| 119 |
)
|
| 120 |
self.model_type = 'base'
|
| 121 |
+
self.zero_shot_model_key = zero_shot_model_key
|
| 122 |
+
logger.info(f"Zero-shot model loaded successfully: {model_preset['name']}!")
|
| 123 |
except Exception as e:
|
| 124 |
logger.error(f"Error loading model: {e}")
|
| 125 |
raise
|
app/fine_tuning/model_manager.py
CHANGED
|
@@ -20,16 +20,27 @@ logger = logging.getLogger(__name__)
|
|
| 20 |
class ModelManager:
|
| 21 |
"""Manage fine-tuned model deployment and versioning"""
|
| 22 |
|
| 23 |
-
def __init__(self, models_dir: str =
|
| 24 |
"""
|
| 25 |
Initialize ModelManager.
|
| 26 |
|
| 27 |
Args:
|
| 28 |
models_dir: Base directory for storing fine-tuned models
|
|
|
|
| 29 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
self.models_dir = models_dir
|
| 31 |
self.base_model_name = "facebook/bart-large-mnli"
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
def get_model_path(self, run_id: int) -> str:
|
| 35 |
"""Get path to model for a specific training run"""
|
|
@@ -56,7 +67,10 @@ class ModelManager:
|
|
| 56 |
model_name = model_path
|
| 57 |
|
| 58 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 59 |
-
model = AutoModelForSequenceClassification.from_pretrained(
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
return model, tokenizer
|
| 62 |
|
|
|
|
| 20 |
class ModelManager:
|
| 21 |
"""Manage fine-tuned model deployment and versioning"""
|
| 22 |
|
| 23 |
+
def __init__(self, models_dir: str = None):
|
| 24 |
"""
|
| 25 |
Initialize ModelManager.
|
| 26 |
|
| 27 |
Args:
|
| 28 |
models_dir: Base directory for storing fine-tuned models
|
| 29 |
+
(defaults to MODELS_DIR env var or './models/finetuned')
|
| 30 |
"""
|
| 31 |
+
if models_dir is None:
|
| 32 |
+
# Use environment variable or local path
|
| 33 |
+
models_dir = os.getenv('MODELS_DIR', 'models/finetuned')
|
| 34 |
+
|
| 35 |
self.models_dir = models_dir
|
| 36 |
self.base_model_name = "facebook/bart-large-mnli"
|
| 37 |
+
|
| 38 |
+
# Create directory if it doesn't exist
|
| 39 |
+
try:
|
| 40 |
+
os.makedirs(models_dir, exist_ok=True)
|
| 41 |
+
except PermissionError:
|
| 42 |
+
logger.error(f"Permission denied creating models directory: {models_dir}")
|
| 43 |
+
raise
|
| 44 |
|
| 45 |
def get_model_path(self, run_id: int) -> str:
|
| 46 |
"""Get path to model for a specific training run"""
|
|
|
|
| 67 |
model_name = model_path
|
| 68 |
|
| 69 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 70 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
| 71 |
+
model_name,
|
| 72 |
+
ignore_mismatched_sizes=True
|
| 73 |
+
)
|
| 74 |
|
| 75 |
return model, tokenizer
|
| 76 |
|
app/fine_tuning/model_presets.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model presets for both fine-tuning and zero-shot classification.
|
| 3 |
+
Provides configuration for various HuggingFace models optimized for text classification.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
MODEL_PRESETS = {
|
| 7 |
+
# Zero-shot capable models (NLI-trained)
|
| 8 |
+
'bart-large-mnli': {
|
| 9 |
+
'name': 'BART-large-MNLI',
|
| 10 |
+
'model_id': 'facebook/bart-large-mnli',
|
| 11 |
+
'max_length': 1024,
|
| 12 |
+
'size': '400M',
|
| 13 |
+
'speed': 'Slow',
|
| 14 |
+
'best_for': 'Zero-shot + Fine-tuning',
|
| 15 |
+
'description': 'Large sequence-to-sequence model, excellent zero-shot performance',
|
| 16 |
+
'recommended_lr': 2e-5,
|
| 17 |
+
'recommended_batch': 4,
|
| 18 |
+
'supports_zero_shot': True
|
| 19 |
+
},
|
| 20 |
+
'deberta-v3-base-mnli': {
|
| 21 |
+
'name': 'DeBERTa-v3-base-MNLI',
|
| 22 |
+
'model_id': 'MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli',
|
| 23 |
+
'max_length': 512,
|
| 24 |
+
'size': '86M',
|
| 25 |
+
'speed': 'Fast',
|
| 26 |
+
'best_for': 'Fast zero-shot classification',
|
| 27 |
+
'description': 'DeBERTa trained on NLI datasets, excellent zero-shot with better speed',
|
| 28 |
+
'recommended_lr': 2e-5,
|
| 29 |
+
'recommended_batch': 8,
|
| 30 |
+
'supports_zero_shot': True
|
| 31 |
+
},
|
| 32 |
+
'distilbart-mnli': {
|
| 33 |
+
'name': 'DistilBART-MNLI',
|
| 34 |
+
'model_id': 'valhalla/distilbart-mnli-12-3',
|
| 35 |
+
'max_length': 1024,
|
| 36 |
+
'size': '134M',
|
| 37 |
+
'speed': 'Medium',
|
| 38 |
+
'best_for': 'Balanced zero-shot',
|
| 39 |
+
'description': 'Distilled BART for zero-shot, good balance of speed and accuracy',
|
| 40 |
+
'recommended_lr': 2e-5,
|
| 41 |
+
'recommended_batch': 8,
|
| 42 |
+
'supports_zero_shot': True
|
| 43 |
+
},
|
| 44 |
+
|
| 45 |
+
# Fine-tuning only models
|
| 46 |
+
'deberta-v3-small': {
|
| 47 |
+
'name': 'DeBERTa-v3-small',
|
| 48 |
+
'model_id': 'microsoft/deberta-v3-small',
|
| 49 |
+
'max_length': 512,
|
| 50 |
+
'size': '44M',
|
| 51 |
+
'speed': 'Very Fast',
|
| 52 |
+
'best_for': 'Fine-tuning with small datasets',
|
| 53 |
+
'description': 'State-of-the-art efficient model, excellent for small datasets',
|
| 54 |
+
'recommended_lr': 3e-5,
|
| 55 |
+
'recommended_batch': 8,
|
| 56 |
+
'supports_zero_shot': False
|
| 57 |
+
},
|
| 58 |
+
'deberta-v3-base': {
|
| 59 |
+
'name': 'DeBERTa-v3-base',
|
| 60 |
+
'model_id': 'microsoft/deberta-v3-base',
|
| 61 |
+
'max_length': 512,
|
| 62 |
+
'size': '86M',
|
| 63 |
+
'speed': 'Fast',
|
| 64 |
+
'best_for': 'High accuracy fine-tuning',
|
| 65 |
+
'description': 'Larger DeBERTa model with better accuracy',
|
| 66 |
+
'recommended_lr': 2e-5,
|
| 67 |
+
'recommended_batch': 8,
|
| 68 |
+
'supports_zero_shot': False
|
| 69 |
+
},
|
| 70 |
+
'distilbert-base': {
|
| 71 |
+
'name': 'DistilBERT-base',
|
| 72 |
+
'model_id': 'distilbert-base-uncased',
|
| 73 |
+
'max_length': 512,
|
| 74 |
+
'size': '66M',
|
| 75 |
+
'speed': 'Fast',
|
| 76 |
+
'best_for': 'Balanced speed and accuracy',
|
| 77 |
+
'description': 'Distilled BERT, 60% faster with 97% performance retention',
|
| 78 |
+
'recommended_lr': 5e-5,
|
| 79 |
+
'recommended_batch': 8,
|
| 80 |
+
'supports_zero_shot': False
|
| 81 |
+
},
|
| 82 |
+
'roberta-base': {
|
| 83 |
+
'name': 'RoBERTa-base',
|
| 84 |
+
'model_id': 'roberta-base',
|
| 85 |
+
'max_length': 512,
|
| 86 |
+
'size': '125M',
|
| 87 |
+
'speed': 'Medium',
|
| 88 |
+
'best_for': 'Maximum accuracy',
|
| 89 |
+
'description': 'Robustly optimized BERT, excellent classification performance',
|
| 90 |
+
'recommended_lr': 2e-5,
|
| 91 |
+
'recommended_batch': 8,
|
| 92 |
+
'supports_zero_shot': False
|
| 93 |
+
},
|
| 94 |
+
'electra-small': {
|
| 95 |
+
'name': 'ELECTRA-small',
|
| 96 |
+
'model_id': 'google/electra-small-discriminator',
|
| 97 |
+
'max_length': 512,
|
| 98 |
+
'size': '14M',
|
| 99 |
+
'speed': 'Fastest',
|
| 100 |
+
'best_for': 'Speed-critical applications',
|
| 101 |
+
'description': 'Very fast and lightweight, good for production',
|
| 102 |
+
'recommended_lr': 5e-5,
|
| 103 |
+
'recommended_batch': 16,
|
| 104 |
+
'supports_zero_shot': False
|
| 105 |
+
},
|
| 106 |
+
'minilm': {
|
| 107 |
+
'name': 'MiniLM-L12',
|
| 108 |
+
'model_id': 'microsoft/MiniLM-L12-H384-uncased',
|
| 109 |
+
'max_length': 512,
|
| 110 |
+
'size': '33M',
|
| 111 |
+
'speed': 'Very Fast',
|
| 112 |
+
'best_for': 'Lightweight production deployment',
|
| 113 |
+
'description': 'Compact model optimized for speed',
|
| 114 |
+
'recommended_lr': 4e-5,
|
| 115 |
+
'recommended_batch': 12,
|
| 116 |
+
'supports_zero_shot': False
|
| 117 |
+
}
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
def get_model_preset(preset_key):
|
| 121 |
+
"""Get model preset configuration by key."""
|
| 122 |
+
return MODEL_PRESETS.get(preset_key, MODEL_PRESETS['bart-large-mnli'])
|
| 123 |
+
|
| 124 |
+
def get_available_models():
|
| 125 |
+
"""Get list of all available models for selection."""
|
| 126 |
+
return [
|
| 127 |
+
{
|
| 128 |
+
'key': key,
|
| 129 |
+
'name': config['name'],
|
| 130 |
+
'size': config['size'],
|
| 131 |
+
'speed': config['speed'],
|
| 132 |
+
'best_for': config['best_for'],
|
| 133 |
+
'supports_zero_shot': config['supports_zero_shot']
|
| 134 |
+
}
|
| 135 |
+
for key, config in MODEL_PRESETS.items()
|
| 136 |
+
]
|
| 137 |
+
|
| 138 |
+
def get_zero_shot_models():
|
| 139 |
+
"""Get list of models that support zero-shot classification."""
|
| 140 |
+
return [
|
| 141 |
+
{
|
| 142 |
+
'key': key,
|
| 143 |
+
'name': config['name'],
|
| 144 |
+
'model_id': config['model_id'],
|
| 145 |
+
'size': config['size'],
|
| 146 |
+
'speed': config['speed'],
|
| 147 |
+
'description': config['description']
|
| 148 |
+
}
|
| 149 |
+
for key, config in MODEL_PRESETS.items()
|
| 150 |
+
if config.get('supports_zero_shot', False)
|
| 151 |
+
]
|
| 152 |
+
|
| 153 |
+
def get_recommended_hyperparams(preset_key, training_mode='lora'):
|
| 154 |
+
"""Get recommended hyperparameters for a model preset."""
|
| 155 |
+
preset = get_model_preset(preset_key)
|
| 156 |
+
|
| 157 |
+
base_params = {
|
| 158 |
+
'learning_rate': preset['recommended_lr'],
|
| 159 |
+
'batch_size': preset['recommended_batch'],
|
| 160 |
+
'max_length': preset['max_length']
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
if training_mode == 'head_only':
|
| 164 |
+
# Higher learning rate for head-only training
|
| 165 |
+
base_params['learning_rate'] = preset['recommended_lr'] * 2
|
| 166 |
+
|
| 167 |
+
return base_params
|
| 168 |
+
|
app/fine_tuning/trainer.py
CHANGED
|
@@ -75,12 +75,27 @@ class BARTFineTuner:
|
|
| 75 |
# Validate splits
|
| 76 |
assert abs(train_split + val_split + test_split - 1.0) < 0.01, "Splits must sum to 1.0"
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
# First split: separate test set
|
| 79 |
train_val_texts, test_texts, train_val_labels, test_labels = train_test_split(
|
| 80 |
texts, labels,
|
| 81 |
test_size=test_split,
|
| 82 |
random_state=random_state,
|
| 83 |
-
stratify=labels
|
| 84 |
)
|
| 85 |
|
| 86 |
# Second split: separate train and validation
|
|
@@ -89,7 +104,7 @@ class BARTFineTuner:
|
|
| 89 |
train_val_texts, train_val_labels,
|
| 90 |
test_size=val_size_adjusted,
|
| 91 |
random_state=random_state,
|
| 92 |
-
stratify=train_val_labels
|
| 93 |
)
|
| 94 |
|
| 95 |
# Tokenize datasets
|
|
@@ -126,6 +141,36 @@ class BARTFineTuner:
|
|
| 126 |
|
| 127 |
return Dataset.from_dict(dataset_dict)
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
def setup_lora_model(self, lora_config: Dict) -> None:
|
| 130 |
"""
|
| 131 |
Set up BART model with LoRA adapters.
|
|
@@ -145,7 +190,8 @@ class BARTFineTuner:
|
|
| 145 |
num_labels=len(self.categories),
|
| 146 |
id2label=self.id2label,
|
| 147 |
label2id=self.label2id,
|
| 148 |
-
problem_type="single_label_classification"
|
|
|
|
| 149 |
)
|
| 150 |
|
| 151 |
# Configure LoRA
|
|
@@ -193,6 +239,10 @@ class BARTFineTuner:
|
|
| 193 |
# Create output directory
|
| 194 |
os.makedirs(output_dir, exist_ok=True)
|
| 195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
# Training arguments
|
| 197 |
training_args = TrainingArguments(
|
| 198 |
output_dir=output_dir,
|
|
@@ -211,7 +261,8 @@ class BARTFineTuner:
|
|
| 211 |
greater_is_better=False,
|
| 212 |
save_total_limit=2,
|
| 213 |
report_to="none", # Disable wandb, tensorboard
|
| 214 |
-
|
|
|
|
| 215 |
)
|
| 216 |
|
| 217 |
# Trainer
|
|
@@ -268,7 +319,8 @@ class BARTFineTuner:
|
|
| 268 |
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 269 |
self.model = AutoModelForSequenceClassification.from_pretrained(
|
| 270 |
model_path,
|
| 271 |
-
num_labels=len(self.categories)
|
|
|
|
| 272 |
)
|
| 273 |
|
| 274 |
# Make predictions
|
|
@@ -281,11 +333,24 @@ class BARTFineTuner:
|
|
| 281 |
|
| 282 |
with torch.no_grad():
|
| 283 |
for i in range(len(test_dataset)):
|
| 284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
outputs = self.model(**batch)
|
| 286 |
pred = torch.argmax(outputs.logits, dim=1).item()
|
| 287 |
predictions.append(pred)
|
| 288 |
-
true_labels.append(
|
| 289 |
|
| 290 |
# Calculate metrics
|
| 291 |
accuracy = accuracy_score(true_labels, predictions)
|
|
|
|
| 75 |
# Validate splits
|
| 76 |
assert abs(train_split + val_split + test_split - 1.0) < 0.01, "Splits must sum to 1.0"
|
| 77 |
|
| 78 |
+
num_classes = len(self.categories)
|
| 79 |
+
total_examples = len(texts)
|
| 80 |
+
|
| 81 |
+
# Calculate minimum examples needed for stratified split
|
| 82 |
+
# Need at least num_classes examples in each split
|
| 83 |
+
min_test_size = int(total_examples * test_split)
|
| 84 |
+
min_val_size = int(total_examples * val_split)
|
| 85 |
+
|
| 86 |
+
# Check if we have enough examples for stratification
|
| 87 |
+
use_stratify = (min_test_size >= num_classes and min_val_size >= num_classes)
|
| 88 |
+
|
| 89 |
+
if not use_stratify:
|
| 90 |
+
logger.warning(f"Dataset too small ({total_examples} examples) for stratified split. "
|
| 91 |
+
f"Using random split instead.")
|
| 92 |
+
|
| 93 |
# First split: separate test set
|
| 94 |
train_val_texts, test_texts, train_val_labels, test_labels = train_test_split(
|
| 95 |
texts, labels,
|
| 96 |
test_size=test_split,
|
| 97 |
random_state=random_state,
|
| 98 |
+
stratify=labels if use_stratify else None
|
| 99 |
)
|
| 100 |
|
| 101 |
# Second split: separate train and validation
|
|
|
|
| 104 |
train_val_texts, train_val_labels,
|
| 105 |
test_size=val_size_adjusted,
|
| 106 |
random_state=random_state,
|
| 107 |
+
stratify=train_val_labels if use_stratify else None
|
| 108 |
)
|
| 109 |
|
| 110 |
# Tokenize datasets
|
|
|
|
| 141 |
|
| 142 |
return Dataset.from_dict(dataset_dict)
|
| 143 |
|
| 144 |
+
def setup_head_only_model(self) -> None:
|
| 145 |
+
"""
|
| 146 |
+
Set up BART model for classification head-only fine-tuning.
|
| 147 |
+
Freezes the encoder and only trains the classification head.
|
| 148 |
+
Better for small datasets (<100 examples).
|
| 149 |
+
"""
|
| 150 |
+
logger.info("Setting up BART model for head-only training")
|
| 151 |
+
|
| 152 |
+
# Load base model
|
| 153 |
+
self.model = AutoModelForSequenceClassification.from_pretrained(
|
| 154 |
+
self.base_model_name,
|
| 155 |
+
num_labels=len(self.categories),
|
| 156 |
+
id2label=self.id2label,
|
| 157 |
+
label2id=self.label2id,
|
| 158 |
+
problem_type="single_label_classification",
|
| 159 |
+
ignore_mismatched_sizes=True
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
# Freeze all parameters except classification head
|
| 163 |
+
for name, param in self.model.named_parameters():
|
| 164 |
+
if 'classification_head' in name or 'classifier' in name:
|
| 165 |
+
param.requires_grad = True
|
| 166 |
+
else:
|
| 167 |
+
param.requires_grad = False
|
| 168 |
+
|
| 169 |
+
# Count trainable parameters
|
| 170 |
+
trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
| 171 |
+
total = sum(p.numel() for p in self.model.parameters())
|
| 172 |
+
logger.info(f"Trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
|
| 173 |
+
|
| 174 |
def setup_lora_model(self, lora_config: Dict) -> None:
|
| 175 |
"""
|
| 176 |
Set up BART model with LoRA adapters.
|
|
|
|
| 190 |
num_labels=len(self.categories),
|
| 191 |
id2label=self.id2label,
|
| 192 |
label2id=self.label2id,
|
| 193 |
+
problem_type="single_label_classification",
|
| 194 |
+
ignore_mismatched_sizes=True # BART-MNLI has 3 classes, we need 6
|
| 195 |
)
|
| 196 |
|
| 197 |
# Configure LoRA
|
|
|
|
| 239 |
# Create output directory
|
| 240 |
os.makedirs(output_dir, exist_ok=True)
|
| 241 |
|
| 242 |
+
# Force CPU training to avoid cuDNN compatibility issues on WSL2
|
| 243 |
+
use_cuda = False
|
| 244 |
+
logger.info("Using CPU for training (CUDA disabled to avoid compatibility issues)")
|
| 245 |
+
|
| 246 |
# Training arguments
|
| 247 |
training_args = TrainingArguments(
|
| 248 |
output_dir=output_dir,
|
|
|
|
| 261 |
greater_is_better=False,
|
| 262 |
save_total_limit=2,
|
| 263 |
report_to="none", # Disable wandb, tensorboard
|
| 264 |
+
use_cpu=not use_cuda, # Use CPU if CUDA test fails
|
| 265 |
+
fp16=use_cuda, # Only use mixed precision with working CUDA
|
| 266 |
)
|
| 267 |
|
| 268 |
# Trainer
|
|
|
|
| 319 |
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 320 |
self.model = AutoModelForSequenceClassification.from_pretrained(
|
| 321 |
model_path,
|
| 322 |
+
num_labels=len(self.categories),
|
| 323 |
+
ignore_mismatched_sizes=True
|
| 324 |
)
|
| 325 |
|
| 326 |
# Make predictions
|
|
|
|
| 333 |
|
| 334 |
with torch.no_grad():
|
| 335 |
for i in range(len(test_dataset)):
|
| 336 |
+
# Get the data - handle both tensor and list formats
|
| 337 |
+
item = test_dataset[i]
|
| 338 |
+
|
| 339 |
+
# Convert to tensors if needed
|
| 340 |
+
input_ids = torch.tensor(item['input_ids']) if isinstance(item['input_ids'], list) else item['input_ids']
|
| 341 |
+
attention_mask = torch.tensor(item['attention_mask']) if isinstance(item['attention_mask'], list) else item['attention_mask']
|
| 342 |
+
label = torch.tensor(item['labels']) if isinstance(item['labels'], list) else item['labels']
|
| 343 |
+
|
| 344 |
+
# Create batch
|
| 345 |
+
batch = {
|
| 346 |
+
'input_ids': input_ids.unsqueeze(0).to(device),
|
| 347 |
+
'attention_mask': attention_mask.unsqueeze(0).to(device)
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
outputs = self.model(**batch)
|
| 351 |
pred = torch.argmax(outputs.logits, dim=1).item()
|
| 352 |
predictions.append(pred)
|
| 353 |
+
true_labels.append(label.item() if isinstance(label, torch.Tensor) else label)
|
| 354 |
|
| 355 |
# Calculate metrics
|
| 356 |
accuracy = accuracy_score(true_labels, predictions)
|
app/models/models.py
CHANGED
|
@@ -51,7 +51,7 @@ class Settings(db.Model):
|
|
| 51 |
|
| 52 |
id = db.Column(db.Integer, primary_key=True)
|
| 53 |
key = db.Column(db.String(50), unique=True, nullable=False)
|
| 54 |
-
value = db.Column(db.String(
|
| 55 |
|
| 56 |
@staticmethod
|
| 57 |
def get_setting(key, default='true'):
|
|
|
|
| 51 |
|
| 52 |
id = db.Column(db.Integer, primary_key=True)
|
| 53 |
key = db.Column(db.String(50), unique=True, nullable=False)
|
| 54 |
+
value = db.Column(db.String(100), nullable=False) # Increased to support model IDs
|
| 55 |
|
| 56 |
@staticmethod
|
| 57 |
def get_setting(key, default='true'):
|
app/routes/admin.py
CHANGED
|
@@ -3,11 +3,15 @@ from app.models.models import Token, Submission, Settings, TrainingExample, Fine
|
|
| 3 |
from app import db
|
| 4 |
from app.analyzer import get_analyzer
|
| 5 |
from functools import wraps
|
|
|
|
| 6 |
import json
|
| 7 |
import csv
|
| 8 |
import io
|
| 9 |
from datetime import datetime
|
| 10 |
import os
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
bp = Blueprint('admin', __name__, url_prefix='/admin')
|
| 13 |
|
|
@@ -801,20 +805,27 @@ def _run_training_job(run_id: int, config: Dict):
|
|
| 801 |
test_split=config.get('test_split', 0.15)
|
| 802 |
)
|
| 803 |
|
| 804 |
-
# Setup
|
| 805 |
-
|
| 806 |
-
|
| 807 |
-
|
| 808 |
-
|
| 809 |
-
|
| 810 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 811 |
|
| 812 |
# Update status to training
|
| 813 |
run.status = 'training'
|
| 814 |
db.session.commit()
|
| 815 |
|
| 816 |
# Train
|
| 817 |
-
models_dir = os.getenv('MODELS_DIR', '
|
| 818 |
output_dir = os.path.join(models_dir, f'run_{run_id}')
|
| 819 |
|
| 820 |
training_config = {
|
|
@@ -889,9 +900,14 @@ def get_training_status(run_id):
|
|
| 889 |
elif run.status == 'failed':
|
| 890 |
progress = 0
|
| 891 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 892 |
status_messages = {
|
| 893 |
'preparing': 'Preparing training data...',
|
| 894 |
-
'training': 'Training model
|
| 895 |
'evaluating': 'Evaluating model performance...',
|
| 896 |
'completed': 'Training completed successfully!',
|
| 897 |
'failed': 'Training failed'
|
|
@@ -971,3 +987,372 @@ def get_run_details(run_id):
|
|
| 971 |
run = FineTuningRun.query.get_or_404(run_id)
|
| 972 |
|
| 973 |
return jsonify(run.to_dict())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from app import db
|
| 4 |
from app.analyzer import get_analyzer
|
| 5 |
from functools import wraps
|
| 6 |
+
from typing import Dict
|
| 7 |
import json
|
| 8 |
import csv
|
| 9 |
import io
|
| 10 |
from datetime import datetime
|
| 11 |
import os
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
|
| 16 |
bp = Blueprint('admin', __name__, url_prefix='/admin')
|
| 17 |
|
|
|
|
| 805 |
test_split=config.get('test_split', 0.15)
|
| 806 |
)
|
| 807 |
|
| 808 |
+
# Setup model based on training mode
|
| 809 |
+
training_mode = config.get('training_mode', 'head_only')
|
| 810 |
+
|
| 811 |
+
if training_mode == 'head_only':
|
| 812 |
+
# Head-only training (recommended for small datasets)
|
| 813 |
+
trainer.setup_head_only_model()
|
| 814 |
+
else:
|
| 815 |
+
# LoRA training
|
| 816 |
+
lora_config = {
|
| 817 |
+
'r': config.get('lora_rank', 16),
|
| 818 |
+
'lora_alpha': config.get('lora_alpha', 32),
|
| 819 |
+
'lora_dropout': config.get('lora_dropout', 0.1)
|
| 820 |
+
}
|
| 821 |
+
trainer.setup_lora_model(lora_config)
|
| 822 |
|
| 823 |
# Update status to training
|
| 824 |
run.status = 'training'
|
| 825 |
db.session.commit()
|
| 826 |
|
| 827 |
# Train
|
| 828 |
+
models_dir = os.getenv('MODELS_DIR', 'models/finetuned')
|
| 829 |
output_dir = os.path.join(models_dir, f'run_{run_id}')
|
| 830 |
|
| 831 |
training_config = {
|
|
|
|
| 900 |
elif run.status == 'failed':
|
| 901 |
progress = 0
|
| 902 |
|
| 903 |
+
# Get training mode from config
|
| 904 |
+
config = run.get_config() if hasattr(run, 'get_config') else {}
|
| 905 |
+
training_mode = config.get('training_mode', 'lora')
|
| 906 |
+
mode_label = 'classification head only' if training_mode == 'head_only' else 'LoRA adapters'
|
| 907 |
+
|
| 908 |
status_messages = {
|
| 909 |
'preparing': 'Preparing training data...',
|
| 910 |
+
'training': f'Training model ({mode_label})...',
|
| 911 |
'evaluating': 'Evaluating model performance...',
|
| 912 |
'completed': 'Training completed successfully!',
|
| 913 |
'failed': 'Training failed'
|
|
|
|
| 987 |
run = FineTuningRun.query.get_or_404(run_id)
|
| 988 |
|
| 989 |
return jsonify(run.to_dict())
|
| 990 |
+
|
| 991 |
+
|
| 992 |
+
@bp.route('/api/set-zero-shot-model', methods=['POST'])
|
| 993 |
+
@admin_required
|
| 994 |
+
def set_zero_shot_model():
|
| 995 |
+
"""Set the zero-shot model for classification"""
|
| 996 |
+
try:
|
| 997 |
+
from app.fine_tuning.model_presets import get_model_preset
|
| 998 |
+
from app.analyzer import reload_analyzer
|
| 999 |
+
|
| 1000 |
+
data = request.get_json()
|
| 1001 |
+
model_key = data.get('model_key')
|
| 1002 |
+
|
| 1003 |
+
if not model_key:
|
| 1004 |
+
return jsonify({'success': False, 'error': 'No model key provided'}), 400
|
| 1005 |
+
|
| 1006 |
+
# Validate model exists and supports zero-shot
|
| 1007 |
+
model_preset = get_model_preset(model_key)
|
| 1008 |
+
if not model_preset.get('supports_zero_shot', False):
|
| 1009 |
+
return jsonify({
|
| 1010 |
+
'success': False,
|
| 1011 |
+
'error': 'Selected model does not support zero-shot classification'
|
| 1012 |
+
}), 400
|
| 1013 |
+
|
| 1014 |
+
# Save setting
|
| 1015 |
+
Settings.set_setting('zero_shot_model', model_key)
|
| 1016 |
+
|
| 1017 |
+
# Reload analyzer with new model
|
| 1018 |
+
reload_analyzer()
|
| 1019 |
+
|
| 1020 |
+
logger.info(f"Zero-shot model changed to: {model_preset['name']}")
|
| 1021 |
+
|
| 1022 |
+
return jsonify({
|
| 1023 |
+
'success': True,
|
| 1024 |
+
'message': f"Zero-shot model changed to {model_preset['name']}",
|
| 1025 |
+
'model_key': model_key,
|
| 1026 |
+
'model_name': model_preset['name']
|
| 1027 |
+
})
|
| 1028 |
+
|
| 1029 |
+
except Exception as e:
|
| 1030 |
+
logger.error(f"Error changing zero-shot model: {str(e)}")
|
| 1031 |
+
return jsonify({'success': False, 'error': str(e)}), 500
|
| 1032 |
+
|
| 1033 |
+
|
| 1034 |
+
@bp.route('/api/get-zero-shot-model', methods=['GET'])
@admin_required
def get_zero_shot_model():
    """Return the configured zero-shot model key and its preset details as JSON.

    The key is read from the ``zero_shot_model`` setting, falling back to
    ``'bart-large-mnli'``. Any failure is logged and reported as a 500.
    """
    try:
        from app.fine_tuning.model_presets import get_model_preset

        current_key = Settings.get_setting('zero_shot_model', 'bart-large-mnli')
        preset = get_model_preset(current_key)

        display_name = preset['name']
        info = {field: preset[field] for field in ('size', 'speed', 'description')}

        payload = {
            'success': True,
            'model_key': current_key,
            'model_name': display_name,
            'model_info': info,
        }
        return jsonify(payload)

    except Exception as e:
        logger.error(f"Error getting zero-shot model: {str(e)}")
        failure = {'success': False, 'error': str(e)}
        return jsonify(failure), 500
|
| 1058 |
+
|
| 1059 |
+
|
| 1060 |
+
@bp.route('/api/delete-training-run/<int:run_id>', methods=['DELETE'])
@admin_required
def delete_training_run(run_id):
    """Delete a training run and its associated model files.

    The active model and runs whose status is ``'training'`` are refused
    with a 400. Training examples linked to the run are kept but unlinked.

    Args:
        run_id: Primary key of the ``FineTuningRun`` to delete.

    Returns:
        A JSON response: 200 on success, 404 if the run does not exist,
        400 if deletion is refused, 500 if the database update fails.
    """
    import shutil

    # Looked up outside the try block and without get_or_404: the 404 that
    # get_or_404 raises would otherwise be caught by the generic handler
    # below and reported as a 500.
    run = FineTuningRun.query.get(run_id)
    if run is None:
        return jsonify({'success': False, 'error': 'Training run not found'}), 404

    # Prevent deletion of active model
    if run.is_active_model:
        return jsonify({
            'success': False,
            'error': 'Cannot delete the active model. Please rollback or deploy another model first.'
        }), 400

    # Prevent deletion of currently training runs
    if run.status == 'training':
        return jsonify({
            'success': False,
            'error': 'Cannot delete a training run that is currently in progress.'
        }), 400

    # Remember the path now; the run object is gone after the commit.
    model_path = run.model_path

    try:
        # Unlink training examples from this run (don't delete the examples themselves)
        for example in run.training_examples:
            example.training_run_id = None
            example.used_in_training = False

        # Delete the training run from database
        db.session.delete(run)
        db.session.commit()
    except Exception as e:
        db.session.rollback()
        logger.error(f"Error deleting training run: {str(e)}")
        return jsonify({'success': False, 'error': str(e)}), 500

    # Files are removed only after the database deletion has been committed,
    # so a failed commit never leaves a run that points at deleted files.
    if model_path and os.path.exists(model_path):
        try:
            shutil.rmtree(model_path)
            logger.info(f"Deleted model files at {model_path}")
        except Exception as e:
            # Best effort: the run is already gone from the database.
            logger.error(f"Error deleting model files: {str(e)}")

    return jsonify({
        'success': True,
        'message': f'Training run #{run_id} deleted successfully'
    })
|
| 1109 |
+
|
| 1110 |
+
|
| 1111 |
+
@bp.route('/api/export-model/<int:run_id>', methods=['GET'])
@admin_required
def export_model(run_id):
    """Export a trained model as a downloadable ZIP file.

    The archive contains a single ``model_run_<id>`` directory holding a copy
    of the run's model files plus a generated ``model_card.json`` and
    ``README.md``.

    Args:
        run_id: Primary key of the ``FineTuningRun`` to export.

    Returns:
        The ZIP file as an attachment, or a JSON error: 404 if the run or its
        model files are missing, 400 if the run is not completed, 500 for any
        other failure.
    """
    import tempfile
    import shutil
    from datetime import datetime

    # Looked up outside the try block and without get_or_404: the 404 that
    # get_or_404 raises would otherwise be caught by the generic handler
    # below and reported as a 500.
    run = FineTuningRun.query.get(run_id)
    if run is None:
        return jsonify({'success': False, 'error': 'Training run not found'}), 404

    if run.status != 'completed':
        return jsonify({
            'success': False,
            'error': 'Can only export completed training runs'
        }), 400

    if not run.model_path or not os.path.exists(run.model_path):
        return jsonify({
            'success': False,
            'error': 'Model files not found'
        }), 404

    try:
        # NOTE(review): assumes get_config()/get_results() return a dict or
        # None; the "or {}" guards the None case - confirm against the model.
        config = run.get_config() or {}
        results = run.get_results() or {}

        # One timestamp for the card, the README and the file name.
        exported_at = datetime.utcnow()
        training_mode = str(config.get('training_mode', 'lora'))
        test_accuracy = results.get('test_accuracy') or 0

        # Runs without a recorded baseline comparison (e.g. imported ones)
        # carry None here; formatting None as a percentage would raise.
        improvement = run.improvement_over_baseline
        improvement_text = f"{improvement*100:.1f}%" if improvement is not None else 'N/A'

        export_name = f"model_run_{run_id}"

        # The context manager removes the directory on success and on error.
        with tempfile.TemporaryDirectory() as temp_dir:
            export_path = os.path.join(temp_dir, export_name)

            # Copy model files
            shutil.copytree(run.model_path, export_path)

            # Create model card with metadata
            model_card = {
                'run_id': run_id,
                'export_date': exported_at.isoformat(),
                'created_at': run.created_at.isoformat() if run.created_at else None,
                'training_mode': training_mode,
                'base_model': 'facebook/bart-large-mnli',
                'model_type': 'BART fine-tuned for text classification',
                'task': 'Multi-class text classification',
                'categories': ['Vision', 'Problem', 'Objectives', 'Directives', 'Values', 'Actions'],
                'training_config': config,
                'results': results,
                'improvement_over_baseline': improvement,
                'num_training_examples': run.num_training_examples,
                'num_validation_examples': run.num_validation_examples,
                'num_test_examples': run.num_test_examples
            }

            with open(os.path.join(export_path, 'model_card.json'), 'w', encoding='utf-8') as f:
                json.dump(model_card, f, indent=2)

            # Create README
            readme_content = f"""# Participatory Planning Model - Run {run_id}

## Model Information
- **Export Date**: {exported_at.strftime('%Y-%m-%d %H:%M UTC')}
- **Training Mode**: {training_mode.upper()}
- **Base Model**: facebook/bart-large-mnli
- **Task**: Multi-class text classification

## Categories
1. Vision
2. Problem
3. Objectives
4. Directives
5. Values
6. Actions

## Training Configuration
- **Learning Rate**: {config.get('learning_rate', 'N/A')}
- **Epochs**: {config.get('num_epochs', 'N/A')}
- **Batch Size**: {config.get('batch_size', 'N/A')}
- **Training Examples**: {run.num_training_examples}
- **Validation Examples**: {run.num_validation_examples}
- **Test Examples**: {run.num_test_examples}

## Performance
- **Test Accuracy**: {test_accuracy*100:.1f}%
- **Improvement over Baseline**: {improvement_text}

## Usage
To load this model:
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("./model_run_{run_id}")
model = AutoModelForSequenceClassification.from_pretrained("./model_run_{run_id}")
```

See model_card.json for detailed metrics.
"""

            with open(os.path.join(export_path, 'README.md'), 'w', encoding='utf-8') as f:
                f.write(readme_content)

            # Create ZIP file
            zip_base = os.path.join(temp_dir, export_name)
            shutil.make_archive(zip_base, 'zip', temp_dir, export_name)

            # Read the archive into memory so it outlives the temp directory.
            with open(f"{zip_base}.zip", 'rb') as f:
                zip_data = io.BytesIO(f.read())

        # Send file from memory
        zip_data.seek(0)
        return send_file(
            zip_data,
            mimetype='application/zip',
            as_attachment=True,
            download_name=f'participatory_planner_model_run_{run_id}_{exported_at.strftime("%Y%m%d")}.zip'
        )

    except Exception as e:
        logger.error(f"Error exporting model: {str(e)}")
        return jsonify({'success': False, 'error': str(e)}), 500
|
| 1240 |
+
|
| 1241 |
+
|
| 1242 |
+
@bp.route('/api/import-model', methods=['POST'])
@admin_required
def import_model():
    """Import a previously exported model from an uploaded ZIP file.

    The upload (form field ``file``) must be a ZIP archive with exactly one
    top-level entry, a directory containing ``config.json`` and either
    ``pytorch_model.bin`` or ``model.safetensors``. An optional
    ``model_card.json`` in that directory supplies the run's configuration,
    example counts and results. The model is copied to
    ``<MODELS_DIR>/run_<id>`` and registered as a completed ``FineTuningRun``.

    Returns:
        A JSON response: 200 with the new ``run_id`` on success, 400 for an
        invalid upload or archive, 500 for any other failure.
    """
    import tempfile
    import zipfile
    import shutil
    # NOTE(review): datetime is used below but no import is visible in this
    # part of the file; the local import is harmless if the module already
    # has one.
    from datetime import datetime

    if 'file' not in request.files:
        return jsonify({'success': False, 'error': 'No file uploaded'}), 400

    file = request.files['file']
    filename = file.filename or ''

    if filename == '':
        return jsonify({'success': False, 'error': 'No file selected'}), 400

    if not filename.endswith('.zip'):
        return jsonify({'success': False, 'error': 'File must be a ZIP archive'}), 400

    # Set once the model has been copied, so a later failure can undo it.
    copied_to = None

    try:
        # Create temporary directory for extraction
        with tempfile.TemporaryDirectory() as temp_dir:
            # Save uploaded ZIP
            zip_path = os.path.join(temp_dir, 'upload.zip')
            file.save(zip_path)

            # Extract ZIP
            extract_dir = os.path.join(temp_dir, 'extracted')
            os.makedirs(extract_dir, exist_ok=True)
            extract_root = os.path.realpath(extract_dir)

            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                # The archive is user-supplied: refuse any member that would
                # resolve outside the extraction directory.
                for member in zip_ref.namelist():
                    target = os.path.realpath(os.path.join(extract_root, member))
                    if target != extract_root and not target.startswith(extract_root + os.sep):
                        return jsonify({'success': False, 'error': 'Invalid model archive structure'}), 400
                zip_ref.extractall(extract_dir)

            # Find the model directory (should be model_run_X)
            contents = os.listdir(extract_dir)
            if len(contents) != 1:
                return jsonify({'success': False, 'error': 'Invalid model archive structure'}), 400

            model_dir = os.path.join(extract_dir, contents[0])

            # Validate it's a valid model
            model_files = ['pytorch_model.bin', 'model.safetensors']  # Either format

            has_config = os.path.exists(os.path.join(model_dir, 'config.json'))
            has_model = any(os.path.exists(os.path.join(model_dir, f)) for f in model_files)

            if not has_config or not has_model:
                return jsonify({
                    'success': False,
                    'error': 'Invalid model archive - missing required files (config.json and model weights)'
                }), 400

            # Read model card if available
            model_info = {}
            model_card_path = os.path.join(model_dir, 'model_card.json')
            if os.path.exists(model_card_path):
                with open(model_card_path, 'r', encoding='utf-8') as f:
                    loaded_card = json.load(f)
                # Anything other than a JSON object is treated as "no card".
                if isinstance(loaded_card, dict):
                    model_info = loaded_card

            # Create new training run record
            training_run = FineTuningRun(
                status='completed',
                created_at=datetime.utcnow()
            )

            # Set config from model card if available
            if 'training_config' in model_info:
                training_run.set_config(model_info['training_config'])
            else:
                # Default config for imported models
                training_run.set_config({
                    'training_mode': 'imported',
                    'imported': True,
                    'original_filename': filename
                })

            # Set metadata from model card
            if 'num_training_examples' in model_info:
                training_run.num_training_examples = model_info['num_training_examples']
            if 'num_validation_examples' in model_info:
                training_run.num_validation_examples = model_info['num_validation_examples']
            if 'num_test_examples' in model_info:
                training_run.num_test_examples = model_info['num_test_examples']
            if 'results' in model_info:
                training_run.set_results(model_info['results'])
            if 'improvement_over_baseline' in model_info:
                training_run.improvement_over_baseline = model_info['improvement_over_baseline']

            training_run.completed_at = datetime.utcnow()

            db.session.add(training_run)
            # flush() assigns the primary key without committing, so a failed
            # copy below rolls the record back instead of leaving a
            # "completed" run that has no model files.
            db.session.flush()

            # Copy model to models directory
            models_dir = os.getenv('MODELS_DIR', 'models/finetuned')
            destination_path = os.path.join(models_dir, f'run_{training_run.id}')

            shutil.copytree(model_dir, destination_path)
            copied_to = destination_path

            training_run.model_path = destination_path
            db.session.commit()

            logger.info(f"Model imported successfully as run {training_run.id}")

            return jsonify({
                'success': True,
                'run_id': training_run.id,
                'message': f'Model imported successfully as run #{training_run.id}',
                'model_info': model_info
            })

    except zipfile.BadZipFile:
        return jsonify({'success': False, 'error': 'Invalid ZIP file'}), 400
    except Exception as e:
        db.session.rollback()
        # The record was rolled back; drop the files copied for it as well.
        if copied_to is not None:
            shutil.rmtree(copied_to, ignore_errors=True)
        logger.error(f"Error importing model: {str(e)}")
        return jsonify({'success': False, 'error': str(e)}), 500
|
app/templates/admin/training.html
CHANGED
|
@@ -76,6 +76,29 @@
|
|
| 76 |
{% endif %}
|
| 77 |
</div>
|
| 78 |
<div class="card-body">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
<!-- Import Training Dataset Section -->
|
| 80 |
<div class="mb-4">
|
| 81 |
<h6><i class="bi bi-upload"></i> Import Training Dataset</h6>
|
|
@@ -83,7 +106,19 @@
|
|
| 83 |
<div class="input-group">
|
| 84 |
<input type="file" class="form-control" id="trainingDatasetFile" accept=".json">
|
| 85 |
<button class="btn btn-outline-secondary" type="button" onclick="importTrainingDataset()">
|
| 86 |
-
<i class="bi bi-cloud-upload"></i> Import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
</button>
|
| 88 |
</div>
|
| 89 |
</div>
|
|
@@ -123,23 +158,21 @@
|
|
| 123 |
</div>
|
| 124 |
|
| 125 |
<div class="row mb-3">
|
| 126 |
-
<div class="col-md-
|
| 127 |
-
<label class="form-label">
|
| 128 |
-
|
| 129 |
-
<
|
| 130 |
-
|
| 131 |
-
</button>
|
| 132 |
-
</label>
|
| 133 |
-
<select class="form-select" id="loraRank" onchange="checkCustomLoraRank()">
|
| 134 |
-
<option value="8">8 (Fast, less capacity)</option>
|
| 135 |
-
<option value="16" selected>16 (Balanced)</option>
|
| 136 |
-
<option value="32">32 (Slow, more capacity)</option>
|
| 137 |
-
<option value="custom">Custom...</option>
|
| 138 |
</select>
|
| 139 |
-
<
|
| 140 |
-
|
| 141 |
-
|
|
|
|
| 142 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
<div class="col-md-4">
|
| 144 |
<label class="form-label">
|
| 145 |
Learning Rate
|
|
@@ -174,9 +207,6 @@
|
|
| 174 |
style="display: none;" placeholder="Enter custom epochs (1-20)"
|
| 175 |
min="1" max="20" value="3">
|
| 176 |
</div>
|
| 177 |
-
</div>
|
| 178 |
-
|
| 179 |
-
<div class="row mb-3">
|
| 180 |
<div class="col-md-4">
|
| 181 |
<label class="form-label">Batch Size</label>
|
| 182 |
<select class="form-select" id="batchSize">
|
|
@@ -185,6 +215,28 @@
|
|
| 185 |
<option value="16">16 (High memory)</option>
|
| 186 |
</select>
|
| 187 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
<div class="col-md-4">
|
| 189 |
<label class="form-label">LoRA Alpha</label>
|
| 190 |
<input type="number" class="form-control" id="loraAlpha" value="32" min="8" max="128" step="8">
|
|
@@ -196,6 +248,8 @@
|
|
| 196 |
<small class="text-muted">Regularization (0.0-0.5)</small>
|
| 197 |
</div>
|
| 198 |
</div>
|
|
|
|
|
|
|
| 199 |
|
| 200 |
<div class="d-grid gap-2">
|
| 201 |
<button type="button" class="btn btn-primary btn-lg" onclick="startTraining()">
|
|
@@ -283,6 +337,16 @@
|
|
| 283 |
<button class="btn btn-sm btn-info" onclick="viewRunDetails({{ run.id }})">
|
| 284 |
<i class="bi bi-eye"></i> Details
|
| 285 |
</button>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
</td>
|
| 287 |
</tr>
|
| 288 |
{% endfor %}
|
|
@@ -497,24 +561,158 @@ function importTrainingDataset() {
|
|
| 497 |
});
|
| 498 |
}
|
| 499 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
// Start training function
|
| 501 |
function startTraining() {
|
| 502 |
if (!confirm('Start fine-tuning the model? This will take several minutes.')) {
|
| 503 |
return;
|
| 504 |
}
|
| 505 |
|
|
|
|
| 506 |
const config = {
|
| 507 |
train_split: parseInt(document.getElementById('trainSplit').value) / 100,
|
| 508 |
val_split: parseInt(document.getElementById('valSplit').value) / 100,
|
| 509 |
test_split: parseInt(document.getElementById('testSplit').value) / 100,
|
| 510 |
-
|
| 511 |
-
lora_alpha: parseInt(document.getElementById('loraAlpha').value),
|
| 512 |
-
lora_dropout: parseFloat(document.getElementById('loraDropout').value),
|
| 513 |
learning_rate: getLearningRate(),
|
| 514 |
num_epochs: getNumEpochs(),
|
| 515 |
batch_size: parseInt(document.getElementById('batchSize').value)
|
| 516 |
};
|
| 517 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 518 |
// Show progress modal
|
| 519 |
const progressModal = new bootstrap.Modal(document.getElementById('trainingProgressModal'));
|
| 520 |
progressModal.show();
|
|
@@ -610,22 +808,67 @@ function rollbackModel() {
|
|
| 610 |
});
|
| 611 |
}
|
| 612 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 613 |
// View run details
|
| 614 |
function viewRunDetails(runId) {
|
| 615 |
fetch(`{{ url_for("admin.get_run_details", run_id=0) }}`.replace('/0', `/${runId}`))
|
| 616 |
.then(response => response.json())
|
| 617 |
.then(data => {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 618 |
const content = `
|
| 619 |
<div class="row">
|
| 620 |
<div class="col-md-6">
|
| 621 |
<h6>Training Configuration</h6>
|
| 622 |
<ul class="list-group">
|
| 623 |
-
|
| 624 |
-
<li class="list-group-item"><strong>Learning Rate:</strong> ${data.config.learning_rate}</li>
|
| 625 |
-
<li class="list-group-item"><strong>Epochs:</strong> ${data.config.num_epochs}</li>
|
| 626 |
-
<li class="list-group-item"><strong>Training Examples:</strong> ${data.num_training_examples}</li>
|
| 627 |
-
<li class="list-group-item"><strong>Validation Examples:</strong> ${data.num_validation_examples}</li>
|
| 628 |
-
<li class="list-group-item"><strong>Test Examples:</strong> ${data.num_test_examples}</li>
|
| 629 |
</ul>
|
| 630 |
</div>
|
| 631 |
<div class="col-md-6">
|
|
@@ -646,6 +889,9 @@ function viewRunDetails(runId) {
|
|
| 646 |
document.getElementById('runDetailsContent').innerHTML = content;
|
| 647 |
const modal = new bootstrap.Modal(document.getElementById('runDetailsModal'));
|
| 648 |
modal.show();
|
|
|
|
|
|
|
|
|
|
| 649 |
});
|
| 650 |
}
|
| 651 |
</script>
|
|
|
|
| 76 |
{% endif %}
|
| 77 |
</div>
|
| 78 |
<div class="card-body">
|
| 79 |
+
<!-- Zero-Shot Model Selection Section -->
|
| 80 |
+
<div class="mb-4 pb-3 border-bottom">
|
| 81 |
+
<h6><i class="bi bi-magic"></i> Zero-Shot Classification Model</h6>
|
| 82 |
+
<p class="text-muted small">Select which model to use for classifying submissions (before fine-tuning)</p>
|
| 83 |
+
<div class="row align-items-end">
|
| 84 |
+
<div class="col-md-6">
|
| 85 |
+
<label class="form-label">Active Model</label>
|
| 86 |
+
<select class="form-select" id="zeroShotModelSelect" onchange="changeZeroShotModel()">
|
| 87 |
+
<option value="bart-large-mnli">BART-large-MNLI (400M) - Current Default</option>
|
| 88 |
+
<option value="deberta-v3-base-mnli">DeBERTa-v3-base-MNLI (86M) - Fast & Accurate</option>
|
| 89 |
+
<option value="distilbart-mnli">DistilBART-MNLI (134M) - Balanced</option>
|
| 90 |
+
</select>
|
| 91 |
+
</div>
|
| 92 |
+
<div class="col-md-6">
|
| 93 |
+
<div id="zeroShotModelInfo" class="alert alert-info mb-0" role="alert">
|
| 94 |
+
<small id="zeroShotModelDescription">
|
| 95 |
+
Loading model info...
|
| 96 |
+
</small>
|
| 97 |
+
</div>
|
| 98 |
+
</div>
|
| 99 |
+
</div>
|
| 100 |
+
</div>
|
| 101 |
+
|
| 102 |
<!-- Import Training Dataset Section -->
|
| 103 |
<div class="mb-4">
|
| 104 |
<h6><i class="bi bi-upload"></i> Import Training Dataset</h6>
|
|
|
|
| 106 |
<div class="input-group">
|
| 107 |
<input type="file" class="form-control" id="trainingDatasetFile" accept=".json">
|
| 108 |
<button class="btn btn-outline-secondary" type="button" onclick="importTrainingDataset()">
|
| 109 |
+
<i class="bi bi-cloud-upload"></i> Import Dataset
|
| 110 |
+
</button>
|
| 111 |
+
</div>
|
| 112 |
+
</div>
|
| 113 |
+
|
| 114 |
+
<!-- Import Fine-Tuned Model Section -->
|
| 115 |
+
<div class="mb-4">
|
| 116 |
+
<h6><i class="bi bi-box-arrow-in-down"></i> Import Fine-Tuned Model</h6>
|
| 117 |
+
<p class="text-muted small">Upload a previously exported model ZIP file to use it in this system</p>
|
| 118 |
+
<div class="input-group">
|
| 119 |
+
<input type="file" class="form-control" id="importModelFile" accept=".zip">
|
| 120 |
+
<button class="btn btn-outline-primary" type="button" onclick="importModel()">
|
| 121 |
+
<i class="bi bi-download"></i> Import Model
|
| 122 |
</button>
|
| 123 |
</div>
|
| 124 |
</div>
|
|
|
|
| 158 |
</div>
|
| 159 |
|
| 160 |
<div class="row mb-3">
|
| 161 |
+
<div class="col-md-12">
|
| 162 |
+
<label class="form-label">Training Mode</label>
|
| 163 |
+
<select class="form-select" id="trainingMode" onchange="updateTrainingModeUI()">
|
| 164 |
+
<option value="head_only">Classification Head Only (Recommended for small datasets)</option>
|
| 165 |
+
<option value="lora">LoRA Fine-Tuning (For larger datasets)</option>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
</select>
|
| 167 |
+
<p class="text-muted small mt-1">
|
| 168 |
+
<strong>Head Only:</strong> Faster, better for <100 examples. Only trains the output layer.<br>
|
| 169 |
+
<strong>LoRA:</strong> Slower, better for >100 examples. Trains adapter layers throughout the model.
|
| 170 |
+
</p>
|
| 171 |
</div>
|
| 172 |
+
</div>
|
| 173 |
+
|
| 174 |
+
<!-- Common Settings (visible for both modes) -->
|
| 175 |
+
<div class="row mb-3">
|
| 176 |
<div class="col-md-4">
|
| 177 |
<label class="form-label">
|
| 178 |
Learning Rate
|
|
|
|
| 207 |
style="display: none;" placeholder="Enter custom epochs (1-20)"
|
| 208 |
min="1" max="20" value="3">
|
| 209 |
</div>
|
|
|
|
|
|
|
|
|
|
| 210 |
<div class="col-md-4">
|
| 211 |
<label class="form-label">Batch Size</label>
|
| 212 |
<select class="form-select" id="batchSize">
|
|
|
|
| 215 |
<option value="16">16 (High memory)</option>
|
| 216 |
</select>
|
| 217 |
</div>
|
| 218 |
+
</div>
|
| 219 |
+
|
| 220 |
+
<!-- LoRA-specific Settings (only visible in LoRA mode) -->
|
| 221 |
+
<div id="loraSettings">
|
| 222 |
+
<div class="row mb-3">
|
| 223 |
+
<div class="col-md-4">
|
| 224 |
+
<label class="form-label">
|
| 225 |
+
LoRA Rank
|
| 226 |
+
<button type="button" class="btn btn-sm btn-link p-0" onclick="toggleCustomLoraRank()">
|
| 227 |
+
<i class="bi bi-pencil-square"></i>
|
| 228 |
+
</button>
|
| 229 |
+
</label>
|
| 230 |
+
<select class="form-select" id="loraRank" onchange="checkCustomLoraRank()">
|
| 231 |
+
<option value="8">8 (Fast, less capacity)</option>
|
| 232 |
+
<option value="16" selected>16 (Balanced)</option>
|
| 233 |
+
<option value="32">32 (Slow, more capacity)</option>
|
| 234 |
+
<option value="custom">Custom...</option>
|
| 235 |
+
</select>
|
| 236 |
+
<input type="number" class="form-control mt-2" id="customLoraRank"
|
| 237 |
+
style="display: none;" placeholder="Enter custom rank (4-64)"
|
| 238 |
+
min="4" max="64" step="4" value="16">
|
| 239 |
+
</div>
|
| 240 |
<div class="col-md-4">
|
| 241 |
<label class="form-label">LoRA Alpha</label>
|
| 242 |
<input type="number" class="form-control" id="loraAlpha" value="32" min="8" max="128" step="8">
|
|
|
|
| 248 |
<small class="text-muted">Regularization (0.0-0.5)</small>
|
| 249 |
</div>
|
| 250 |
</div>
|
| 251 |
+
</div><!-- End loraSettings -->
|
| 252 |
+
|
| 253 |
|
| 254 |
<div class="d-grid gap-2">
|
| 255 |
<button type="button" class="btn btn-primary btn-lg" onclick="startTraining()">
|
|
|
|
| 337 |
<button class="btn btn-sm btn-info" onclick="viewRunDetails({{ run.id }})">
|
| 338 |
<i class="bi bi-eye"></i> Details
|
| 339 |
</button>
|
| 340 |
+
{% if run.status == 'completed' %}
|
| 341 |
+
<a href="{{ url_for('admin.export_model', run_id=run.id) }}" class="btn btn-sm btn-success" download>
|
| 342 |
+
<i class="bi bi-download"></i> Export
|
| 343 |
+
</a>
|
| 344 |
+
{% endif %}
|
| 345 |
+
{% if not run.is_active_model and run.status != 'training' %}
|
| 346 |
+
<button class="btn btn-sm btn-danger" onclick="deleteRun({{ run.id }})">
|
| 347 |
+
<i class="bi bi-trash"></i> Delete
|
| 348 |
+
</button>
|
| 349 |
+
{% endif %}
|
| 350 |
</td>
|
| 351 |
</tr>
|
| 352 |
{% endfor %}
|
|
|
|
| 561 |
});
|
| 562 |
}
|
| 563 |
|
| 564 |
+
// Import fine-tuned model function
|
| 565 |
+
/**
 * Upload the ZIP selected in #importModelFile to the import endpoint.
 * Reloads the page on success; on failure shows the error and restores
 * the Import button.
 */
function importModel() {
    const fileInput = document.getElementById('importModelFile');
    const file = fileInput.files[0];

    if (!file) {
        alert('Please select a model ZIP file to import');
        return;
    }

    if (!file.name.endsWith('.zip')) {
        alert('Please select a ZIP file');
        return;
    }

    if (!confirm('Import this fine-tuned model? It will be added to your training history and can be deployed.')) {
        return;
    }

    const formData = new FormData();
    formData.append('file', file);

    // Find the button through the DOM instead of the implicit global `event`:
    // event.target is the inner <i> icon when the icon itself is clicked, so
    // the button would be neither disabled nor restored.
    const button = fileInput.closest('.input-group').querySelector('button');
    const originalText = button.innerHTML;

    const restoreButton = () => {
        button.innerHTML = originalText;
        button.disabled = false;
    };

    // Show loading state
    button.innerHTML = '<span class="spinner-border spinner-border-sm" role="status"></span> Importing...';
    button.disabled = true;

    fetch('{{ url_for("admin.import_model") }}', {
        method: 'POST',
        body: formData
    })
    .then(response => response.json())
    .then(data => {
        if (data.success) {
            alert(`Successfully imported model as run #${data.run_id}!`);
            location.reload();
        } else {
            alert('Error importing model: ' + data.error);
            restoreButton();
        }
    })
    .catch(err => {
        alert('Error: ' + err.message);
        restoreButton();
    });
}
|
| 613 |
+
|
| 614 |
+
// Update UI based on training mode
|
| 615 |
+
/**
 * Show the LoRA-specific settings panel unless the selected training mode
 * is "head_only".
 */
function updateTrainingModeUI() {
    const selectedMode = document.getElementById('trainingMode').value;
    const loraPanel = document.getElementById('loraSettings');

    loraPanel.style.display = selectedMode === 'head_only' ? 'none' : 'block';
}
|
| 625 |
+
|
| 626 |
+
// Initialize on page load
|
| 627 |
+
// Once the DOM is ready, sync the LoRA panel with the selected training mode
// and fetch the currently configured zero-shot model.
document.addEventListener('DOMContentLoaded', () => {
    updateTrainingModeUI();
    loadCurrentZeroShotModel();
});
|
| 631 |
+
|
| 632 |
+
// Load current zero-shot model
|
| 633 |
+
/**
 * Fetch the configured zero-shot model, select it in #zeroShotModelSelect
 * and refresh the description panel. Failures are only logged to the console.
 */
function loadCurrentZeroShotModel() {
    const applyCurrentModel = (payload) => {
        if (!payload.success) {
            return;
        }
        document.getElementById('zeroShotModelSelect').value = payload.model_key;
        updateZeroShotModelDescription(payload.model_info);
    };

    const reportFailure = (error) => {
        console.error('Error loading zero-shot model:', error);
    };

    fetch('{{ url_for("admin.get_zero_shot_model") }}')
        .then((response) => response.json())
        .then(applyCurrentModel)
        .catch(reportFailure);
}
|
| 646 |
+
|
| 647 |
+
// Update zero-shot model description
|
| 648 |
+
/**
 * Render size, speed and description of the given model into
 * #zeroShotModelDescription. Does nothing when modelInfo is falsy.
 */
function updateZeroShotModelDescription(modelInfo) {
    const descriptionEl = document.getElementById('zeroShotModelDescription');
    if (!modelInfo) {
        return;
    }

    const { size, speed, description } = modelInfo;
    descriptionEl.innerHTML = `<strong>${size}</strong> parameters | Speed: <strong>${speed}</strong><br>${description}`;
}
|
| 654 |
+
|
| 655 |
+
// Change zero-shot model
|
| 656 |
+
/**
 * Ask for confirmation, then post the model chosen in #zeroShotModelSelect
 * to the server. The dropdown is disabled while the request is in flight and
 * is re-synced with the server's value afterwards, whatever the outcome.
 */
function changeZeroShotModel() {
    // Look the dropdown up by id rather than through the implicit global
    // `event`, which is only defined while an inline handler is running.
    const select = document.getElementById('zeroShotModelSelect');
    const modelKey = select.value;

    if (!confirm(`Switch zero-shot model? This will reload the analyzer and may take a moment.`)) {
        loadCurrentZeroShotModel(); // Revert selection
        return;
    }

    select.disabled = true;

    fetch('{{ url_for("admin.set_zero_shot_model") }}', {
        method: 'POST',
        headers: {'Content-Type': 'application/json'},
        body: JSON.stringify({model_key: modelKey})
    })
    .then(response => response.json())
    .then(data => {
        if (data.success) {
            alert(`β Zero-shot model changed to ${data.model_name}!\n\nAll new classifications will use this model.`);
            loadCurrentZeroShotModel(); // Refresh info
        } else {
            alert('Error changing model: ' + data.error);
            loadCurrentZeroShotModel(); // Revert selection
        }
    })
    .catch(err => {
        alert('Error: ' + err.message);
        loadCurrentZeroShotModel(); // Revert selection
    })
    .finally(() => {
        select.disabled = false;
    });
}
|
| 691 |
+
|
| 692 |
// Start training function
|
| 693 |
function startTraining() {
|
| 694 |
if (!confirm('Start fine-tuning the model? This will take several minutes.')) {
|
| 695 |
return;
|
| 696 |
}
|
| 697 |
|
| 698 |
+
const mode = document.getElementById('trainingMode').value;
|
| 699 |
const config = {
|
| 700 |
train_split: parseInt(document.getElementById('trainSplit').value) / 100,
|
| 701 |
val_split: parseInt(document.getElementById('valSplit').value) / 100,
|
| 702 |
test_split: parseInt(document.getElementById('testSplit').value) / 100,
|
| 703 |
+
training_mode: mode,
|
|
|
|
|
|
|
| 704 |
learning_rate: getLearningRate(),
|
| 705 |
num_epochs: getNumEpochs(),
|
| 706 |
batch_size: parseInt(document.getElementById('batchSize').value)
|
| 707 |
};
|
| 708 |
|
| 709 |
+
// Only include LoRA settings if in LoRA mode
|
| 710 |
+
if (mode === 'lora') {
|
| 711 |
+
config.lora_rank = getLoraRank();
|
| 712 |
+
config.lora_alpha = parseInt(document.getElementById('loraAlpha').value);
|
| 713 |
+
config.lora_dropout = parseFloat(document.getElementById('loraDropout').value);
|
| 714 |
+
}
|
| 715 |
+
|
| 716 |
// Show progress modal
|
| 717 |
const progressModal = new bootstrap.Modal(document.getElementById('trainingProgressModal'));
|
| 718 |
progressModal.show();
|
|
|
|
| 808 |
});
|
| 809 |
}
|
| 810 |
|
| 811 |
+
// Delete training run
/**
 * Ask the user for confirmation, then delete a training run.
 *
 * Issues a DELETE request to the URL rendered for the
 * `admin.delete_training_run` endpoint. When the JSON reply has a truthy
 * `success` field the page is reloaded; otherwise the failure reason is
 * shown in an alert.
 *
 * @param {number|string} runId - Identifier of the training run to delete.
 */
function deleteRun(runId) {
    if (!confirm('Delete this training run and all associated files? This action cannot be undone.')) {
        return;
    }

    // The URL is rendered server-side with a placeholder id of 0. Replace only
    // a whole "/0" path segment (followed by "/", "?" or end of string) so a
    // "/0" that merely prefixes another segment is never rewritten by accident.
    const urlTemplate = `{{ url_for("admin.delete_training_run", run_id=0) }}`;
    const url = urlTemplate.replace(/\/0(?=\/|\?|$)/, `/${encodeURIComponent(runId)}`);

    fetch(url, {
        method: 'DELETE'
    })
    .then(response => {
        // Error responses may not carry a JSON body (e.g. an HTML error page);
        // parse defensively so the user sees the HTTP status rather than a
        // JSON syntax error.
        return response.json()
            .catch(() => null)
            .then(data => {
                if (data && data.success) {
                    alert('Training run deleted successfully');
                    location.reload();
                    return;
                }

                // Prefer the server-provided message; otherwise fall back to
                // something more useful than "undefined".
                const reason = (data && data.error)
                    || (response.ok ? 'Unknown error' : `HTTP ${response.status} ${response.statusText}`);
                alert('Error deleting run: ' + reason);
            });
    })
    .catch(err => {
        // Network-level failure (request never completed).
        alert('Error: ' + err.message);
    });
}
|
| 833 |
+
|
| 834 |
// View run details
|
| 835 |
function viewRunDetails(runId) {
|
| 836 |
fetch(`{{ url_for("admin.get_run_details", run_id=0) }}`.replace('/0', `/${runId}`))
|
| 837 |
.then(response => response.json())
|
| 838 |
.then(data => {
|
| 839 |
+
const config = data.training_config || {};
|
| 840 |
+
const trainingMode = config.training_mode || 'lora';
|
| 841 |
+
const modeLabel = trainingMode === 'head_only' ? 'Classification Head Only' : 'LoRA Fine-Tuning';
|
| 842 |
+
|
| 843 |
+
// Build configuration list based on training mode
|
| 844 |
+
let configItems = `
|
| 845 |
+
<li class="list-group-item"><strong>Mode:</strong> ${modeLabel}</li>
|
| 846 |
+
<li class="list-group-item"><strong>Learning Rate:</strong> ${config.learning_rate || 'N/A'}</li>
|
| 847 |
+
<li class="list-group-item"><strong>Epochs:</strong> ${config.num_epochs || 'N/A'}</li>
|
| 848 |
+
<li class="list-group-item"><strong>Batch Size:</strong> ${config.batch_size || 'N/A'}</li>
|
| 849 |
+
`;
|
| 850 |
+
|
| 851 |
+
// Add LoRA-specific settings if applicable
|
| 852 |
+
if (trainingMode === 'lora') {
|
| 853 |
+
configItems += `
|
| 854 |
+
<li class="list-group-item"><strong>LoRA Rank:</strong> ${config.lora_rank || 'N/A'}</li>
|
| 855 |
+
<li class="list-group-item"><strong>LoRA Alpha:</strong> ${config.lora_alpha || 'N/A'}</li>
|
| 856 |
+
<li class="list-group-item"><strong>LoRA Dropout:</strong> ${config.lora_dropout || 'N/A'}</li>
|
| 857 |
+
`;
|
| 858 |
+
}
|
| 859 |
+
|
| 860 |
+
configItems += `
|
| 861 |
+
<li class="list-group-item"><strong>Training Examples:</strong> ${data.num_training_examples || 'N/A'}</li>
|
| 862 |
+
<li class="list-group-item"><strong>Validation Examples:</strong> ${data.num_validation_examples || 'N/A'}</li>
|
| 863 |
+
<li class="list-group-item"><strong>Test Examples:</strong> ${data.num_test_examples || 'N/A'}</li>
|
| 864 |
+
`;
|
| 865 |
+
|
| 866 |
const content = `
|
| 867 |
<div class="row">
|
| 868 |
<div class="col-md-6">
|
| 869 |
<h6>Training Configuration</h6>
|
| 870 |
<ul class="list-group">
|
| 871 |
+
${configItems}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 872 |
</ul>
|
| 873 |
</div>
|
| 874 |
<div class="col-md-6">
|
|
|
|
| 889 |
document.getElementById('runDetailsContent').innerHTML = content;
|
| 890 |
const modal = new bootstrap.Modal(document.getElementById('runDetailsModal'));
|
| 891 |
modal.show();
|
| 892 |
+
})
|
| 893 |
+
.catch(err => {
|
| 894 |
+
alert('Error loading run details: ' + err.message);
|
| 895 |
});
|
| 896 |
}
|
| 897 |
</script>
|
requirements.txt
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
Flask==3.0.0
|
| 2 |
Flask-SQLAlchemy==3.1.1
|
| 3 |
python-dotenv==1.0.0
|
| 4 |
-
transformers==4.
|
| 5 |
torch==2.5.0
|
| 6 |
sentencepiece>=0.2.0
|
| 7 |
gunicorn==21.2.0
|
| 8 |
|
| 9 |
# Fine-tuning dependencies
|
| 10 |
-
peft
|
| 11 |
datasets>=2.14.0
|
| 12 |
scikit-learn>=1.3.0
|
| 13 |
matplotlib>=3.7.0
|
|
|
|
| 1 |
Flask==3.0.0
|
| 2 |
Flask-SQLAlchemy==3.1.1
|
| 3 |
python-dotenv==1.0.0
|
| 4 |
+
transformers==4.46.0
|
| 5 |
torch==2.5.0
|
| 6 |
sentencepiece>=0.2.0
|
| 7 |
gunicorn==21.2.0
|
| 8 |
|
| 9 |
# Fine-tuning dependencies
|
| 10 |
+
peft==0.13.2
|
| 11 |
datasets>=2.14.0
|
| 12 |
scikit-learn>=1.3.0
|
| 13 |
matplotlib>=3.7.0
|