thadillo commited on
Commit
1377fb1
Β·
1 Parent(s): e6341fe

πŸš€ Deploy to HF Spaces: Model selection + Fine-tuning updates

Browse files

- Add model selection (7+ transformer models)
- Add zero-shot model selection (3 NLI models)
- Improve fine-tuning with head-only and LoRA modes
- Add training run management (export/delete)
- Configure for HF Spaces (port 7860, persistent storage)
- Update database schema for model tracking
- Add comprehensive AI model presets

.gitignore CHANGED
@@ -33,3 +33,6 @@ instance/
33
  # OS
34
  .DS_Store
35
  Thumbs.db
 
 
 
 
33
  # OS
34
  .DS_Store
35
  Thumbs.db
36
+
37
+ # Models
38
+ models/finetuned/
.hfignore ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ venv/
8
+ ENV/
9
+ env/
10
+ .venv
11
+
12
+ # Environment files
13
+ .env
14
+ .env.*
15
+
16
+ # Local development
17
+ instance/
18
+ *.db
19
+ *.sqlite
20
+ *.sqlite3
21
+
22
+ # Models and cache (will be generated on HF)
23
+ models/finetuned/*
24
+ .cache/
25
+ *.pth
26
+ *.bin
27
+ *.onnx
28
+
29
+ # Git
30
+ .git/
31
+ .gitignore
32
+ .gitattributes
33
+
34
+ # IDE
35
+ .vscode/
36
+ .idea/
37
+ *.swp
38
+ *.swo
39
+ *~
40
+
41
+ # OS
42
+ .DS_Store
43
+ Thumbs.db
44
+
45
+ # Documentation (keep only README.md)
46
+ DEPLOYMENT.md
47
+ QUICKSTART.md
48
+ PROJECT_STRUCTURE.md
49
+ MIGRATION_SUMMARY.md
50
+ Claude's Plan.md
51
+ AI_MODEL_COMPARISON.md
52
+ TRAINING_STRATEGY.md
53
+ ZERO_SHOT_MODEL_SELECTION.md
54
+ HF_DEPLOYMENT_CHECKLIST.md
55
+
56
+ # Test files
57
+ test_*.py
58
+ mock_data*.json
59
+
60
+ # Local-specific files
61
+ Dockerfile
62
+ docker-compose.yml
63
+ .dockerignore
64
+ gunicorn_config.py
65
+ run.py
66
+ start.sh
67
+
68
+ # Keep these for HF:
69
+ # - Dockerfile (will be copied from Dockerfile.hf)
70
+ # - README.md (will be copied from README_HF.md)
71
+ # - app_hf.py
72
+ # - wsgi.py
73
+ # - requirements.txt
74
+ # - app/ directory
75
+
app/analyzer.py CHANGED
@@ -90,7 +90,8 @@ class SubmissionAnalyzer:
90
  finetuned_path,
91
  num_labels=len(self.categories),
92
  id2label=self.id2label,
93
- label2id=self.label2id
 
94
  )
95
  self.model.eval()
96
  self.model_type = 'finetuned'
@@ -102,14 +103,23 @@ class SubmissionAnalyzer:
102
 
103
  # Load base zero-shot model
104
  try:
105
- logger.info("Loading base zero-shot classification model...")
 
 
 
 
 
 
 
 
106
  self.classifier = pipeline(
107
  "zero-shot-classification",
108
- model="facebook/bart-large-mnli",
109
  device=-1 # Use CPU (-1), change to 0 for GPU
110
  )
111
  self.model_type = 'base'
112
- logger.info("Base model loaded successfully!")
 
113
  except Exception as e:
114
  logger.error(f"Error loading model: {e}")
115
  raise
 
90
  finetuned_path,
91
  num_labels=len(self.categories),
92
  id2label=self.id2label,
93
+ label2id=self.label2id,
94
+ ignore_mismatched_sizes=True
95
  )
96
  self.model.eval()
97
  self.model_type = 'finetuned'
 
103
 
104
  # Load base zero-shot model
105
  try:
106
+ # Get selected zero-shot model from settings
107
+ from app.models.models import Settings
108
+ from app.fine_tuning.model_presets import get_model_preset
109
+
110
+ zero_shot_model_key = Settings.get_setting('zero_shot_model', 'bart-large-mnli')
111
+ model_preset = get_model_preset(zero_shot_model_key)
112
+ zero_shot_model_id = model_preset['model_id']
113
+
114
+ logger.info(f"Loading zero-shot classification model: {zero_shot_model_id}...")
115
  self.classifier = pipeline(
116
  "zero-shot-classification",
117
+ model=zero_shot_model_id,
118
  device=-1 # Use CPU (-1), change to 0 for GPU
119
  )
120
  self.model_type = 'base'
121
+ self.zero_shot_model_key = zero_shot_model_key
122
+ logger.info(f"Zero-shot model loaded successfully: {model_preset['name']}!")
123
  except Exception as e:
124
  logger.error(f"Error loading model: {e}")
125
  raise
app/fine_tuning/model_manager.py CHANGED
@@ -20,16 +20,27 @@ logger = logging.getLogger(__name__)
20
  class ModelManager:
21
  """Manage fine-tuned model deployment and versioning"""
22
 
23
- def __init__(self, models_dir: str = "/data/models/finetuned"):
24
  """
25
  Initialize ModelManager.
26
 
27
  Args:
28
  models_dir: Base directory for storing fine-tuned models
 
29
  """
 
 
 
 
30
  self.models_dir = models_dir
31
  self.base_model_name = "facebook/bart-large-mnli"
32
- os.makedirs(models_dir, exist_ok=True)
 
 
 
 
 
 
33
 
34
  def get_model_path(self, run_id: int) -> str:
35
  """Get path to model for a specific training run"""
@@ -56,7 +67,10 @@ class ModelManager:
56
  model_name = model_path
57
 
58
  tokenizer = AutoTokenizer.from_pretrained(model_name)
59
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
 
 
 
60
 
61
  return model, tokenizer
62
 
 
20
  class ModelManager:
21
  """Manage fine-tuned model deployment and versioning"""
22
 
23
+ def __init__(self, models_dir: str = None):
24
  """
25
  Initialize ModelManager.
26
 
27
  Args:
28
  models_dir: Base directory for storing fine-tuned models
29
+ (defaults to MODELS_DIR env var or './models/finetuned')
30
  """
31
+ if models_dir is None:
32
+ # Use environment variable or local path
33
+ models_dir = os.getenv('MODELS_DIR', 'models/finetuned')
34
+
35
  self.models_dir = models_dir
36
  self.base_model_name = "facebook/bart-large-mnli"
37
+
38
+ # Create directory if it doesn't exist
39
+ try:
40
+ os.makedirs(models_dir, exist_ok=True)
41
+ except PermissionError:
42
+ logger.error(f"Permission denied creating models directory: {models_dir}")
43
+ raise
44
 
45
  def get_model_path(self, run_id: int) -> str:
46
  """Get path to model for a specific training run"""
 
67
  model_name = model_path
68
 
69
  tokenizer = AutoTokenizer.from_pretrained(model_name)
70
+ model = AutoModelForSequenceClassification.from_pretrained(
71
+ model_name,
72
+ ignore_mismatched_sizes=True
73
+ )
74
 
75
  return model, tokenizer
76
 
app/fine_tuning/model_presets.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model presets for both fine-tuning and zero-shot classification.
3
+ Provides configuration for various HuggingFace models optimized for text classification.
4
+ """
5
+
6
+ MODEL_PRESETS = {
7
+ # Zero-shot capable models (NLI-trained)
8
+ 'bart-large-mnli': {
9
+ 'name': 'BART-large-MNLI',
10
+ 'model_id': 'facebook/bart-large-mnli',
11
+ 'max_length': 1024,
12
+ 'size': '400M',
13
+ 'speed': 'Slow',
14
+ 'best_for': 'Zero-shot + Fine-tuning',
15
+ 'description': 'Large sequence-to-sequence model, excellent zero-shot performance',
16
+ 'recommended_lr': 2e-5,
17
+ 'recommended_batch': 4,
18
+ 'supports_zero_shot': True
19
+ },
20
+ 'deberta-v3-base-mnli': {
21
+ 'name': 'DeBERTa-v3-base-MNLI',
22
+ 'model_id': 'MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli',
23
+ 'max_length': 512,
24
+ 'size': '86M',
25
+ 'speed': 'Fast',
26
+ 'best_for': 'Fast zero-shot classification',
27
+ 'description': 'DeBERTa trained on NLI datasets, excellent zero-shot with better speed',
28
+ 'recommended_lr': 2e-5,
29
+ 'recommended_batch': 8,
30
+ 'supports_zero_shot': True
31
+ },
32
+ 'distilbart-mnli': {
33
+ 'name': 'DistilBART-MNLI',
34
+ 'model_id': 'valhalla/distilbart-mnli-12-3',
35
+ 'max_length': 1024,
36
+ 'size': '134M',
37
+ 'speed': 'Medium',
38
+ 'best_for': 'Balanced zero-shot',
39
+ 'description': 'Distilled BART for zero-shot, good balance of speed and accuracy',
40
+ 'recommended_lr': 2e-5,
41
+ 'recommended_batch': 8,
42
+ 'supports_zero_shot': True
43
+ },
44
+
45
+ # Fine-tuning only models
46
+ 'deberta-v3-small': {
47
+ 'name': 'DeBERTa-v3-small',
48
+ 'model_id': 'microsoft/deberta-v3-small',
49
+ 'max_length': 512,
50
+ 'size': '44M',
51
+ 'speed': 'Very Fast',
52
+ 'best_for': 'Fine-tuning with small datasets',
53
+ 'description': 'State-of-the-art efficient model, excellent for small datasets',
54
+ 'recommended_lr': 3e-5,
55
+ 'recommended_batch': 8,
56
+ 'supports_zero_shot': False
57
+ },
58
+ 'deberta-v3-base': {
59
+ 'name': 'DeBERTa-v3-base',
60
+ 'model_id': 'microsoft/deberta-v3-base',
61
+ 'max_length': 512,
62
+ 'size': '86M',
63
+ 'speed': 'Fast',
64
+ 'best_for': 'High accuracy fine-tuning',
65
+ 'description': 'Larger DeBERTa model with better accuracy',
66
+ 'recommended_lr': 2e-5,
67
+ 'recommended_batch': 8,
68
+ 'supports_zero_shot': False
69
+ },
70
+ 'distilbert-base': {
71
+ 'name': 'DistilBERT-base',
72
+ 'model_id': 'distilbert-base-uncased',
73
+ 'max_length': 512,
74
+ 'size': '66M',
75
+ 'speed': 'Fast',
76
+ 'best_for': 'Balanced speed and accuracy',
77
+ 'description': 'Distilled BERT, 60% faster with 97% performance retention',
78
+ 'recommended_lr': 5e-5,
79
+ 'recommended_batch': 8,
80
+ 'supports_zero_shot': False
81
+ },
82
+ 'roberta-base': {
83
+ 'name': 'RoBERTa-base',
84
+ 'model_id': 'roberta-base',
85
+ 'max_length': 512,
86
+ 'size': '125M',
87
+ 'speed': 'Medium',
88
+ 'best_for': 'Maximum accuracy',
89
+ 'description': 'Robustly optimized BERT, excellent classification performance',
90
+ 'recommended_lr': 2e-5,
91
+ 'recommended_batch': 8,
92
+ 'supports_zero_shot': False
93
+ },
94
+ 'electra-small': {
95
+ 'name': 'ELECTRA-small',
96
+ 'model_id': 'google/electra-small-discriminator',
97
+ 'max_length': 512,
98
+ 'size': '14M',
99
+ 'speed': 'Fastest',
100
+ 'best_for': 'Speed-critical applications',
101
+ 'description': 'Very fast and lightweight, good for production',
102
+ 'recommended_lr': 5e-5,
103
+ 'recommended_batch': 16,
104
+ 'supports_zero_shot': False
105
+ },
106
+ 'minilm': {
107
+ 'name': 'MiniLM-L12',
108
+ 'model_id': 'microsoft/MiniLM-L12-H384-uncased',
109
+ 'max_length': 512,
110
+ 'size': '33M',
111
+ 'speed': 'Very Fast',
112
+ 'best_for': 'Lightweight production deployment',
113
+ 'description': 'Compact model optimized for speed',
114
+ 'recommended_lr': 4e-5,
115
+ 'recommended_batch': 12,
116
+ 'supports_zero_shot': False
117
+ }
118
+ }
119
+
120
+ def get_model_preset(preset_key):
121
+ """Get model preset configuration by key."""
122
+ return MODEL_PRESETS.get(preset_key, MODEL_PRESETS['bart-large-mnli'])
123
+
124
+ def get_available_models():
125
+ """Get list of all available models for selection."""
126
+ return [
127
+ {
128
+ 'key': key,
129
+ 'name': config['name'],
130
+ 'size': config['size'],
131
+ 'speed': config['speed'],
132
+ 'best_for': config['best_for'],
133
+ 'supports_zero_shot': config['supports_zero_shot']
134
+ }
135
+ for key, config in MODEL_PRESETS.items()
136
+ ]
137
+
138
+ def get_zero_shot_models():
139
+ """Get list of models that support zero-shot classification."""
140
+ return [
141
+ {
142
+ 'key': key,
143
+ 'name': config['name'],
144
+ 'model_id': config['model_id'],
145
+ 'size': config['size'],
146
+ 'speed': config['speed'],
147
+ 'description': config['description']
148
+ }
149
+ for key, config in MODEL_PRESETS.items()
150
+ if config.get('supports_zero_shot', False)
151
+ ]
152
+
153
+ def get_recommended_hyperparams(preset_key, training_mode='lora'):
154
+ """Get recommended hyperparameters for a model preset."""
155
+ preset = get_model_preset(preset_key)
156
+
157
+ base_params = {
158
+ 'learning_rate': preset['recommended_lr'],
159
+ 'batch_size': preset['recommended_batch'],
160
+ 'max_length': preset['max_length']
161
+ }
162
+
163
+ if training_mode == 'head_only':
164
+ # Higher learning rate for head-only training
165
+ base_params['learning_rate'] = preset['recommended_lr'] * 2
166
+
167
+ return base_params
168
+
app/fine_tuning/trainer.py CHANGED
@@ -75,12 +75,27 @@ class BARTFineTuner:
75
  # Validate splits
76
  assert abs(train_split + val_split + test_split - 1.0) < 0.01, "Splits must sum to 1.0"
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  # First split: separate test set
79
  train_val_texts, test_texts, train_val_labels, test_labels = train_test_split(
80
  texts, labels,
81
  test_size=test_split,
82
  random_state=random_state,
83
- stratify=labels # Ensure balanced splits
84
  )
85
 
86
  # Second split: separate train and validation
@@ -89,7 +104,7 @@ class BARTFineTuner:
89
  train_val_texts, train_val_labels,
90
  test_size=val_size_adjusted,
91
  random_state=random_state,
92
- stratify=train_val_labels
93
  )
94
 
95
  # Tokenize datasets
@@ -126,6 +141,36 @@ class BARTFineTuner:
126
 
127
  return Dataset.from_dict(dataset_dict)
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  def setup_lora_model(self, lora_config: Dict) -> None:
130
  """
131
  Set up BART model with LoRA adapters.
@@ -145,7 +190,8 @@ class BARTFineTuner:
145
  num_labels=len(self.categories),
146
  id2label=self.id2label,
147
  label2id=self.label2id,
148
- problem_type="single_label_classification"
 
149
  )
150
 
151
  # Configure LoRA
@@ -193,6 +239,10 @@ class BARTFineTuner:
193
  # Create output directory
194
  os.makedirs(output_dir, exist_ok=True)
195
 
 
 
 
 
196
  # Training arguments
197
  training_args = TrainingArguments(
198
  output_dir=output_dir,
@@ -211,7 +261,8 @@ class BARTFineTuner:
211
  greater_is_better=False,
212
  save_total_limit=2,
213
  report_to="none", # Disable wandb, tensorboard
214
- fp16=torch.cuda.is_available(), # Use mixed precision if GPU available
 
215
  )
216
 
217
  # Trainer
@@ -268,7 +319,8 @@ class BARTFineTuner:
268
  self.tokenizer = AutoTokenizer.from_pretrained(model_path)
269
  self.model = AutoModelForSequenceClassification.from_pretrained(
270
  model_path,
271
- num_labels=len(self.categories)
 
272
  )
273
 
274
  # Make predictions
@@ -281,11 +333,24 @@ class BARTFineTuner:
281
 
282
  with torch.no_grad():
283
  for i in range(len(test_dataset)):
284
- batch = {k: test_dataset[i][k].unsqueeze(0).to(device) for k in ['input_ids', 'attention_mask']}
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  outputs = self.model(**batch)
286
  pred = torch.argmax(outputs.logits, dim=1).item()
287
  predictions.append(pred)
288
- true_labels.append(test_dataset[i]['labels'].item())
289
 
290
  # Calculate metrics
291
  accuracy = accuracy_score(true_labels, predictions)
 
75
  # Validate splits
76
  assert abs(train_split + val_split + test_split - 1.0) < 0.01, "Splits must sum to 1.0"
77
 
78
+ num_classes = len(self.categories)
79
+ total_examples = len(texts)
80
+
81
+ # Calculate minimum examples needed for stratified split
82
+ # Need at least num_classes examples in each split
83
+ min_test_size = int(total_examples * test_split)
84
+ min_val_size = int(total_examples * val_split)
85
+
86
+ # Check if we have enough examples for stratification
87
+ use_stratify = (min_test_size >= num_classes and min_val_size >= num_classes)
88
+
89
+ if not use_stratify:
90
+ logger.warning(f"Dataset too small ({total_examples} examples) for stratified split. "
91
+ f"Using random split instead.")
92
+
93
  # First split: separate test set
94
  train_val_texts, test_texts, train_val_labels, test_labels = train_test_split(
95
  texts, labels,
96
  test_size=test_split,
97
  random_state=random_state,
98
+ stratify=labels if use_stratify else None
99
  )
100
 
101
  # Second split: separate train and validation
 
104
  train_val_texts, train_val_labels,
105
  test_size=val_size_adjusted,
106
  random_state=random_state,
107
+ stratify=train_val_labels if use_stratify else None
108
  )
109
 
110
  # Tokenize datasets
 
141
 
142
  return Dataset.from_dict(dataset_dict)
143
 
144
+ def setup_head_only_model(self) -> None:
145
+ """
146
+ Set up BART model for classification head-only fine-tuning.
147
+ Freezes the encoder and only trains the classification head.
148
+ Better for small datasets (<100 examples).
149
+ """
150
+ logger.info("Setting up BART model for head-only training")
151
+
152
+ # Load base model
153
+ self.model = AutoModelForSequenceClassification.from_pretrained(
154
+ self.base_model_name,
155
+ num_labels=len(self.categories),
156
+ id2label=self.id2label,
157
+ label2id=self.label2id,
158
+ problem_type="single_label_classification",
159
+ ignore_mismatched_sizes=True
160
+ )
161
+
162
+ # Freeze all parameters except classification head
163
+ for name, param in self.model.named_parameters():
164
+ if 'classification_head' in name or 'classifier' in name:
165
+ param.requires_grad = True
166
+ else:
167
+ param.requires_grad = False
168
+
169
+ # Count trainable parameters
170
+ trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
171
+ total = sum(p.numel() for p in self.model.parameters())
172
+ logger.info(f"Trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
173
+
174
  def setup_lora_model(self, lora_config: Dict) -> None:
175
  """
176
  Set up BART model with LoRA adapters.
 
190
  num_labels=len(self.categories),
191
  id2label=self.id2label,
192
  label2id=self.label2id,
193
+ problem_type="single_label_classification",
194
+ ignore_mismatched_sizes=True # BART-MNLI has 3 classes, we need 6
195
  )
196
 
197
  # Configure LoRA
 
239
  # Create output directory
240
  os.makedirs(output_dir, exist_ok=True)
241
 
242
+ # Force CPU training to avoid cuDNN compatibility issues on WSL2
243
+ use_cuda = False
244
+ logger.info("Using CPU for training (CUDA disabled to avoid compatibility issues)")
245
+
246
  # Training arguments
247
  training_args = TrainingArguments(
248
  output_dir=output_dir,
 
261
  greater_is_better=False,
262
  save_total_limit=2,
263
  report_to="none", # Disable wandb, tensorboard
264
+ use_cpu=not use_cuda, # Use CPU if CUDA test fails
265
+ fp16=use_cuda, # Only use mixed precision with working CUDA
266
  )
267
 
268
  # Trainer
 
319
  self.tokenizer = AutoTokenizer.from_pretrained(model_path)
320
  self.model = AutoModelForSequenceClassification.from_pretrained(
321
  model_path,
322
+ num_labels=len(self.categories),
323
+ ignore_mismatched_sizes=True
324
  )
325
 
326
  # Make predictions
 
333
 
334
  with torch.no_grad():
335
  for i in range(len(test_dataset)):
336
+ # Get the data - handle both tensor and list formats
337
+ item = test_dataset[i]
338
+
339
+ # Convert to tensors if needed
340
+ input_ids = torch.tensor(item['input_ids']) if isinstance(item['input_ids'], list) else item['input_ids']
341
+ attention_mask = torch.tensor(item['attention_mask']) if isinstance(item['attention_mask'], list) else item['attention_mask']
342
+ label = torch.tensor(item['labels']) if isinstance(item['labels'], list) else item['labels']
343
+
344
+ # Create batch
345
+ batch = {
346
+ 'input_ids': input_ids.unsqueeze(0).to(device),
347
+ 'attention_mask': attention_mask.unsqueeze(0).to(device)
348
+ }
349
+
350
  outputs = self.model(**batch)
351
  pred = torch.argmax(outputs.logits, dim=1).item()
352
  predictions.append(pred)
353
+ true_labels.append(label.item() if isinstance(label, torch.Tensor) else label)
354
 
355
  # Calculate metrics
356
  accuracy = accuracy_score(true_labels, predictions)
app/models/models.py CHANGED
@@ -51,7 +51,7 @@ class Settings(db.Model):
51
 
52
  id = db.Column(db.Integer, primary_key=True)
53
  key = db.Column(db.String(50), unique=True, nullable=False)
54
- value = db.Column(db.String(10), nullable=False) # 'true' or 'false'
55
 
56
  @staticmethod
57
  def get_setting(key, default='true'):
 
51
 
52
  id = db.Column(db.Integer, primary_key=True)
53
  key = db.Column(db.String(50), unique=True, nullable=False)
54
+ value = db.Column(db.String(100), nullable=False) # Increased to support model IDs
55
 
56
  @staticmethod
57
  def get_setting(key, default='true'):
app/routes/admin.py CHANGED
@@ -3,11 +3,15 @@ from app.models.models import Token, Submission, Settings, TrainingExample, Fine
3
  from app import db
4
  from app.analyzer import get_analyzer
5
  from functools import wraps
 
6
  import json
7
  import csv
8
  import io
9
  from datetime import datetime
10
  import os
 
 
 
11
 
12
  bp = Blueprint('admin', __name__, url_prefix='/admin')
13
 
@@ -801,20 +805,27 @@ def _run_training_job(run_id: int, config: Dict):
801
  test_split=config.get('test_split', 0.15)
802
  )
803
 
804
- # Setup LoRA model
805
- lora_config = {
806
- 'r': config.get('lora_rank', 16),
807
- 'lora_alpha': config.get('lora_alpha', 32),
808
- 'lora_dropout': config.get('lora_dropout', 0.1)
809
- }
810
- trainer.setup_lora_model(lora_config)
 
 
 
 
 
 
 
811
 
812
  # Update status to training
813
  run.status = 'training'
814
  db.session.commit()
815
 
816
  # Train
817
- models_dir = os.getenv('MODELS_DIR', '/data/models/finetuned')
818
  output_dir = os.path.join(models_dir, f'run_{run_id}')
819
 
820
  training_config = {
@@ -889,9 +900,14 @@ def get_training_status(run_id):
889
  elif run.status == 'failed':
890
  progress = 0
891
 
 
 
 
 
 
892
  status_messages = {
893
  'preparing': 'Preparing training data...',
894
- 'training': 'Training model with LoRA...',
895
  'evaluating': 'Evaluating model performance...',
896
  'completed': 'Training completed successfully!',
897
  'failed': 'Training failed'
@@ -971,3 +987,372 @@ def get_run_details(run_id):
971
  run = FineTuningRun.query.get_or_404(run_id)
972
 
973
  return jsonify(run.to_dict())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from app import db
4
  from app.analyzer import get_analyzer
5
  from functools import wraps
6
+ from typing import Dict
7
  import json
8
  import csv
9
  import io
10
  from datetime import datetime
11
  import os
12
+ import logging
13
+
14
+ logger = logging.getLogger(__name__)
15
 
16
  bp = Blueprint('admin', __name__, url_prefix='/admin')
17
 
 
805
  test_split=config.get('test_split', 0.15)
806
  )
807
 
808
+ # Setup model based on training mode
809
+ training_mode = config.get('training_mode', 'head_only')
810
+
811
+ if training_mode == 'head_only':
812
+ # Head-only training (recommended for small datasets)
813
+ trainer.setup_head_only_model()
814
+ else:
815
+ # LoRA training
816
+ lora_config = {
817
+ 'r': config.get('lora_rank', 16),
818
+ 'lora_alpha': config.get('lora_alpha', 32),
819
+ 'lora_dropout': config.get('lora_dropout', 0.1)
820
+ }
821
+ trainer.setup_lora_model(lora_config)
822
 
823
  # Update status to training
824
  run.status = 'training'
825
  db.session.commit()
826
 
827
  # Train
828
+ models_dir = os.getenv('MODELS_DIR', 'models/finetuned')
829
  output_dir = os.path.join(models_dir, f'run_{run_id}')
830
 
831
  training_config = {
 
900
  elif run.status == 'failed':
901
  progress = 0
902
 
903
+ # Get training mode from config
904
+ config = run.get_config() if hasattr(run, 'get_config') else {}
905
+ training_mode = config.get('training_mode', 'lora')
906
+ mode_label = 'classification head only' if training_mode == 'head_only' else 'LoRA adapters'
907
+
908
  status_messages = {
909
  'preparing': 'Preparing training data...',
910
+ 'training': f'Training model ({mode_label})...',
911
  'evaluating': 'Evaluating model performance...',
912
  'completed': 'Training completed successfully!',
913
  'failed': 'Training failed'
 
987
  run = FineTuningRun.query.get_or_404(run_id)
988
 
989
  return jsonify(run.to_dict())
990
+
991
+
992
+ @bp.route('/api/set-zero-shot-model', methods=['POST'])
993
+ @admin_required
994
+ def set_zero_shot_model():
995
+ """Set the zero-shot model for classification"""
996
+ try:
997
+ from app.fine_tuning.model_presets import get_model_preset
998
+ from app.analyzer import reload_analyzer
999
+
1000
+ data = request.get_json()
1001
+ model_key = data.get('model_key')
1002
+
1003
+ if not model_key:
1004
+ return jsonify({'success': False, 'error': 'No model key provided'}), 400
1005
+
1006
+ # Validate model exists and supports zero-shot
1007
+ model_preset = get_model_preset(model_key)
1008
+ if not model_preset.get('supports_zero_shot', False):
1009
+ return jsonify({
1010
+ 'success': False,
1011
+ 'error': 'Selected model does not support zero-shot classification'
1012
+ }), 400
1013
+
1014
+ # Save setting
1015
+ Settings.set_setting('zero_shot_model', model_key)
1016
+
1017
+ # Reload analyzer with new model
1018
+ reload_analyzer()
1019
+
1020
+ logger.info(f"Zero-shot model changed to: {model_preset['name']}")
1021
+
1022
+ return jsonify({
1023
+ 'success': True,
1024
+ 'message': f"Zero-shot model changed to {model_preset['name']}",
1025
+ 'model_key': model_key,
1026
+ 'model_name': model_preset['name']
1027
+ })
1028
+
1029
+ except Exception as e:
1030
+ logger.error(f"Error changing zero-shot model: {str(e)}")
1031
+ return jsonify({'success': False, 'error': str(e)}), 500
1032
+
1033
+
1034
+ @bp.route('/api/get-zero-shot-model', methods=['GET'])
1035
+ @admin_required
1036
+ def get_zero_shot_model():
1037
+ """Get the current zero-shot model"""
1038
+ try:
1039
+ from app.fine_tuning.model_presets import get_model_preset
1040
+
1041
+ model_key = Settings.get_setting('zero_shot_model', 'bart-large-mnli')
1042
+ model_preset = get_model_preset(model_key)
1043
+
1044
+ return jsonify({
1045
+ 'success': True,
1046
+ 'model_key': model_key,
1047
+ 'model_name': model_preset['name'],
1048
+ 'model_info': {
1049
+ 'size': model_preset['size'],
1050
+ 'speed': model_preset['speed'],
1051
+ 'description': model_preset['description']
1052
+ }
1053
+ })
1054
+
1055
+ except Exception as e:
1056
+ logger.error(f"Error getting zero-shot model: {str(e)}")
1057
+ return jsonify({'success': False, 'error': str(e)}), 500
1058
+
1059
+
1060
+ @bp.route('/api/delete-training-run/<int:run_id>', methods=['DELETE'])
1061
+ @admin_required
1062
+ def delete_training_run(run_id):
1063
+ """Delete a training run and its associated files"""
1064
+ try:
1065
+ run = FineTuningRun.query.get_or_404(run_id)
1066
+
1067
+ # Prevent deletion of active model
1068
+ if run.is_active_model:
1069
+ return jsonify({
1070
+ 'success': False,
1071
+ 'error': 'Cannot delete the active model. Please rollback or deploy another model first.'
1072
+ }), 400
1073
+
1074
+ # Prevent deletion of currently training runs
1075
+ if run.status == 'training':
1076
+ return jsonify({
1077
+ 'success': False,
1078
+ 'error': 'Cannot delete a training run that is currently in progress.'
1079
+ }), 400
1080
+
1081
+ # Delete model files if they exist
1082
+ import shutil
1083
+ if run.model_path and os.path.exists(run.model_path):
1084
+ try:
1085
+ shutil.rmtree(run.model_path)
1086
+ logger.info(f"Deleted model files at {run.model_path}")
1087
+ except Exception as e:
1088
+ logger.error(f"Error deleting model files: {str(e)}")
1089
+ # Continue with database deletion even if file deletion fails
1090
+
1091
+ # Unlink training examples from this run (don't delete the examples themselves)
1092
+ for example in run.training_examples:
1093
+ example.training_run_id = None
1094
+ example.used_in_training = False
1095
+
1096
+ # Delete the training run from database
1097
+ db.session.delete(run)
1098
+ db.session.commit()
1099
+
1100
+ return jsonify({
1101
+ 'success': True,
1102
+ 'message': f'Training run #{run_id} deleted successfully'
1103
+ })
1104
+
1105
+ except Exception as e:
1106
+ db.session.rollback()
1107
+ logger.error(f"Error deleting training run: {str(e)}")
1108
+ return jsonify({'success': False, 'error': str(e)}), 500
1109
+
1110
+
1111
+ @bp.route('/api/export-model/<int:run_id>', methods=['GET'])
1112
+ @admin_required
1113
+ def export_model(run_id):
1114
+ """Export a trained model as a downloadable ZIP file"""
1115
+ try:
1116
+ import tempfile
1117
+ import shutil
1118
+ from datetime import datetime
1119
+
1120
+ run = FineTuningRun.query.get_or_404(run_id)
1121
+
1122
+ if run.status != 'completed':
1123
+ return jsonify({
1124
+ 'success': False,
1125
+ 'error': 'Can only export completed training runs'
1126
+ }), 400
1127
+
1128
+ if not run.model_path or not os.path.exists(run.model_path):
1129
+ return jsonify({
1130
+ 'success': False,
1131
+ 'error': 'Model files not found'
1132
+ }), 404
1133
+
1134
+ # Create temporary directory for export
1135
+ temp_dir = tempfile.mkdtemp()
1136
+ try:
1137
+ export_name = f"model_run_{run_id}"
1138
+ export_path = os.path.join(temp_dir, export_name)
1139
+
1140
+ # Copy model files
1141
+ shutil.copytree(run.model_path, export_path)
1142
+
1143
+ # Create model card with metadata
1144
+ config = run.get_config()
1145
+ results = run.get_results()
1146
+
1147
+ model_card = {
1148
+ 'run_id': run_id,
1149
+ 'export_date': datetime.utcnow().isoformat(),
1150
+ 'created_at': run.created_at.isoformat() if run.created_at else None,
1151
+ 'training_mode': config.get('training_mode', 'lora'),
1152
+ 'base_model': 'facebook/bart-large-mnli',
1153
+ 'model_type': 'BART fine-tuned for text classification',
1154
+ 'task': 'Multi-class text classification',
1155
+ 'categories': ['Vision', 'Problem', 'Objectives', 'Directives', 'Values', 'Actions'],
1156
+ 'training_config': config,
1157
+ 'results': results,
1158
+ 'improvement_over_baseline': run.improvement_over_baseline,
1159
+ 'num_training_examples': run.num_training_examples,
1160
+ 'num_validation_examples': run.num_validation_examples,
1161
+ 'num_test_examples': run.num_test_examples
1162
+ }
1163
+
1164
+ with open(os.path.join(export_path, 'model_card.json'), 'w') as f:
1165
+ json.dump(model_card, f, indent=2)
1166
+
1167
+ # Create README
1168
+ readme_content = f"""# Participatory Planning Model - Run {run_id}
1169
+
1170
+ ## Model Information
1171
+ - **Export Date**: {datetime.utcnow().strftime('%Y-%m-%d %H:%M UTC')}
1172
+ - **Training Mode**: {config.get('training_mode', 'lora').upper()}
1173
+ - **Base Model**: facebook/bart-large-mnli
1174
+ - **Task**: Multi-class text classification
1175
+
1176
+ ## Categories
1177
+ 1. Vision
1178
+ 2. Problem
1179
+ 3. Objectives
1180
+ 4. Directives
1181
+ 5. Values
1182
+ 6. Actions
1183
+
1184
+ ## Training Configuration
1185
+ - **Learning Rate**: {config.get('learning_rate', 'N/A')}
1186
+ - **Epochs**: {config.get('num_epochs', 'N/A')}
1187
+ - **Batch Size**: {config.get('batch_size', 'N/A')}
1188
+ - **Training Examples**: {run.num_training_examples}
1189
+ - **Validation Examples**: {run.num_validation_examples}
1190
+ - **Test Examples**: {run.num_test_examples}
1191
+
1192
+ ## Performance
1193
+ - **Test Accuracy**: {results.get('test_accuracy', 0)*100:.1f}%
1194
+ - **Improvement over Baseline**: {run.improvement_over_baseline*100:.1f}%
1195
+
1196
+ ## Usage
1197
+ To load this model:
1198
+ ```python
1199
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
1200
+
1201
+ tokenizer = AutoTokenizer.from_pretrained("./model_run_{run_id}")
1202
+ model = AutoModelForSequenceClassification.from_pretrained("./model_run_{run_id}")
1203
+ ```
1204
+
1205
+ See model_card.json for detailed metrics.
1206
+ """
1207
+
1208
+ with open(os.path.join(export_path, 'README.md'), 'w') as f:
1209
+ f.write(readme_content)
1210
+
1211
+ # Create ZIP file
1212
+ zip_path = os.path.join(temp_dir, f"model_run_{run_id}")
1213
+ shutil.make_archive(zip_path, 'zip', temp_dir, export_name)
1214
+ zip_file = f"{zip_path}.zip"
1215
+
1216
+ # Read ZIP file into memory before cleaning up temp dir
1217
+ with open(zip_file, 'rb') as f:
1218
+ zip_data = io.BytesIO(f.read())
1219
+
1220
+ # Clean up temp directory
1221
+ shutil.rmtree(temp_dir)
1222
+
1223
+ # Send file from memory
1224
+ zip_data.seek(0)
1225
+ return send_file(
1226
+ zip_data,
1227
+ mimetype='application/zip',
1228
+ as_attachment=True,
1229
+ download_name=f'participatory_planner_model_run_{run_id}_{datetime.now().strftime("%Y%m%d")}.zip'
1230
+ )
1231
+ except Exception as e:
1232
+ # Clean up temp dir if error occurs
1233
+ if os.path.exists(temp_dir):
1234
+ shutil.rmtree(temp_dir)
1235
+ raise e
1236
+
1237
+ except Exception as e:
1238
+ logger.error(f"Error exporting model: {str(e)}")
1239
+ return jsonify({'success': False, 'error': str(e)}), 500
1240
+
1241
+
1242
+ @bp.route('/api/import-model', methods=['POST'])
1243
+ @admin_required
1244
+ def import_model():
1245
+ """Import a previously exported model from ZIP file"""
1246
+ try:
1247
+ import tempfile
1248
+ import zipfile
1249
+ import shutil
1250
+
1251
+ if 'file' not in request.files:
1252
+ return jsonify({'success': False, 'error': 'No file uploaded'}), 400
1253
+
1254
+ file = request.files['file']
1255
+
1256
+ if file.filename == '':
1257
+ return jsonify({'success': False, 'error': 'No file selected'}), 400
1258
+
1259
+ if not file.filename.endswith('.zip'):
1260
+ return jsonify({'success': False, 'error': 'File must be a ZIP archive'}), 400
1261
+
1262
+ # Create temporary directory for extraction
1263
+ with tempfile.TemporaryDirectory() as temp_dir:
1264
+ # Save uploaded ZIP
1265
+ zip_path = os.path.join(temp_dir, 'upload.zip')
1266
+ file.save(zip_path)
1267
+
1268
+ # Extract ZIP
1269
+ extract_dir = os.path.join(temp_dir, 'extracted')
1270
+ os.makedirs(extract_dir, exist_ok=True)
1271
+
1272
+ with zipfile.ZipFile(zip_path, 'r') as zip_ref:
1273
+ zip_ref.extractall(extract_dir)
1274
+
1275
+ # Find the model directory (should be model_run_X)
1276
+ contents = os.listdir(extract_dir)
1277
+ if len(contents) != 1:
1278
+ return jsonify({'success': False, 'error': 'Invalid model archive structure'}), 400
1279
+
1280
+ model_dir = os.path.join(extract_dir, contents[0])
1281
+
1282
+ # Validate it's a valid model
1283
+ required_files = ['config.json']
1284
+ model_files = ['pytorch_model.bin', 'model.safetensors'] # Either format
1285
+
1286
+ has_config = os.path.exists(os.path.join(model_dir, 'config.json'))
1287
+ has_model = any(os.path.exists(os.path.join(model_dir, f)) for f in model_files)
1288
+
1289
+ if not has_config or not has_model:
1290
+ return jsonify({
1291
+ 'success': False,
1292
+ 'error': 'Invalid model archive - missing required files (config.json and model weights)'
1293
+ }), 400
1294
+
1295
+ # Read model card if available
1296
+ model_info = {}
1297
+ model_card_path = os.path.join(model_dir, 'model_card.json')
1298
+ if os.path.exists(model_card_path):
1299
+ with open(model_card_path, 'r') as f:
1300
+ model_info = json.load(f)
1301
+
1302
+ # Create new training run record
1303
+ training_run = FineTuningRun(
1304
+ status='completed',
1305
+ created_at=datetime.utcnow()
1306
+ )
1307
+
1308
+ # Set config from model card if available
1309
+ if 'training_config' in model_info:
1310
+ training_run.set_config(model_info['training_config'])
1311
+ else:
1312
+ # Default config for imported models
1313
+ training_run.set_config({
1314
+ 'training_mode': 'imported',
1315
+ 'imported': True,
1316
+ 'original_filename': file.filename
1317
+ })
1318
+
1319
+ # Set metadata from model card
1320
+ if 'num_training_examples' in model_info:
1321
+ training_run.num_training_examples = model_info['num_training_examples']
1322
+ if 'num_validation_examples' in model_info:
1323
+ training_run.num_validation_examples = model_info['num_validation_examples']
1324
+ if 'num_test_examples' in model_info:
1325
+ training_run.num_test_examples = model_info['num_test_examples']
1326
+ if 'results' in model_info:
1327
+ training_run.set_results(model_info['results'])
1328
+ if 'improvement_over_baseline' in model_info:
1329
+ training_run.improvement_over_baseline = model_info['improvement_over_baseline']
1330
+
1331
+ training_run.completed_at = datetime.utcnow()
1332
+
1333
+ db.session.add(training_run)
1334
+ db.session.commit()
1335
+
1336
+ # Copy model to models directory
1337
+ models_dir = os.getenv('MODELS_DIR', 'models/finetuned')
1338
+ destination_path = os.path.join(models_dir, f'run_{training_run.id}')
1339
+
1340
+ shutil.copytree(model_dir, destination_path)
1341
+ training_run.model_path = destination_path
1342
+ db.session.commit()
1343
+
1344
+ logger.info(f"Model imported successfully as run {training_run.id}")
1345
+
1346
+ return jsonify({
1347
+ 'success': True,
1348
+ 'run_id': training_run.id,
1349
+ 'message': f'Model imported successfully as run #{training_run.id}',
1350
+ 'model_info': model_info
1351
+ })
1352
+
1353
+ except zipfile.BadZipFile:
1354
+ return jsonify({'success': False, 'error': 'Invalid ZIP file'}), 400
1355
+ except Exception as e:
1356
+ db.session.rollback()
1357
+ logger.error(f"Error importing model: {str(e)}")
1358
+ return jsonify({'success': False, 'error': str(e)}), 500
app/templates/admin/training.html CHANGED
@@ -76,6 +76,29 @@
76
  {% endif %}
77
  </div>
78
  <div class="card-body">
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  <!-- Import Training Dataset Section -->
80
  <div class="mb-4">
81
  <h6><i class="bi bi-upload"></i> Import Training Dataset</h6>
@@ -83,7 +106,19 @@
83
  <div class="input-group">
84
  <input type="file" class="form-control" id="trainingDatasetFile" accept=".json">
85
  <button class="btn btn-outline-secondary" type="button" onclick="importTrainingDataset()">
86
- <i class="bi bi-cloud-upload"></i> Import
 
 
 
 
 
 
 
 
 
 
 
 
87
  </button>
88
  </div>
89
  </div>
@@ -123,23 +158,21 @@
123
  </div>
124
 
125
  <div class="row mb-3">
126
- <div class="col-md-4">
127
- <label class="form-label">
128
- LoRA Rank
129
- <button type="button" class="btn btn-sm btn-link p-0" onclick="toggleCustomLoraRank()">
130
- <i class="bi bi-pencil-square"></i>
131
- </button>
132
- </label>
133
- <select class="form-select" id="loraRank" onchange="checkCustomLoraRank()">
134
- <option value="8">8 (Fast, less capacity)</option>
135
- <option value="16" selected>16 (Balanced)</option>
136
- <option value="32">32 (Slow, more capacity)</option>
137
- <option value="custom">Custom...</option>
138
  </select>
139
- <input type="number" class="form-control mt-2" id="customLoraRank"
140
- style="display: none;" placeholder="Enter custom rank (4-64)"
141
- min="4" max="64" step="4" value="16">
 
142
  </div>
 
 
 
 
143
  <div class="col-md-4">
144
  <label class="form-label">
145
  Learning Rate
@@ -174,9 +207,6 @@
174
  style="display: none;" placeholder="Enter custom epochs (1-20)"
175
  min="1" max="20" value="3">
176
  </div>
177
- </div>
178
-
179
- <div class="row mb-3">
180
  <div class="col-md-4">
181
  <label class="form-label">Batch Size</label>
182
  <select class="form-select" id="batchSize">
@@ -185,6 +215,28 @@
185
  <option value="16">16 (High memory)</option>
186
  </select>
187
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  <div class="col-md-4">
189
  <label class="form-label">LoRA Alpha</label>
190
  <input type="number" class="form-control" id="loraAlpha" value="32" min="8" max="128" step="8">
@@ -196,6 +248,8 @@
196
  <small class="text-muted">Regularization (0.0-0.5)</small>
197
  </div>
198
  </div>
 
 
199
 
200
  <div class="d-grid gap-2">
201
  <button type="button" class="btn btn-primary btn-lg" onclick="startTraining()">
@@ -283,6 +337,16 @@
283
  <button class="btn btn-sm btn-info" onclick="viewRunDetails({{ run.id }})">
284
  <i class="bi bi-eye"></i> Details
285
  </button>
 
 
 
 
 
 
 
 
 
 
286
  </td>
287
  </tr>
288
  {% endfor %}
@@ -497,24 +561,158 @@ function importTrainingDataset() {
497
  });
498
  }
499
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
  // Start training function
501
  function startTraining() {
502
  if (!confirm('Start fine-tuning the model? This will take several minutes.')) {
503
  return;
504
  }
505
 
 
506
  const config = {
507
  train_split: parseInt(document.getElementById('trainSplit').value) / 100,
508
  val_split: parseInt(document.getElementById('valSplit').value) / 100,
509
  test_split: parseInt(document.getElementById('testSplit').value) / 100,
510
- lora_rank: getLoraRank(),
511
- lora_alpha: parseInt(document.getElementById('loraAlpha').value),
512
- lora_dropout: parseFloat(document.getElementById('loraDropout').value),
513
  learning_rate: getLearningRate(),
514
  num_epochs: getNumEpochs(),
515
  batch_size: parseInt(document.getElementById('batchSize').value)
516
  };
517
 
 
 
 
 
 
 
 
518
  // Show progress modal
519
  const progressModal = new bootstrap.Modal(document.getElementById('trainingProgressModal'));
520
  progressModal.show();
@@ -610,22 +808,67 @@ function rollbackModel() {
610
  });
611
  }
612
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
613
  // View run details
614
  function viewRunDetails(runId) {
615
  fetch(`{{ url_for("admin.get_run_details", run_id=0) }}`.replace('/0', `/${runId}`))
616
  .then(response => response.json())
617
  .then(data => {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
618
  const content = `
619
  <div class="row">
620
  <div class="col-md-6">
621
  <h6>Training Configuration</h6>
622
  <ul class="list-group">
623
- <li class="list-group-item"><strong>LoRA Rank:</strong> ${data.config.lora_rank}</li>
624
- <li class="list-group-item"><strong>Learning Rate:</strong> ${data.config.learning_rate}</li>
625
- <li class="list-group-item"><strong>Epochs:</strong> ${data.config.num_epochs}</li>
626
- <li class="list-group-item"><strong>Training Examples:</strong> ${data.num_training_examples}</li>
627
- <li class="list-group-item"><strong>Validation Examples:</strong> ${data.num_validation_examples}</li>
628
- <li class="list-group-item"><strong>Test Examples:</strong> ${data.num_test_examples}</li>
629
  </ul>
630
  </div>
631
  <div class="col-md-6">
@@ -646,6 +889,9 @@ function viewRunDetails(runId) {
646
  document.getElementById('runDetailsContent').innerHTML = content;
647
  const modal = new bootstrap.Modal(document.getElementById('runDetailsModal'));
648
  modal.show();
 
 
 
649
  });
650
  }
651
  </script>
 
76
  {% endif %}
77
  </div>
78
  <div class="card-body">
79
+ <!-- Zero-Shot Model Selection Section -->
80
+ <div class="mb-4 pb-3 border-bottom">
81
+ <h6><i class="bi bi-magic"></i> Zero-Shot Classification Model</h6>
82
+ <p class="text-muted small">Select which model to use for classifying submissions (before fine-tuning)</p>
83
+ <div class="row align-items-end">
84
+ <div class="col-md-6">
85
+ <label class="form-label">Active Model</label>
86
+ <select class="form-select" id="zeroShotModelSelect" onchange="changeZeroShotModel()">
87
+ <option value="bart-large-mnli">BART-large-MNLI (400M) - Current Default</option>
88
+ <option value="deberta-v3-base-mnli">DeBERTa-v3-base-MNLI (86M) - Fast &amp; Accurate</option>
89
+ <option value="distilbart-mnli">DistilBART-MNLI (134M) - Balanced</option>
90
+ </select>
91
+ </div>
92
+ <div class="col-md-6">
93
+ <div id="zeroShotModelInfo" class="alert alert-info mb-0" role="alert">
94
+ <small id="zeroShotModelDescription">
95
+ Loading model info...
96
+ </small>
97
+ </div>
98
+ </div>
99
+ </div>
100
+ </div>
101
+
102
  <!-- Import Training Dataset Section -->
103
  <div class="mb-4">
104
  <h6><i class="bi bi-upload"></i> Import Training Dataset</h6>
 
106
  <div class="input-group">
107
  <input type="file" class="form-control" id="trainingDatasetFile" accept=".json">
108
  <button class="btn btn-outline-secondary" type="button" onclick="importTrainingDataset()">
109
+ <i class="bi bi-cloud-upload"></i> Import Dataset
110
+ </button>
111
+ </div>
112
+ </div>
113
+
114
+ <!-- Import Fine-Tuned Model Section -->
115
+ <div class="mb-4">
116
+ <h6><i class="bi bi-box-arrow-in-down"></i> Import Fine-Tuned Model</h6>
117
+ <p class="text-muted small">Upload a previously exported model ZIP file to use it in this system</p>
118
+ <div class="input-group">
119
+ <input type="file" class="form-control" id="importModelFile" accept=".zip">
120
+ <button class="btn btn-outline-primary" type="button" onclick="importModel()">
121
+ <i class="bi bi-download"></i> Import Model
122
  </button>
123
  </div>
124
  </div>
 
158
  </div>
159
 
160
  <div class="row mb-3">
161
+ <div class="col-md-12">
162
+ <label class="form-label">Training Mode</label>
163
+ <select class="form-select" id="trainingMode" onchange="updateTrainingModeUI()">
164
+ <option value="head_only">Classification Head Only (Recommended for small datasets)</option>
165
+ <option value="lora">LoRA Fine-Tuning (For larger datasets)</option>
 
 
 
 
 
 
 
166
  </select>
167
+ <p class="text-muted small mt-1">
168
+ <strong>Head Only:</strong> Faster, better for &lt;100 examples. Only trains the output layer.<br>
169
+ <strong>LoRA:</strong> Slower, better for &gt;100 examples. Trains adapter layers throughout the model.
170
+ </p>
171
  </div>
172
+ </div>
173
+
174
+ <!-- Common Settings (visible for both modes) -->
175
+ <div class="row mb-3">
176
  <div class="col-md-4">
177
  <label class="form-label">
178
  Learning Rate
 
207
  style="display: none;" placeholder="Enter custom epochs (1-20)"
208
  min="1" max="20" value="3">
209
  </div>
 
 
 
210
  <div class="col-md-4">
211
  <label class="form-label">Batch Size</label>
212
  <select class="form-select" id="batchSize">
 
215
  <option value="16">16 (High memory)</option>
216
  </select>
217
  </div>
218
+ </div>
219
+
220
+ <!-- LoRA-specific Settings (only visible in LoRA mode) -->
221
+ <div id="loraSettings">
222
+ <div class="row mb-3">
223
+ <div class="col-md-4">
224
+ <label class="form-label">
225
+ LoRA Rank
226
+ <button type="button" class="btn btn-sm btn-link p-0" onclick="toggleCustomLoraRank()">
227
+ <i class="bi bi-pencil-square"></i>
228
+ </button>
229
+ </label>
230
+ <select class="form-select" id="loraRank" onchange="checkCustomLoraRank()">
231
+ <option value="8">8 (Fast, less capacity)</option>
232
+ <option value="16" selected>16 (Balanced)</option>
233
+ <option value="32">32 (Slow, more capacity)</option>
234
+ <option value="custom">Custom...</option>
235
+ </select>
236
+ <input type="number" class="form-control mt-2" id="customLoraRank"
237
+ style="display: none;" placeholder="Enter custom rank (4-64)"
238
+ min="4" max="64" step="4" value="16">
239
+ </div>
240
  <div class="col-md-4">
241
  <label class="form-label">LoRA Alpha</label>
242
  <input type="number" class="form-control" id="loraAlpha" value="32" min="8" max="128" step="8">
 
248
  <small class="text-muted">Regularization (0.0-0.5)</small>
249
  </div>
250
  </div>
251
+ </div><!-- End loraSettings -->
252
+
253
 
254
  <div class="d-grid gap-2">
255
  <button type="button" class="btn btn-primary btn-lg" onclick="startTraining()">
 
337
  <button class="btn btn-sm btn-info" onclick="viewRunDetails({{ run.id }})">
338
  <i class="bi bi-eye"></i> Details
339
  </button>
340
+ {% if run.status == 'completed' %}
341
+ <a href="{{ url_for('admin.export_model', run_id=run.id) }}" class="btn btn-sm btn-success" download>
342
+ <i class="bi bi-download"></i> Export
343
+ </a>
344
+ {% endif %}
345
+ {% if not run.is_active_model and run.status != 'training' %}
346
+ <button class="btn btn-sm btn-danger" onclick="deleteRun({{ run.id }})">
347
+ <i class="bi bi-trash"></i> Delete
348
+ </button>
349
+ {% endif %}
350
  </td>
351
  </tr>
352
  {% endfor %}
 
561
  });
562
  }
563
 
564
+ // Import fine-tuned model function
565
+ function importModel() {
566
+ const fileInput = document.getElementById('importModelFile');
567
+ const file = fileInput.files[0];
568
+
569
+ if (!file) {
570
+ alert('Please select a model ZIP file to import');
571
+ return;
572
+ }
573
+
574
+ if (!file.name.endsWith('.zip')) {
575
+ alert('Please select a ZIP file');
576
+ return;
577
+ }
578
+
579
+ if (!confirm('Import this fine-tuned model? It will be added to your training history and can be deployed.')) {
580
+ return;
581
+ }
582
+
583
+ const formData = new FormData();
584
+ formData.append('file', file);
585
+
586
+ // Show loading state
587
+ const button = event.target;
588
+ const originalText = button.innerHTML;
589
+ button.innerHTML = '<span class="spinner-border spinner-border-sm" role="status"></span> Importing...';
590
+ button.disabled = true;
591
+
592
+ fetch('{{ url_for("admin.import_model") }}', {
593
+ method: 'POST',
594
+ body: formData
595
+ })
596
+ .then(response => response.json())
597
+ .then(data => {
598
+ if (data.success) {
599
+ alert(`Successfully imported model as run #${data.run_id}!`);
600
+ location.reload();
601
+ } else {
602
+ alert('Error importing model: ' + data.error);
603
+ button.innerHTML = originalText;
604
+ button.disabled = false;
605
+ }
606
+ })
607
+ .catch(err => {
608
+ alert('Error: ' + err.message);
609
+ button.innerHTML = originalText;
610
+ button.disabled = false;
611
+ });
612
+ }
613
+
614
+ // Update UI based on training mode
615
+ function updateTrainingModeUI() {
616
+ const mode = document.getElementById('trainingMode').value;
617
+ const loraSettings = document.getElementById('loraSettings');
618
+
619
+ if (mode === 'head_only') {
620
+ loraSettings.style.display = 'none';
621
+ } else {
622
+ loraSettings.style.display = 'block';
623
+ }
624
+ }
625
+
626
+ // Initialize on page load
627
+ document.addEventListener('DOMContentLoaded', function() {
628
+ updateTrainingModeUI();
629
+ loadCurrentZeroShotModel();
630
+ });
631
+
632
+ // Load current zero-shot model
633
+ function loadCurrentZeroShotModel() {
634
+ fetch('{{ url_for("admin.get_zero_shot_model") }}')
635
+ .then(response => response.json())
636
+ .then(data => {
637
+ if (data.success) {
638
+ document.getElementById('zeroShotModelSelect').value = data.model_key;
639
+ updateZeroShotModelDescription(data.model_info);
640
+ }
641
+ })
642
+ .catch(err => {
643
+ console.error('Error loading zero-shot model:', err);
644
+ });
645
+ }
646
+
647
+ // Update zero-shot model description
648
+ function updateZeroShotModelDescription(modelInfo) {
649
+ const desc = document.getElementById('zeroShotModelDescription');
650
+ if (modelInfo) {
651
+ desc.innerHTML = `<strong>${modelInfo.size}</strong> parameters | Speed: <strong>${modelInfo.speed}</strong><br>${modelInfo.description}`;
652
+ }
653
+ }
654
+
655
+ // Change zero-shot model
656
+ function changeZeroShotModel() {
657
+ const modelKey = document.getElementById('zeroShotModelSelect').value;
658
+
659
+ if (!confirm(`Switch zero-shot model? This will reload the analyzer and may take a moment.`)) {
660
+ loadCurrentZeroShotModel(); // Revert selection
661
+ return;
662
+ }
663
+
664
+ const button = event.target;
665
+ const originalHtml = button.parentElement.innerHTML;
666
+ button.disabled = true;
667
+
668
+ fetch('{{ url_for("admin.set_zero_shot_model") }}', {
669
+ method: 'POST',
670
+ headers: {'Content-Type': 'application/json'},
671
+ body: JSON.stringify({model_key: modelKey})
672
+ })
673
+ .then(response => response.json())
674
+ .then(data => {
675
+ if (data.success) {
676
+ alert(`βœ“ Zero-shot model changed to ${data.model_name}!\n\nAll new classifications will use this model.`);
677
+ loadCurrentZeroShotModel(); // Refresh info
678
+ } else {
679
+ alert('Error changing model: ' + data.error);
680
+ loadCurrentZeroShotModel(); // Revert selection
681
+ }
682
+ })
683
+ .catch(err => {
684
+ alert('Error: ' + err.message);
685
+ loadCurrentZeroShotModel(); // Revert selection
686
+ })
687
+ .finally(() => {
688
+ button.disabled = false;
689
+ });
690
+ }
691
+
692
  // Start training function
693
  function startTraining() {
694
  if (!confirm('Start fine-tuning the model? This will take several minutes.')) {
695
  return;
696
  }
697
 
698
+ const mode = document.getElementById('trainingMode').value;
699
  const config = {
700
  train_split: parseInt(document.getElementById('trainSplit').value) / 100,
701
  val_split: parseInt(document.getElementById('valSplit').value) / 100,
702
  test_split: parseInt(document.getElementById('testSplit').value) / 100,
703
+ training_mode: mode,
 
 
704
  learning_rate: getLearningRate(),
705
  num_epochs: getNumEpochs(),
706
  batch_size: parseInt(document.getElementById('batchSize').value)
707
  };
708
 
709
+ // Only include LoRA settings if in LoRA mode
710
+ if (mode === 'lora') {
711
+ config.lora_rank = getLoraRank();
712
+ config.lora_alpha = parseInt(document.getElementById('loraAlpha').value);
713
+ config.lora_dropout = parseFloat(document.getElementById('loraDropout').value);
714
+ }
715
+
716
  // Show progress modal
717
  const progressModal = new bootstrap.Modal(document.getElementById('trainingProgressModal'));
718
  progressModal.show();
 
808
  });
809
  }
810
 
811
+ // Delete training run
812
+ function deleteRun(runId) {
813
+ if (!confirm('Delete this training run and all associated files? This action cannot be undone.')) {
814
+ return;
815
+ }
816
+
817
+ fetch(`{{ url_for("admin.delete_training_run", run_id=0) }}`.replace('/0', `/${runId}`), {
818
+ method: 'DELETE'
819
+ })
820
+ .then(response => response.json())
821
+ .then(data => {
822
+ if (data.success) {
823
+ alert('Training run deleted successfully');
824
+ location.reload();
825
+ } else {
826
+ alert('Error deleting run: ' + data.error);
827
+ }
828
+ })
829
+ .catch(err => {
830
+ alert('Error: ' + err.message);
831
+ });
832
+ }
833
+
834
  // View run details
835
  function viewRunDetails(runId) {
836
  fetch(`{{ url_for("admin.get_run_details", run_id=0) }}`.replace('/0', `/${runId}`))
837
  .then(response => response.json())
838
  .then(data => {
839
+ const config = data.training_config || {};
840
+ const trainingMode = config.training_mode || 'lora';
841
+ const modeLabel = trainingMode === 'head_only' ? 'Classification Head Only' : 'LoRA Fine-Tuning';
842
+
843
+ // Build configuration list based on training mode
844
+ let configItems = `
845
+ <li class="list-group-item"><strong>Mode:</strong> ${modeLabel}</li>
846
+ <li class="list-group-item"><strong>Learning Rate:</strong> ${config.learning_rate || 'N/A'}</li>
847
+ <li class="list-group-item"><strong>Epochs:</strong> ${config.num_epochs || 'N/A'}</li>
848
+ <li class="list-group-item"><strong>Batch Size:</strong> ${config.batch_size || 'N/A'}</li>
849
+ `;
850
+
851
+ // Add LoRA-specific settings if applicable
852
+ if (trainingMode === 'lora') {
853
+ configItems += `
854
+ <li class="list-group-item"><strong>LoRA Rank:</strong> ${config.lora_rank || 'N/A'}</li>
855
+ <li class="list-group-item"><strong>LoRA Alpha:</strong> ${config.lora_alpha || 'N/A'}</li>
856
+ <li class="list-group-item"><strong>LoRA Dropout:</strong> ${config.lora_dropout || 'N/A'}</li>
857
+ `;
858
+ }
859
+
860
+ configItems += `
861
+ <li class="list-group-item"><strong>Training Examples:</strong> ${data.num_training_examples || 'N/A'}</li>
862
+ <li class="list-group-item"><strong>Validation Examples:</strong> ${data.num_validation_examples || 'N/A'}</li>
863
+ <li class="list-group-item"><strong>Test Examples:</strong> ${data.num_test_examples || 'N/A'}</li>
864
+ `;
865
+
866
  const content = `
867
  <div class="row">
868
  <div class="col-md-6">
869
  <h6>Training Configuration</h6>
870
  <ul class="list-group">
871
+ ${configItems}
 
 
 
 
 
872
  </ul>
873
  </div>
874
  <div class="col-md-6">
 
889
  document.getElementById('runDetailsContent').innerHTML = content;
890
  const modal = new bootstrap.Modal(document.getElementById('runDetailsModal'));
891
  modal.show();
892
+ })
893
+ .catch(err => {
894
+ alert('Error loading run details: ' + err.message);
895
  });
896
  }
897
  </script>
requirements.txt CHANGED
@@ -1,13 +1,13 @@
1
  Flask==3.0.0
2
  Flask-SQLAlchemy==3.1.1
3
  python-dotenv==1.0.0
4
- transformers==4.36.0
5
  torch==2.5.0
6
  sentencepiece>=0.2.0
7
  gunicorn==21.2.0
8
 
9
  # Fine-tuning dependencies
10
- peft>=0.7.0
11
  datasets>=2.14.0
12
  scikit-learn>=1.3.0
13
  matplotlib>=3.7.0
 
1
  Flask==3.0.0
2
  Flask-SQLAlchemy==3.1.1
3
  python-dotenv==1.0.0
4
+ transformers==4.46.0
5
  torch==2.5.0
6
  sentencepiece>=0.2.0
7
  gunicorn==21.2.0
8
 
9
  # Fine-tuning dependencies
10
+ peft==0.13.2
11
  datasets>=2.14.0
12
  scikit-learn>=1.3.0
13
  matplotlib>=3.7.0