smartTranscend commited on
Commit
569e864
·
verified ·
1 Parent(s): ff78cd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1094 -468
app.py CHANGED
@@ -1,77 +1,196 @@
1
  import gradio as gr
2
  import pandas as pd
 
3
  import torch
4
  from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
5
  from peft import LoraConfig, AdaLoraConfig, get_peft_model, TaskType
6
  from datasets import Dataset
7
  from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
8
  from torch import nn
 
9
  import os
10
  from datetime import datetime
11
  import gc
 
 
 
 
 
12
 
 
13
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
14
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
15
 
16
- # 設置較小的預設值以節省記憶體
17
  torch.backends.cudnn.benchmark = False
18
  if torch.cuda.is_available():
19
  torch.cuda.empty_cache()
20
 
21
- # 全域變數
22
  trained_models = {}
23
  model_counter = 0
24
- baseline_results = {}
25
- baseline_model_cache = {}
26
 
27
- def calculate_improvement(baseline_val, finetuned_val):
28
- """安全計算改善率"""
29
- if baseline_val == 0:
30
- if finetuned_val > 0:
31
- return float('inf')
32
- else:
33
- return 0.0
34
- return (finetuned_val - baseline_val) / baseline_val * 100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- def format_improve(val):
37
- """格式化改善率"""
38
- if val == float('inf'):
39
- return "N/A (baseline=0)"
40
- return f"{val:+.1f}%"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
 
42
  def compute_metrics(pred):
 
43
  try:
44
  labels = pred.label_ids
45
  preds = pred.predictions.argmax(-1)
46
- precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary', pos_label=1, zero_division=0)
 
 
 
 
47
  acc = accuracy_score(labels, preds)
 
 
48
  cm = confusion_matrix(labels, preds)
 
49
  if cm.shape == (2, 2):
50
  tn, fp, fn, tp = cm.ravel()
51
- else:
52
- tn = fp = fn = tp = 0
53
  sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
54
  specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
 
 
 
 
 
55
  return {
56
- 'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall,
57
- 'sensitivity': sensitivity, 'specificity': specificity,
58
- 'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)
 
 
 
 
 
 
 
 
 
59
  }
60
  except Exception as e:
61
  print(f"Error in compute_metrics: {e}")
62
- return {
63
- 'accuracy': 0, 'f1': 0, 'precision': 0, 'recall': 0,
64
- 'sensitivity': 0, 'specificity': 0, 'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0
65
- }
66
 
67
- def evaluate_baseline(model, tokenizer, test_dataset, device):
 
68
  """評估未微調的基準模型"""
69
  model.eval()
70
  all_preds = []
71
  all_labels = []
72
 
73
- from torch.utils.data import DataLoader
74
-
75
  def collate_fn(batch):
76
  return {
77
  'input_ids': torch.stack([torch.tensor(item['input_ids']) for item in batch]),
@@ -79,7 +198,13 @@ def evaluate_baseline(model, tokenizer, test_dataset, device):
79
  'labels': torch.tensor([item['label'] for item in batch])
80
  }
81
 
82
- dataloader = DataLoader(test_dataset, batch_size=16, collate_fn=collate_fn)
 
 
 
 
 
 
83
 
84
  with torch.no_grad():
85
  for batch in dataloader:
@@ -90,636 +215,1137 @@ def evaluate_baseline(model, tokenizer, test_dataset, device):
90
  all_preds.extend(preds.cpu().numpy())
91
  all_labels.extend(labels.numpy())
92
 
93
- precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary', pos_label=1, zero_division=0)
 
 
 
94
  acc = accuracy_score(all_labels, all_preds)
 
95
  cm = confusion_matrix(all_labels, all_preds)
 
96
  if cm.shape == (2, 2):
97
  tn, fp, fn, tp = cm.ravel()
98
- else:
99
- tn = fp = fn = tp = 0
100
  sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
101
  specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
 
 
102
 
103
  return {
104
- 'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall,
105
- 'sensitivity': sensitivity, 'specificity': specificity,
106
- 'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)
 
 
 
 
 
 
 
 
 
107
  }
108
 
109
- class WeightedTrainer(Trainer):
110
- def __init__(self, *args, class_weights=None, use_focal_loss=False, **kwargs):
 
 
 
 
 
111
  super().__init__(*args, **kwargs)
112
  self.class_weights = class_weights
113
  self.use_focal_loss = use_focal_loss
 
 
 
 
 
 
 
 
114
 
115
  def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
 
116
  labels = inputs.pop("labels")
117
  outputs = model(**inputs)
118
  logits = outputs.logits
119
 
120
- if self.use_focal_loss:
121
- # Focal Loss: 更關注難分類的樣本
122
  ce_loss = nn.CrossEntropyLoss(weight=self.class_weights, reduction='none')(
123
  logits.view(-1, 2), labels.view(-1)
124
  )
125
  pt = torch.exp(-ce_loss)
126
- focal_loss = ((1 - pt) ** 2 * ce_loss).mean()
127
  loss = focal_loss
128
- else:
 
129
  loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
130
  loss = loss_fct(logits.view(-1, 2), labels.view(-1))
 
 
 
 
131
 
132
  return (loss, outputs) if return_outputs else loss
133
-
134
- def evaluate_baseline(model, tokenizer, test_dataset, device):
135
- """評估未微調的基準模型"""
136
- model.eval()
137
- all_preds = []
138
- all_labels = []
139
 
140
- from torch.utils.data import DataLoader
141
-
142
- def collate_fn(batch):
143
- return {
144
- 'input_ids': torch.stack([torch.tensor(item['input_ids']) for item in batch]),
145
- 'attention_mask': torch.stack([torch.tensor(item['attention_mask']) for item in batch]),
146
- 'labels': torch.tensor([item['label'] for item in batch])
147
- }
148
-
149
- dataloader = DataLoader(test_dataset, batch_size=16, collate_fn=collate_fn)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
- with torch.no_grad():
152
- for batch in dataloader:
153
- labels = batch.pop('labels')
154
- inputs = {k: v.to(device) for k, v in batch.items()}
155
- outputs = model(**inputs)
156
- preds = torch.argmax(outputs.logits, dim=-1)
157
- all_preds.extend(preds.cpu().numpy())
158
- all_labels.extend(labels.numpy())
159
 
160
- precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary', pos_label=1, zero_division=0)
161
- acc = accuracy_score(all_labels, all_preds)
162
- cm = confusion_matrix(all_labels, all_preds)
163
- if cm.shape == (2, 2):
164
- tn, fp, fn, tp = cm.ravel()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  else:
166
- tn = fp = fn = tp = 0
167
- sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
168
- specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
169
-
170
- return {
171
- 'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall,
172
- 'sensitivity': sensitivity, 'specificity': specificity,
173
- 'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)
174
- }
175
 
176
- def train_bert_model(csv_file, base_model, method, num_epochs, batch_size, learning_rate,
177
- weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
178
- weight_mult, best_metric):
179
- global trained_models, model_counter, baseline_results
 
 
 
 
180
 
181
  model_mapping = {
182
  "BERT-base": "bert-base-uncased",
 
 
 
183
  }
184
 
185
  model_name = model_mapping.get(base_model, "bert-base-uncased")
186
 
187
  try:
 
188
  if csv_file is None:
189
- return "❌ 請上傳 CSV", "", "", ""
190
 
191
  df = pd.read_csv(csv_file.name)
192
  if 'Text' not in df.columns or 'label' not in df.columns:
193
- return "❌ 需要 Text 和 label 欄位", "", "", ""
194
 
 
195
  df_clean = pd.DataFrame({
196
- 'text': df['Text'].astype(str),
197
  'label': df['label'].astype(int)
198
  }).dropna()
199
 
 
200
  n0 = int(sum(df_clean['label'] == 0))
201
  n1 = int(sum(df_clean['label'] == 1))
202
- if n1 == 0:
203
- return "❌ 無死亡樣本", "", "", ""
204
 
205
- ratio = n0 / n1
206
- # 動態調整權重計算
207
- if ratio > 10: # 極度不平衡
208
- w0, w1 = 1.0, min(ratio * weight_mult, ratio * 0.7) # 限制最大權重
209
- else:
210
- w0, w1 = 1.0, ratio * weight_mult
211
-
212
- info = f"📊 資料: {len(df_clean)} 筆\n存活: {n0} | 死亡: {n1}\n比例: {ratio:.2f}:1\n"
213
- info += f"⚖️ 權重: {w0:.2f} / {w1:.2f}\n模型: {base_model}\n方法: {method.upper()}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
  tokenizer = BertTokenizer.from_pretrained(model_name)
 
 
216
  dataset = Dataset.from_pandas(df_clean[['text', 'label']])
217
 
218
- def preprocess(ex):
219
- return tokenizer(ex['text'], truncation=True, padding='max_length', max_length=128)
 
 
 
 
 
220
 
221
  tokenized = dataset.map(preprocess, batched=True, remove_columns=['text'])
222
- split = tokenized.train_test_split(test_size=0.2, seed=42)
223
 
 
224
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
225
- info += f"\n裝置: {'GPU ✅' if torch.cuda.is_available() else 'CPU ⚠️'}"
 
 
 
 
 
226
 
227
- # 🔇 靜默評估基準模型(不顯示在資料資訊中)
228
  baseline_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
229
  baseline_model = baseline_model.to(device)
230
 
231
- baseline_perf = evaluate_baseline(baseline_model, tokenizer, split['test'], device)
232
- baseline_key = f"{base_model}_baseline"
233
- baseline_results[baseline_key] = baseline_perf
 
 
 
234
 
235
- # 清理基準模型以釋放記憶體
236
  del baseline_model
237
  if torch.cuda.is_available():
238
  torch.cuda.empty_cache()
239
  gc.collect()
240
 
241
- # 開始微調
242
- info += f"\n\n🔧 套用 {method.upper()} 微調..."
 
 
243
  model = BertForSequenceClassification.from_pretrained(
244
- model_name, num_labels=2,
 
245
  hidden_dropout_prob=dropout,
246
  attention_probs_dropout_prob=dropout
247
  )
248
 
 
249
  peft_applied = False
250
  if method == "lora":
 
 
251
  config = LoraConfig(
252
- task_type=TaskType.SEQ_CLS,
253
- r=int(lora_r),
254
  lora_alpha=int(lora_alpha),
255
- lora_dropout=lora_dropout,
256
- target_modules=["query", "value"],
257
  bias="none"
258
  )
259
  model = get_peft_model(model, config)
260
  peft_applied = True
261
- info += f"\n✅ LoRA 已套用(r={int(lora_r)}, alpha={int(lora_alpha)})"
 
 
 
 
262
  elif method == "adalora":
 
 
263
  config = AdaLoraConfig(
264
- task_type=TaskType.SEQ_CLS,
265
- r=int(lora_r),
266
  lora_alpha=int(lora_alpha),
267
- lora_dropout=lora_dropout,
268
  target_modules=["query", "value"],
269
- init_r=12, tinit=200, tfinal=1000, deltaT=10
 
 
 
 
270
  )
271
  model = get_peft_model(model, config)
272
  peft_applied = True
273
- info += f"\n✅ AdaLoRA 已套用(r={int(lora_r)}, alpha={int(lora_alpha)})"
 
 
 
274
 
275
- if not peft_applied:
276
- info += f"\n⚠️ 警告:{method} 方法未被識別,使用 Full Fine-tuning"
 
277
 
278
  model = model.to(device)
279
 
280
- total = sum(p.numel() for p in model.parameters())
281
- trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
282
- info += f"\n\n💾 參數量\n總參數: {total:,}\n可訓練: {trainable:,}\n比例: {trainable/total*100:.2f}%"
 
 
 
 
 
 
 
283
 
 
284
  weights = torch.tensor([w0, w1], dtype=torch.float).to(device)
 
285
 
286
- args = TrainingArguments(
287
- output_dir='./results',
 
 
 
 
 
 
288
  num_train_epochs=int(num_epochs),
289
- per_device_train_batch_size=int(batch_size),
290
- per_device_eval_batch_size=int(batch_size)*2,
291
- learning_rate=float(learning_rate),
292
  weight_decay=float(weight_decay),
293
- evaluation_strategy="epoch",
294
- save_strategy="no", # 🔧 改為不保存,避免 PEFT 載入問題
295
- load_best_model_at_end=False, # 🔧 關閉,直接用最後一個 epoch
296
- report_to="none",
297
- logging_steps=10,
298
- warmup_steps=50,
299
- logging_first_step=True
 
 
300
  )
301
 
302
- trainer = WeightedTrainer(
303
- model=model,
304
- args=args,
 
 
 
 
305
  train_dataset=split['train'],
306
- eval_dataset=split['test'],
307
  compute_metrics=compute_metrics,
308
  class_weights=weights,
309
- use_focal_loss=(ratio > 10) # 極度不平衡時使用 Focal Loss
 
 
 
 
310
  )
311
 
312
- if ratio > 10:
313
- info += "\n\n⚡ 使用 Focal Loss 處理極度不平衡資料"
314
-
315
- info += "\n\n⏳ 開始訓練..."
 
 
 
 
 
316
 
317
- # 訓練前檢查
318
- info += f"\n📊 訓練前檢查:"
319
- info += f"\n - 訓練樣本: {len(split['train'])}"
320
- info += f"\n - 測試樣本: {len(split['test'])}"
321
- info += f"\n - 批次數/epoch: {len(split['train']) // int(batch_size)}"
322
 
 
323
  train_result = trainer.train()
324
 
325
- # 訓練後資訊
326
- info += f"\n\n✅ 訓練完成!"
327
- info += f"\n📉 最終 Training Loss: {train_result.training_loss:.4f}"
328
 
329
- results = trainer.evaluate()
 
330
 
331
- # 生成帶時間戳的模型 ID
332
  model_counter += 1
333
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
334
- model_id = f"{base_model}_{method}_{timestamp}"
 
335
  trained_models[model_id] = {
336
- 'model': model,
337
- 'tokenizer': tokenizer,
338
- 'results': results,
339
- 'baseline': baseline_perf, # 保存基準結果供後續使用
340
  'config': {
341
- 'type': base_model,
342
- 'model_name': model_name,
343
- 'method': method,
344
- 'metric': best_metric
 
 
 
 
 
345
  },
346
- 'timestamp': timestamp
 
347
  }
348
 
349
- # 計算改善
350
- f1_improve = calculate_improvement(baseline_perf['f1'], results['eval_f1'])
351
- acc_improve = calculate_improvement(baseline_perf['accuracy'], results['eval_accuracy'])
352
- prec_improve = calculate_improvement(baseline_perf['precision'], results['eval_precision'])
353
- rec_improve = calculate_improvement(baseline_perf['recall'], results['eval_recall'])
354
- sens_improve = calculate_improvement(baseline_perf['sensitivity'], results['eval_sensitivity'])
355
- spec_improve = calculate_improvement(baseline_perf['specificity'], results['eval_specificity'])
356
-
357
- # 純 BERT 輸出
358
- baseline_output = f"🔬 純 BERT(未微調)\n\n"
359
- baseline_output += f"📊 表現\n"
360
- baseline_output += f"F1: {baseline_perf['f1']:.4f}\n"
361
- baseline_output += f"Accuracy: {baseline_perf['accuracy']:.4f}\n"
362
- baseline_output += f"Precision: {baseline_perf['precision']:.4f}\n"
363
- baseline_output += f"Recall: {baseline_perf['recall']:.4f}\n"
364
- baseline_output += f"Sensitivity: {baseline_perf['sensitivity']:.4f}\n"
365
- baseline_output += f"Specificity: {baseline_perf['specificity']:.4f}\n\n"
366
- baseline_output += f"混淆矩陣\n"
367
- baseline_output += f"TP: {baseline_perf['tp']} | TN: {baseline_perf['tn']}\n"
368
- baseline_output += f"FP: {baseline_perf['fp']} | FN: {baseline_perf['fn']}"
369
-
370
- # 微調 BERT 輸出
371
- finetuned_output = f"✅ 微調 BERT\n"
372
- finetuned_output += f"模型: {model_id}\n\n"
373
- finetuned_output += f"📊 表現\n"
374
- finetuned_output += f"F1: {results['eval_f1']:.4f}\n"
375
- finetuned_output += f"Accuracy: {results['eval_accuracy']:.4f}\n"
376
- finetuned_output += f"Precision: {results['eval_precision']:.4f}\n"
377
- finetuned_output += f"Recall: {results['eval_recall']:.4f}\n"
378
- finetuned_output += f"Sensitivity: {results['eval_sensitivity']:.4f}\n"
379
- finetuned_output += f"Specificity: {results['eval_specificity']:.4f}\n\n"
380
- finetuned_output += f"混淆矩陣\n"
381
- finetuned_output += f"TP: {results['eval_tp']} | TN: {results['eval_tn']}\n"
382
- finetuned_output += f"FP: {results['eval_fp']} | FN: {results['eval_fn']}"
383
-
384
- # 比較結果輸出
385
- comparison_output = f"📊 純 BERT vs 微調 BERT 比較\n\n"
386
- comparison_output += f"指標改善:\n"
387
- comparison_output += f"F1: {baseline_perf['f1']:.4f} → {results['eval_f1']:.4f} ({format_improve(f1_improve)})\n"
388
- comparison_output += f"Accuracy: {baseline_perf['accuracy']:.4f} → {results['eval_accuracy']:.4f} ({format_improve(acc_improve)})\n"
389
- comparison_output += f"Precision: {baseline_perf['precision']:.4f} → {results['eval_precision']:.4f} ({format_improve(prec_improve)})\n"
390
- comparison_output += f"Recall: {baseline_perf['recall']:.4f} → {results['eval_recall']:.4f} ({format_improve(rec_improve)})\n"
391
- comparison_output += f"Sensitivity: {baseline_perf['sensitivity']:.4f} → {results['eval_sensitivity']:.4f} ({format_improve(sens_improve)})\n"
392
- comparison_output += f"Specificity: {baseline_perf['specificity']:.4f} → {results['eval_specificity']:.4f} ({format_improve(spec_improve)})\n\n"
393
- comparison_output += f"混淆矩陣變化:\n"
394
- comparison_output += f"TP: {baseline_perf['tp']} → {results['eval_tp']} ({results['eval_tp'] - baseline_perf['tp']:+d})\n"
395
- comparison_output += f"TN: {baseline_perf['tn']} → {results['eval_tn']} ({results['eval_tn'] - baseline_perf['tn']:+d})\n"
396
- comparison_output += f"FP: {baseline_perf['fp']} → {results['eval_fp']} ({results['eval_fp'] - baseline_perf['fp']:+d})\n"
397
- comparison_output += f"FN: {baseline_perf['fn']} → {results['eval_fn']} ({results['eval_fn'] - baseline_perf['fn']:+d})"
398
-
399
- info += "\n\n✅ 訓練完成!"
400
-
401
- return info, baseline_output, finetuned_output, comparison_output
402
 
403
  except Exception as e:
404
  import traceback
405
- error_msg = f"❌ 錯誤: {str(e)}\n\n{traceback.format_exc()}"
406
- return error_msg, "", "", ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
 
 
408
  def predict(model_id, text):
409
- global baseline_model_cache
410
 
411
  if not model_id or model_id not in trained_models:
412
- return "❌ 請選擇模型"
413
- if not text:
414
- return "❌ 請輸入文字"
 
415
 
416
  try:
 
417
  info = trained_models[model_id]
418
- model, tokenizer = info['model'], info['tokenizer']
 
419
  config = info['config']
420
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
421
 
422
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
423
- inputs_cuda = {k: v.to(device) for k, v in inputs.items()}
 
 
 
 
 
 
 
424
 
425
- # 預測:微調模型
426
  model.eval()
427
  with torch.no_grad():
428
- outputs = model(**inputs_cuda)
429
- probs_finetuned = torch.nn.functional.softmax(outputs.logits, dim=-1)
 
430
  pred_finetuned = torch.argmax(probs_finetuned, dim=-1).item()
 
431
 
432
- result_finetuned = "存活" if pred_finetuned == 0 else "死亡"
433
-
434
- # 預測:基準模型(使用快取)
435
- cache_key = config['model_name']
436
- if cache_key not in baseline_model_cache:
437
- baseline_model = BertForSequenceClassification.from_pretrained(config['model_name'], num_labels=2)
438
- baseline_model = baseline_model.to(device)
439
- baseline_model.eval()
440
- baseline_model_cache[cache_key] = baseline_model
441
- else:
442
- baseline_model = baseline_model_cache[cache_key]
443
 
444
  with torch.no_grad():
445
- outputs_baseline = baseline_model(**inputs_cuda)
446
- probs_baseline = torch.nn.functional.softmax(outputs_baseline.logits, dim=-1)
 
447
  pred_baseline = torch.argmax(probs_baseline, dim=-1).item()
 
448
 
449
- result_baseline = "存活" if pred_baseline == 0 else "死亡"
450
-
451
- # 判斷是否一致
452
  agreement = "✅ 一致" if pred_finetuned == pred_baseline else "⚠️ 不一致"
453
 
454
- output = f"""🔮 預測結果比較
 
 
 
 
455
 
456
- 📝 輸入文字: {text[:100]}{'...' if len(text) > 100 else ''}
457
 
458
- {'='*50}
 
 
 
459
 
460
- 🧬 微調模型 ({model_id})
461
- 預測: {result_finetuned}
462
- 信心: {probs_finetuned[0][pred_finetuned].item():.2%}
463
  機率分布:
464
- 存活: {probs_finetuned[0][0].item():.2%}
465
- 死亡: {probs_finetuned[0][1].item():.2%}
466
 
467
- {'='*50}
 
 
 
 
 
 
 
 
 
 
468
 
469
- 🔬 基準模型(未微調 {config['type']})
470
- 預測: {result_baseline}
471
- 信心: {probs_baseline[0][pred_baseline].item():.2%}
472
  機率分布:
473
- 存活: {probs_baseline[0][0].item():.2%}
474
- 死亡: {probs_baseline[0][1].item():.2%}
475
 
476
- {'='*50}
477
 
478
- 📊 結論
 
479
  兩模型預測: {agreement}
480
  """
481
 
482
  if pred_finetuned != pred_baseline:
483
- output += f"\n💡 分析: 微調模型預測為【{result_finetuned}】,而基準模型預測為【{result_baseline}】"
484
- output += f"\n 這顯示了 fine-tuning 對此案例的影響!"
 
 
 
 
 
 
 
 
 
 
 
 
485
 
486
- f1_improve = calculate_improvement(info['baseline']['f1'], info['results']['eval_f1'])
 
 
 
 
487
 
488
  output += f"""
489
 
490
- 📈 模型表現
 
491
  微調模型 F1: {info['results']['eval_f1']:.4f}
492
  基準模型 F1: {info['baseline']['f1']:.4f}
493
- 改善幅度: {format_improve(f1_improve)}
 
 
 
494
  """
495
 
496
  return output
497
 
498
  except Exception as e:
499
  import traceback
500
- return f"❌ 錯誤: {str(e)}\n\n{traceback.format_exc()}"
501
 
502
- def compare():
 
 
 
503
  if not trained_models:
504
- return "❌ 尚未訓練模型"
 
 
 
505
 
506
- text = "# 📊 模型比較\n\n"
507
- text += "## 微調模型表現\n\n"
508
- text += "| 模型 | 基礎 | 方法 | F1 | Acc | Prec | Recall | Sens | Spec |\n"
509
- text += "|------|------|------|-----|-----|------|--------|------|------|\n"
510
 
511
- for mid, info in trained_models.items():
512
  r = info['results']
513
  c = info['config']
514
- text += f"| {mid} | {c['type']} | {c['method'].upper()} | {r['eval_f1']:.4f} | {r['eval_accuracy']:.4f} | "
515
- text += f"{r['eval_precision']:.4f} | {r['eval_recall']:.4f} | "
516
- text += f"{r['eval_sensitivity']:.4f} | {r['eval_specificity']:.4f} |\n"
 
 
 
 
 
517
 
518
- text += "\n## 基準模型表現(未微調)\n\n"
519
- text += "| 模型 | F1 | Acc | Prec | Recall | Sens | Spec |\n"
520
- text += "|------|-----|-----|------|--------|------|------|\n"
521
 
522
- for mid, info in trained_models.items():
523
- b = info['baseline']
524
- c = info['config']
525
- text += f"| {c['type']}-baseline | {b['f1']:.4f} | {b['accuracy']:.4f} | "
526
- text += f"{b['precision']:.4f} | {b['recall']:.4f} | "
527
- text += f"{b['sensitivity']:.4f} | {b['specificity']:.4f} |\n"
528
 
529
- text += "\n## 🏆 最佳模型\n\n"
530
- for metric in ['f1', 'accuracy', 'precision', 'recall', 'sensitivity', 'specificity']:
531
- best = max(trained_models.items(), key=lambda x: x[1]['results'][f'eval_{metric}'])
532
- baseline_val = best[1]['baseline'][metric]
533
- finetuned_val = best[1]['results'][f'eval_{metric}']
534
- improvement = calculate_improvement(baseline_val, finetuned_val)
535
-
536
- text += f"**{metric.upper()}**: {best[0]} ({finetuned_val:.4f}, 改善 {format_improve(improvement)})\n\n"
537
 
538
- return text
539
-
540
- def refresh_model_list():
541
- return gr.Dropdown(choices=list(trained_models.keys()))
542
-
543
- # Gradio UI
544
- with gr.Blocks(title="BERT Fine-tuning 教學平台", theme=gr.themes.Soft()) as demo:
545
- gr.Markdown("# 🧬 BERT Fine-tuning 教學平台")
546
- gr.Markdown("### 比較基準模型 vs 微調模型的表現差異")
547
 
548
- with gr.Tab("訓練"):
549
- gr.Markdown("## 步驟 1: 選擇基礎模型")
550
-
551
- base_model = gr.Dropdown(
552
- choices=["BERT-base"],
553
- value="BERT-base",
554
- label="基礎模型",
555
- info="更多模型即將推出"
556
- )
557
-
558
- gr.Markdown("## 步驟 2: 選擇微調方法")
559
-
560
- method = gr.Radio(
561
- choices=["lora", "adalora"],
562
- value="lora",
563
- label="微調方法",
564
- info="兩種都是參數高效方法,推薦從 LoRA 開始"
565
  )
566
 
567
- gr.Markdown("## 步驟 3: 上傳資料")
568
- csv_file = gr.File(label="CSV 檔案 (需包含 Text 和 label 欄位)", file_types=[".csv"])
569
-
570
- gr.Markdown("## 步驟 4: 設定訓練參數")
571
-
572
- gr.Markdown("### 🎯 基本訓練參數")
573
- with gr.Row():
574
- num_epochs = gr.Number(value=5, label="訓練輪數 (epochs)", minimum=1, maximum=100, precision=0,
575
- info="建議 5-8 輪")
576
- batch_size = gr.Number(value=4, label="批次大小 (batch_size)", minimum=1, maximum=128, precision=0,
577
- info="記憶體不足時降到 4")
578
- learning_rate = gr.Number(value=5e-5, label="學習率 (learning_rate)", minimum=0, maximum=1,
579
- info="5e-5 是平衡選擇")
580
-
581
- gr.Markdown("### ⚙️ 進階參數")
582
- with gr.Row():
583
- weight_decay = gr.Number(value=0.01, label="權重衰減 (weight_decay)", minimum=0, maximum=1)
584
- dropout = gr.Number(value=0.1, label="Dropout 機率", minimum=0, maximum=1)
585
-
586
- gr.Markdown("### 🔧 LoRA 參數")
587
- with gr.Row():
588
- lora_r = gr.Number(value=32, label="LoRA Rank (r)", minimum=1, maximum=256, precision=0,
589
- info="提高到 32,增加表達能力")
590
- lora_alpha = gr.Number(value=64, label="LoRA Alpha", minimum=1, maximum=512, precision=0,
591
- info="Alpha = Rank × 2")
592
- lora_dropout = gr.Number(value=0.05, label="LoRA Dropout", minimum=0, maximum=1,
593
- info="降低 dropout,避免欠擬合")
594
-
595
- gr.Markdown("### ⚖️ 評估設定")
596
- with gr.Row():
597
- weight_mult = gr.Number(value=1.0, label="類別權重倍數", minimum=0, maximum=5,
598
- info="⚠️ 資料極度不平衡時建議 0.5-1.5,不要超過 2.0")
599
- best_metric = gr.Dropdown(
600
- choices=["f1", "accuracy", "precision", "recall", "sensitivity", "specificity"],
601
- value="f1",
602
- label="最佳模型選擇指標",
603
- info="訓練時用此指標選擇最佳模型"
604
- )
605
-
606
- train_btn = gr.Button("🚀 開始訓練", variant="primary", size="lg")
607
-
608
- gr.Markdown("## 📊 訓練結果")
609
-
610
- data_info = gr.Textbox(label="📋 資料資訊", lines=10)
611
-
612
- with gr.Row():
613
- baseline_result = gr.Textbox(label="🔬 純 BERT(未微調)", lines=14)
614
- finetuned_result = gr.Textbox(label="✅ 微調 BERT", lines=14)
615
 
616
- comparison_result = gr.Textbox(label="📊 BERT vs 微調 BERT 比較", lines=14)
617
-
618
- train_btn.click(
619
- train_bert_model,
620
- inputs=[csv_file, base_model, method, num_epochs, batch_size, learning_rate,
621
- weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
622
- weight_mult, best_metric],
623
- outputs=[data_info, baseline_result, finetuned_result, comparison_result]
624
- )
625
 
626
- with gr.Tab("預測"):
627
- gr.Markdown("## 使用訓練好的模型預測")
628
-
629
- with gr.Row():
630
- model_drop = gr.Dropdown(label="選擇模型", choices=list(trained_models.keys()))
631
- refresh = gr.Button("🔄 刷新")
632
-
633
- text_input = gr.Textbox(label="輸入病例描述", lines=4,
634
- placeholder="Patient diagnosed with...")
635
- predict_btn = gr.Button("預測", variant="primary", size="lg")
636
- pred_output = gr.Textbox(label="預測結果(含基準模型對比)", lines=20)
637
-
638
- refresh.click(refresh_model_list, outputs=[model_drop])
639
- predict_btn.click(predict, inputs=[model_drop, text_input], outputs=[pred_output])
640
-
641
- gr.Examples(
642
- examples=[
643
- ["Patient with stage II breast cancer, good response to treatment."],
644
- ["Advanced metastatic cancer, multiple organ involvement."]
645
- ],
646
- inputs=text_input
647
- )
648
 
649
- with gr.Tab("比較"):
650
- gr.Markdown("## 比較所有模型(含基準模型)")
651
- compare_btn = gr.Button("比較", variant="primary", size="lg")
652
- compare_output = gr.Markdown()
653
- compare_btn.click(compare, outputs=[compare_output])
 
 
 
 
 
 
 
654
 
655
- with gr.Tab("說明"):
656
- gr.Markdown("""
657
- ## 📖 使用說明
658
-
659
- ### 🎯 平台特色
660
-
661
- 本平台會自動比較:
662
- - **基準模型**:未經微調的原始 BERT
663
- - **微調模型**:使用你的資料訓練後的 BERT
664
-
665
- 這樣可以清楚看到 fine-tuning 帶來的改善!
666
-
667
- ### 基礎模型
668
-
669
- - **BERT-base**: 標準 BERT,110M 參數 ⭐目前支援
670
-
671
- ### 微調方法
672
-
673
- - **LoRA**: 低秩適應,參數高效的微調方法 ⭐強烈推薦
674
- - 只訓練少量參數(通常 <1%)
675
- - 訓練速度快,效果好
676
- - 適合大多數情況
677
-
678
- - **AdaLoRA**: 自適應 LoRA,動態調整秩
679
- - 自動找出最重要的參數
680
- - 可能比 LoRA 效果稍好
681
- - 訓練時間稍長
682
-
683
- ### 評估指標
684
-
685
- - **F1**: 平衡指標,推薦用於不平衡資料 ⭐
686
- - **Accuracy**: 整體準確率
687
- - **Precision**: 減少假陽性
688
- - **Recall/Sensitivity**: 減少假陰性
689
- - **Specificity**: 真陰性率
690
-
691
- ### 參數建議
692
-
693
- 針對不平衡資料(如醫療資料):
694
- - **微調方法**: LoRA(快速有效)或 AdaLoRA(追求極致)
695
- - **LoRA Rank**: 8-16(平衡效果與速度)
696
- - **類別權重倍數**:
697
- - ⚠️ **極度不平衡 (>10:1)**: 0.5-1.0(你的情況!)
698
- - 中度不平衡 (3-10:1): 1.0-1.5
699
- - 輕度不平衡 (<3:1): 1.5-2.5
700
- - **Learning rate**: 3e-5 到 5e-5(較高的學習率配合 LoRA)
701
- - **Epochs**: 5-10(極度不平衡需要更多輪)
702
- - **Batch size**: 8-16(依 GPU 記憶體調整)
703
-
704
- ### 資料格式
705
 
706
- CSV 必須包含:
707
- - `Text`: 病例描述
708
- - `label`: 0=存活, 1=死亡
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
709
 
710
- ### 🚀 快速開始
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
711
 
712
- 1. 上傳包含 `Text` 和 `label` 欄位的 CSV
713
- 2. 使用預設參數(適合大多數情況)
714
- 3. 點擊「開始訓練」
715
- 4. 在「預測」分頁測試模型
716
- 5. 在「比較」分頁查看所有模型表現
717
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
718
 
 
719
  if __name__ == "__main__":
 
720
  demo.launch(
721
  server_name="0.0.0.0",
722
  server_port=7860,
723
  share=False,
724
- max_threads=4 # 限制執行緒數
725
  )
 
1
  import gradio as gr
2
  import pandas as pd
3
+ import numpy as np
4
  import torch
5
  from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
6
  from peft import LoraConfig, AdaLoraConfig, get_peft_model, TaskType
7
  from datasets import Dataset
8
  from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
9
  from torch import nn
10
+ from torch.utils.data import DataLoader, WeightedRandomSampler
11
  import os
12
  from datetime import datetime
13
  import gc
14
+ import json
15
+ from functools import lru_cache
16
+ from typing import Dict, List, Tuple, Optional
17
+ import warnings
18
+ warnings.filterwarnings('ignore')
19
 
20
+ # 環境設置
21
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
22
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
23
 
24
+ # 優化 CUDA 設置
25
  torch.backends.cudnn.benchmark = False
26
  if torch.cuda.is_available():
27
  torch.cuda.empty_cache()
28
 
29
+ # ==================== 全域變數 ====================
30
  trained_models = {}
31
  model_counter = 0
32
+ training_histories = {} # 新增:儲存訓練歷史
 
33
 
34
+ # ==================== 訓練監控類 ====================
35
+ class TrainingMonitor:
36
+ """訓練過程監控器"""
37
+ def __init__(self):
38
+ self.history = {
39
+ 'epoch': [],
40
+ 'train_loss': [],
41
+ 'eval_loss': [],
42
+ 'eval_accuracy': [],
43
+ 'eval_f1': [],
44
+ 'eval_precision': [],
45
+ 'eval_recall': [],
46
+ 'learning_rate': [],
47
+ 'best_epoch': None,
48
+ 'best_metric_value': None
49
+ }
50
+
51
+ def log_epoch(self, epoch: int, train_loss: float, eval_metrics: Dict, lr: float):
52
+ """記錄每個 epoch 的結果"""
53
+ self.history['epoch'].append(epoch)
54
+ self.history['train_loss'].append(train_loss)
55
+ self.history['eval_loss'].append(eval_metrics.get('eval_loss', 0))
56
+ self.history['eval_accuracy'].append(eval_metrics.get('eval_accuracy', 0))
57
+ self.history['eval_f1'].append(eval_metrics.get('eval_f1', 0))
58
+ self.history['eval_precision'].append(eval_metrics.get('eval_precision', 0))
59
+ self.history['eval_recall'].append(eval_metrics.get('eval_recall', 0))
60
+ self.history['learning_rate'].append(lr)
61
+
62
+ def update_best(self, epoch: int, metric_value: float):
63
+ """更新最佳結果"""
64
+ self.history['best_epoch'] = epoch
65
+ self.history['best_metric_value'] = metric_value
66
+
67
+ def get_summary(self) -> str:
68
+ """獲取訓練摘要"""
69
+ if not self.history['epoch']:
70
+ return "尚無訓練記錄"
71
+
72
+ summary = "📈 訓練歷程摘要\n"
73
+ summary += f"總訓練輪數: {len(self.history['epoch'])}\n"
74
+ summary += f"最佳 Epoch: {self.history['best_epoch']}\n"
75
+ summary += f"最佳指標值: {self.history['best_metric_value']:.4f}\n\n"
76
+
77
+ summary += "各 Epoch 表現:\n"
78
+ for i, epoch in enumerate(self.history['epoch']):
79
+ summary += f"Epoch {epoch}: Loss={self.history['train_loss'][i]:.4f}, "
80
+ summary += f"F1={self.history['eval_f1'][i]:.4f}, "
81
+ summary += f"Acc={self.history['eval_accuracy'][i]:.4f}\n"
82
+
83
+ return summary
84
 
85
+ # ==================== 權重計算改進 ====================
86
+ def calculate_class_weights(n0: int, n1: int, weight_mult: float = 1.0,
87
+ method: str = 'sqrt') -> Tuple[float, float]:
88
+ """
89
+ 改進的類別權重計算
90
+
91
+ Args:
92
+ n0: 負類樣本數(存活)
93
+ n1: 正類樣本數(死亡)
94
+ weight_mult: 權重倍數調整
95
+ method: 計算方法 ('balanced', 'sqrt', 'log', 'custom')
96
+
97
+ Returns:
98
+ (w0, w1): 類別權重
99
+ """
100
+ if n1 == 0:
101
+ return 1.0, 1.0
102
+
103
+ ratio = n0 / n1
104
+ total = n0 + n1
105
+
106
+ if method == 'balanced':
107
+ # sklearn 風格的平衡權重
108
+ w0 = total / (2 * n0) if n0 > 0 else 1.0
109
+ w1 = total / (2 * n1) if n1 > 0 else 1.0
110
+ w1 *= weight_mult
111
+ elif method == 'sqrt':
112
+ # 使用平方根緩和極端權重(推薦用於極度不平衡)
113
+ w0 = 1.0
114
+ w1 = min(np.sqrt(ratio) * weight_mult, 10.0) # 設置上限為 10
115
+ elif method == 'log':
116
+ # 使用對數進一步緩和
117
+ w0 = 1.0
118
+ w1 = min(np.log1p(ratio) * weight_mult, 8.0) # 設置上限為 8
119
+ elif method == 'custom':
120
+ # 自定義邏輯,根據不平衡程度調整
121
+ if ratio > 20: # 極���不平衡
122
+ w0 = 1.0
123
+ w1 = min(5.0 * weight_mult, 10.0)
124
+ elif ratio > 10: # 高度不平衡
125
+ w0 = 1.0
126
+ w1 = min(ratio * 0.3 * weight_mult, 8.0)
127
+ elif ratio > 5: # 中度不平衡
128
+ w0 = 1.0
129
+ w1 = min(ratio * 0.5 * weight_mult, 6.0)
130
+ else: # 輕度不平衡
131
+ w0 = 1.0
132
+ w1 = ratio * weight_mult
133
+ else:
134
+ # 預設使用 sqrt 方法
135
+ w0 = 1.0
136
+ w1 = min(np.sqrt(ratio) * weight_mult, 10.0)
137
+
138
+ return w0, w1
139
 
140
+ # ==================== 評估指標計算 ====================
141
  def compute_metrics(pred):
142
+ """計算完整的評估指標"""
143
  try:
144
  labels = pred.label_ids
145
  preds = pred.predictions.argmax(-1)
146
+
147
+ # 基本指標
148
+ precision, recall, f1, _ = precision_recall_fscore_support(
149
+ labels, preds, average='binary', pos_label=1, zero_division=0
150
+ )
151
  acc = accuracy_score(labels, preds)
152
+
153
+ # 混淆矩陣
154
  cm = confusion_matrix(labels, preds)
155
+ tn = fp = fn = tp = 0
156
  if cm.shape == (2, 2):
157
  tn, fp, fn, tp = cm.ravel()
158
+
159
+ # 敏感度和特異度
160
  sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
161
  specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
162
+
163
+ # 額外指標
164
+ ppv = tp / (tp + fp) if (tp + fp) > 0 else 0 # 陽性預測值
165
+ npv = tn / (tn + fn) if (tn + fn) > 0 else 0 # 陰性預測值
166
+
167
  return {
168
+ 'accuracy': acc,
169
+ 'f1': f1,
170
+ 'precision': precision,
171
+ 'recall': recall,
172
+ 'sensitivity': sensitivity,
173
+ 'specificity': specificity,
174
+ 'ppv': ppv,
175
+ 'npv': npv,
176
+ 'tp': int(tp),
177
+ 'tn': int(tn),
178
+ 'fp': int(fp),
179
+ 'fn': int(fn)
180
  }
181
  except Exception as e:
182
  print(f"Error in compute_metrics: {e}")
183
+ return {k: 0 for k in ['accuracy', 'f1', 'precision', 'recall',
184
+ 'sensitivity', 'specificity', 'ppv', 'npv',
185
+ 'tp', 'tn', 'fp', 'fn']}
 
186
 
187
+ # ==================== 基準模型評估(修正版,只保留一個) ====================
188
+ def evaluate_baseline(model, tokenizer, test_dataset, device, batch_size=16):
189
  """評估未微調的基準模型"""
190
  model.eval()
191
  all_preds = []
192
  all_labels = []
193
 
 
 
194
  def collate_fn(batch):
195
  return {
196
  'input_ids': torch.stack([torch.tensor(item['input_ids']) for item in batch]),
 
198
  'labels': torch.tensor([item['label'] for item in batch])
199
  }
200
 
201
+ dataloader = DataLoader(
202
+ test_dataset,
203
+ batch_size=batch_size,
204
+ collate_fn=collate_fn,
205
+ pin_memory=torch.cuda.is_available(),
206
+ num_workers=0 # 避免多進程問題
207
+ )
208
 
209
  with torch.no_grad():
210
  for batch in dataloader:
 
215
  all_preds.extend(preds.cpu().numpy())
216
  all_labels.extend(labels.numpy())
217
 
218
+ # 計算所有指標
219
+ precision, recall, f1, _ = precision_recall_fscore_support(
220
+ all_labels, all_preds, average='binary', pos_label=1, zero_division=0
221
+ )
222
  acc = accuracy_score(all_labels, all_preds)
223
+
224
  cm = confusion_matrix(all_labels, all_preds)
225
+ tn = fp = fn = tp = 0
226
  if cm.shape == (2, 2):
227
  tn, fp, fn, tp = cm.ravel()
228
+
 
229
  sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
230
  specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
231
+ ppv = tp / (tp + fp) if (tp + fp) > 0 else 0
232
+ npv = tn / (tn + fn) if (tn + fn) > 0 else 0
233
 
234
  return {
235
+ 'accuracy': acc,
236
+ 'f1': f1,
237
+ 'precision': precision,
238
+ 'recall': recall,
239
+ 'sensitivity': sensitivity,
240
+ 'specificity': specificity,
241
+ 'ppv': ppv,
242
+ 'npv': npv,
243
+ 'tp': int(tp),
244
+ 'tn': int(tn),
245
+ 'fp': int(fp),
246
+ 'fn': int(fn)
247
  }
248
 
249
+ # ==================== 自定義 Trainer 與 Early Stopping ====================
250
+ class CustomTrainer(Trainer):
251
+ """支援類別權重、Focal Loss 和 Early Stopping 的 Trainer"""
252
+
253
+ def __init__(self, *args, class_weights=None, use_focal_loss=False,
254
+ focal_gamma=2.0, monitor=None, early_stopping_patience=3,
255
+ early_stopping_metric='eval_f1', **kwargs):
256
  super().__init__(*args, **kwargs)
257
  self.class_weights = class_weights
258
  self.use_focal_loss = use_focal_loss
259
+ self.focal_gamma = focal_gamma
260
+ self.monitor = monitor
261
+ self.early_stopping_patience = early_stopping_patience
262
+ self.early_stopping_metric = early_stopping_metric
263
+ self.best_metric = -float('inf')
264
+ self.best_model_state = None
265
+ self.patience_counter = 0
266
+ self.current_epoch = 0
267
 
268
  def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
269
+ """計算損失函數"""
270
  labels = inputs.pop("labels")
271
  outputs = model(**inputs)
272
  logits = outputs.logits
273
 
274
+ if self.use_focal_loss and self.class_weights is not None:
275
+ # Focal Loss 實現
276
  ce_loss = nn.CrossEntropyLoss(weight=self.class_weights, reduction='none')(
277
  logits.view(-1, 2), labels.view(-1)
278
  )
279
  pt = torch.exp(-ce_loss)
280
+ focal_loss = ((1 - pt) ** self.focal_gamma * ce_loss).mean()
281
  loss = focal_loss
282
+ elif self.class_weights is not None:
283
+ # 標準加權交叉熵
284
  loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
285
  loss = loss_fct(logits.view(-1, 2), labels.view(-1))
286
+ else:
287
+ # 標準交叉熵
288
+ loss_fct = nn.CrossEntropyLoss()
289
+ loss = loss_fct(logits.view(-1, 2), labels.view(-1))
290
 
291
  return (loss, outputs) if return_outputs else loss
 
 
 
 
 
 
292
 
293
+ def on_epoch_end(self, args, state, control, **kwargs):
294
+ """每個 epoch 結束時的回調"""
295
+ self.current_epoch += 1
296
+
297
+ # 評估模型
298
+ metrics = self.evaluate()
299
+
300
+ # 記錄到監控器
301
+ if self.monitor:
302
+ self.monitor.log_epoch(
303
+ epoch=self.current_epoch,
304
+ train_loss=state.log_history[-1].get('loss', 0) if state.log_history else 0,
305
+ eval_metrics=metrics,
306
+ lr=self.get_learning_rate()
307
+ )
308
+
309
+ # Early Stopping 檢查
310
+ current_metric = metrics.get(self.early_stopping_metric, 0)
311
+
312
+ if current_metric > self.best_metric:
313
+ self.best_metric = current_metric
314
+ self.best_model_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
315
+ self.patience_counter = 0
316
+
317
+ if self.monitor:
318
+ self.monitor.update_best(self.current_epoch, current_metric)
319
+
320
+ print(f"✅ Epoch {self.current_epoch}: 新最佳 {self.early_stopping_metric} = {current_metric:.4f}")
321
+ else:
322
+ self.patience_counter += 1
323
+ print(f"⏳ Epoch {self.current_epoch}: 無改善 (patience: {self.patience_counter}/{self.early_stopping_patience})")
324
+
325
+ if self.patience_counter >= self.early_stopping_patience:
326
+ print(f"🛑 Early Stopping 於 Epoch {self.current_epoch}")
327
+ control.should_training_stop = True
328
+
329
+ return control
330
 
331
+ def get_learning_rate(self):
332
+ """獲取當前學習率"""
333
+ if self.optimizer is None:
334
+ return 0
335
+ return self.optimizer.param_groups[0]['lr']
 
 
 
336
 
337
+ def load_best_model(self):
338
+ """載入最佳模型"""
339
+ if self.best_model_state:
340
+ self.model.load_state_dict(self.best_model_state)
341
+ print(f"✅ 已載入最佳模型 (最佳 {self.early_stopping_metric} = {self.best_metric:.4f})")
342
+
343
+ # ==================== 基準模型快取(改進版) ====================
344
+ @lru_cache(maxsize=3)
345
+ def get_cached_baseline_model(model_name: str, num_labels: int = 2):
346
+ """使用 LRU 快取管理基準模型"""
347
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
348
+ model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
349
+ return model.to(device)
350
+
351
+ # ==================== 改善率計算 ====================
352
+ def calculate_improvement(baseline_val: float, finetuned_val: float) -> float:
353
+ """安全計算改善率"""
354
+ if baseline_val == 0:
355
+ return float('inf') if finetuned_val > 0 else 0.0
356
+ return (finetuned_val - baseline_val) / baseline_val * 100
357
+
358
+ def format_improvement(val: float) -> str:
359
+ """格式化改善率顯示"""
360
+ if val == float('inf'):
361
+ return "N/A (baseline=0)"
362
+ elif val > 0:
363
+ return f"↑ {val:.1f}%"
364
+ elif val < 0:
365
+ return f"↓ {abs(val):.1f}%"
366
  else:
367
+ return "→ 0.0%"
 
 
 
 
 
 
 
 
368
 
369
+ # ==================== 主要訓練函數(改進版) ====================
370
+ def train_bert_model(csv_file, base_model, method, num_epochs, batch_size, learning_rate,
371
+ weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
372
+ weight_mult, weight_method, best_metric, use_early_stopping, patience):
373
+ """
374
+ 改進的 BERT 模型訓練函數
375
+ """
376
+ global trained_models, model_counter, training_histories
377
 
378
  model_mapping = {
379
  "BERT-base": "bert-base-uncased",
380
+ "BERT-base-chinese": "bert-base-chinese",
381
+ "BioBERT": "dmis-lab/biobert-base-cased-v1.2",
382
+ "SciBERT": "allenai/scibert_scivocab_uncased"
383
  }
384
 
385
  model_name = model_mapping.get(base_model, "bert-base-uncased")
386
 
387
  try:
388
+ # ========== 資料驗證與載入 ==========
389
  if csv_file is None:
390
+ return "❌ 請上傳 CSV 檔案", "", "", "", ""
391
 
392
  df = pd.read_csv(csv_file.name)
393
  if 'Text' not in df.columns or 'label' not in df.columns:
394
+ return "❌ CSV 必須包含 'Text''label' 欄位", "", "", "", ""
395
 
396
+ # 資料清理
397
  df_clean = pd.DataFrame({
398
+ 'text': df['Text'].astype(str),
399
  'label': df['label'].astype(int)
400
  }).dropna()
401
 
402
+ # 統計資料
403
  n0 = int(sum(df_clean['label'] == 0))
404
  n1 = int(sum(df_clean['label'] == 1))
 
 
405
 
406
+ if n1 == 0:
407
+ return "❌ 資料集中沒有正類樣本(死亡)", "", "", "", ""
408
+
409
+ ratio = n0 / n1 if n1 > 0 else 0
410
+
411
+ # ========== 計算類別權重 ==========
412
+ w0, w1 = calculate_class_weights(n0, n1, weight_mult, method=weight_method)
413
+
414
+ # ========== 準備資料資訊 ==========
415
+ info = f"📊 資料集統計\n"
416
+ info += f"{'='*50}\n"
417
+ info += f"總樣本數: {len(df_clean):,}\n"
418
+ info += f"存活 (0): {n0:,} ({n0/len(df_clean)*100:.1f}%)\n"
419
+ info += f"死亡 (1): {n1:,} ({n1/len(df_clean)*100:.1f}%)\n"
420
+ info += f"不平衡比例: {ratio:.2f}:1\n"
421
+ info += f"\n⚖️ 類別權重設定\n"
422
+ info += f"{'='*50}\n"
423
+ info += f"計算方法: {weight_method}\n"
424
+ info += f"存活權重: {w0:.3f}\n"
425
+ info += f"死亡權重: {w1:.3f}\n"
426
+ info += f"權重比例: 1:{w1/w0:.2f}\n"
427
+
428
+ # ========== 模型與分詞器初始化 ==========
429
+ info += f"\n🤖 模型配置\n"
430
+ info += f"{'='*50}\n"
431
+ info += f"基礎模型: {base_model}\n"
432
+ info += f"模型路徑: {model_name}\n"
433
+ info += f"微調方法: {method.upper()}\n"
434
 
435
  tokenizer = BertTokenizer.from_pretrained(model_name)
436
+
437
+ # ========== 資料集準備 ==========
438
  dataset = Dataset.from_pandas(df_clean[['text', 'label']])
439
 
440
+ def preprocess(examples):
441
+ return tokenizer(
442
+ examples['text'],
443
+ truncation=True,
444
+ padding='max_length',
445
+ max_length=128
446
+ )
447
 
448
  tokenized = dataset.map(preprocess, batched=True, remove_columns=['text'])
449
+ split = tokenized.train_test_split(test_size=0.2, seed=42, stratify=tokenized['label'])
450
 
451
+ # ========== 設備配置 ==========
452
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
453
+ info += f"運算設備: {'GPU ✅ (' + torch.cuda.get_device_name(0) + ')' if torch.cuda.is_available() else 'CPU ⚠️'}\n"
454
+
455
+ # ========== 評估基準模型 ==========
456
+ info += f"\n📏 基準模型評估\n"
457
+ info += f"{'='*50}\n"
458
+ info += f"正在評估未微調的 {base_model}...\n"
459
 
 
460
  baseline_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
461
  baseline_model = baseline_model.to(device)
462
 
463
+ baseline_perf = evaluate_baseline(
464
+ baseline_model, tokenizer, split['test'], device, batch_size=batch_size*2
465
+ )
466
+
467
+ info += f"基準 F1 分數: {baseline_perf['f1']:.4f}\n"
468
+ info += f"基準準確率: {baseline_perf['accuracy']:.4f}\n"
469
 
470
+ # 清理基準模型記憶體
471
  del baseline_model
472
  if torch.cuda.is_available():
473
  torch.cuda.empty_cache()
474
  gc.collect()
475
 
476
+ # ========== 配置微調模型 ==========
477
+ info += f"\n🔧 微調配置\n"
478
+ info += f"{'='*50}\n"
479
+
480
  model = BertForSequenceClassification.from_pretrained(
481
+ model_name,
482
+ num_labels=2,
483
  hidden_dropout_prob=dropout,
484
  attention_probs_dropout_prob=dropout
485
  )
486
 
487
+ # 應用 PEFT 方法
488
  peft_applied = False
489
  if method == "lora":
490
+ from peft import LoraConfig, get_peft_model, TaskType
491
+
492
  config = LoraConfig(
493
+ task_type=TaskType.SEQ_CLS,
494
+ r=int(lora_r),
495
  lora_alpha=int(lora_alpha),
496
+ lora_dropout=lora_dropout,
497
+ target_modules=["query", "value"],
498
  bias="none"
499
  )
500
  model = get_peft_model(model, config)
501
  peft_applied = True
502
+ info += f"✅ LoRA 已套用\n"
503
+ info += f" - Rank (r): {int(lora_r)}\n"
504
+ info += f" - Alpha: {int(lora_alpha)}\n"
505
+ info += f" - Dropout: {lora_dropout}\n"
506
+
507
  elif method == "adalora":
508
+ from peft import AdaLoraConfig, get_peft_model, TaskType
509
+
510
  config = AdaLoraConfig(
511
+ task_type=TaskType.SEQ_CLS,
512
+ r=int(lora_r),
513
  lora_alpha=int(lora_alpha),
514
+ lora_dropout=lora_dropout,
515
  target_modules=["query", "value"],
516
+ init_r=12,
517
+ target_r=int(lora_r),
518
+ tinit=200,
519
+ tfinal=1000,
520
+ deltaT=10
521
  )
522
  model = get_peft_model(model, config)
523
  peft_applied = True
524
+ info += f"✅ AdaLoRA 已套用\n"
525
+ info += f" - Initial Rank: 12\n"
526
+ info += f" - Target Rank: {int(lora_r)}\n"
527
+ info += f" - Alpha: {int(lora_alpha)}\n"
528
 
529
+ elif method == "full":
530
+ info += f" Full Fine-tuning 模式\n"
531
+ peft_applied = False
532
 
533
  model = model.to(device)
534
 
535
+ # 參數統計
536
+ total_params = sum(p.numel() for p in model.parameters())
537
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
538
+
539
+ info += f"\n💾 模型參數\n"
540
+ info += f"{'='*50}\n"
541
+ info += f"總參數量: {total_params:,}\n"
542
+ info += f"可訓練參數: {trainable_params:,}\n"
543
+ info += f"可訓練比例: {trainable_params/total_params*100:.2f}%\n"
544
+ info += f"記憶體節省: {(1 - trainable_params/total_params)*100:.1f}%\n"
545
 
546
+ # ========== 準備訓練 ==========
547
  weights = torch.tensor([w0, w1], dtype=torch.float).to(device)
548
+ use_focal = ratio > 10 # 極度不平衡時使用 Focal Loss
549
 
550
+ if use_focal:
551
+ info += f"\n⚡ 特殊設定\n"
552
+ info += f"{'='*50}\n"
553
+ info += f"使用 Focal Loss (γ=2.0) 處理極度不平衡\n"
554
+
555
+ # 訓練參數
556
+ training_args = TrainingArguments(
557
+ output_dir='./results',
558
  num_train_epochs=int(num_epochs),
559
+ per_device_train_batch_size=int(batch_size),
560
+ per_device_eval_batch_size=int(batch_size) * 2,
561
+ learning_rate=float(learning_rate),
562
  weight_decay=float(weight_decay),
563
+ evaluation_strategy="epoch",
564
+ save_strategy="no", # 使用自定義保存策略
565
+ load_best_model_at_end=False,
566
+ report_to="none",
567
+ logging_steps=max(1, len(split['train']) // (int(batch_size) * 10)),
568
+ warmup_steps=min(500, len(split['train']) // int(batch_size)),
569
+ logging_first_step=True,
570
+ remove_unused_columns=False,
571
+ label_smoothing_factor=0.1 if ratio > 20 else 0.0, # 極度不平衡時使用標籤平滑
572
  )
573
 
574
+ # 創建監控器
575
+ monitor = TrainingMonitor()
576
+
577
+ # 創建自定義 Trainer
578
+ trainer = CustomTrainer(
579
+ model=model,
580
+ args=training_args,
581
  train_dataset=split['train'],
582
+ eval_dataset=split['test'],
583
  compute_metrics=compute_metrics,
584
  class_weights=weights,
585
+ use_focal_loss=use_focal,
586
+ focal_gamma=2.0,
587
+ monitor=monitor,
588
+ early_stopping_patience=patience if use_early_stopping else 999,
589
+ early_stopping_metric=f'eval_{best_metric}'
590
  )
591
 
592
+ info += f"\n🚀 訓練設定\n"
593
+ info += f"{'='*50}\n"
594
+ info += f"訓練樣本: {len(split['train']):,}\n"
595
+ info += f"測試樣本: {len(split['test']):,}\n"
596
+ info += f"批次大小: {int(batch_size)}\n"
597
+ info += f"訓練輪數: {int(num_epochs)}\n"
598
+ info += f"批次數/輪: {len(split['train']) // int(batch_size)}\n"
599
+ info += f"Early Stopping: {'開啟 (patience=' + str(patience) + ')' if use_early_stopping else '關閉'}\n"
600
+ info += f"最佳指標: {best_metric}\n"
601
 
602
+ info += f"\n⏳ 開始訓練...\n"
603
+ info += f"{'='*50}\n"
 
 
 
604
 
605
+ # ========== 執行訓練 ==========
606
  train_result = trainer.train()
607
 
608
+ # 載入最佳模型
609
+ if use_early_stopping:
610
+ trainer.load_best_model()
611
 
612
+ # 最終評估
613
+ final_results = trainer.evaluate()
614
 
615
+ # ========== 保存模型與結果 ==========
616
  model_counter += 1
617
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
618
+ model_id = f"{base_model}_{method}_{model_counter}_{timestamp}"
619
+
620
  trained_models[model_id] = {
621
+ 'model': model,
622
+ 'tokenizer': tokenizer,
623
+ 'results': final_results,
624
+ 'baseline': baseline_perf,
625
  'config': {
626
+ 'type': base_model,
627
+ 'model_name': model_name,
628
+ 'method': method,
629
+ 'metric': best_metric,
630
+ 'epochs': int(num_epochs),
631
+ 'batch_size': int(batch_size),
632
+ 'learning_rate': float(learning_rate),
633
+ 'weight_method': weight_method,
634
+ 'weight_mult': weight_mult
635
  },
636
+ 'timestamp': timestamp,
637
+ 'monitor': monitor # 保存訓練歷史
638
  }
639
 
640
+ training_histories[model_id] = monitor.history
641
+
642
+ info += f"\n✅ 訓練完成!\n"
643
+ info += f"最終 Training Loss: {train_result.training_loss:.4f}\n"
644
+ if monitor.history['best_epoch']:
645
+ info += f"最佳 Epoch: {monitor.history['best_epoch']}\n"
646
+
647
+ # ========== 準備輸出結果 ==========
648
+ # 基準模型結果
649
+ baseline_output = format_baseline_results(baseline_perf)
650
+
651
+ # 微調模型結果
652
+ finetuned_output = format_finetuned_results(model_id, final_results)
653
+
654
+ # 比較結果
655
+ comparison_output = format_comparison_results(baseline_perf, final_results)
656
+
657
+ # 訓練歷程
658
+ history_output = monitor.get_summary()
659
+
660
+ return info, baseline_output, finetuned_output, comparison_output, history_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
661
 
662
  except Exception as e:
663
  import traceback
664
+ error_msg = f"❌ 錯誤發生\n\n錯誤類型: {type(e).__name__}\n錯誤訊息: {str(e)}\n\n"
665
+ error_msg += f"詳細追蹤:\n{traceback.format_exc()}"
666
+ return error_msg, "", "", "", ""
667
+
668
+ # ==================== 格式化輸出函數 ====================
669
+ def format_baseline_results(baseline_perf: Dict) -> str:
670
+ """格式化基準模型結果"""
671
+ output = "🔬 純 BERT(未微調)\n\n"
672
+ output += "📊 模型表現\n"
673
+ output += f"{'='*30}\n"
674
+ output += f"F1 Score: {baseline_perf['f1']:.4f}\n"
675
+ output += f"Accuracy: {baseline_perf['accuracy']:.4f}\n"
676
+ output += f"Precision: {baseline_perf['precision']:.4f}\n"
677
+ output += f"Recall: {baseline_perf['recall']:.4f}\n"
678
+ output += f"Sensitivity: {baseline_perf['sensitivity']:.4f}\n"
679
+ output += f"Specificity: {baseline_perf['specificity']:.4f}\n"
680
+ output += f"PPV: {baseline_perf['ppv']:.4f}\n"
681
+ output += f"NPV: {baseline_perf['npv']:.4f}\n\n"
682
+ output += "📈 混淆矩陣\n"
683
+ output += f"{'='*30}\n"
684
+ output += f" 預測 0 預測 1\n"
685
+ output += f"實際 0 {baseline_perf['tn']:4d} {baseline_perf['fp']:4d}\n"
686
+ output += f"實際 1 {baseline_perf['fn']:4d} {baseline_perf['tp']:4d}\n"
687
+ return output
688
+
689
+ def format_finetuned_results(model_id: str, results: Dict) -> str:
690
+ """格式化微調模型結果"""
691
+ output = f"✅ 微調 BERT\n"
692
+ output += f"模型 ID: {model_id}\n\n"
693
+ output += "📊 模型表現\n"
694
+ output += f"{'='*30}\n"
695
+ output += f"F1 Score: {results['eval_f1']:.4f}\n"
696
+ output += f"Accuracy: {results['eval_accuracy']:.4f}\n"
697
+ output += f"Precision: {results['eval_precision']:.4f}\n"
698
+ output += f"Recall: {results['eval_recall']:.4f}\n"
699
+ output += f"Sensitivity: {results['eval_sensitivity']:.4f}\n"
700
+ output += f"Specificity: {results['eval_specificity']:.4f}\n"
701
+ output += f"PPV: {results['eval_ppv']:.4f}\n"
702
+ output += f"NPV: {results['eval_npv']:.4f}\n\n"
703
+ output += "📈 混淆矩陣\n"
704
+ output += f"{'='*30}\n"
705
+ output += f" 預測 0 預測 1\n"
706
+ output += f"實際 0 {results['eval_tn']:4d} {results['eval_fp']:4d}\n"
707
+ output += f"實際 1 {results['eval_fn']:4d} {results['eval_tp']:4d}\n"
708
+ return output
709
+
710
+ def format_comparison_results(baseline_perf: Dict, finetuned_results: Dict) -> str:
711
+ """格式化比較結果"""
712
+ output = "📊 純 BERT vs 微調 BERT 比較\n\n"
713
+ output += "指標改善分析:\n"
714
+ output += f"{'='*50}\n"
715
+ output += f"{'指標':<12} {'基準':>8} {'微調':>8} {'變化':>10} {'改善率':>10}\n"
716
+ output += f"{'-'*50}\n"
717
+
718
+ metrics = [
719
+ ('F1', 'f1', 'eval_f1'),
720
+ ('Accuracy', 'accuracy', 'eval_accuracy'),
721
+ ('Precision', 'precision', 'eval_precision'),
722
+ ('Recall', 'recall', 'eval_recall'),
723
+ ('Sensitivity', 'sensitivity', 'eval_sensitivity'),
724
+ ('Specificity', 'specificity', 'eval_specificity'),
725
+ ('PPV', 'ppv', 'eval_ppv'),
726
+ ('NPV', 'npv', 'eval_npv')
727
+ ]
728
+
729
+ for name, base_key, fine_key in metrics:
730
+ base_val = baseline_perf[base_key]
731
+ fine_val = finetuned_results[fine_key]
732
+ change = fine_val - base_val
733
+ improve = calculate_improvement(base_val, fine_val)
734
+
735
+ output += f"{name:<12} {base_val:>8.4f} {fine_val:>8.4f} "
736
+ output += f"{change:+10.4f} {format_improvement(improve):>10}\n"
737
+
738
+ output += f"\n混淆矩陣變化:\n"
739
+ output += f"{'='*40}\n"
740
+ output += f"{'項目':<10} {'基準':>8} {'微調':>8} {'變化':>10}\n"
741
+ output += f"{'-'*40}\n"
742
+
743
+ cm_items = [
744
+ ('True Pos', 'tp', 'eval_tp'),
745
+ ('True Neg', 'tn', 'eval_tn'),
746
+ ('False Pos', 'fp', 'eval_fp'),
747
+ ('False Neg', 'fn', 'eval_fn')
748
+ ]
749
+
750
+ for name, base_key, fine_key in cm_items:
751
+ base_val = baseline_perf[base_key]
752
+ fine_val = finetuned_results[fine_key]
753
+ change = fine_val - base_val
754
+
755
+ output += f"{name:<10} {base_val:>8d} {fine_val:>8d} {change:+10d}\n"
756
+
757
+ # 總結
758
+ output += f"\n📈 整體評估:\n"
759
+ output += f"{'='*40}\n"
760
+
761
+ f1_improve = calculate_improvement(baseline_perf['f1'], finetuned_results['eval_f1'])
762
+ if f1_improve > 10:
763
+ output += "✅ 顯著改善:微調帶來明顯的性能提升!\n"
764
+ elif f1_improve > 0:
765
+ output += "✅ 有所改善:微調產生正向影響。\n"
766
+ elif f1_improve == 0:
767
+ output += "➖ 無變化:微調未產生明顯影響。\n"
768
+ else:
769
+ output += "⚠️ 性能下降:可能需要調整超參數。\n"
770
+
771
+ return output
772
 
773
+ # ==================== 預測函數(改進版) ====================
774
  def predict(model_id, text):
775
+ """使用選定模型進行預測並與基準模型比較"""
776
 
777
  if not model_id or model_id not in trained_models:
778
+ return "❌ 請選擇一個已訓練的模型"
779
+
780
+ if not text or len(text.strip()) == 0:
781
+ return "❌ 請輸入要預測的文字"
782
 
783
  try:
784
+ # 獲取模型資訊
785
  info = trained_models[model_id]
786
+ model = info['model']
787
+ tokenizer = info['tokenizer']
788
  config = info['config']
789
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
790
 
791
+ # 文字預處理
792
+ inputs = tokenizer(
793
+ text,
794
+ return_tensors="pt",
795
+ truncation=True,
796
+ padding=True,
797
+ max_length=128
798
+ )
799
+ inputs_device = {k: v.to(device) for k, v in inputs.items()}
800
 
801
+ # ========== 微調模型預測 ==========
802
  model.eval()
803
  with torch.no_grad():
804
+ outputs = model(**inputs_device)
805
+ logits = outputs.logits
806
+ probs_finetuned = torch.nn.functional.softmax(logits, dim=-1)
807
  pred_finetuned = torch.argmax(probs_finetuned, dim=-1).item()
808
+ confidence_finetuned = probs_finetuned[0][pred_finetuned].item()
809
 
810
+ # ========== 基準模型預測 ==========
811
+ baseline_model = get_cached_baseline_model(config['model_name'])
812
+ baseline_model.eval()
 
 
 
 
 
 
 
 
813
 
814
  with torch.no_grad():
815
+ outputs_baseline = baseline_model(**inputs_device)
816
+ logits_baseline = outputs_baseline.logits
817
+ probs_baseline = torch.nn.functional.softmax(logits_baseline, dim=-1)
818
  pred_baseline = torch.argmax(probs_baseline, dim=-1).item()
819
+ confidence_baseline = probs_baseline[0][pred_baseline].item()
820
 
821
+ # ========== 格式化輸出 ==========
822
+ result_finetuned = "🟢 存活" if pred_finetuned == 0 else "🔴 死亡"
823
+ result_baseline = "🟢 存活" if pred_baseline == 0 else "🔴 死亡"
824
  agreement = "✅ 一致" if pred_finetuned == pred_baseline else "⚠️ 不一致"
825
 
826
+ output = f"""🔮 預測結果比較分析
827
+
828
+ 📝 輸入文字
829
+ {'='*60}
830
+ {text[:200]}{'...' if len(text) > 200 else ''}
831
 
832
+ {'='*60}
833
 
834
+ 🎯 微調模型預測 ({model_id})
835
+ {'='*60}
836
+ 預測結果: {result_finetuned}
837
+ 預測信心: {confidence_finetuned:.1%}
838
 
 
 
 
839
  機率分布:
840
+ 存活 (0): {probs_finetuned[0][0].item():.2%}
841
+ 死亡 (1): {probs_finetuned[0][1].item():.2%}
842
 
843
+ 模型配置:
844
+ • 方法: {config['method'].upper()}
845
+ • 基礎模型: {config['type']}
846
+ • 訓練輪數: {config['epochs']}
847
+
848
+ {'='*60}
849
+
850
+ 🔬 基準模型預測(未微調 {config['type']})
851
+ {'='*60}
852
+ 預測結果: {result_baseline}
853
+ 預測信心: {confidence_baseline:.1%}
854
 
 
 
 
855
  機率分布:
856
+ 存活 (0): {probs_baseline[0][0].item():.2%}
857
+ 死亡 (1): {probs_baseline[0][1].item():.2%}
858
 
859
+ {'='*60}
860
 
861
+ 📊 預測分析
862
+ {'='*60}
863
  兩模型預測: {agreement}
864
  """
865
 
866
  if pred_finetuned != pred_baseline:
867
+ output += f"""
868
+ 💡 差異分析:
869
+ 微調模型預測【{result_finetuned}】(信心: {confidence_finetuned:.1%})
870
+ 基準模型預測【{result_baseline}】(信心: {confidence_baseline:.1%})
871
+
872
+ 這種差異顯示了微調對此特定案例的影響。
873
+ 微調模型可能學習到了更適合您資料集的特徵。
874
+ """
875
+ else:
876
+ output += f"""
877
+ ✅ 預測一致性分析:
878
+ 兩個模型都預測為【{result_finetuned}】
879
+ 信心差異: {abs(confidence_finetuned - confidence_baseline):.1%}
880
+ """
881
 
882
+ # 加入模型整體表現對比
883
+ f1_improve = calculate_improvement(
884
+ info['baseline']['f1'],
885
+ info['results']['eval_f1']
886
+ )
887
 
888
  output += f"""
889
 
890
+ 📈 模型整體表現對比
891
+ {'='*60}
892
  微調模型 F1: {info['results']['eval_f1']:.4f}
893
  基準模型 F1: {info['baseline']['f1']:.4f}
894
+ 改善幅度: {format_improvement(f1_improve)}
895
+
896
+ 微調模型準確率: {info['results']['eval_accuracy']:.4f}
897
+ 基準模型準確率: {info['baseline']['accuracy']:.4f}
898
  """
899
 
900
  return output
901
 
902
  except Exception as e:
903
  import traceback
904
+ return f"❌ 預測時發生錯誤\n\n{str(e)}\n\n{traceback.format_exc()}"
905
 
906
+ # ==================== 模型比較函數 ====================
907
+ def compare_models():
908
+ """比較所有已訓練的模型"""
909
+
910
  if not trained_models:
911
+ return "❌ 尚未訓練任何模型。請先在「訓練」頁面訓練模型。"
912
+
913
+ output = "# 📊 模型比較報告\n\n"
914
+ output += f"共有 {len(trained_models)} 個已訓練模型\n\n"
915
 
916
+ # 微調模型表現表格
917
+ output += "## 🎯 微調模型表現\n\n"
918
+ output += "| 模型 ID | 基礎模型 | 方法 | F1 | 準確率 | 精確率 | 召回率 | 敏感度 | 特異度 |\n"
919
+ output += "|---------|----------|------|-----|--------|--------|--------|--------|--------|\n"
920
 
921
+ for model_id, info in trained_models.items():
922
  r = info['results']
923
  c = info['config']
924
+
925
+ # 縮短模型 ID 顯示
926
+ short_id = f"{c['type']}_{c['method']}_{info['timestamp'][-6:]}"
927
+
928
+ output += f"| {short_id} | {c['type']} | {c['method'].upper()} | "
929
+ output += f"{r['eval_f1']:.4f} | {r['eval_accuracy']:.4f} | "
930
+ output += f"{r['eval_precision']:.4f} | {r['eval_recall']:.4f} | "
931
+ output += f"{r['eval_sensitivity']:.4f} | {r['eval_specificity']:.4f} |\n"
932
 
933
+ # 基準模型表現
934
+ output += "\n## 🔬 基準模型表現(未微調)\n\n"
 
935
 
936
+ # 獲取唯一的基準模型
937
+ unique_baselines = {}
938
+ for model_id, info in trained_models.items():
939
+ base_type = info['config']['type']
940
+ if base_type not in unique_baselines:
941
+ unique_baselines[base_type] = info['baseline']
942
 
943
+ output += "| 基礎模型 | F1 | 準確率 | 精確率 | 召回率 | 敏感度 | 特異度 |\n"
944
+ output += "|----------|-----|--------|--------|--------|--------|--------|\n"
 
 
 
 
 
 
945
 
946
+ for base_type, baseline in unique_baselines.items():
947
+ output += f"| {base_type} | {baseline['f1']:.4f} | {baseline['accuracy']:.4f} | "
948
+ output += f"{baseline['precision']:.4f} | {baseline['recall']:.4f} | "
949
+ output += f"{baseline['sensitivity']:.4f} | {baseline['specificity']:.4f} |\n"
 
 
 
 
 
950
 
951
+ # 最佳模型分析
952
+ output += "\n## 🏆 最佳模型(各指標)\n\n"
953
+
954
+ metrics_to_check = [
955
+ ('F1 Score', 'eval_f1'),
956
+ ('準確率', 'eval_accuracy'),
957
+ ('精確率', 'eval_precision'),
958
+ ('召回率', 'eval_recall'),
959
+ ('敏感度', 'eval_sensitivity'),
960
+ ('特異度', 'eval_specificity')
961
+ ]
962
+
963
+ for metric_name, metric_key in metrics_to_check:
964
+ best_model = max(
965
+ trained_models.items(),
966
+ key=lambda x: x[1]['results'][metric_key]
 
967
  )
968
 
969
+ model_id = best_model[0]
970
+ value = best_model[1]['results'][metric_key]
971
+ baseline_val = best_model[1]['baseline'][metric_key.replace('eval_', '')]
972
+ improvement = calculate_improvement(baseline_val, value)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
973
 
974
+ output += f"**{metric_name}**: {model_id[:30]}... "
975
+ output += f"({value:.4f}, 改善 {format_improvement(improvement)})\n\n"
 
 
 
 
 
 
 
976
 
977
+ # 改善統計
978
+ output += "## 📈 改善統計\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
979
 
980
+ improvements = []
981
+ for model_id, info in trained_models.items():
982
+ f1_base = info['baseline']['f1']
983
+ f1_fine = info['results']['eval_f1']
984
+ improve = calculate_improvement(f1_base, f1_fine)
985
+
986
+ if improve != float('inf'):
987
+ improvements.append({
988
+ 'model': model_id,
989
+ 'improvement': improve,
990
+ 'method': info['config']['method']
991
+ })
992
 
993
+ if improvements:
994
+ avg_improvement = np.mean([x['improvement'] for x in improvements])
995
+ max_improvement = max(improvements, key=lambda x: x['improvement'])
996
+ min_improvement = min(improvements, key=lambda x: x['improvement'])
997
+
998
+ output += f"平均 F1 改善: {format_improvement(avg_improvement)}\n"
999
+ output += f"最大改善: {max_improvement['model'][:30]}... ({format_improvement(max_improvement['improvement'])})\n"
1000
+ output += f"最小改善: {min_improvement['model'][:30]}... ({format_improvement(min_improvement['improvement'])})\n\n"
1001
+
1002
+ # 方法比較
1003
+ method_improvements = {}
1004
+ for imp in improvements:
1005
+ method = imp['method']
1006
+ if method not in method_improvements:
1007
+ method_improvements[method] = []
1008
+ method_improvements[method].append(imp['improvement'])
1009
+
1010
+ output += "### 各方法平均改善:\n"
1011
+ for method, imps in method_improvements.items():
1012
+ avg_imp = np.mean(imps)
1013
+ output += f"- **{method.upper()}**: {format_improvement(avg_imp)}\n"
1014
+
1015
+ return output
1016
+
1017
+ # ==================== Gradio UI ====================
1018
+ def create_demo():
1019
+ """創建 Gradio 介面"""
1020
+
1021
+ with gr.Blocks(
1022
+ title="BERT Fine-tuning 教學平台",
1023
+ theme=gr.themes.Soft(),
1024
+ css="""
1025
+ .gradio-container {font-family: 'Microsoft JhengHei', 'Arial', sans-serif;}
1026
+ """
1027
+ ) as demo:
1028
+
1029
+ gr.Markdown(
1030
+ """
1031
+ # 🧬 BERT Fine-tuning 教學平台
1032
+ ### 比較基準模型 vs 微調模型的表現差異(改進版)
1033
+ """
1034
+ )
 
 
 
 
 
 
 
 
1035
 
1036
+ with gr.Tab("🎯 訓練"):
1037
+ gr.Markdown("## 步驟 1: 選擇基礎模型")
1038
+
1039
+ base_model = gr.Dropdown(
1040
+ choices=["BERT-base", "BERT-base-chinese", "BioBERT", "SciBERT"],
1041
+ value="BERT-base",
1042
+ label="基礎模型",
1043
+ info="選擇適合您資料的預訓練模型"
1044
+ )
1045
+
1046
+ gr.Markdown("## 步驟 2: 選擇微調方法")
1047
+
1048
+ method = gr.Radio(
1049
+ choices=["lora", "adalora", "full"],
1050
+ value="lora",
1051
+ label="微調方法",
1052
+ info="LoRA 和 AdaLoRA 是參數高效方法,Full 是完全微調"
1053
+ )
1054
+
1055
+ gr.Markdown("## 步驟 3: 上傳資料")
1056
+
1057
+ csv_file = gr.File(
1058
+ label="CSV 檔案(需包含 Text 和 label 欄位)",
1059
+ file_types=[".csv"]
1060
+ )
1061
+
1062
+ gr.Markdown("## 步驟 4: 設定訓練參數")
1063
+
1064
+ with gr.Accordion("🎯 基本訓練參數", open=True):
1065
+ with gr.Row():
1066
+ num_epochs = gr.Number(
1067
+ value=5, label="訓練輪數", minimum=1, maximum=50, precision=0,
1068
+ info="建議 3-10 輪,過多可能過擬合"
1069
+ )
1070
+ batch_size = gr.Number(
1071
+ value=8, label="批次大小", minimum=1, maximum=64, precision=0,
1072
+ info="GPU 記憶體不足時請降低"
1073
+ )
1074
+ learning_rate = gr.Number(
1075
+ value=3e-5, label="學習率", minimum=1e-6, maximum=1e-3,
1076
+ info="建議 1e-5 到 5e-5"
1077
+ )
1078
+
1079
+ with gr.Accordion("⚙️ 進階參數"):
1080
+ with gr.Row():
1081
+ weight_decay = gr.Number(
1082
+ value=0.01, label="權重衰減", minimum=0, maximum=1,
1083
+ info="防止過擬合,建議 0.01-0.1"
1084
+ )
1085
+ dropout = gr.Number(
1086
+ value=0.1, label="Dropout 率", minimum=0, maximum=0.5,
1087
+ info="防止過擬合,建議 0.1-0.3"
1088
+ )
1089
+
1090
+ with gr.Accordion("🔧 PEFT 參數(LoRA/AdaLoRA)"):
1091
+ with gr.Row():
1092
+ lora_r = gr.Number(
1093
+ value=16, label="LoRA Rank (r)", minimum=1, maximum=64, precision=0,
1094
+ info="越大表達能力越強,但參數越多"
1095
+ )
1096
+ lora_alpha = gr.Number(
1097
+ value=32, label="LoRA Alpha", minimum=1, maximum=128, precision=0,
1098
+ info="通常設為 Rank 的 2 倍"
1099
+ )
1100
+ lora_dropout = gr.Number(
1101
+ value=0.05, label="LoRA Dropout", minimum=0, maximum=0.5,
1102
+ info="LoRA 層的 dropout"
1103
+ )
1104
+
1105
+ with gr.Accordion("⚖️ 類別平衡設定"):
1106
+ with gr.Row():
1107
+ weight_mult = gr.Number(
1108
+ value=1.0, label="權重倍數", minimum=0.1, maximum=5.0,
1109
+ info="調整少數類權重的倍數"
1110
+ )
1111
+ weight_method = gr.Dropdown(
1112
+ choices=["sqrt", "log", "balanced", "custom"],
1113
+ value="sqrt",
1114
+ label="權重計算方法",
1115
+ info="sqrt 和 log 適合極度不平衡資料"
1116
+ )
1117
+
1118
+ with gr.Accordion("🎯 訓練策略"):
1119
+ with gr.Row():
1120
+ best_metric = gr.Dropdown(
1121
+ choices=["f1", "accuracy", "precision", "recall", "sensitivity", "specificity"],
1122
+ value="f1",
1123
+ label="最佳模型指標",
1124
+ info="根據此指標選擇最佳模型"
1125
+ )
1126
+ use_early_stopping = gr.Checkbox(
1127
+ value=True, label="啟用 Early Stopping",
1128
+ info="當模型不再改善時提前停止"
1129
+ )
1130
+ patience = gr.Number(
1131
+ value=3, label="Patience", minimum=1, maximum=10, precision=0,
1132
+ info="幾輪無改善後停止訓練"
1133
+ )
1134
+
1135
+ train_btn = gr.Button("🚀 開始訓練", variant="primary", size="lg")
1136
+
1137
+ gr.Markdown("## 📊 訓練結果")
1138
+
1139
+ with gr.Row():
1140
+ data_info = gr.Textbox(label="📋 訓練資訊", lines=25)
1141
+ history_output = gr.Textbox(label="📈 訓練歷程", lines=25)
1142
+
1143
+ with gr.Row():
1144
+ baseline_result = gr.Textbox(label="🔬 基準模型(未微調)", lines=15)
1145
+ finetuned_result = gr.Textbox(label="✅ 微調模型", lines=15)
1146
+
1147
+ comparison_result = gr.Textbox(label="📊 效能比較分析", lines=20)
1148
+
1149
+ train_btn.click(
1150
+ train_bert_model,
1151
+ inputs=[
1152
+ csv_file, base_model, method, num_epochs, batch_size, learning_rate,
1153
+ weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
1154
+ weight_mult, weight_method, best_metric, use_early_stopping, patience
1155
+ ],
1156
+ outputs=[data_info, baseline_result, finetuned_result, comparison_result, history_output]
1157
+ )
1158
 
1159
+ with gr.Tab("🔮 預測"):
1160
+ gr.Markdown("## 使用訓練好的模型進行預測")
1161
+
1162
+ with gr.Row():
1163
+ model_dropdown = gr.Dropdown(
1164
+ label="選擇模型",
1165
+ choices=list(trained_models.keys()),
1166
+ interactive=True
1167
+ )
1168
+ refresh_btn = gr.Button("🔄 刷新模型列表", size="sm")
1169
+
1170
+ text_input = gr.Textbox(
1171
+ label="輸入要預測的文字",
1172
+ lines=5,
1173
+ placeholder="請輸入病例描述或相關文字..."
1174
+ )
1175
+
1176
+ predict_btn = gr.Button("🎯 執行預測", variant="primary", size="lg")
1177
+
1178
+ pred_output = gr.Textbox(label="預測結果與分析", lines=25)
1179
+
1180
+ # 刷新模型列表
1181
+ refresh_btn.click(
1182
+ lambda: gr.Dropdown(choices=list(trained_models.keys())),
1183
+ outputs=[model_dropdown]
1184
+ )
1185
+
1186
+ # 執行預測
1187
+ predict_btn.click(
1188
+ predict,
1189
+ inputs=[model_dropdown, text_input],
1190
+ outputs=[pred_output]
1191
+ )
1192
+
1193
+ # 範例
1194
+ gr.Examples(
1195
+ examples=[
1196
+ ["Patient with stage II breast cancer, showing good response to chemotherapy treatment."],
1197
+ ["Advanced metastatic cancer with multiple organ failure, poor prognosis."],
1198
+ ["Early stage tumor detected, surgery scheduled, excellent recovery expected."],
1199
+ ["Terminal stage disease, palliative care initiated, family counseling provided."]
1200
+ ],
1201
+ inputs=text_input
1202
+ )
1203
 
1204
+ with gr.Tab("📊 比較"):
1205
+ gr.Markdown("## 比較所有已訓練的模型")
1206
+
1207
+ compare_btn = gr.Button("📊 生成比較報告", variant="primary", size="lg")
1208
+ compare_output = gr.Markdown()
1209
+
1210
+ compare_btn.click(compare_models, outputs=[compare_output])
1211
+
1212
+ with gr.Tab("📖 說明"):
1213
+ gr.Markdown("""
1214
+ ## 📖 使用說明
1215
+
1216
+ ### 🎯 平台特色
1217
+
1218
+ 本改進版平台提供以下功能:
1219
+
1220
+ 1. **自動基準比較**:每次訓練都會自動評估基準模型,清楚顯示微調的改善
1221
+ 2. **訓練監控**:記錄每個 epoch 的詳細訓練歷程
1222
+ 3. **Early Stopping**:避免過擬合,自動選擇最佳模型
1223
+ 4. **多種權重策略**:針對不平衡資料提供多種處理方法
1224
+ 5. **完整評估指標**:包含 F1、準確率、精確率、召回率、敏感度、特異度、PPV、NPV
1225
+
1226
+ ### 🤖 支援的基礎模型
1227
+
1228
+ - **BERT-base**: 標準英文 BERT,適用於一般英文文本
1229
+ - **BERT-base-chinese**: 中文 BERT,適用於中文文本
1230
+ - **BioBERT**: 生物醫學領域專用 BERT
1231
+ - **SciBERT**: 科學文獻專用 BERT
1232
+
1233
+ ### 🔧 微調方法說明
1234
+
1235
+ - **LoRA** (Low-Rank Adaptation)
1236
+ - 參數效率最高,只訓練 <1% 參數
1237
+ - 訓練速度快,記憶體需求低
1238
+ - 適合大多數場景
1239
+
1240
+ - **AdaLoRA** (Adaptive LoRA)
1241
+ - 自動調整秩的分配
1242
+ - 可能獲得更好的效果
1243
+ - 訓練時間稍長
1244
+
1245
+ - **Full** (完全微調)
1246
+ - 訓練所有參數
1247
+ - 可能獲得最佳效果
1248
+ - 需要較大記憶體和時間
1249
+
1250
+ ### ⚖️ 處理不平衡資料
1251
+
1252
+ #### 權重計算方法:
1253
+
1254
+ 1. **sqrt** (平方根法) - 推薦用於極度不平衡
1255
+ - 使用平方根緩和權重
1256
+ - 避免權重過大導致過擬合
1257
+
1258
+ 2. **log** (對數法) - 更保守的方法
1259
+ - 使用對數進一步緩和
1260
+ - 適合極度不平衡且容易過擬合的情況
1261
+
1262
+ 3. **balanced** (平衡法)
1263
+ - sklearn 風格的自動平衡
1264
+ - 適合中度不平衡
1265
+
1266
+ 4. **custom** (自定義)
1267
+ - 根據不平衡程度自動調整
1268
+ - 綜合考慮多種因素
1269
+
1270
+ #### 建議參數設定:
1271
+
1272
+ **極度不平衡 (>20:1)**
1273
+ - 權重方法: sqrt 或 log
1274
+ - 權重倍數: 0.5-1.0
1275
+ - 使用 Focal Loss (自動啟用)
1276
+ - Early Stopping: 建議開啟
1277
+
1278
+ **高度不平衡 (10-20:1)**
1279
+ - 權重方法: sqrt
1280
+ - 權重倍數: 0.8-1.5
1281
+ - Early Stopping: 建議開啟
1282
+
1283
+ **中度不平衡 (5-10:1)**
1284
+ - 權重方法: balanced
1285
+ - 權重倍數: 1.0-2.0
1286
+
1287
+ **輕度不平衡 (<5:1)**
1288
+ - 權重方法: balanced
1289
+ - 權重倍數: 1.5-3.0
1290
+
1291
+ ### 📊 評估指標說明
1292
+
1293
+ - **F1 Score**: 精確率和召回率的調和平均,適合不平衡資料
1294
+ - **Accuracy**: 整體準確率
1295
+ - **Precision**: 預測為正類中實際為正類的比例
1296
+ - **Recall/Sensitivity**: 實際正類中被正確預測的比例
1297
+ - **Specificity**: 實際負類中被正確預測的比例
1298
+ - **PPV**: 陽性預測值
1299
+ - **NPV**: 陰性預測值
1300
+
1301
+ ### 🚀 快速開始指南
1302
+
1303
+ 1. **準備資料**
1304
+ - CSV 格式,包含 `Text` 和 `label` 欄位
1305
+ - label: 0=負類(如存活), 1=正類(如死亡)
1306
+
1307
+ 2. **選擇模型與方法**
1308
+ - 英文資料:BERT-base + LoRA
1309
+ - 中文資料:BERT-base-chinese + LoRA
1310
+ - 醫學資料:BioBERT + LoRA
1311
+
1312
+ 3. **設定參數**
1313
+ - 使用預設參數作為起點
1314
+ - 根據資料不平衡程度調整權重設定
1315
+
1316
+ 4. **訓練與評估**
1317
+ - 點擊「開始訓練」
1318
+ - 查看基準 vs 微調的比較
1319
+ - 觀察訓練歷程
1320
+
1321
+ 5. **測試預測**
1322
+ - 在「預測」頁面選擇模型
1323
+ - 輸入文字進行預測
1324
+ - 比較微調前後的差異
1325
+
1326
+ ### ⚠️ 注意事項
1327
+
1328
+ - GPU 可大幅加速訓練
1329
+ - 批次大小過大可能導致記憶體不足
1330
+ - Early Stopping 可避免過擬合
1331
+ - 極度不平衡資料建議使用較保守的權重設定
1332
+
1333
+ ### 💡 優化建議
1334
+
1335
+ 1. **記憶體不足**:降低批次大小或使用 LoRA
1336
+ 2. **過擬合**:增加 dropout、使用 Early Stopping、降低學習率
1337
+ 3. **欠擬合**:增加訓練輪數、提高學習率、增加模型容量
1338
+ 4. **不平衡資料**:調整類別權重、使用適當的評估指標(F1)
1339
+ """)
1340
+
1341
+ return demo
1342
 
1343
+ # ==================== 主程式 ====================
1344
  if __name__ == "__main__":
1345
+ demo = create_demo()
1346
  demo.launch(
1347
  server_name="0.0.0.0",
1348
  server_port=7860,
1349
  share=False,
1350
+ max_threads=4
1351
  )