smartTranscend commited on
Commit
fa3056f
·
verified ·
1 Parent(s): 76f5d0d

Upload 3 files

Browse files
Files changed (3) hide show
  1. bert_readme.txt +214 -0
  2. bert_requirements.py +9 -0
  3. bert_second_finetuning.py +1657 -0
bert_readme.txt ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🥼 BERT 乳癌存活預測 - 二次微調完整平台
2
+
3
+ 完整的 BERT 二次微調系統,支援從第一次微調到二次微調的完整流程,並可在新數據上比較多個模型的表現。
4
+
5
+ ## 🌟 核心功能
6
+
7
+ ### 1️⃣ 第一次微調
8
+ - 從純 BERT 開始訓練
9
+ - 支援三種微調方法:
10
+ - **Full Fine-tuning**: 訓練所有參數
11
+ - **LoRA**: 低秩適配,參數高效
12
+ - **AdaLoRA**: 自適應 LoRA,動態調整秩
13
+ - 自動比較純 BERT vs 第一次微調的表現
14
+
15
+ ### 2️⃣ 二次微調
16
+ - 基於第一次微調模型繼續訓練
17
+ - 使用新的訓練數據
18
+ - 自動繼承第一次的微調方法
19
+ - 適合增量學習和領域適應
20
+
21
+ ### 3️⃣ 新數據測試
22
+ - 上傳新測試數據
23
+ - 同時比較最多 3 個模型:
24
+ - 純 BERT (Baseline)
25
+ - 第一次微調模型
26
+ - 第二次微調模型
27
+ - 並排顯示所有評估指標
28
+
29
+ ### 4️⃣ 模型預測
30
+ - 選擇任一已訓練模型
31
+ - 輸入病歷文本進行預測
32
+ - 同時顯示未微調和微調模型的預測結果
33
+
34
+ ## 📋 資料格式
35
+
36
+ CSV 檔案必須包含以下欄位:
37
+ - **Text**: 病歷文本 (英文)
38
+ - **label**: 標籤 (0=存活, 1=死亡)
39
+
40
+ 範例:
41
+ ```csv
42
+ Text,label
43
+ "Patient is a 45-year-old female with stage II breast cancer...",0
44
+ "65-year-old woman diagnosed with triple-negative breast cancer...",1
45
+ ```
46
+
47
+ ## 🚀 使用流程
48
+
49
+ ### 步驟 1: 第一次微調
50
+ 1. 進入「1️⃣ 第一次微調」頁面
51
+ 2. 上傳訓練數據 A (CSV)
52
+ 3. 選擇微調方法 (建議先用 Full Fine-tuning)
53
+ 4. 調整訓練參數:
54
+ - 權重倍數: 0.8 (處理不平衡數據)
55
+ - 訓練輪數: 8-10
56
+ - 學習率: 2e-5
57
+ 5. 點擊「開始第一次微調」
58
+ 6. 等待訓練完成,查看結果
59
+
60
+ ### 步驟 2: 二次微調
61
+ 1. 進入「2️⃣ 二次微調」頁面
62
+ 2. 點擊「🔄 重新整理模型列表」
63
+ 3. 選擇第一次微調的模型
64
+ 4. 上傳新的訓練數據 B
65
+ 5. 調整訓練參數 (建議):
66
+ - 訓練輪數: 3-5 (比第一次少)
67
+ - 學習率: 1e-5 (比第一次小)
68
+ 6. 點擊「開始二次微調」
69
+ 7. 等待訓練完成
70
+
71
+ ### 步驟 3: 新數據測試
72
+ 1. 進入「3️⃣ 新數據測試」頁面
73
+ 2. 上傳測試數據 C
74
+ 3. 選擇要比較的模型:
75
+ - 純 BERT: 選擇「評估純 BERT」
76
+ - 第一次微調: 從下拉選單選擇
77
+ - 第二次微調: 從下拉選單選擇
78
+ 4. 點擊「開始測試」
79
+ 5. 查看三個模型的比較結果
80
+
81
+ ### 步驟 4: 預測
82
+ 1. 進入「4️⃣ 模型預測」頁面
83
+ 2. 選擇要使用的模型
84
+ 3. 輸入病歷文本
85
+ 4. 點擊「開始預測」
86
+ 5. 查看預測結果
87
+
88
+ ## 🎯 微調方法比較
89
+
90
+ | 方法 | 參數量 | 訓練速度 | 記憶體使用 | 效果 |
91
+ |------|--------|---------|-----------|------|
92
+ | **Full Fine-tuning** | 100% | 1x (基準) | 高 | 最佳 |
93
+ | **LoRA** | ~1% | 3-5x 快 | 低 | 良好 |
94
+ | **AdaLoRA** | ~1% | 3-5x 快 | 低 | 良好 |
95
+
96
+ ## 💡 二次微調最佳實踐
97
+
98
+ ### 何時使用二次微調?
99
+
100
+ 1. **領域適應**
101
+ - 第一次: 使用通用醫療數據
102
+ - 第二次: 使用特定醫院/科別數據
103
+
104
+ 2. **增量學習**
105
+ - 第一次: 使用歷史數據
106
+ - 第二次: 加入新收集的數據
107
+
108
+ 3. **數據稀缺**
109
+ - 第一次: 使用大量相關領域數據
110
+ - 第二次: 使用少量目標領域數據
111
+
112
+ ### 參數調整建議
113
+
114
+ | 參數 | 第一次微調 | 第二次微調 | 原因 |
115
+ |------|----------|----------|------|
116
+ | **Epochs** | 8-10 | 3-5 | 避免過度擬合 |
117
+ | **Learning Rate** | 2e-5 | 1e-5 | 保護已學習知識 |
118
+ | **Warmup Steps** | 200 | 100 | 較少的預熱 |
119
+ | **權重倍數** | 根據數據調整 | 根據新數據調整 | 處理不平衡 |
120
+
121
+ ### 注意事項
122
+
123
+ ⚠️ **重要提醒**:
124
+ - 第二次微調會自動使用第一次的微調方法,無法更換
125
+ - 建議第二次的學習率比第一次小,避免「災難性遺忘」
126
+ - 如果第二次數據與第一次差異很大,可能需要更多輪數
127
+ - 始終在新數據上測試,確保沒有性能下降
128
+
129
+ ## 📊 評估指標說明
130
+
131
+ | 指標 | 說明 | 適用場景 |
132
+ |------|------|---------|
133
+ | **F1 Score** | 精確率和召回率的調和平均 | 平衡評估,通用指標 |
134
+ | **Accuracy** | 整體準確率 | 數據平衡時使用 |
135
+ | **Precision** | 預測為死亡中的準確率 | 避免誤報時優化 |
136
+ | **Recall** | 實際死亡中被識別的比例 | 避免漏診時優化 |
137
+ | **Sensitivity** | 等同於 Recall | 醫療場景常用 |
138
+ | **Specificity** | 實際存活中被識別的比例 | 避免過度治療 |
139
+ | **AUC** | ROC 曲線下面積 | 整體分類能力 |
140
+
141
+ ## 🔧 技術細節
142
+
143
+ ### 訓練流程
144
+
145
+ 1. **數據準備**
146
+ - 載入 CSV
147
+ - 保持原始類別比例
148
+ - Tokenization (max_length=256)
149
+ - 80/20 訓練/驗證分割
150
+
151
+ 2. **模型初始化**
152
+ - 第一次: 從 `bert-base-uncased` 載入
153
+ - 第二次: 從第一次微調模型載入
154
+ - 應用 PEFT 配置 (如果使用 LoRA/AdaLoRA)
155
+
156
+ 3. **訓練**
157
+ - 使用類別權重處理不平衡
158
+ - Early stopping (基於驗證集)
159
+ - 保存最佳模型
160
+
161
+ 4. **評估**
162
+ - 在驗證集上評估
163
+ - 計算所有指標
164
+ - 生成混淆矩陣
165
+
166
+ ### 模型儲存
167
+
168
+ - 模型檔案: `./breast_cancer_bert_{method}_{type}_{timestamp}/`
169
+ - 模型清單: `./saved_models_list.json`
170
+ - 包含所有訓練資訊和超參數
171
+
172
+ ## 🐛 常見問題
173
+
174
+ ### Q1: 為什麼二次微調不能更換方法?
175
+ **A**: 因為不同方法的參數結構不同。例如 LoRA 添加了低秩矩陣,如果切換到 Full Fine-tuning,這些參數會遺失。
176
+
177
+ ### Q2: 第二次微調的數據量應該多少?
178
+ **A**: 建議至少 100 筆,但可以比第一次少。如果數據太少,可能會過度擬合。
179
+
180
+ ### Q3: 如何選擇最佳化指標?
181
+ **A**:
182
+ - 醫療場景通常優先 **Recall** (避免漏診)
183
+ - 如果誤報代價高,選 **Precision**
184
+ - 平衡場景選 **F1 Score**
185
+
186
+ ### Q4: GPU 記憶體不足怎麼辦?
187
+ **A**:
188
+ - 使用 LoRA 或 AdaLoRA (減少 90% 記憶體)
189
+ - 減小 batch size
190
+ - 減少 max_length
191
+
192
+ ### Q5: 訓練時間太長?
193
+ **A**:
194
+ - 使用 LoRA/AdaLoRA (快 3-5 倍)
195
+ - 減少 epochs
196
+ - 增加 batch size (如果記憶體允許)
197
+
198
+ ## 📝 版本資訊
199
+
200
+ - **Version**: 1.0.0
201
+ - **Python**: 3.10+
202
+ - **主要依賴**:
203
+ - transformers 4.36.0
204
+ - torch 2.1.0
205
+ - peft 0.7.1
206
+ - gradio 4.44.0
207
+
208
+ ## 📄 授權
209
+
210
+ 本專案完全保留您的原始程式邏輯,僅新增二次微調和測試功能。
211
+
212
+ ## 🙏 致謝
213
+
214
+ 基於 BERT 模型和 Hugging Face Transformers 庫開發。
bert_requirements.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio==4.44.0
2
+ pandas==2.0.3
3
+ torch==2.1.0
4
+ transformers==4.36.0
5
+ datasets==2.14.6
6
+ scikit-learn==1.3.2
7
+ numpy==1.24.3
8
+ peft==0.7.1
9
+ accelerate==0.25.0
bert_second_finetuning.py ADDED
@@ -0,0 +1,1657 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import torch
4
+ from torch import nn
5
+ from transformers import (
6
+ BertTokenizer,
7
+ BertForSequenceClassification,
8
+ TrainingArguments,
9
+ Trainer
10
+ )
11
+ from datasets import Dataset
12
+ from sklearn.metrics import (
13
+ accuracy_score,
14
+ precision_recall_fscore_support,
15
+ roc_auc_score,
16
+ confusion_matrix
17
+ )
18
+ import numpy as np
19
+ from datetime import datetime
20
+ import json
21
+ import os
22
+ import gc
23
+
24
+ # PEFT 相關的 import(LoRA 和 AdaLoRA)
25
+ try:
26
+ from peft import (
27
+ LoraConfig,
28
+ AdaLoraConfig,
29
+ get_peft_model,
30
+ TaskType,
31
+ PeftModel
32
+ )
33
+ PEFT_AVAILABLE = True
34
+ except ImportError:
35
+ PEFT_AVAILABLE = False
36
+ print("⚠️ PEFT 未安裝,LoRA 和 AdaLoRA 功能將不可用")
37
+
38
+ # 檢查 GPU
39
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
40
+
41
+ _MODEL_PATH = None
42
+ LAST_TOKENIZER = None
43
+ LAST_TUNING_METHOD = None
44
+
45
+ # ==================== 您的原始函數 - 完全不動 ====================
46
+
47
+ def evaluate_baseline_bert(eval_dataset, df_clean):
48
+ """
49
+ 評估原始 BERT(完全沒看過資料)的表現
50
+ 這部分是從您的格子 5 提取的 baseline 比較邏輯
51
+ """
52
+ print("\n" + "=" * 80)
53
+ print("評估 Baseline 純 BERT(完全沒看過資料)")
54
+ print("=" * 80)
55
+
56
+ # 載入純 BERT
57
+ baseline_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
58
+ baseline_model = BertForSequenceClassification.from_pretrained(
59
+ "bert-base-uncased",
60
+ num_labels=2
61
+ ).to(device)
62
+ baseline_model.eval()
63
+
64
+ print(" ⚠️ 這個模型完全沒有使用您的資料訓練")
65
+
66
+ # 重新處理驗證集
67
+ baseline_dataset = Dataset.from_pandas(df_clean[['text', 'label']])
68
+
69
+ def baseline_preprocess(examples):
70
+ return baseline_tokenizer(examples['text'], truncation=True, padding='max_length', max_length=256)
71
+
72
+ baseline_tokenized = baseline_dataset.map(baseline_preprocess, batched=True)
73
+ baseline_split = baseline_tokenized.train_test_split(test_size=0.2, seed=42)
74
+ baseline_eval_dataset = baseline_split['test']
75
+
76
+ # 建立 Baseline Trainer
77
+ baseline_trainer_args = TrainingArguments(
78
+ output_dir='./temp_baseline',
79
+ per_device_eval_batch_size=32,
80
+ report_to="none"
81
+ )
82
+
83
+ baseline_trainer = Trainer(
84
+ model=baseline_model,
85
+ args=baseline_trainer_args,
86
+ )
87
+
88
+ # 評估 Baseline
89
+ print("📄 評估純 BERT...")
90
+ predictions_output = baseline_trainer.predict(baseline_eval_dataset)
91
+
92
+ all_preds = predictions_output.predictions.argmax(-1)
93
+ all_labels = predictions_output.label_ids
94
+ probs = torch.nn.functional.softmax(torch.tensor(predictions_output.predictions), dim=-1)[:, 1].numpy()
95
+
96
+ # 計算指標
97
+ precision, recall, f1, _ = precision_recall_fscore_support(
98
+ all_labels, all_preds, average='binary', pos_label=1, zero_division=0
99
+ )
100
+ acc = accuracy_score(all_labels, all_preds)
101
+
102
+ try:
103
+ auc = roc_auc_score(all_labels, probs)
104
+ except:
105
+ auc = 0.0
106
+
107
+ cm = confusion_matrix(all_labels, all_preds)
108
+ if cm.shape == (2, 2):
109
+ tn, fp, fn, tp = cm.ravel()
110
+ sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
111
+ specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
112
+ else:
113
+ sensitivity = specificity = 0
114
+ tn = fp = fn = tp = 0
115
+
116
+ baseline_results = {
117
+ 'f1': float(f1),
118
+ 'accuracy': float(acc),
119
+ 'precision': float(precision),
120
+ 'recall': float(recall),
121
+ 'sensitivity': float(sensitivity),
122
+ 'specificity': float(specificity),
123
+ 'auc': float(auc),
124
+ 'tp': int(tp),
125
+ 'tn': int(tn),
126
+ 'fp': int(fp),
127
+ 'fn': int(fn)
128
+ }
129
+
130
+ print("✅ Baseline 評估完成")
131
+
132
+ # 清理
133
+ del baseline_model
134
+ del baseline_trainer
135
+ torch.cuda.empty_cache()
136
+ gc.collect()
137
+
138
+ return baseline_results
139
+
140
+ def run_original_code_with_tuning(
141
+ file_path,
142
+ weight_multiplier,
143
+ epochs,
144
+ batch_size,
145
+ learning_rate,
146
+ warmup_steps,
147
+ tuning_method,
148
+ best_metric,
149
+ # LoRA 參數
150
+ lora_r,
151
+ lora_alpha,
152
+ lora_dropout,
153
+ lora_modules,
154
+ # AdaLoRA 參數
155
+ adalora_init_r,
156
+ adalora_target_r,
157
+ adalora_tinit,
158
+ adalora_tfinal,
159
+ adalora_delta_t,
160
+ # 新增:是否為二次微調
161
+ is_second_finetuning=False,
162
+ base_model_path=None
163
+ ):
164
+ """
165
+ 您的原始程式碼 + 不同微調方法的選項 + Baseline 比較
166
+ 核心邏輯完全不變,只是在模型初始化部分加入條件判斷
167
+
168
+ 新增參數:
169
+ - is_second_finetuning: 是否為二次微調
170
+ - base_model_path: 第一次微調模型的路徑(僅二次微調時使用)
171
+ """
172
+
173
+ global LAST_MODEL_PATH, LAST_TOKENIZER, LAST_TUNING_METHOD
174
+
175
+ # ==================== 清空記憶體(訓練前) ====================
176
+ torch.cuda.empty_cache()
177
+ gc.collect()
178
+ print("🧹 記憶體已清空")
179
+
180
+ # ==================== 您的原始程式碼開始 ====================
181
+
182
+ # 讀取上傳的檔案
183
+ df_original = pd.read_csv(file_path)
184
+ df_clean = pd.DataFrame({
185
+ 'text': df_original['Text'],
186
+ 'label': df_original['label']
187
+ })
188
+ df_clean = df_clean.dropna()
189
+
190
+ training_type = "二次微調" if is_second_finetuning else "第一次微調"
191
+
192
+ print("\n" + "=" * 80)
193
+ print(f"乳癌存活預測 BERT {training_type} - {tuning_method} 方法")
194
+ print("=" * 80)
195
+ print(f"開始時間: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
196
+ print(f"訓練類型: {training_type}")
197
+ print(f"微調方法: {tuning_method}")
198
+ print(f"最佳化指標: {best_metric}")
199
+ if is_second_finetuning:
200
+ print(f"基礎模型: {base_model_path}")
201
+ print("=" * 80)
202
+
203
+ # 載入 Tokenizer
204
+ print("\n📦 載入 BERT Tokenizer...")
205
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
206
+ print("✅ Tokenizer 載入完成")
207
+
208
+ # 評估函數 - 完全是您的原始程式碼,不動
209
+ def compute_metrics(pred):
210
+ labels = pred.label_ids
211
+ preds = pred.predictions.argmax(-1)
212
+ probs = torch.nn.functional.softmax(torch.tensor(pred.predictions), dim=-1)[:, 1].numpy()
213
+
214
+ precision, recall, f1, _ = precision_recall_fscore_support(
215
+ labels, preds, average='binary', pos_label=1, zero_division=0
216
+ )
217
+ acc = accuracy_score(labels, preds)
218
+
219
+ try:
220
+ auc = roc_auc_score(labels, probs)
221
+ except:
222
+ auc = 0.0
223
+
224
+ cm = confusion_matrix(labels, preds)
225
+ if cm.shape == (2, 2):
226
+ tn, fp, fn, tp = cm.ravel()
227
+ else:
228
+ if len(np.unique(preds)) == 1:
229
+ if preds[0] == 0:
230
+ tn, fp, fn, tp = sum(labels == 0), 0, sum(labels == 1), 0
231
+ else:
232
+ tn, fp, fn, tp = 0, sum(labels == 0), 0, sum(labels == 1)
233
+ else:
234
+ tn = fp = fn = tp = 0
235
+
236
+ sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
237
+ specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
238
+
239
+ return {
240
+ 'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall,
241
+ 'auc': auc, 'sensitivity': sensitivity, 'specificity': specificity,
242
+ 'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)
243
+ }
244
+
245
+ # ============================================================================
246
+ # 步驟 1:準備資料(不做平衡) - 您的原始程式碼
247
+ # ============================================================================
248
+
249
+ print("\n" + "=" * 80)
250
+ print("步驟 1:準備資料(保持原始比例)")
251
+ print("=" * 80)
252
+
253
+ print(f"\n原始資料分布:")
254
+ print(f" 存活 (0): {sum(df_clean['label']==0)} 筆 ({sum(df_clean['label']==0)/len(df_clean)*100:.1f}%)")
255
+ print(f" 死亡 (1): {sum(df_clean['label']==1)} 筆 ({sum(df_clean['label']==1)/len(df_clean)*100:.1f}%)")
256
+
257
+ ratio = sum(df_clean['label']==0) / sum(df_clean['label']==1)
258
+ print(f" 不平衡比例: {ratio:.1f}:1")
259
+
260
+ # ============================================================================
261
+ # 步驟 2:Tokenization - 您的原始程式碼
262
+ # ============================================================================
263
+
264
+ print("\n" + "=" * 80)
265
+ print("步驟 2:Tokenization")
266
+ print("=" * 80)
267
+
268
+ dataset = Dataset.from_pandas(df_clean[['text', 'label']])
269
+
270
+ def preprocess_function(examples):
271
+ return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=256)
272
+
273
+ tokenized_dataset = dataset.map(preprocess_function, batched=True)
274
+ train_test_split = tokenized_dataset.train_test_split(test_size=0.2, seed=42)
275
+ train_dataset = train_test_split['train']
276
+ eval_dataset = train_test_split['test']
277
+
278
+ print(f"\n✅ 資料集準備完成:")
279
+ print(f" 訓練集: {len(train_dataset)} 筆")
280
+ print(f" 驗證集: {len(eval_dataset)} 筆")
281
+
282
+ # ============================================================================
283
+ # 步驟 3:設定權重 - 您的原始程式碼
284
+ # ============================================================================
285
+
286
+ print("\n" + "=" * 80)
287
+ print(f"步驟 3:設定類別權重({weight_multiplier}x 倍數)")
288
+ print("=" * 80)
289
+
290
+ weight_0 = 1.0
291
+ weight_1 = ratio * weight_multiplier
292
+
293
+ print(f"\n權重設定:")
294
+ print(f" 倍數: {weight_multiplier}x")
295
+ print(f" 存活類權重: {weight_0:.3f}")
296
+ print(f" 死亡類權重: {weight_1:.3f} (= {ratio:.1f} × {weight_multiplier})")
297
+
298
+ class_weights = torch.tensor([weight_0, weight_1], dtype=torch.float).to(device)
299
+
300
+ # ============================================================================
301
+ # 步驟 4:訓練模型 - 這裡加入二次微調的邏輯
302
+ # ============================================================================
303
+
304
+ print("\n" + "=" * 80)
305
+ print(f"步驟 4:訓練 {tuning_method} BERT 模型 ({training_type})")
306
+ print("=" * 80)
307
+
308
+ print(f"\n🔄 初始化模型 ({tuning_method})...")
309
+
310
+ # 【新增】二次微調:載入第一次微調的模型
311
+ if is_second_finetuning and base_model_path:
312
+ print(f"📦 載入第一次微調模型: {base_model_path}")
313
+
314
+ # 讀取第一次模型資訊
315
+ with open('./saved_models_list.json', 'r') as f:
316
+ models_list = json.load(f)
317
+
318
+ base_model_info = None
319
+ for model_info in models_list:
320
+ if model_info['model_path'] == base_model_path:
321
+ base_model_info = model_info
322
+ break
323
+
324
+ if base_model_info is None:
325
+ raise ValueError(f"找不到基礎模型資訊: {base_model_path}")
326
+
327
+ base_tuning_method = base_model_info['tuning_method']
328
+ print(f" 第一次微調方法: {base_tuning_method}")
329
+
330
+ # 根據第一次的方法載入模型
331
+ if base_tuning_method in ["LoRA", "AdaLoRA"] and PEFT_AVAILABLE:
332
+ # 載入 PEFT 模型
333
+ base_bert = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
334
+ model = PeftModel.from_pretrained(base_bert, base_model_path)
335
+ print(f" ✅ 已載入 {base_tuning_method} 模型")
336
+ else:
337
+ # 載入一般模型
338
+ model = BertForSequenceClassification.from_pretrained(base_model_path, num_labels=2)
339
+ print(f" ✅ 已載入 Full Fine-tuning 模型")
340
+
341
+ model = model.to(device)
342
+ print(f" ⚠️ 注意:二次微調將使用與第一次相同的方法 ({base_tuning_method})")
343
+
344
+ # 二次微調時強制使用相同方法
345
+ tuning_method = base_tuning_method
346
+
347
+ else:
348
+ # 【原始邏輯】第一次微調:從純 BERT 開始
349
+ model = BertForSequenceClassification.from_pretrained(
350
+ "bert-base-uncased", num_labels=2, problem_type="single_label_classification"
351
+ )
352
+
353
+ # 根據選擇的微調方法設定模型
354
+ if tuning_method == "Full Fine-tuning":
355
+ # 您的原始方法 - 完全不動
356
+ model = model.to(device)
357
+ print("✅ 使用完整 Fine-tuning(所有參數可訓練)")
358
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
359
+ all_params = sum(p.numel() for p in model.parameters())
360
+ print(f" 可訓練參數: {trainable_params:,} / {all_params:,} ({100 * trainable_params / all_params:.2f}%)")
361
+
362
+ elif tuning_method == "LoRA" and PEFT_AVAILABLE:
363
+ # LoRA 設定
364
+ target_modules = lora_modules.split(",") if lora_modules else ["query", "value"]
365
+ target_modules = [m.strip() for m in target_modules]
366
+
367
+ peft_config = LoraConfig(
368
+ task_type=TaskType.SEQ_CLS,
369
+ r=int(lora_r),
370
+ lora_alpha=int(lora_alpha),
371
+ lora_dropout=float(lora_dropout),
372
+ target_modules=target_modules
373
+ )
374
+ model = get_peft_model(model, peft_config)
375
+ model = model.to(device)
376
+ print("✅ 使用 LoRA 微調")
377
+ print(f" LoRA rank (r): {lora_r}")
378
+ print(f" LoRA alpha: {lora_alpha}")
379
+ print(f" LoRA dropout: {lora_dropout}")
380
+ print(f" 目標模組: {target_modules}")
381
+ model.print_trainable_parameters()
382
+
383
+ elif tuning_method == "AdaLoRA" and PEFT_AVAILABLE:
384
+ # AdaLoRA 設定
385
+ target_modules = lora_modules.split(",") if lora_modules else ["query", "value"]
386
+ target_modules = [m.strip() for m in target_modules]
387
+
388
+ peft_config = AdaLoraConfig(
389
+ task_type=TaskType.SEQ_CLS,
390
+ init_r=int(adalora_init_r),
391
+ target_r=int(adalora_target_r),
392
+ tinit=int(adalora_tinit),
393
+ tfinal=int(adalora_tfinal),
394
+ deltaT=int(adalora_delta_t),
395
+ lora_alpha=int(lora_alpha),
396
+ lora_dropout=float(lora_dropout),
397
+ target_modules=target_modules
398
+ )
399
+ model = get_peft_model(model, peft_config)
400
+ model = model.to(device)
401
+ print("✅ 使用 AdaLoRA 微調")
402
+ print(f" 初始 rank: {adalora_init_r}")
403
+ print(f" 目標 rank: {adalora_target_r}")
404
+ print(f" Tinit: {adalora_tinit}, Tfinal: {adalora_tfinal}, DeltaT: {adalora_delta_t}")
405
+ model.print_trainable_parameters()
406
+
407
+ else:
408
+ # 預設使用 Full Fine-tuning
409
+ model = model.to(device)
410
+ print("⚠️ PEFT 未安裝或方法無效,使用 Full Fine-tuning")
411
+
412
+ # 自訂 Trainer(使用權重) - 您的原始程式碼
413
+ class WeightedTrainer(Trainer):
414
+ def compute_loss(self, model, inputs, return_outputs=False):
415
+ labels = inputs.pop("labels")
416
+ outputs = model(**inputs)
417
+ loss_fct = nn.CrossEntropyLoss(weight=class_weights)
418
+ loss = loss_fct(outputs.logits.view(-1, 2), labels.view(-1))
419
+ return (loss, outputs) if return_outputs else loss
420
+
421
+ # 訓練設定 - 根據選擇的最佳指標調整
422
+ metric_map = {
423
+ "f1": "f1",
424
+ "accuracy": "accuracy",
425
+ "precision": "precision",
426
+ "recall": "recall",
427
+ "sensitivity": "sensitivity",
428
+ "specificity": "specificity",
429
+ "auc": "auc"
430
+ }
431
+
432
+ training_args = TrainingArguments(
433
+ output_dir='./results_weight',
434
+ num_train_epochs=epochs,
435
+ per_device_train_batch_size=batch_size,
436
+ per_device_eval_batch_size=batch_size*2,
437
+ warmup_steps=warmup_steps,
438
+ weight_decay=0.01,
439
+ learning_rate=learning_rate,
440
+ logging_steps=50,
441
+ evaluation_strategy="epoch",
442
+ save_strategy="epoch",
443
+ load_best_model_at_end=True,
444
+ metric_for_best_model=metric_map.get(best_metric, "f1"),
445
+ report_to="none",
446
+ greater_is_better=True
447
+ )
448
+
449
+ trainer = WeightedTrainer(
450
+ model=model, args=training_args,
451
+ train_dataset=train_dataset, eval_dataset=eval_dataset,
452
+ compute_metrics=compute_metrics
453
+ )
454
+
455
+ print(f"\n🚀 開始訓練({epochs} epochs)...")
456
+ print(f" 最佳化指標: {best_metric}")
457
+ print("-" * 80)
458
+
459
+ trainer.train()
460
+
461
+ print("\n✅ 模型訓練完成!")
462
+
463
+ # 評估模型
464
+ print("\n📊 評估模型...")
465
+ results = trainer.evaluate()
466
+
467
+ print(f"\n{training_type} {tuning_method} BERT ({weight_multiplier}x 權重) 表現:")
468
+ print(f" F1 Score: {results['eval_f1']:.4f}")
469
+ print(f" Accuracy: {results['eval_accuracy']:.4f}")
470
+ print(f" Precision: {results['eval_precision']:.4f}")
471
+ print(f" Recall: {results['eval_recall']:.4f}")
472
+ print(f" Sensitivity: {results['eval_sensitivity']:.4f}")
473
+ print(f" Specificity: {results['eval_specificity']:.4f}")
474
+ print(f" AUC: {results['eval_auc']:.4f}")
475
+ print(f" 混淆矩陣: Tp={results['eval_tp']}, Tn={results['eval_tn']}, "
476
+ f"Fp={results['eval_fp']}, Fn={results['eval_fn']}")
477
+
478
+ # ============================================================================
479
+ # 步驟 5:Baseline 比較(純 BERT) - 僅第一次微調時執行
480
+ # ============================================================================
481
+
482
+ if not is_second_finetuning:
483
+ print("\n" + "=" * 80)
484
+ print("步驟 5:Baseline 比較 - 純 BERT(完全沒看過資料)")
485
+ print("=" * 80)
486
+
487
+ baseline_results = evaluate_baseline_bert(eval_dataset, df_clean)
488
+
489
+ # ============================================================================
490
+ # 步驟 6:比較結果
491
+ # ============================================================================
492
+
493
+ print("\n" + "=" * 80)
494
+ print(f"📊 【對比結果】純 BERT vs {tuning_method} BERT")
495
+ print("=" * 80)
496
+
497
+ print("\n📋 詳細比較表:")
498
+ print("-" * 100)
499
+ print(f"{'指標':<15} {'純 BERT':<20} {tuning_method:<20} {'改善幅度':<20}")
500
+ print("-" * 100)
501
+
502
+ metrics_to_compare = [
503
+ ('F1 Score', 'f1', 'eval_f1'),
504
+ ('Accuracy', 'accuracy', 'eval_accuracy'),
505
+ ('Precision', 'precision', 'eval_precision'),
506
+ ('Recall', 'recall', 'eval_recall'),
507
+ ('Sensitivity', 'sensitivity', 'eval_sensitivity'),
508
+ ('Specificity', 'specificity', 'eval_specificity'),
509
+ ('AUC', 'auc', 'eval_auc')
510
+ ]
511
+
512
+ for name, baseline_key, finetuned_key in metrics_to_compare:
513
+ baseline_val = baseline_results[baseline_key]
514
+ finetuned_val = results[finetuned_key]
515
+ improvement = ((finetuned_val - baseline_val) / baseline_val * 100) if baseline_val > 0 else 0
516
+
517
+ print(f"{name:<15} {baseline_val:<20.4f} {finetuned_val:<20.4f} {improvement:>+18.1f}%")
518
+
519
+ print("-" * 100)
520
+ else:
521
+ baseline_results = None
522
+
523
+ # 儲存模型
524
+ training_label = "second" if is_second_finetuning else "first"
525
+ save_dir = f'./breast_cancer_bert_{tuning_method.lower().replace(" ", "_")}_{training_label}_{datetime.now().strftime("%Y%m%d_%H%M%S")}'
526
+
527
+ if tuning_method in ["LoRA", "AdaLoRA"] and PEFT_AVAILABLE:
528
+ # PEFT 模型儲存方式
529
+ model.save_pretrained(save_dir)
530
+ tokenizer.save_pretrained(save_dir)
531
+ else:
532
+ # 一般模型儲存方式
533
+ model.save_pretrained(save_dir)
534
+ tokenizer.save_pretrained(save_dir)
535
+
536
+ # 儲存模型資訊到 JSON 檔案(用於預測頁面選擇)
537
+ model_info = {
538
+ 'model_path': save_dir,
539
+ 'tuning_method': tuning_method,
540
+ 'training_type': training_type,
541
+ 'best_metric': best_metric,
542
+ 'best_metric_value': float(results[f'eval_{metric_map.get(best_metric, "f1")}']),
543
+ 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
544
+ 'weight_multiplier': weight_multiplier,
545
+ 'epochs': epochs,
546
+ 'is_second_finetuning': is_second_finetuning,
547
+ 'base_model_path': base_model_path if is_second_finetuning else None
548
+ }
549
+
550
+ # 讀取現有的模型列表
551
+ models_list_file = './saved_models_list.json'
552
+ if os.path.exists(models_list_file):
553
+ with open(models_list_file, 'r') as f:
554
+ models_list = json.load(f)
555
+ else:
556
+ models_list = []
557
+
558
+ # 加入新模型資訊
559
+ models_list.append(model_info)
560
+
561
+ # 儲存更新後的列表
562
+ with open(models_list_file, 'w') as f:
563
+ json.dump(models_list, f, indent=2)
564
+
565
+ # 儲存到全域變數供預測使用
566
+ LAST_MODEL_PATH = save_dir
567
+ LAST_TOKENIZER = tokenizer
568
+ LAST_TUNING_METHOD = tuning_method
569
+
570
+ print(f"\n💾 模型已儲存至: {save_dir}")
571
+ print("\n" + "=" * 80)
572
+ print("🎉 訓練完成!")
573
+ print("=" * 80)
574
+ print(f"完成時間: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
575
+
576
+ # ==================== 清空記憶體(訓練後) ====================
577
+ del model
578
+ del trainer
579
+ torch.cuda.empty_cache()
580
+ gc.collect()
581
+ print("🧹 訓練後記憶體已清空")
582
+
583
+ # 加入所有資訊到結果中
584
+ results['tuning_method'] = tuning_method
585
+ results['training_type'] = training_type
586
+ results['best_metric'] = best_metric
587
+ results['best_metric_value'] = results[f'eval_{metric_map.get(best_metric, "f1")}']
588
+ results['baseline_results'] = baseline_results
589
+ results['model_path'] = save_dir
590
+ results['is_second_finetuning'] = is_second_finetuning
591
+
592
+ return results
593
+
594
+ # ==================== 新增:新數據測試函數 ====================
595
+
596
+ def test_on_new_data(test_file_path, baseline_model_path, first_model_path, second_model_path):
597
+ """
598
+ 在新測試數據上比較三個模型的表現:
599
+ 1. 純 BERT (baseline)
600
+ 2. 第一次微調模型
601
+ 3. 第二次微調模型
602
+ """
603
+
604
+ print("\n" + "=" * 80)
605
+ print("📊 新數據測試 - 三模型比較")
606
+ print("=" * 80)
607
+
608
+ # 載入測試數據
609
+ df_test = pd.read_csv(test_file_path)
610
+ df_clean = pd.DataFrame({
611
+ 'text': df_test['Text'],
612
+ 'label': df_test['label']
613
+ })
614
+ df_clean = df_clean.dropna()
615
+
616
+ print(f"\n測試數據:")
617
+ print(f" 總筆數: {len(df_clean)}")
618
+ print(f" 存活 (0): {sum(df_clean['label']==0)} 筆")
619
+ print(f" 死亡 (1): {sum(df_clean['label']==1)} 筆")
620
+
621
+ # 準備測試數據
622
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
623
+ test_dataset = Dataset.from_pandas(df_clean[['text', 'label']])
624
+
625
+ def preprocess_function(examples):
626
+ return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=256)
627
+
628
+ test_tokenized = test_dataset.map(preprocess_function, batched=True)
629
+
630
+ # 評估函數
631
+ def evaluate_model(model, dataset_name):
632
+ model.eval()
633
+
634
+ trainer_args = TrainingArguments(
635
+ output_dir='./temp_test',
636
+ per_device_eval_batch_size=32,
637
+ report_to="none"
638
+ )
639
+
640
+ trainer = Trainer(
641
+ model=model,
642
+ args=trainer_args,
643
+ )
644
+
645
+ predictions_output = trainer.predict(test_tokenized)
646
+
647
+ all_preds = predictions_output.predictions.argmax(-1)
648
+ all_labels = predictions_output.label_ids
649
+ probs = torch.nn.functional.softmax(torch.tensor(predictions_output.predictions), dim=-1)[:, 1].numpy()
650
+
651
+ precision, recall, f1, _ = precision_recall_fscore_support(
652
+ all_labels, all_preds, average='binary', pos_label=1, zero_division=0
653
+ )
654
+ acc = accuracy_score(all_labels, all_preds)
655
+
656
+ try:
657
+ auc = roc_auc_score(all_labels, probs)
658
+ except:
659
+ auc = 0.0
660
+
661
+ cm = confusion_matrix(all_labels, all_preds)
662
+ if cm.shape == (2, 2):
663
+ tn, fp, fn, tp = cm.ravel()
664
+ sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
665
+ specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
666
+ else:
667
+ sensitivity = specificity = 0
668
+ tn = fp = fn = tp = 0
669
+
670
+ results = {
671
+ 'f1': float(f1),
672
+ 'accuracy': float(acc),
673
+ 'precision': float(precision),
674
+ 'recall': float(recall),
675
+ 'sensitivity': float(sensitivity),
676
+ 'specificity': float(specificity),
677
+ 'auc': float(auc),
678
+ 'tp': int(tp),
679
+ 'tn': int(tn),
680
+ 'fp': int(fp),
681
+ 'fn': int(fn)
682
+ }
683
+
684
+ print(f"\n✅ {dataset_name} 評估完成")
685
+
686
+ del trainer
687
+ torch.cuda.empty_cache()
688
+ gc.collect()
689
+
690
+ return results
691
+
692
+ all_results = {}
693
+
694
+ # 1. 評估純 BERT
695
+ if baseline_model_path != "跳過":
696
+ print("\n" + "-" * 80)
697
+ print("1️⃣ 評估純 BERT (Baseline)")
698
+ print("-" * 80)
699
+ baseline_model = BertForSequenceClassification.from_pretrained(
700
+ "bert-base-uncased",
701
+ num_labels=2
702
+ ).to(device)
703
+ all_results['baseline'] = evaluate_model(baseline_model, "純 BERT")
704
+ del baseline_model
705
+ torch.cuda.empty_cache()
706
+ else:
707
+ all_results['baseline'] = None
708
+
709
+ # 2. 評估第一次微調模型
710
+ if first_model_path != "請選擇":
711
+ print("\n" + "-" * 80)
712
+ print("2️⃣ 評估第一次微調模型")
713
+ print("-" * 80)
714
+
715
+ # 讀取模型資訊
716
+ with open('./saved_models_list.json', 'r') as f:
717
+ models_list = json.load(f)
718
+
719
+ first_model_info = None
720
+ for model_info in models_list:
721
+ if model_info['model_path'] == first_model_path:
722
+ first_model_info = model_info
723
+ break
724
+
725
+ if first_model_info:
726
+ tuning_method = first_model_info['tuning_method']
727
+
728
+ if tuning_method in ["LoRA", "AdaLoRA"] and PEFT_AVAILABLE:
729
+ base_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
730
+ first_model = PeftModel.from_pretrained(base_model, first_model_path)
731
+ first_model = first_model.to(device)
732
+ else:
733
+ first_model = BertForSequenceClassification.from_pretrained(first_model_path).to(device)
734
+
735
+ all_results['first'] = evaluate_model(first_model, "第一次微調模型")
736
+ del first_model
737
+ torch.cuda.empty_cache()
738
+ else:
739
+ all_results['first'] = None
740
+ else:
741
+ all_results['first'] = None
742
+
743
+ # 3. 評估第二次微調模型
744
+ if second_model_path != "請選擇":
745
+ print("\n" + "-" * 80)
746
+ print("3️⃣ 評估第二次微調模型")
747
+ print("-" * 80)
748
+
749
+ # 讀取模型資訊
750
+ with open('./saved_models_list.json', 'r') as f:
751
+ models_list = json.load(f)
752
+
753
+ second_model_info = None
754
+ for model_info in models_list:
755
+ if model_info['model_path'] == second_model_path:
756
+ second_model_info = model_info
757
+ break
758
+
759
+ if second_model_info:
760
+ tuning_method = second_model_info['tuning_method']
761
+
762
+ if tuning_method in ["LoRA", "AdaLoRA"] and PEFT_AVAILABLE:
763
+ base_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
764
+ second_model = PeftModel.from_pretrained(base_model, second_model_path)
765
+ second_model = second_model.to(device)
766
+ else:
767
+ second_model = BertForSequenceClassification.from_pretrained(second_model_path).to(device)
768
+
769
+ all_results['second'] = evaluate_model(second_model, "第二次微調模型")
770
+ del second_model
771
+ torch.cuda.empty_cache()
772
+ else:
773
+ all_results['second'] = None
774
+ else:
775
+ all_results['second'] = None
776
+
777
+ print("\n" + "=" * 80)
778
+ print("✅ 新數據測試完成")
779
+ print("=" * 80)
780
+
781
+ return all_results
782
+
783
+ # ==================== 預測函數(保持原樣) ====================
784
+
785
+ def predict_text(model_choice, text_input):
786
+ """
787
+ 預測功能 - 支持選擇已訓練的模型,並同時顯示未微調和微調的預測結果
788
+ """
789
+
790
+ if not text_input or text_input.strip() == "":
791
+ return "請輸入文本", "請輸入文本"
792
+
793
+ try:
794
+ # ==================== 未微調的 BERT 預測 ====================
795
+ print("\n使用未微調 BERT 預測...")
796
+ baseline_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
797
+ baseline_model = BertForSequenceClassification.from_pretrained(
798
+ "bert-base-uncased",
799
+ num_labels=2
800
+ ).to(device)
801
+ baseline_model.eval()
802
+
803
+ # Tokenize 輸入(未微調)
804
+ baseline_inputs = baseline_tokenizer(
805
+ text_input,
806
+ truncation=True,
807
+ padding='max_length',
808
+ max_length=256,
809
+ return_tensors='pt'
810
+ ).to(device)
811
+
812
+ # 預測(未微調)
813
+ with torch.no_grad():
814
+ baseline_outputs = baseline_model(**baseline_inputs)
815
+ baseline_probs = torch.nn.functional.softmax(baseline_outputs.logits, dim=-1)
816
+ baseline_pred_class = baseline_probs.argmax(-1).item()
817
+ baseline_confidence = baseline_probs[0][baseline_pred_class].item()
818
+
819
+ baseline_result = "存活" if baseline_pred_class == 0 else "死亡"
820
+ baseline_prob_survive = baseline_probs[0][0].item()
821
+ baseline_prob_death = baseline_probs[0][1].item()
822
+
823
+ baseline_output = f"""
824
+ # 🔵 未微調 BERT 預測結果
825
+
826
+ ## 預測類別: **{baseline_result}**
827
+
828
+ ## 信心度: **{baseline_confidence:.1%}**
829
+
830
+ ## 機率分布:
831
+ - 🟢 **存活機率**: {baseline_prob_survive:.2%}
832
+ - 🔴 **死亡機率**: {baseline_prob_death:.2%}
833
+
834
+ ---
835
+ **說明**: 此為原始 BERT 模型,未經任何領域資料訓練
836
+ """
837
+
838
+ # 清空記憶體
839
+ del baseline_model
840
+ del baseline_tokenizer
841
+ torch.cuda.empty_cache()
842
+
843
+ # ==================== 微調後的 BERT 預測 ====================
844
+
845
+ if model_choice == "請先訓練模型":
846
+ finetuned_output = """
847
+ # 🟢 微調 BERT 預測結果
848
+
849
+ ❌ 尚未訓練任何模型,請先在「模型訓練」頁面訓練模型
850
+ """
851
+ return baseline_output, finetuned_output
852
+
853
+ # 解析選擇的模型路徑
854
+ model_path = model_choice.split(" | ")[0].replace("路徑: ", "")
855
+
856
+ # 從 JSON 讀取模型資訊
857
+ with open('./saved_models_list.json', 'r') as f:
858
+ models_list = json.load(f)
859
+
860
+ selected_model_info = None
861
+ for model_info in models_list:
862
+ if model_info['model_path'] == model_path:
863
+ selected_model_info = model_info
864
+ break
865
+
866
+ if selected_model_info is None:
867
+ finetuned_output = f"""
868
+ # 🟢 微調 BERT 預測結果
869
+
870
+ ❌ 找不到模型:{model_path}
871
+ """
872
+ return baseline_output, finetuned_output
873
+
874
+ print(f"\n使用微調模型: {model_path}")
875
+
876
+ # 載入 tokenizer
877
+ finetuned_tokenizer = BertTokenizer.from_pretrained(model_path)
878
+
879
+ # 載入模型
880
+ tuning_method = selected_model_info['tuning_method']
881
+ if tuning_method in ["LoRA", "AdaLoRA"] and PEFT_AVAILABLE:
882
+ # 載入 PEFT 模型
883
+ base_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
884
+ finetuned_model = PeftModel.from_pretrained(base_model, model_path)
885
+ finetuned_model = finetuned_model.to(device)
886
+ else:
887
+ # 載入一般模型
888
+ finetuned_model = BertForSequenceClassification.from_pretrained(model_path).to(device)
889
+
890
+ finetuned_model.eval()
891
+
892
+ # Tokenize 輸入(微調)
893
+ finetuned_inputs = finetuned_tokenizer(
894
+ text_input,
895
+ truncation=True,
896
+ padding='max_length',
897
+ max_length=256,
898
+ return_tensors='pt'
899
+ ).to(device)
900
+
901
+ # 預測(微調)
902
+ with torch.no_grad():
903
+ finetuned_outputs = finetuned_model(**finetuned_inputs)
904
+ finetuned_probs = torch.nn.functional.softmax(finetuned_outputs.logits, dim=-1)
905
+ finetuned_pred_class = finetuned_probs.argmax(-1).item()
906
+ finetuned_confidence = finetuned_probs[0][finetuned_pred_class].item()
907
+
908
+ finetuned_result = "存活" if finetuned_pred_class == 0 else "死亡"
909
+ finetuned_prob_survive = finetuned_probs[0][0].item()
910
+ finetuned_prob_death = finetuned_probs[0][1].item()
911
+
912
+ training_type_label = "二次微調" if selected_model_info.get('is_second_finetuning', False) else "第一次微調"
913
+
914
+ finetuned_output = f"""
915
+ # 🟢 微調 BERT 預測結果
916
+
917
+ ## 預測類別: **{finetuned_result}**
918
+
919
+ ## 信心度: **{finetuned_confidence:.1%}**
920
+
921
+ ## 機率分布:
922
+ - 🟢 **存活機率**: {finetuned_prob_survive:.2%}
923
+ - 🔴 **死亡機率**: {finetuned_prob_death:.2%}
924
+
925
+ ---
926
+ ### 模型資訊:
927
+ - **訓練類型**: {training_type_label}
928
+ - **微調方法**: {selected_model_info['tuning_method']}
929
+ - **最佳化指標**: {selected_model_info['best_metric']}
930
+ - **訓練時間**: {selected_model_info['timestamp']}
931
+ - **模型路徑**: {model_path}
932
+
933
+ ---
934
+ **注意**: 此預測僅供參考,實際醫療決策應由專業醫師判斷。
935
+ """
936
+
937
+ # 清空記憶體
938
+ del finetuned_model
939
+ del finetuned_tokenizer
940
+ torch.cuda.empty_cache()
941
+
942
+ return baseline_output, finetuned_output
943
+
944
+ except Exception as e:
945
+ import traceback
946
+ error_msg = f"❌ 預測錯誤:{str(e)}\n\n詳細錯誤訊息:\n{traceback.format_exc()}"
947
+ return error_msg, error_msg
948
+
949
+ def get_available_models():
950
+ """
951
+ 取得所有已訓練的模型列表
952
+ """
953
+ models_list_file = './saved_models_list.json'
954
+ if not os.path.exists(models_list_file):
955
+ return ["請先訓練模型"]
956
+
957
+ with open(models_list_file, 'r') as f:
958
+ models_list = json.load(f)
959
+
960
+ if len(models_list) == 0:
961
+ return ["請先訓練模型"]
962
+
963
+ # 格式化模型選項
964
+ model_choices = []
965
+ for i, model_info in enumerate(models_list, 1):
966
+ training_type = model_info.get('training_type', '第一次微調')
967
+ choice = f"路徑: {model_info['model_path']} | 類型: {training_type} | 方法: {model_info['tuning_method']} | 時間: {model_info['timestamp']}"
968
+ model_choices.append(choice)
969
+
970
+ return model_choices
971
+
972
+ def get_first_finetuning_models():
973
+ """
974
+ 取得所有第一次微調的模型(用於二次微調選擇)
975
+ """
976
+ models_list_file = './saved_models_list.json'
977
+ if not os.path.exists(models_list_file):
978
+ return ["請先進行第一次微調"]
979
+
980
+ with open(models_list_file, 'r') as f:
981
+ models_list = json.load(f)
982
+
983
+ # 只返回第一次微調的模型
984
+ first_models = [m for m in models_list if not m.get('is_second_finetuning', False)]
985
+
986
+ if len(first_models) == 0:
987
+ return ["請先進行第一次微調"]
988
+
989
+ model_choices = []
990
+ for model_info in first_models:
991
+ choice = f"{model_info['model_path']}"
992
+ model_choices.append(choice)
993
+
994
+ return model_choices
995
+
996
+ # ==================== Wrapper 函數 ====================
997
+
998
+ def train_first_wrapper(
999
+ file, tuning_method, weight_mult, epochs, batch_size, lr, warmup, best_metric,
1000
+ lora_r, lora_alpha, lora_dropout, lora_modules,
1001
+ adalora_init_r, adalora_target_r, adalora_tinit, adalora_tfinal, adalora_delta_t
1002
+ ):
1003
+ """第一次微調的包裝函數"""
1004
+
1005
+ if file is None:
1006
+ return "請上傳 CSV 檔案", "", ""
1007
+
1008
+ try:
1009
+ results = run_original_code_with_tuning(
1010
+ file_path=file.name,
1011
+ weight_multiplier=weight_mult,
1012
+ epochs=int(epochs),
1013
+ batch_size=int(batch_size),
1014
+ learning_rate=lr,
1015
+ warmup_steps=int(warmup),
1016
+ tuning_method=tuning_method,
1017
+ best_metric=best_metric,
1018
+ lora_r=lora_r,
1019
+ lora_alpha=lora_alpha,
1020
+ lora_dropout=lora_dropout,
1021
+ lora_modules=lora_modules,
1022
+ adalora_init_r=adalora_init_r,
1023
+ adalora_target_r=adalora_target_r,
1024
+ adalora_tinit=adalora_tinit,
1025
+ adalora_tfinal=adalora_tfinal,
1026
+ adalora_delta_t=adalora_delta_t,
1027
+ is_second_finetuning=False
1028
+ )
1029
+
1030
+ baseline_results = results['baseline_results']
1031
+
1032
+ # 格式化輸出
1033
+ data_info = f"""
1034
+ # 📊 資料資訊 (第一次微調)
1035
+
1036
+ ## 🔧 訓練配置
1037
+ - **微調方法**: {results['tuning_method']}
1038
+ - **最佳化指標**: {results['best_metric']}
1039
+ - **最佳指標值**: {results['best_metric_value']:.4f}
1040
+
1041
+ ## ⚙️ 訓練參數
1042
+ - **權重倍數**: {weight_mult}x
1043
+ - **訓練輪數**: {epochs}
1044
+ - **批次大小**: {batch_size}
1045
+ - **學習率**: {lr}
1046
+ - **Warmup Steps**: {warmup}
1047
+
1048
+ ✅ 第一次微調完成!可進行二次微調或預測!
1049
+ """
1050
+
1051
+ baseline_output = f"""
1052
+ # 🔵 純 BERT (Baseline)
1053
+
1054
+ ### 📈 評估指標
1055
+
1056
+ | 指標 | 數值 |
1057
+ |------|------|
1058
+ | **F1 Score** | {baseline_results['f1']:.4f} |
1059
+ | **Accuracy** | {baseline_results['accuracy']:.4f} |
1060
+ | **Precision** | {baseline_results['precision']:.4f} |
1061
+ | **Recall** | {baseline_results['recall']:.4f} |
1062
+ | **Sensitivity** | {baseline_results['sensitivity']:.4f} |
1063
+ | **Specificity** | {baseline_results['specificity']:.4f} |
1064
+ | **AUC** | {baseline_results['auc']:.4f} |
1065
+
1066
+ ### 📈 混淆矩陣
1067
+
1068
+ | | 預測:存活 | 預測:死亡 |
1069
+ |---|-----------|-----------|
1070
+ | **實際:存活** | TN={baseline_results['tn']} | FP={baseline_results['fp']} |
1071
+ | **實際:死亡** | FN={baseline_results['fn']} | TP={baseline_results['tp']} |
1072
+ """
1073
+
1074
+ finetuned_output = f"""
1075
+ # 🟢 第一次微調 BERT
1076
+
1077
+ ### 📈 評估指標
1078
+
1079
+ | 指標 | 數值 |
1080
+ |------|------|
1081
+ | **F1 Score** | {results['eval_f1']:.4f} |
1082
+ | **Accuracy** | {results['eval_accuracy']:.4f} |
1083
+ | **Precision** | {results['eval_precision']:.4f} |
1084
+ | **Recall** | {results['eval_recall']:.4f} |
1085
+ | **Sensitivity** | {results['eval_sensitivity']:.4f} |
1086
+ | **Specificity** | {results['eval_specificity']:.4f} |
1087
+ | **AUC** | {results['eval_auc']:.4f} |
1088
+
1089
+ ### 📈 混淆矩陣
1090
+
1091
+ | | 預測:存活 | 預測:死亡 |
1092
+ |---|-----------|-----------|
1093
+ | **實際:存活** | TN={results['eval_tn']} | FP={results['eval_fp']} |
1094
+ | **實際:死亡** | FN={results['eval_fn']} | TP={results['eval_tp']} |
1095
+ """
1096
+
1097
+ return data_info, baseline_output, finetuned_output
1098
+
1099
+ except Exception as e:
1100
+ import traceback
1101
+ error_msg = f"❌ 錯誤:{str(e)}\n\n詳細錯誤訊息:\n{traceback.format_exc()}"
1102
+ return error_msg, "", ""
1103
+
1104
+ def train_second_wrapper(
1105
+ base_model_choice, file, weight_mult, epochs, batch_size, lr, warmup, best_metric
1106
+ ):
1107
+ """二次微調的包裝函數"""
1108
+
1109
+ if base_model_choice == "請先進行第一次微調":
1110
+ return "請先在「第一次微調」頁面訓練模型", ""
1111
+
1112
+ if file is None:
1113
+ return "請上傳新的訓練數據 CSV 檔案", ""
1114
+
1115
+ try:
1116
+ # 解析基礎模型路徑
1117
+ base_model_path = base_model_choice
1118
+
1119
+ # 讀取第一次模型資訊
1120
+ with open('./saved_models_list.json', 'r') as f:
1121
+ models_list = json.load(f)
1122
+
1123
+ base_model_info = None
1124
+ for model_info in models_list:
1125
+ if model_info['model_path'] == base_model_path:
1126
+ base_model_info = model_info
1127
+ break
1128
+
1129
+ if base_model_info is None:
1130
+ return "找不到基礎模型資訊", ""
1131
+
1132
+ # 使用第一次的參數(二次微調不更換方法)
1133
+ tuning_method = base_model_info['tuning_method']
1134
+
1135
+ # 獲取第一次的 PEFT 參數
1136
+ lora_r = 16
1137
+ lora_alpha = 32
1138
+ lora_dropout = 0.1
1139
+ lora_modules = "query,value"
1140
+ adalora_init_r = 12
1141
+ adalora_target_r = 8
1142
+ adalora_tinit = 0
1143
+ adalora_tfinal = 0
1144
+ adalora_delta_t = 1
1145
+
1146
+ results = run_original_code_with_tuning(
1147
+ file_path=file.name,
1148
+ weight_multiplier=weight_mult,
1149
+ epochs=int(epochs),
1150
+ batch_size=int(batch_size),
1151
+ learning_rate=lr,
1152
+ warmup_steps=int(warmup),
1153
+ tuning_method=tuning_method,
1154
+ best_metric=best_metric,
1155
+ lora_r=lora_r,
1156
+ lora_alpha=lora_alpha,
1157
+ lora_dropout=lora_dropout,
1158
+ lora_modules=lora_modules,
1159
+ adalora_init_r=adalora_init_r,
1160
+ adalora_target_r=adalora_target_r,
1161
+ adalora_tinit=adalora_tinit,
1162
+ adalora_tfinal=adalora_tfinal,
1163
+ adalora_delta_t=adalora_delta_t,
1164
+ is_second_finetuning=True,
1165
+ base_model_path=base_model_path
1166
+ )
1167
+
1168
+ data_info = f"""
1169
+ # 📊 二次微調結果
1170
+
1171
+ ## 🔧 訓練配置
1172
+ - **基礎模型**: {base_model_path}
1173
+ - **微調方法**: {results['tuning_method']} (繼承自第一次)
1174
+ - **最佳化指標**: {results['best_metric']}
1175
+ - **最佳指標值**: {results['best_metric_value']:.4f}
1176
+
1177
+ ## ⚙️ 訓練參數
1178
+ - **權重倍數**: {weight_mult}x
1179
+ - **訓練輪數**: {epochs}
1180
+ - **批次大小**: {batch_size}
1181
+ - **學習率**: {lr}
1182
+ - **Warmup Steps**: {warmup}
1183
+
1184
+ ✅ 二次微調完成!可進行預測或新數據測試!
1185
+ """
1186
+
1187
+ finetuned_output = f"""
1188
+ # 🟢 二次微調 BERT
1189
+
1190
+ ### 📈 評估指標
1191
+
1192
+ | 指標 | 數值 |
1193
+ |------|------|
1194
+ | **F1 Score** | {results['eval_f1']:.4f} |
1195
+ | **Accuracy** | {results['eval_accuracy']:.4f} |
1196
+ | **Precision** | {results['eval_precision']:.4f} |
1197
+ | **Recall** | {results['eval_recall']:.4f} |
1198
+ | **Sensitivity** | {results['eval_sensitivity']:.4f} |
1199
+ | **Specificity** | {results['eval_specificity']:.4f} |
1200
+ | **AUC** | {results['eval_auc']:.4f} |
1201
+
1202
+ ### 📈 混淆矩陣
1203
+
1204
+ | | 預測:存活 | 預測:死亡 |
1205
+ |---|-----------|-----------|
1206
+ | **實際:存活** | TN={results['eval_tn']} | FP={results['eval_fp']} |
1207
+ | **實際:死亡** | FN={results['eval_fn']} | TP={results['eval_tp']} |
1208
+ """
1209
+
1210
+ return data_info, finetuned_output
1211
+
1212
+ except Exception as e:
1213
+ import traceback
1214
+ error_msg = f"❌ 錯誤:{str(e)}\n\n詳細錯誤訊息:\n{traceback.format_exc()}"
1215
+ return error_msg, ""
1216
+
1217
+ def test_new_data_wrapper(test_file, baseline_choice, first_choice, second_choice):
1218
+ """新數據測試的包裝函數"""
1219
+
1220
+ if test_file is None:
1221
+ return "請上傳測試數據 CSV 檔案", "", ""
1222
+
1223
+ try:
1224
+ all_results = test_on_new_data(
1225
+ test_file.name,
1226
+ baseline_choice,
1227
+ first_choice,
1228
+ second_choice
1229
+ )
1230
+
1231
+ # 格式化輸出
1232
+ outputs = []
1233
+
1234
+ # 1. 純 BERT
1235
+ if all_results['baseline']:
1236
+ r = all_results['baseline']
1237
+ baseline_output = f"""
1238
+ # 🔵 純 BERT (Baseline)
1239
+
1240
+ | 指標 | 數值 |
1241
+ |------|------|
1242
+ | **F1 Score** | {r['f1']:.4f} |
1243
+ | **Accuracy** | {r['accuracy']:.4f} |
1244
+ | **Precision** | {r['precision']:.4f} |
1245
+ | **Recall** | {r['recall']:.4f} |
1246
+ | **Sensitivity** | {r['sensitivity']:.4f} |
1247
+ | **Specificity** | {r['specificity']:.4f} |
1248
+ | **AUC** | {r['auc']:.4f} |
1249
+
1250
+ ### 混淆矩陣
1251
+ | | 預測:存活 | 預測:死亡 |
1252
+ |---|-----------|-----------|
1253
+ | **實際:存活** | TN={r['tn']} | FP={r['fp']} |
1254
+ | **實際:死亡** | FN={r['fn']} | TP={r['tp']} |
1255
+ """
1256
+ else:
1257
+ baseline_output = "未選擇評估純 BERT"
1258
+ outputs.append(baseline_output)
1259
+
1260
+ # 2. 第一次微調
1261
+ if all_results['first']:
1262
+ r = all_results['first']
1263
+ first_output = f"""
1264
+ # 🟢 第一次微調模型
1265
+
1266
+ | 指標 | 數值 |
1267
+ |------|------|
1268
+ | **F1 Score** | {r['f1']:.4f} |
1269
+ | **Accuracy** | {r['accuracy']:.4f} |
1270
+ | **Precision** | {r['precision']:.4f} |
1271
+ | **Recall** | {r['recall']:.4f} |
1272
+ | **Sensitivity** | {r['sensitivity']:.4f} |
1273
+ | **Specificity** | {r['specificity']:.4f} |
1274
+ | **AUC** | {r['auc']:.4f} |
1275
+
1276
+ ### 混淆矩陣
1277
+ | | 預測:存活 | 預測:死亡 |
1278
+ |---|-----------|-----------|
1279
+ | **實際:存活** | TN={r['tn']} | FP={r['fp']} |
1280
+ | **實際:死亡** | FN={r['fn']} | TP={r['tp']} |
1281
+ """
1282
+ else:
1283
+ first_output = "未選擇第一次微調模型"
1284
+ outputs.append(first_output)
1285
+
1286
+ # 3. 第二次微調
1287
+ if all_results['second']:
1288
+ r = all_results['second']
1289
+ second_output = f"""
1290
+ # 🟡 第二次微調模型
1291
+
1292
+ | 指標 | 數值 |
1293
+ |------|------|
1294
+ | **F1 Score** | {r['f1']:.4f} |
1295
+ | **Accuracy** | {r['accuracy']:.4f} |
1296
+ | **Precision** | {r['precision']:.4f} |
1297
+ | **Recall** | {r['recall']:.4f} |
1298
+ | **Sensitivity** | {r['sensitivity']:.4f} |
1299
+ | **Specificity** | {r['specificity']:.4f} |
1300
+ | **AUC** | {r['auc']:.4f} |
1301
+
1302
+ ### 混淆矩陣
1303
+ | | 預測:存活 | 預測:死亡 |
1304
+ |---|-----------|-----------|
1305
+ | **實際:存活** | TN={r['tn']} | FP={r['fp']} |
1306
+ | **實際:死亡** | FN={r['fn']} | TP={r['tp']} |
1307
+ """
1308
+ else:
1309
+ second_output = "未選擇第二次微調模型"
1310
+ outputs.append(second_output)
1311
+
1312
+ return outputs[0], outputs[1], outputs[2]
1313
+
1314
+ except Exception as e:
1315
+ import traceback
1316
+ error_msg = f"❌ 錯誤:{str(e)}\n\n詳細錯誤訊息:\n{traceback.format_exc()}"
1317
+ return error_msg, "", ""
1318
+
1319
+ # ============================================================================
1320
+ # Gradio 介面
1321
+ # ============================================================================
1322
+
1323
+ with gr.Blocks(title="BERT 二次微調平台", theme=gr.themes.Soft()) as demo:
1324
+
1325
+ gr.Markdown("""
1326
+ # 🥼 BERT 乳癌存活預測 - 二次微調完整平台
1327
+
1328
+ ### 🌟 功能特色:
1329
+ - 🎯 第一次微調:從純 BERT 開始訓練
1330
+ - 🔄 第二次微調:基於第一次模型用新數據繼續訓練
1331
+ - 📊 新數據測試:比較三個模型在新數據的表現
1332
+ - 🔮 預測功能:使用訓練好的模型進行預測
1333
+ """)
1334
+
1335
+ # Tab 1: 第一次微調
1336
+ with gr.Tab("1️⃣ 第一次微調"):
1337
+ with gr.Row():
1338
+ with gr.Column(scale=1):
1339
+ gr.Markdown("### 📤 資料上傳")
1340
+ file_input_first = gr.File(label="上傳訓練數據 CSV", file_types=[".csv"])
1341
+
1342
+ gr.Markdown("### 🔧 微調方法選擇")
1343
+ tuning_method_first = gr.Radio(
1344
+ choices=["Full Fine-tuning", "LoRA", "AdaLoRA"],
1345
+ value="Full Fine-tuning",
1346
+ label="選擇微調方法"
1347
+ )
1348
+
1349
+ gr.Markdown("### 🎯 最佳模型選擇")
1350
+ best_metric_first = gr.Dropdown(
1351
+ choices=["f1", "accuracy", "precision", "recall", "sensitivity", "specificity", "auc"],
1352
+ value="f1",
1353
+ label="選擇最佳化指標"
1354
+ )
1355
+
1356
+ gr.Markdown("### ⚙️ 訓練參數")
1357
+ weight_slider_first = gr.Slider(0.1, 2.0, value=0.8, step=0.1, label="權重倍數")
1358
+ epochs_input_first = gr.Number(value=8, label="訓練輪數")
1359
+ batch_size_input_first = gr.Number(value=16, label="批次大小")
1360
+ lr_input_first = gr.Number(value=2e-5, label="學習率")
1361
+ warmup_input_first = gr.Number(value=200, label="Warmup Steps")
1362
+
1363
+ # LoRA 參數
1364
+ with gr.Column(visible=False) as lora_params_first:
1365
+ gr.Markdown("### 🔷 LoRA 參數")
1366
+ lora_r_first = gr.Slider(4, 64, value=16, step=4, label="LoRA Rank (r)")
1367
+ lora_alpha_first = gr.Slider(8, 128, value=32, step=8, label="LoRA Alpha")
1368
+ lora_dropout_first = gr.Slider(0.0, 0.5, value=0.1, step=0.05, label="LoRA Dropout")
1369
+ lora_modules_first = gr.Textbox(value="query,value", label="目標模組")
1370
+
1371
+ # AdaLoRA 參數
1372
+ with gr.Column(visible=False) as adalora_params_first:
1373
+ gr.Markdown("### 🔶 AdaLoRA 參數")
1374
+ adalora_init_r_first = gr.Slider(4, 64, value=12, step=4, label="初始 Rank")
1375
+ adalora_target_r_first = gr.Slider(4, 64, value=8, step=4, label="目標 Rank")
1376
+ adalora_tinit_first = gr.Number(value=0, label="Tinit")
1377
+ adalora_tfinal_first = gr.Number(value=0, label="Tfinal")
1378
+ adalora_delta_t_first = gr.Number(value=1, label="Delta T")
1379
+
1380
+ train_button_first = gr.Button("🚀 開始第一次微調", variant="primary", size="lg")
1381
+
1382
+ with gr.Column(scale=2):
1383
+ gr.Markdown("### 📊 第一次微調結果")
1384
+ data_info_output_first = gr.Markdown(value="等待訓練...")
1385
+ with gr.Row():
1386
+ baseline_output_first = gr.Markdown(value="### 純 BERT\n等待訓練...")
1387
+ finetuned_output_first = gr.Markdown(value="### 第一次微調\n等待訓練...")
1388
+
1389
+ # Tab 2: 二次微調
1390
+ with gr.Tab("2️⃣ 二次微調"):
1391
+ with gr.Row():
1392
+ with gr.Column(scale=1):
1393
+ gr.Markdown("### 🔄 選擇基礎模型")
1394
+ base_model_dropdown = gr.Dropdown(
1395
+ label="選擇第一次微調的模型",
1396
+ choices=["請先進行第一次微調"],
1397
+ value="請先進行第一次微調"
1398
+ )
1399
+ refresh_base_models = gr.Button("🔄 重新整理模型列表", size="sm")
1400
+
1401
+ gr.Markdown("### 📤 上傳新訓練數據")
1402
+ file_input_second = gr.File(label="上傳新的訓練數據 CSV", file_types=[".csv"])
1403
+
1404
+ gr.Markdown("### ⚙️ 訓練參數")
1405
+ gr.Markdown("⚠️ 微調方法將自動繼承第一次微調的方法")
1406
+ best_metric_second = gr.Dropdown(
1407
+ choices=["f1", "accuracy", "precision", "recall", "sensitivity", "specificity", "auc"],
1408
+ value="f1",
1409
+ label="選擇最佳化指標"
1410
+ )
1411
+ weight_slider_second = gr.Slider(0.1, 2.0, value=0.8, step=0.1, label="權重倍數")
1412
+ epochs_input_second = gr.Number(value=5, label="訓練輪數", info="建議比第一次少")
1413
+ batch_size_input_second = gr.Number(value=16, label="批次大小")
1414
+ lr_input_second = gr.Number(value=1e-5, label="學習率", info="建議比第一次小")
1415
+ warmup_input_second = gr.Number(value=100, label="Warmup Steps")
1416
+
1417
+ train_button_second = gr.Button("🚀 開始二次微調", variant="primary", size="lg")
1418
+
1419
+ with gr.Column(scale=2):
1420
+ gr.Markdown("### 📊 二次微調結果")
1421
+ data_info_output_second = gr.Markdown(value="等待訓練...")
1422
+ finetuned_output_second = gr.Markdown(value="### 二次微調\n等待訓練...")
1423
+
1424
+ # Tab 3: 新數據測試
1425
+ with gr.Tab("3️⃣ 新數據測試"):
1426
+ with gr.Row():
1427
+ with gr.Column(scale=1):
1428
+ gr.Markdown("### 📤 上傳測試數據")
1429
+ test_file_input = gr.File(label="上傳測試數據 CSV", file_types=[".csv"])
1430
+
1431
+ gr.Markdown("### 🎯 選擇要比較的模型")
1432
+ gr.Markdown("可選擇 1-3 個模型進行比較")
1433
+
1434
+ baseline_test_choice = gr.Radio(
1435
+ choices=["評估純 BERT", "跳過"],
1436
+ value="評估純 BERT",
1437
+ label="純 BERT (Baseline)"
1438
+ )
1439
+
1440
+ first_model_test_dropdown = gr.Dropdown(
1441
+ label="第一次微調模型",
1442
+ choices=["請選擇"],
1443
+ value="請選擇"
1444
+ )
1445
+
1446
+ second_model_test_dropdown = gr.Dropdown(
1447
+ label="第二次微調模型",
1448
+ choices=["請選擇"],
1449
+ value="請選擇"
1450
+ )
1451
+
1452
+ refresh_test_models = gr.Button("🔄 重新整理模型列表", size="sm")
1453
+ test_button = gr.Button("📊 開始測試", variant="primary", size="lg")
1454
+
1455
+ with gr.Column(scale=2):
1456
+ gr.Markdown("### 📊 新數據測試結果 - 三模型比較")
1457
+ with gr.Row():
1458
+ baseline_test_output = gr.Markdown(value="### 純 BERT\n等待測試...")
1459
+ first_test_output = gr.Markdown(value="### 第一次微調\n等待測試...")
1460
+ second_test_output = gr.Markdown(value="### 二次微調\n等待測試...")
1461
+
1462
+ # Tab 4: 預測
1463
+ with gr.Tab("4️⃣ 模型預測"):
1464
+ gr.Markdown("""
1465
+ ### 使用訓練好的模型進行預測
1466
+ 選擇已訓練的模型,輸入病歷文本進行預測。
1467
+ """)
1468
+
1469
+ with gr.Row():
1470
+ with gr.Column():
1471
+ model_dropdown = gr.Dropdown(
1472
+ label="選擇模型",
1473
+ choices=["請先訓練模型"],
1474
+ value="請先訓練模型"
1475
+ )
1476
+ refresh_predict_models = gr.Button("🔄 重新整理模型列表", size="sm")
1477
+
1478
+ text_input = gr.Textbox(
1479
+ label="輸入病歷文本",
1480
+ placeholder="請輸入患者的病歷描述(英文)...",
1481
+ lines=10
1482
+ )
1483
+
1484
+ predict_button = gr.Button("🔮 開始預測", variant="primary", size="lg")
1485
+
1486
+ with gr.Column():
1487
+ gr.Markdown("### 預測結果比較")
1488
+ baseline_prediction_output = gr.Markdown(label="未微調 BERT", value="等待預測...")
1489
+ finetuned_prediction_output = gr.Markdown(label="微調 BERT", value="等待預測...")
1490
+
1491
+ # Tab 5: 使用說明
1492
+ with gr.Tab("📖 使用說明"):
1493
+ gr.Markdown("""
1494
+ ## 🔄 二次微調流程說明
1495
+
1496
+ ### 步驟 1: 第一次微調
1497
+ 1. 上傳訓練數據 A (CSV 格式: Text, label)
1498
+ 2. 選擇微調方法 (Full Fine-tuning / LoRA / AdaLoRA)
1499
+ 3. 調整訓練參數
1500
+ 4. 開始訓練
1501
+ 5. 系統會自動比較純 BERT vs 第一次微調的表現
1502
+
1503
+ ### 步驟 2: 二次微調
1504
+ 1. 選擇已訓練的第一次微調模型
1505
+ 2. 上傳新的訓練數據 B
1506
+ 3. 調整訓練參數 (建議 epochs 更少, learning rate 更小)
1507
+ 4. 開始訓練 (方法自動繼承第一次)
1508
+ 5. 模型會基於第一次的權重繼續學習
1509
+
1510
+ ### 步驟 3: 新數據測試
1511
+ 1. 上傳測試數據 C
1512
+ 2. 選擇要比較的模型 (純 BERT / 第一次 / 第二次)
1513
+ 3. 系統會並排顯示三個模型的表現
1514
+
1515
+ ### 步驟 4: 預測
1516
+ 1. 選擇任一已訓練模型
1517
+ 2. 輸入病歷文本
1518
+ 3. 查看預測結果
1519
+
1520
+ ## 🎯 微調方法說明
1521
+
1522
+ | 方法 | 訓練速度 | 記憶體 | 效果 |
1523
+ |------|---------|--------|------|
1524
+ | **Full Fine-tuning** | 1x (基準) | 高 | 最佳 |
1525
+ | **LoRA** | 3-5x 快 | 低 | 良好 |
1526
+ | **AdaLoRA** | 3-5x 快 | 低 | 良好 |
1527
+
1528
+ ## 💡 二次微調建議
1529
+
1530
+ ### 訓練參數調整:
1531
+ - **Epochs**: 第二次建議 3-5 輪 (第一次通常 8-10 輪)
1532
+ - **Learning Rate**: 第二次建議 1e-5 (第一次通常 2e-5)
1533
+ - **Warmup Steps**: 第二次建議減半
1534
+
1535
+ ### 適用場景:
1536
+ 1. **領域適應**: 第一次用通用醫療數據,第二次用特定醫院數據
1537
+ 2. **增量學習**: 隨時間增加新病例數據
1538
+ 3. **數據稀缺**: 先用大量相關數據預訓練,再用少量目標數據微調
1539
+
1540
+ ## ⚠️ 注意事項
1541
+
1542
+ - CSV 格式必須包含 `Text` 和 `label` 欄位
1543
+ - 第二次微調會自動使用第一次的微調方法
1544
+ - 建議第二次的學習率比第一次小,避免破壞已學習的知識
1545
+ - 新數據測試可以同時評估最多 3 個模型
1546
+
1547
+ ## 📊 指標說明
1548
+
1549
+ - **F1 Score**: 平衡指標,綜合考慮精確率和召回率
1550
+ - **Accuracy**: 整體準確率
1551
+ - **Precision**: 預測為死亡中的準確率
1552
+ - **Recall/Sensitivity**: 實際死亡中被正確識別的比例
1553
+ - **Specificity**: 實際存活中被正確識別的比例
1554
+ - **AUC**: ROC 曲線下面積,整體分類能力
1555
+ """)
1556
+
1557
+ # ==================== 事件綁定 ====================
1558
+
1559
+ # 第一次微調 - 參數面板顯示/隱藏
1560
+ def update_first_params(method):
1561
+ if method == "LoRA":
1562
+ return gr.update(visible=True), gr.update(visible=False)
1563
+ elif method == "AdaLoRA":
1564
+ return gr.update(visible=True), gr.update(visible=True)
1565
+ else:
1566
+ return gr.update(visible=False), gr.update(visible=False)
1567
+
1568
+ tuning_method_first.change(
1569
+ fn=update_first_params,
1570
+ inputs=[tuning_method_first],
1571
+ outputs=[lora_params_first, adalora_params_first]
1572
+ )
1573
+
1574
+ # 第一次微調按鈕
1575
+ train_button_first.click(
1576
+ fn=train_first_wrapper,
1577
+ inputs=[
1578
+ file_input_first, tuning_method_first, weight_slider_first,
1579
+ epochs_input_first, batch_size_input_first, lr_input_first,
1580
+ warmup_input_first, best_metric_first,
1581
+ lora_r_first, lora_alpha_first, lora_dropout_first, lora_modules_first,
1582
+ adalora_init_r_first, adalora_target_r_first, adalora_tinit_first,
1583
+ adalora_tfinal_first, adalora_delta_t_first
1584
+ ],
1585
+ outputs=[data_info_output_first, baseline_output_first, finetuned_output_first]
1586
+ )
1587
+
1588
+ # 刷新基礎模型列表
1589
+ def refresh_base_models_list():
1590
+ choices = get_first_finetuning_models()
1591
+ return gr.update(choices=choices, value=choices[0])
1592
+
1593
+ refresh_base_models.click(
1594
+ fn=refresh_base_models_list,
1595
+ outputs=[base_model_dropdown]
1596
+ )
1597
+
1598
+ # 二次微調按鈕
1599
+ train_button_second.click(
1600
+ fn=train_second_wrapper,
1601
+ inputs=[
1602
+ base_model_dropdown, file_input_second, weight_slider_second,
1603
+ epochs_input_second, batch_size_input_second, lr_input_second,
1604
+ warmup_input_second, best_metric_second
1605
+ ],
1606
+ outputs=[data_info_output_second, finetuned_output_second]
1607
+ )
1608
+
1609
+ # 刷新測試模型列表
1610
+ def refresh_test_models_list():
1611
+ all_models = get_available_models()
1612
+ first_models = get_first_finetuning_models()
1613
+
1614
+ # 篩選第二次微調模型
1615
+ with open('./saved_models_list.json', 'r') as f:
1616
+ models_list = json.load(f)
1617
+ second_models = [m['model_path'] for m in models_list if m.get('is_second_finetuning', False)]
1618
+
1619
+ if len(second_models) == 0:
1620
+ second_models = ["請選擇"]
1621
+
1622
+ return (
1623
+ gr.update(choices=first_models if first_models[0] != "請先進行第一次微調" else ["請選擇"], value="請選擇"),
1624
+ gr.update(choices=second_models, value="請選擇")
1625
+ )
1626
+
1627
+ refresh_test_models.click(
1628
+ fn=refresh_test_models_list,
1629
+ outputs=[first_model_test_dropdown, second_model_test_dropdown]
1630
+ )
1631
+
1632
+ # 測試按鈕
1633
+ test_button.click(
1634
+ fn=test_new_data_wrapper,
1635
+ inputs=[test_file_input, baseline_test_choice, first_model_test_dropdown, second_model_test_dropdown],
1636
+ outputs=[baseline_test_output, first_test_output, second_test_output]
1637
+ )
1638
+
1639
+ # 刷新預測模型列表
1640
+ def refresh_predict_models_list():
1641
+ choices = get_available_models()
1642
+ return gr.update(choices=choices, value=choices[0])
1643
+
1644
+ refresh_predict_models.click(
1645
+ fn=refresh_predict_models_list,
1646
+ outputs=[model_dropdown]
1647
+ )
1648
+
1649
+ # 預測按鈕
1650
+ predict_button.click(
1651
+ fn=predict_text,
1652
+ inputs=[model_dropdown, text_input],
1653
+ outputs=[baseline_prediction_output, finetuned_prediction_output]
1654
+ )
1655
+
1656
+ if __name__ == "__main__":
1657
+ demo.launch()