import gradio as gr import pandas as pd import numpy as np import torch from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer from peft import LoraConfig, AdaLoraConfig, get_peft_model, TaskType from datasets import Dataset from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix from torch import nn from torch.utils.data import DataLoader, WeightedRandomSampler import os from datetime import datetime import gc import json from functools import lru_cache from typing import Dict, List, Tuple, Optional import warnings warnings.filterwarnings('ignore') # 環境設置 os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512" # 優化 CUDA 設置 torch.backends.cudnn.benchmark = False if torch.cuda.is_available(): torch.cuda.empty_cache() # ==================== 全域變數 ==================== trained_models = {} model_counter = 0 training_histories = {} # 新增:儲存訓練歷史 # ==================== 訓練監控類 ==================== class TrainingMonitor: """訓練過程監控器""" def __init__(self): self.history = { 'epoch': [], 'train_loss': [], 'eval_loss': [], 'eval_accuracy': [], 'eval_f1': [], 'eval_precision': [], 'eval_recall': [], 'learning_rate': [], 'best_epoch': None, 'best_metric_value': None } def log_epoch(self, epoch: int, train_loss: float, eval_metrics: Dict, lr: float): """記錄每個 epoch 的結果""" self.history['epoch'].append(epoch) self.history['train_loss'].append(train_loss) self.history['eval_loss'].append(eval_metrics.get('eval_loss', 0)) self.history['eval_accuracy'].append(eval_metrics.get('eval_accuracy', 0)) self.history['eval_f1'].append(eval_metrics.get('eval_f1', 0)) self.history['eval_precision'].append(eval_metrics.get('eval_precision', 0)) self.history['eval_recall'].append(eval_metrics.get('eval_recall', 0)) self.history['learning_rate'].append(lr) def update_best(self, epoch: int, metric_value: float): """更新最佳結果""" self.history['best_epoch'] = epoch self.history['best_metric_value'] = metric_value def get_summary(self) -> str: """獲取訓練摘要""" if not self.history['epoch']: return "尚無訓練記錄" summary = "📈 訓練歷程摘要\n" summary += f"總訓練輪數: {len(self.history['epoch'])}\n" summary += f"最佳 Epoch: {self.history['best_epoch']}\n" summary += f"最佳指標值: {self.history['best_metric_value']:.4f}\n\n" summary += "各 Epoch 表現:\n" for i, epoch in enumerate(self.history['epoch']): summary += f"Epoch {epoch}: Loss={self.history['train_loss'][i]:.4f}, " summary += f"F1={self.history['eval_f1'][i]:.4f}, " summary += f"Acc={self.history['eval_accuracy'][i]:.4f}\n" return summary # ==================== 權重計算改進 ==================== def calculate_class_weights(n0: int, n1: int, weight_mult: float = 1.0, method: str = 'sqrt') -> Tuple[float, float]: """ 改進的類別權重計算 Args: n0: 負類樣本數(存活) n1: 正類樣本數(死亡) weight_mult: 權重倍數調整 method: 計算方法 ('balanced', 'sqrt', 'log', 'custom') Returns: (w0, w1): 類別權重 """ if n1 == 0: return 1.0, 1.0 ratio = n0 / n1 total = n0 + n1 if method == 'balanced': # sklearn 風格的平衡權重 w0 = total / (2 * n0) if n0 > 0 else 1.0 w1 = total / (2 * n1) if n1 > 0 else 1.0 w1 *= weight_mult elif method == 'sqrt': # 使用平方根緩和極端權重(推薦用於極度不平衡) w0 = 1.0 w1 = min(np.sqrt(ratio) * weight_mult, 10.0) # 設置上限為 10 elif method == 'log': # 使用對數進一步緩和 w0 = 1.0 w1 = min(np.log1p(ratio) * weight_mult, 8.0) # 設置上限為 8 elif method == 'custom': # 自定義邏輯,根據不平衡程度調整 if ratio > 20: # 極度不平衡 w0 = 1.0 w1 = min(5.0 * weight_mult, 10.0) elif ratio > 10: # 高度不平衡 w0 = 1.0 w1 = min(ratio * 0.3 * weight_mult, 8.0) elif ratio > 5: # 中度不平衡 w0 = 1.0 w1 = min(ratio * 0.5 * weight_mult, 6.0) else: # 輕度不平衡 w0 = 1.0 w1 = ratio * weight_mult else: # 預設使用 sqrt 方法 w0 = 1.0 w1 = min(np.sqrt(ratio) * weight_mult, 10.0) return w0, w1 # ==================== 評估指標計算 ==================== def compute_metrics(pred): """計算完整的評估指標""" try: labels = pred.label_ids preds = pred.predictions.argmax(-1) # 基本指標 precision, recall, f1, _ = precision_recall_fscore_support( labels, preds, average='binary', pos_label=1, zero_division=0 ) acc = accuracy_score(labels, preds) # 混淆矩陣 cm = confusion_matrix(labels, preds) tn = fp = fn = tp = 0 if cm.shape == (2, 2): tn, fp, fn, tp = cm.ravel() # 敏感度和特異度 sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0 specificity = tn / (tn + fp) if (tn + fp) > 0 else 0 # 額外指標 ppv = tp / (tp + fp) if (tp + fp) > 0 else 0 # 陽性預測值 npv = tn / (tn + fn) if (tn + fn) > 0 else 0 # 陰性預測值 return { 'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall, 'sensitivity': sensitivity, 'specificity': specificity, 'ppv': ppv, 'npv': npv, 'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn) } except Exception as e: print(f"Error in compute_metrics: {e}") return {k: 0 for k in ['accuracy', 'f1', 'precision', 'recall', 'sensitivity', 'specificity', 'ppv', 'npv', 'tp', 'tn', 'fp', 'fn']} # ==================== 基準模型評估(修正版,只保留一個) ==================== def evaluate_baseline(model, tokenizer, test_dataset, device, batch_size=16): """評估未微調的基準模型""" model.eval() all_preds = [] all_labels = [] def collate_fn(batch): return { 'input_ids': torch.stack([torch.tensor(item['input_ids']) for item in batch]), 'attention_mask': torch.stack([torch.tensor(item['attention_mask']) for item in batch]), 'labels': torch.tensor([item['label'] for item in batch]) } dataloader = DataLoader( test_dataset, batch_size=batch_size, collate_fn=collate_fn, pin_memory=torch.cuda.is_available(), num_workers=0 # 避免多進程問題 ) with torch.no_grad(): for batch in dataloader: labels = batch.pop('labels') inputs = {k: v.to(device) for k, v in batch.items()} outputs = model(**inputs) preds = torch.argmax(outputs.logits, dim=-1) all_preds.extend(preds.cpu().numpy()) all_labels.extend(labels.numpy()) # 計算所有指標 precision, recall, f1, _ = precision_recall_fscore_support( all_labels, all_preds, average='binary', pos_label=1, zero_division=0 ) acc = accuracy_score(all_labels, all_preds) cm = confusion_matrix(all_labels, all_preds) tn = fp = fn = tp = 0 if cm.shape == (2, 2): tn, fp, fn, tp = cm.ravel() sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0 specificity = tn / (tn + fp) if (tn + fp) > 0 else 0 ppv = tp / (tp + fp) if (tp + fp) > 0 else 0 npv = tn / (tn + fn) if (tn + fn) > 0 else 0 return { 'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall, 'sensitivity': sensitivity, 'specificity': specificity, 'ppv': ppv, 'npv': npv, 'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn) } # ==================== 自定義 Trainer 與 Early Stopping ==================== class CustomTrainer(Trainer): """支援類別權重、Focal Loss 和 Early Stopping 的 Trainer""" def __init__(self, *args, class_weights=None, use_focal_loss=False, focal_gamma=2.0, monitor=None, early_stopping_patience=3, early_stopping_metric='eval_f1', **kwargs): super().__init__(*args, **kwargs) self.class_weights = class_weights self.use_focal_loss = use_focal_loss self.focal_gamma = focal_gamma self.monitor = monitor self.early_stopping_patience = early_stopping_patience self.early_stopping_metric = early_stopping_metric self.best_metric = -float('inf') self.best_model_state = None self.patience_counter = 0 self.current_epoch = 0 def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """計算損失函數""" labels = inputs.pop("labels") outputs = model(**inputs) logits = outputs.logits if self.use_focal_loss and self.class_weights is not None: # Focal Loss 實現 ce_loss = nn.CrossEntropyLoss(weight=self.class_weights, reduction='none')( logits.view(-1, 2), labels.view(-1) ) pt = torch.exp(-ce_loss) focal_loss = ((1 - pt) ** self.focal_gamma * ce_loss).mean() loss = focal_loss elif self.class_weights is not None: # 標準加權交叉熵 loss_fct = nn.CrossEntropyLoss(weight=self.class_weights) loss = loss_fct(logits.view(-1, 2), labels.view(-1)) else: # 標準交叉熵 loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, 2), labels.view(-1)) return (loss, outputs) if return_outputs else loss def on_epoch_end(self, args, state, control, **kwargs): """每個 epoch 結束時的回調""" self.current_epoch += 1 # 評估模型 metrics = self.evaluate() # 記錄到監控器 if self.monitor: self.monitor.log_epoch( epoch=self.current_epoch, train_loss=state.log_history[-1].get('loss', 0) if state.log_history else 0, eval_metrics=metrics, lr=self.get_learning_rate() ) # Early Stopping 檢查 current_metric = metrics.get(self.early_stopping_metric, 0) if current_metric > self.best_metric: self.best_metric = current_metric self.best_model_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()} self.patience_counter = 0 if self.monitor: self.monitor.update_best(self.current_epoch, current_metric) print(f"✅ Epoch {self.current_epoch}: 新最佳 {self.early_stopping_metric} = {current_metric:.4f}") else: self.patience_counter += 1 print(f"⏳ Epoch {self.current_epoch}: 無改善 (patience: {self.patience_counter}/{self.early_stopping_patience})") if self.patience_counter >= self.early_stopping_patience: print(f"🛑 Early Stopping 於 Epoch {self.current_epoch}") control.should_training_stop = True return control def get_learning_rate(self): """獲取當前學習率""" if self.optimizer is None: return 0 return self.optimizer.param_groups[0]['lr'] def load_best_model(self): """載入最佳模型""" if self.best_model_state: self.model.load_state_dict(self.best_model_state) print(f"✅ 已載入最佳模型 (最佳 {self.early_stopping_metric} = {self.best_metric:.4f})") # ==================== 基準模型快取(改進版) ==================== @lru_cache(maxsize=3) def get_cached_baseline_model(model_name: str, num_labels: int = 2): """使用 LRU 快取管理基準模型""" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels) return model.to(device) # ==================== 改善率計算 ==================== def calculate_improvement(baseline_val: float, finetuned_val: float) -> float: """安全計算改善率""" if baseline_val == 0: return float('inf') if finetuned_val > 0 else 0.0 return (finetuned_val - baseline_val) / baseline_val * 100 def format_improvement(val: float) -> str: """格式化改善率顯示""" if val == float('inf'): return "N/A (baseline=0)" elif val > 0: return f"↑ {val:.1f}%" elif val < 0: return f"↓ {abs(val):.1f}%" else: return "→ 0.0%" # ==================== 主要訓練函數(改進版) ==================== def train_bert_model(csv_file, base_model, method, num_epochs, batch_size, learning_rate, weight_decay, dropout, lora_r, lora_alpha, lora_dropout, weight_mult, weight_method, best_metric, use_early_stopping, patience): """ 改進的 BERT 模型訓練函數 """ global trained_models, model_counter, training_histories model_mapping = { "BERT-base": "bert-base-uncased", "BERT-base-chinese": "bert-base-chinese", "BioBERT": "dmis-lab/biobert-base-cased-v1.2", "SciBERT": "allenai/scibert_scivocab_uncased" } model_name = model_mapping.get(base_model, "bert-base-uncased") try: # ========== 資料驗證與載入 ========== if csv_file is None: return "❌ 請上傳 CSV 檔案", "", "", "", "" df = pd.read_csv(csv_file.name) if 'Text' not in df.columns or 'label' not in df.columns: return "❌ CSV 必須包含 'Text' 和 'label' 欄位", "", "", "", "" # 資料清理 df_clean = pd.DataFrame({ 'text': df['Text'].astype(str), 'label': df['label'].astype(int) }).dropna() # 統計資料 n0 = int(sum(df_clean['label'] == 0)) n1 = int(sum(df_clean['label'] == 1)) if n1 == 0: return "❌ 資料集中沒有正類樣本(死亡)", "", "", "", "" ratio = n0 / n1 if n1 > 0 else 0 # ========== 計算類別權重 ========== w0, w1 = calculate_class_weights(n0, n1, weight_mult, method=weight_method) # ========== 準備資料資訊 ========== info = f"📊 資料集統計\n" info += f"{'='*50}\n" info += f"總樣本數: {len(df_clean):,}\n" info += f"存活 (0): {n0:,} ({n0/len(df_clean)*100:.1f}%)\n" info += f"死亡 (1): {n1:,} ({n1/len(df_clean)*100:.1f}%)\n" info += f"不平衡比例: {ratio:.2f}:1\n" info += f"\n⚖️ 類別權重設定\n" info += f"{'='*50}\n" info += f"計算方法: {weight_method}\n" info += f"存活權重: {w0:.3f}\n" info += f"死亡權重: {w1:.3f}\n" info += f"權重比例: 1:{w1/w0:.2f}\n" # ========== 模型與分詞器初始化 ========== info += f"\n🤖 模型配置\n" info += f"{'='*50}\n" info += f"基礎模型: {base_model}\n" info += f"模型路徑: {model_name}\n" info += f"微調方法: {method.upper()}\n" tokenizer = BertTokenizer.from_pretrained(model_name) # ========== 資料集準備 ========== dataset = Dataset.from_pandas(df_clean[['text', 'label']]) def preprocess(examples): return tokenizer( examples['text'], truncation=True, padding='max_length', max_length=128 ) tokenized = dataset.map(preprocess, batched=True, remove_columns=['text']) split = tokenized.train_test_split(test_size=0.2, seed=42, stratify=tokenized['label']) # ========== 設備配置 ========== device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') info += f"運算設備: {'GPU ✅ (' + torch.cuda.get_device_name(0) + ')' if torch.cuda.is_available() else 'CPU ⚠️'}\n" # ========== 評估基準模型 ========== info += f"\n📏 基準模型評估\n" info += f"{'='*50}\n" info += f"正在評估未微調的 {base_model}...\n" baseline_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2) baseline_model = baseline_model.to(device) baseline_perf = evaluate_baseline( baseline_model, tokenizer, split['test'], device, batch_size=batch_size*2 ) info += f"基準 F1 分數: {baseline_perf['f1']:.4f}\n" info += f"基準準確率: {baseline_perf['accuracy']:.4f}\n" # 清理基準模型記憶體 del baseline_model if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() # ========== 配置微調模型 ========== info += f"\n🔧 微調配置\n" info += f"{'='*50}\n" model = BertForSequenceClassification.from_pretrained( model_name, num_labels=2, hidden_dropout_prob=dropout, attention_probs_dropout_prob=dropout ) # 應用 PEFT 方法 peft_applied = False if method == "lora": from peft import LoraConfig, get_peft_model, TaskType config = LoraConfig( task_type=TaskType.SEQ_CLS, r=int(lora_r), lora_alpha=int(lora_alpha), lora_dropout=lora_dropout, target_modules=["query", "value"], bias="none" ) model = get_peft_model(model, config) peft_applied = True info += f"✅ LoRA 已套用\n" info += f" - Rank (r): {int(lora_r)}\n" info += f" - Alpha: {int(lora_alpha)}\n" info += f" - Dropout: {lora_dropout}\n" elif method == "adalora": from peft import AdaLoraConfig, get_peft_model, TaskType config = AdaLoraConfig( task_type=TaskType.SEQ_CLS, r=int(lora_r), lora_alpha=int(lora_alpha), lora_dropout=lora_dropout, target_modules=["query", "value"], init_r=12, target_r=int(lora_r), tinit=200, tfinal=1000, deltaT=10 ) model = get_peft_model(model, config) peft_applied = True info += f"✅ AdaLoRA 已套用\n" info += f" - Initial Rank: 12\n" info += f" - Target Rank: {int(lora_r)}\n" info += f" - Alpha: {int(lora_alpha)}\n" elif method == "full": info += f"✅ Full Fine-tuning 模式\n" peft_applied = False model = model.to(device) # 參數統計 total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) info += f"\n💾 模型參數\n" info += f"{'='*50}\n" info += f"總參數量: {total_params:,}\n" info += f"可訓練參數: {trainable_params:,}\n" info += f"可訓練比例: {trainable_params/total_params*100:.2f}%\n" info += f"記憶體節省: {(1 - trainable_params/total_params)*100:.1f}%\n" # ========== 準備訓練 ========== weights = torch.tensor([w0, w1], dtype=torch.float).to(device) use_focal = ratio > 10 # 極度不平衡時使用 Focal Loss if use_focal: info += f"\n⚡ 特殊設定\n" info += f"{'='*50}\n" info += f"使用 Focal Loss (γ=2.0) 處理極度不平衡\n" # 訓練參數 training_args = TrainingArguments( output_dir='./results', num_train_epochs=int(num_epochs), per_device_train_batch_size=int(batch_size), per_device_eval_batch_size=int(batch_size) * 2, learning_rate=float(learning_rate), weight_decay=float(weight_decay), evaluation_strategy="epoch", save_strategy="no", # 使用自定義保存策略 load_best_model_at_end=False, report_to="none", logging_steps=max(1, len(split['train']) // (int(batch_size) * 10)), warmup_steps=min(500, len(split['train']) // int(batch_size)), logging_first_step=True, remove_unused_columns=False, label_smoothing_factor=0.1 if ratio > 20 else 0.0, # 極度不平衡時使用標籤平滑 ) # 創建監控器 monitor = TrainingMonitor() # 創建自定義 Trainer trainer = CustomTrainer( model=model, args=training_args, train_dataset=split['train'], eval_dataset=split['test'], compute_metrics=compute_metrics, class_weights=weights, use_focal_loss=use_focal, focal_gamma=2.0, monitor=monitor, early_stopping_patience=patience if use_early_stopping else 999, early_stopping_metric=f'eval_{best_metric}' ) info += f"\n🚀 訓練設定\n" info += f"{'='*50}\n" info += f"訓練樣本: {len(split['train']):,}\n" info += f"測試樣本: {len(split['test']):,}\n" info += f"批次大小: {int(batch_size)}\n" info += f"訓練輪數: {int(num_epochs)}\n" info += f"批次數/輪: {len(split['train']) // int(batch_size)}\n" info += f"Early Stopping: {'開啟 (patience=' + str(patience) + ')' if use_early_stopping else '關閉'}\n" info += f"最佳指標: {best_metric}\n" info += f"\n⏳ 開始訓練...\n" info += f"{'='*50}\n" # ========== 執行訓練 ========== train_result = trainer.train() # 載入最佳模型 if use_early_stopping: trainer.load_best_model() # 最終評估 final_results = trainer.evaluate() # ========== 保存模型與結果 ========== model_counter += 1 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") model_id = f"{base_model}_{method}_{model_counter}_{timestamp}" trained_models[model_id] = { 'model': model, 'tokenizer': tokenizer, 'results': final_results, 'baseline': baseline_perf, 'config': { 'type': base_model, 'model_name': model_name, 'method': method, 'metric': best_metric, 'epochs': int(num_epochs), 'batch_size': int(batch_size), 'learning_rate': float(learning_rate), 'weight_method': weight_method, 'weight_mult': weight_mult }, 'timestamp': timestamp, 'monitor': monitor # 保存訓練歷史 } training_histories[model_id] = monitor.history info += f"\n✅ 訓練完成!\n" info += f"最終 Training Loss: {train_result.training_loss:.4f}\n" if monitor.history['best_epoch']: info += f"最佳 Epoch: {monitor.history['best_epoch']}\n" # ========== 準備輸出結果 ========== # 基準模型結果 baseline_output = format_baseline_results(baseline_perf) # 微調模型結果 finetuned_output = format_finetuned_results(model_id, final_results) # 比較結果 comparison_output = format_comparison_results(baseline_perf, final_results) # 訓練歷程 history_output = monitor.get_summary() return info, baseline_output, finetuned_output, comparison_output, history_output except Exception as e: import traceback error_msg = f"❌ 錯誤發生\n\n錯誤類型: {type(e).__name__}\n錯誤訊息: {str(e)}\n\n" error_msg += f"詳細追蹤:\n{traceback.format_exc()}" return error_msg, "", "", "", "" # ==================== 格式化輸出函數 ==================== def format_baseline_results(baseline_perf: Dict) -> str: """格式化基準模型結果""" output = "🔬 純 BERT(未微調)\n\n" output += "📊 模型表現\n" output += f"{'='*30}\n" output += f"F1 Score: {baseline_perf['f1']:.4f}\n" output += f"Accuracy: {baseline_perf['accuracy']:.4f}\n" output += f"Precision: {baseline_perf['precision']:.4f}\n" output += f"Recall: {baseline_perf['recall']:.4f}\n" output += f"Sensitivity: {baseline_perf['sensitivity']:.4f}\n" output += f"Specificity: {baseline_perf['specificity']:.4f}\n" output += f"PPV: {baseline_perf['ppv']:.4f}\n" output += f"NPV: {baseline_perf['npv']:.4f}\n\n" output += "📈 混淆矩陣\n" output += f"{'='*30}\n" output += f" 預測 0 預測 1\n" output += f"實際 0 {baseline_perf['tn']:4d} {baseline_perf['fp']:4d}\n" output += f"實際 1 {baseline_perf['fn']:4d} {baseline_perf['tp']:4d}\n" return output def format_finetuned_results(model_id: str, results: Dict) -> str: """格式化微調模型結果""" output = f"✅ 微調 BERT\n" output += f"模型 ID: {model_id}\n\n" output += "📊 模型表現\n" output += f"{'='*30}\n" output += f"F1 Score: {results['eval_f1']:.4f}\n" output += f"Accuracy: {results['eval_accuracy']:.4f}\n" output += f"Precision: {results['eval_precision']:.4f}\n" output += f"Recall: {results['eval_recall']:.4f}\n" output += f"Sensitivity: {results['eval_sensitivity']:.4f}\n" output += f"Specificity: {results['eval_specificity']:.4f}\n" output += f"PPV: {results['eval_ppv']:.4f}\n" output += f"NPV: {results['eval_npv']:.4f}\n\n" output += "📈 混淆矩陣\n" output += f"{'='*30}\n" output += f" 預測 0 預測 1\n" output += f"實際 0 {results['eval_tn']:4d} {results['eval_fp']:4d}\n" output += f"實際 1 {results['eval_fn']:4d} {results['eval_tp']:4d}\n" return output def format_comparison_results(baseline_perf: Dict, finetuned_results: Dict) -> str: """格式化比較結果""" output = "📊 純 BERT vs 微調 BERT 比較\n\n" output += "指標改善分析:\n" output += f"{'='*50}\n" output += f"{'指標':<12} {'基準':>8} {'微調':>8} {'變化':>10} {'改善率':>10}\n" output += f"{'-'*50}\n" metrics = [ ('F1', 'f1', 'eval_f1'), ('Accuracy', 'accuracy', 'eval_accuracy'), ('Precision', 'precision', 'eval_precision'), ('Recall', 'recall', 'eval_recall'), ('Sensitivity', 'sensitivity', 'eval_sensitivity'), ('Specificity', 'specificity', 'eval_specificity'), ('PPV', 'ppv', 'eval_ppv'), ('NPV', 'npv', 'eval_npv') ] for name, base_key, fine_key in metrics: base_val = baseline_perf[base_key] fine_val = finetuned_results[fine_key] change = fine_val - base_val improve = calculate_improvement(base_val, fine_val) output += f"{name:<12} {base_val:>8.4f} {fine_val:>8.4f} " output += f"{change:+10.4f} {format_improvement(improve):>10}\n" output += f"\n混淆矩陣變化:\n" output += f"{'='*40}\n" output += f"{'項目':<10} {'基準':>8} {'微調':>8} {'變化':>10}\n" output += f"{'-'*40}\n" cm_items = [ ('True Pos', 'tp', 'eval_tp'), ('True Neg', 'tn', 'eval_tn'), ('False Pos', 'fp', 'eval_fp'), ('False Neg', 'fn', 'eval_fn') ] for name, base_key, fine_key in cm_items: base_val = baseline_perf[base_key] fine_val = finetuned_results[fine_key] change = fine_val - base_val output += f"{name:<10} {base_val:>8d} {fine_val:>8d} {change:+10d}\n" # 總結 output += f"\n📈 整體評估:\n" output += f"{'='*40}\n" f1_improve = calculate_improvement(baseline_perf['f1'], finetuned_results['eval_f1']) if f1_improve > 10: output += "✅ 顯著改善:微調帶來明顯的性能提升!\n" elif f1_improve > 0: output += "✅ 有所改善:微調產生正向影響。\n" elif f1_improve == 0: output += "➖ 無變化:微調未產生明顯影響。\n" else: output += "⚠️ 性能下降:可能需要調整超參數。\n" return output # ==================== 預測函數(改進版) ==================== def predict(model_id, text): """使用選定模型進行預測並與基準模型比較""" if not model_id or model_id not in trained_models: return "❌ 請選擇一個已訓練的模型" if not text or len(text.strip()) == 0: return "❌ 請輸入要預測的文字" try: # 獲取模型資訊 info = trained_models[model_id] model = info['model'] tokenizer = info['tokenizer'] config = info['config'] device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 文字預處理 inputs = tokenizer( text, return_tensors="pt", truncation=True, padding=True, max_length=128 ) inputs_device = {k: v.to(device) for k, v in inputs.items()} # ========== 微調模型預測 ========== model.eval() with torch.no_grad(): outputs = model(**inputs_device) logits = outputs.logits probs_finetuned = torch.nn.functional.softmax(logits, dim=-1) pred_finetuned = torch.argmax(probs_finetuned, dim=-1).item() confidence_finetuned = probs_finetuned[0][pred_finetuned].item() # ========== 基準模型預測 ========== baseline_model = get_cached_baseline_model(config['model_name']) baseline_model.eval() with torch.no_grad(): outputs_baseline = baseline_model(**inputs_device) logits_baseline = outputs_baseline.logits probs_baseline = torch.nn.functional.softmax(logits_baseline, dim=-1) pred_baseline = torch.argmax(probs_baseline, dim=-1).item() confidence_baseline = probs_baseline[0][pred_baseline].item() # ========== 格式化輸出 ========== result_finetuned = "🟢 存活" if pred_finetuned == 0 else "🔴 死亡" result_baseline = "🟢 存活" if pred_baseline == 0 else "🔴 死亡" agreement = "✅ 一致" if pred_finetuned == pred_baseline else "⚠️ 不一致" output = f"""🔮 預測結果比較分析 📝 輸入文字 {'='*60} {text[:200]}{'...' if len(text) > 200 else ''} {'='*60} 🎯 微調模型預測 ({model_id}) {'='*60} 預測結果: {result_finetuned} 預測信心: {confidence_finetuned:.1%} 機率分布: • 存活 (0): {probs_finetuned[0][0].item():.2%} • 死亡 (1): {probs_finetuned[0][1].item():.2%} 模型配置: • 方法: {config['method'].upper()} • 基礎模型: {config['type']} • 訓練輪數: {config['epochs']} {'='*60} 🔬 基準模型預測(未微調 {config['type']}) {'='*60} 預測結果: {result_baseline} 預測信心: {confidence_baseline:.1%} 機率分布: • 存活 (0): {probs_baseline[0][0].item():.2%} • 死亡 (1): {probs_baseline[0][1].item():.2%} {'='*60} 📊 預測分析 {'='*60} 兩模型預測: {agreement} """ if pred_finetuned != pred_baseline: output += f""" 💡 差異分析: 微調模型預測【{result_finetuned}】(信心: {confidence_finetuned:.1%}) 基準模型預測【{result_baseline}】(信心: {confidence_baseline:.1%}) 這種差異顯示了微調對此特定案例的影響。 微調模型可能學習到了更適合您資料集的特徵。 """ else: output += f""" ✅ 預測一致性分析: 兩個模型都預測為【{result_finetuned}】 信心差異: {abs(confidence_finetuned - confidence_baseline):.1%} """ # 加入模型整體表現對比 f1_improve = calculate_improvement( info['baseline']['f1'], info['results']['eval_f1'] ) output += f""" 📈 模型整體表現對比 {'='*60} 微調模型 F1: {info['results']['eval_f1']:.4f} 基準模型 F1: {info['baseline']['f1']:.4f} 改善幅度: {format_improvement(f1_improve)} 微調模型準確率: {info['results']['eval_accuracy']:.4f} 基準模型準確率: {info['baseline']['accuracy']:.4f} """ return output except Exception as e: import traceback return f"❌ 預測時發生錯誤\n\n{str(e)}\n\n{traceback.format_exc()}" # ==================== 模型比較函數 ==================== def compare_models(): """比較所有已訓練的模型""" if not trained_models: return "❌ 尚未訓練任何模型。請先在「訓練」頁面訓練模型。" output = "# 📊 模型比較報告\n\n" output += f"共有 {len(trained_models)} 個已訓練模型\n\n" # 微調模型表現表格 output += "## 🎯 微調模型表現\n\n" output += "| 模型 ID | 基礎模型 | 方法 | F1 | 準確率 | 精確率 | 召回率 | 敏感度 | 特異度 |\n" output += "|---------|----------|------|-----|--------|--------|--------|--------|--------|\n" for model_id, info in trained_models.items(): r = info['results'] c = info['config'] # 縮短模型 ID 顯示 short_id = f"{c['type']}_{c['method']}_{info['timestamp'][-6:]}" output += f"| {short_id} | {c['type']} | {c['method'].upper()} | " output += f"{r['eval_f1']:.4f} | {r['eval_accuracy']:.4f} | " output += f"{r['eval_precision']:.4f} | {r['eval_recall']:.4f} | " output += f"{r['eval_sensitivity']:.4f} | {r['eval_specificity']:.4f} |\n" # 基準模型表現 output += "\n## 🔬 基準模型表現(未微調)\n\n" # 獲取唯一的基準模型 unique_baselines = {} for model_id, info in trained_models.items(): base_type = info['config']['type'] if base_type not in unique_baselines: unique_baselines[base_type] = info['baseline'] output += "| 基礎模型 | F1 | 準確率 | 精確率 | 召回率 | 敏感度 | 特異度 |\n" output += "|----------|-----|--------|--------|--------|--------|--------|\n" for base_type, baseline in unique_baselines.items(): output += f"| {base_type} | {baseline['f1']:.4f} | {baseline['accuracy']:.4f} | " output += f"{baseline['precision']:.4f} | {baseline['recall']:.4f} | " output += f"{baseline['sensitivity']:.4f} | {baseline['specificity']:.4f} |\n" # 最佳模型分析 output += "\n## 🏆 最佳模型(各指標)\n\n" metrics_to_check = [ ('F1 Score', 'eval_f1'), ('準確率', 'eval_accuracy'), ('精確率', 'eval_precision'), ('召回率', 'eval_recall'), ('敏感度', 'eval_sensitivity'), ('特異度', 'eval_specificity') ] for metric_name, metric_key in metrics_to_check: best_model = max( trained_models.items(), key=lambda x: x[1]['results'][metric_key] ) model_id = best_model[0] value = best_model[1]['results'][metric_key] baseline_val = best_model[1]['baseline'][metric_key.replace('eval_', '')] improvement = calculate_improvement(baseline_val, value) output += f"**{metric_name}**: {model_id[:30]}... " output += f"({value:.4f}, 改善 {format_improvement(improvement)})\n\n" # 改善統計 output += "## 📈 改善統計\n\n" improvements = [] for model_id, info in trained_models.items(): f1_base = info['baseline']['f1'] f1_fine = info['results']['eval_f1'] improve = calculate_improvement(f1_base, f1_fine) if improve != float('inf'): improvements.append({ 'model': model_id, 'improvement': improve, 'method': info['config']['method'] }) if improvements: avg_improvement = np.mean([x['improvement'] for x in improvements]) max_improvement = max(improvements, key=lambda x: x['improvement']) min_improvement = min(improvements, key=lambda x: x['improvement']) output += f"平均 F1 改善: {format_improvement(avg_improvement)}\n" output += f"最大改善: {max_improvement['model'][:30]}... ({format_improvement(max_improvement['improvement'])})\n" output += f"最小改善: {min_improvement['model'][:30]}... ({format_improvement(min_improvement['improvement'])})\n\n" # 方法比較 method_improvements = {} for imp in improvements: method = imp['method'] if method not in method_improvements: method_improvements[method] = [] method_improvements[method].append(imp['improvement']) output += "### 各方法平均改善:\n" for method, imps in method_improvements.items(): avg_imp = np.mean(imps) output += f"- **{method.upper()}**: {format_improvement(avg_imp)}\n" return output # ==================== Gradio UI ==================== def create_demo(): """創建 Gradio 介面""" with gr.Blocks( title="BERT Fine-tuning 教學平台", theme=gr.themes.Soft(), css=""" .gradio-container {font-family: 'Microsoft JhengHei', 'Arial', sans-serif;} """ ) as demo: gr.Markdown( """ # 🧬 BERT Fine-tuning 教學平台 ### 比較基準模型 vs 微調模型的表現差異(改進版) """ ) with gr.Tab("🎯 訓練"): gr.Markdown("## 步驟 1: 選擇基礎模型") base_model = gr.Dropdown( choices=["BERT-base", "BERT-base-chinese", "BioBERT", "SciBERT"], value="BERT-base", label="基礎模型", info="選擇適合您資料的預訓練模型" ) gr.Markdown("## 步驟 2: 選擇微調方法") method = gr.Radio( choices=["lora", "adalora", "full"], value="lora", label="微調方法", info="LoRA 和 AdaLoRA 是參數高效方法,Full 是完全微調" ) gr.Markdown("## 步驟 3: 上傳資料") csv_file = gr.File( label="CSV 檔案(需包含 Text 和 label 欄位)", file_types=[".csv"] ) gr.Markdown("## 步驟 4: 設定訓練參數") with gr.Accordion("🎯 基本訓練參數", open=True): with gr.Row(): num_epochs = gr.Number( value=5, label="訓練輪數", minimum=1, maximum=50, precision=0, info="建議 3-10 輪,過多可能過擬合" ) batch_size = gr.Number( value=8, label="批次大小", minimum=1, maximum=64, precision=0, info="GPU 記憶體不足時請降低" ) learning_rate = gr.Number( value=3e-5, label="學習率", minimum=1e-6, maximum=1e-3, info="建議 1e-5 到 5e-5" ) with gr.Accordion("⚙️ 進階參數"): with gr.Row(): weight_decay = gr.Number( value=0.01, label="權重衰減", minimum=0, maximum=1, info="防止過擬合,建議 0.01-0.1" ) dropout = gr.Number( value=0.1, label="Dropout 率", minimum=0, maximum=0.5, info="防止過擬合,建議 0.1-0.3" ) with gr.Accordion("🔧 PEFT 參數(LoRA/AdaLoRA)"): with gr.Row(): lora_r = gr.Number( value=16, label="LoRA Rank (r)", minimum=1, maximum=64, precision=0, info="越大表達能力越強,但參數越多" ) lora_alpha = gr.Number( value=32, label="LoRA Alpha", minimum=1, maximum=128, precision=0, info="通常設為 Rank 的 2 倍" ) lora_dropout = gr.Number( value=0.05, label="LoRA Dropout", minimum=0, maximum=0.5, info="LoRA 層的 dropout" ) with gr.Accordion("⚖️ 類別平衡設定"): with gr.Row(): weight_mult = gr.Number( value=1.0, label="權重倍數", minimum=0.1, maximum=5.0, info="調整少數類權重的倍數" ) weight_method = gr.Dropdown( choices=["sqrt", "log", "balanced", "custom"], value="sqrt", label="權重計算方法", info="sqrt 和 log 適合極度不平衡資料" ) with gr.Accordion("🎯 訓練策略"): with gr.Row(): best_metric = gr.Dropdown( choices=["f1", "accuracy", "precision", "recall", "sensitivity", "specificity"], value="f1", label="最佳模型指標", info="根據此指標選擇最佳模型" ) use_early_stopping = gr.Checkbox( value=True, label="啟用 Early Stopping", info="當模型不再改善時提前停止" ) patience = gr.Number( value=3, label="Patience", minimum=1, maximum=10, precision=0, info="幾輪無改善後停止訓練" ) train_btn = gr.Button("🚀 開始訓練", variant="primary", size="lg") gr.Markdown("## 📊 訓練結果") with gr.Row(): data_info = gr.Textbox(label="📋 訓練資訊", lines=25) history_output = gr.Textbox(label="📈 訓練歷程", lines=25) with gr.Row(): baseline_result = gr.Textbox(label="🔬 基準模型(未微調)", lines=15) finetuned_result = gr.Textbox(label="✅ 微調模型", lines=15) comparison_result = gr.Textbox(label="📊 效能比較分析", lines=20) train_btn.click( train_bert_model, inputs=[ csv_file, base_model, method, num_epochs, batch_size, learning_rate, weight_decay, dropout, lora_r, lora_alpha, lora_dropout, weight_mult, weight_method, best_metric, use_early_stopping, patience ], outputs=[data_info, baseline_result, finetuned_result, comparison_result, history_output] ) with gr.Tab("🔮 預測"): gr.Markdown("## 使用訓練好的模型進行預測") with gr.Row(): model_dropdown = gr.Dropdown( label="選擇模型", choices=list(trained_models.keys()), interactive=True ) refresh_btn = gr.Button("🔄 刷新模型列表", size="sm") text_input = gr.Textbox( label="輸入要預測的文字", lines=5, placeholder="請輸入病例描述或相關文字..." ) predict_btn = gr.Button("🎯 執行預測", variant="primary", size="lg") pred_output = gr.Textbox(label="預測結果與分析", lines=25) # 刷新模型列表 refresh_btn.click( lambda: gr.Dropdown(choices=list(trained_models.keys())), outputs=[model_dropdown] ) # 執行預測 predict_btn.click( predict, inputs=[model_dropdown, text_input], outputs=[pred_output] ) # 範例 gr.Examples( examples=[ ["Patient with stage II breast cancer, showing good response to chemotherapy treatment."], ["Advanced metastatic cancer with multiple organ failure, poor prognosis."], ["Early stage tumor detected, surgery scheduled, excellent recovery expected."], ["Terminal stage disease, palliative care initiated, family counseling provided."] ], inputs=text_input ) with gr.Tab("📊 比較"): gr.Markdown("## 比較所有已訓練的模型") compare_btn = gr.Button("📊 生成比較報告", variant="primary", size="lg") compare_output = gr.Markdown() compare_btn.click(compare_models, outputs=[compare_output]) with gr.Tab("📖 說明"): gr.Markdown(""" ## 📖 使用說明 ### 🎯 平台特色 本改進版平台提供以下功能: 1. **自動基準比較**:每次訓練都會自動評估基準模型,清楚顯示微調的改善 2. **訓練監控**:記錄每個 epoch 的詳細訓練歷程 3. **Early Stopping**:避免過擬合,自動選擇最佳模型 4. **多種權重策略**:針對不平衡資料提供多種處理方法 5. **完整評估指標**:包含 F1、準確率、精確率、召回率、敏感度、特異度、PPV、NPV ### 🤖 支援的基礎模型 - **BERT-base**: 標準英文 BERT,適用於一般英文文本 - **BERT-base-chinese**: 中文 BERT,適用於中文文本 - **BioBERT**: 生物醫學領域專用 BERT - **SciBERT**: 科學文獻專用 BERT ### 🔧 微調方法說明 - **LoRA** (Low-Rank Adaptation) - 參數效率最高,只訓練 <1% 參數 - 訓練速度快,記憶體需求低 - 適合大多數場景 - **AdaLoRA** (Adaptive LoRA) - 自動調整秩的分配 - 可能獲得更好的效果 - 訓練時間稍長 - **Full** (完全微調) - 訓練所有參數 - 可能獲得最佳效果 - 需要較大記憶體和時間 ### ⚖️ 處理不平衡資料 #### 權重計算方法: 1. **sqrt** (平方根法) - 推薦用於極度不平衡 - 使用平方根緩和權重 - 避免權重過大導致過擬合 2. **log** (對數法) - 更保守的方法 - 使用對數進一步緩和 - 適合極度不平衡且容易過擬合的情況 3. **balanced** (平衡法) - sklearn 風格的自動平衡 - 適合中度不平衡 4. **custom** (自定義) - 根據不平衡程度自動調整 - 綜合考慮多種因素 #### 建議參數設定: **極度不平衡 (>20:1)** - 權重方法: sqrt 或 log - 權重倍數: 0.5-1.0 - 使用 Focal Loss (自動啟用) - Early Stopping: 建議開啟 **高度不平衡 (10-20:1)** - 權重方法: sqrt - 權重倍數: 0.8-1.5 - Early Stopping: 建議開啟 **中度不平衡 (5-10:1)** - 權重方法: balanced - 權重倍數: 1.0-2.0 **輕度不平衡 (<5:1)** - 權重方法: balanced - 權重倍數: 1.5-3.0 ### 📊 評估指標說明 - **F1 Score**: 精確率和召回率的調和平均,適合不平衡資料 - **Accuracy**: 整體準確率 - **Precision**: 預測為正類中實際為正類的比例 - **Recall/Sensitivity**: 實際正類中被正確預測的比例 - **Specificity**: 實際負類中被正確預測的比例 - **PPV**: 陽性預測值 - **NPV**: 陰性預測值 ### 🚀 快速開始指南 1. **準備資料** - CSV 格式,包含 `Text` 和 `label` 欄位 - label: 0=負類(如存活), 1=正類(如死亡) 2. **選擇模型與方法** - 英文資料:BERT-base + LoRA - 中文資料:BERT-base-chinese + LoRA - 醫學資料:BioBERT + LoRA 3. **設定參數** - 使用預設參數作為起點 - 根據資料不平衡程度調整權重設定 4. **訓練與評估** - 點擊「開始訓練」 - 查看基準 vs 微調的比較 - 觀察訓練歷程 5. **測試預測** - 在「預測」頁面選擇模型 - 輸入文字進行預測 - 比較微調前後的差異 ### ⚠️ 注意事項 - GPU 可大幅加速訓練 - 批次大小過大可能導致記憶體不足 - Early Stopping 可避免過擬合 - 極度不平衡資料建議使用較保守的權重設定 ### 💡 優化建議 1. **記憶體不足**:降低批次大小或使用 LoRA 2. **過擬合**:增加 dropout、使用 Early Stopping、降低學習率 3. **欠擬合**:增加訓練輪數、提高學習率、增加模型容量 4. **不平衡資料**:調整類別權重、使用適當的評估指標(F1) """) return demo # ==================== 主程式 ==================== if __name__ == "__main__": demo = create_demo() demo.launch( server_name="0.0.0.0", server_port=7860, share=False, max_threads=4 )