Spaces: Paused

Update app.py

app.py CHANGED

@@ -1,77 +1,196 @@
import gradio as gr
import pandas as pd
import torch
from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
from peft import LoraConfig, AdaLoraConfig, get_peft_model, TaskType
from datasets import Dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from torch import nn
import os
from datetime import datetime
import gc

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

-#
torch.backends.cudnn.benchmark = False
if torch.cuda.is_available():
    torch.cuda.empty_cache()

-# Global variables
trained_models = {}
model_counter = 0
-baseline_model_cache = {}
def compute_metrics(pred):
    try:
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)
        acc = accuracy_score(labels, preds)
        cm = confusion_matrix(labels, preds)
        if cm.shape == (2, 2):
            tn, fp, fn, tp = cm.ravel()
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        return {
-            'accuracy': acc,
        }
    except Exception as e:
        print(f"Error in compute_metrics: {e}")
-        return {
-        }

    """Evaluate the non-fine-tuned baseline model"""
    model.eval()
    all_preds = []
    all_labels = []

-    from torch.utils.data import DataLoader
    def collate_fn(batch):
        return {
            'input_ids': torch.stack([torch.tensor(item['input_ids']) for item in batch]),
@@ -79,7 +198,13 @@ def evaluate_baseline(model, tokenizer, test_dataset, device):
            'labels': torch.tensor([item['label'] for item in batch])
        }

-    dataloader = DataLoader(

    with torch.no_grad():
        for batch in dataloader:
@@ -90,636 +215,1137 @@ def evaluate_baseline(model, tokenizer, test_dataset, device):
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())

    acc = accuracy_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)
    if cm.shape == (2, 2):
        tn, fp, fn, tp = cm.ravel()
-    tn = fp = fn = tp = 0
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    return {
-        'accuracy': acc,
    }

        super().__init__(*args, **kwargs)
        self.class_weights = class_weights
        self.use_focal_loss = use_focal_loss

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits

-        if self.use_focal_loss:
-            # Focal Loss
            ce_loss = nn.CrossEntropyLoss(weight=self.class_weights, reduction='none')(
                logits.view(-1, 2), labels.view(-1)
            )
            pt = torch.exp(-ce_loss)
-            focal_loss = ((1 - pt) **
            loss = focal_loss
            loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
            loss = loss_fct(logits.view(-1, 2), labels.view(-1))

        return (loss, outputs) if return_outputs else loss

-def evaluate_baseline(model, tokenizer, test_dataset, device):
-    """Evaluate the non-fine-tuned baseline model"""
-    model.eval()
-    all_preds = []
-    all_labels = []

-            preds = torch.argmax(outputs.logits, dim=-1)
-            all_preds.extend(preds.cpu().numpy())
-            all_labels.extend(labels.numpy())

    else:
-        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
-        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
-    return {
-        'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall,
-        'sensitivity': sensitivity, 'specificity': specificity,
-        'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)
-    }

    model_mapping = {
        "BERT-base": "bert-base-uncased",
    }

    model_name = model_mapping.get(base_model, "bert-base-uncased")

    try:
        if csv_file is None:
-            return "❌ Please upload a CSV", "", "", ""

        df = pd.read_csv(csv_file.name)
        if 'Text' not in df.columns or 'label' not in df.columns:
-            return "❌

        df_clean = pd.DataFrame({
-            'text': df['Text'].astype(str),
            'label': df['label'].astype(int)
        }).dropna()

        n0 = int(sum(df_clean['label'] == 0))
        n1 = int(sum(df_clean['label'] == 1))
-        if n1 == 0:
-            return "❌ No death samples", "", "", ""

        tokenizer = BertTokenizer.from_pretrained(model_name)
        dataset = Dataset.from_pandas(df_clean[['text', 'label']])

-        def preprocess(
-            return tokenizer(

        tokenized = dataset.map(preprocess, batched=True, remove_columns=['text'])
-        split = tokenized.train_test_split(test_size=0.2, seed=42)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        info += f"

-        # 🔇 Silently evaluate the baseline model (not shown in the data info)
        baseline_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
        baseline_model = baseline_model.to(device)

-        baseline_perf = evaluate_baseline(

        del baseline_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

-        info += f"\n
        model = BertForSequenceClassification.from_pretrained(
-            model_name,
            hidden_dropout_prob=dropout,
            attention_probs_dropout_prob=dropout
        )

        peft_applied = False
        if method == "lora":
            config = LoraConfig(
-                task_type=TaskType.SEQ_CLS,
-                r=int(lora_r),
                lora_alpha=int(lora_alpha),
-                lora_dropout=lora_dropout,
-                target_modules=["query", "value"],
                bias="none"
            )
            model = get_peft_model(model, config)
            peft_applied = True
-            info += f"
        elif method == "adalora":
            config = AdaLoraConfig(
-                task_type=TaskType.SEQ_CLS,
-                r=int(lora_r),
                lora_alpha=int(lora_alpha),
-                lora_dropout=lora_dropout,
                target_modules=["query", "value"],
-                init_r=12,
            )
            model = get_peft_model(model, config)
            peft_applied = True
-            info += f"

-        info += f"

        model = model.to(device)

        weights = torch.tensor([w0, w1], dtype=torch.float).to(device)

            num_train_epochs=int(num_epochs),
-            per_device_train_batch_size=int(batch_size),
-            per_device_eval_batch_size=int(batch_size)*2,
-            learning_rate=float(learning_rate),
            weight_decay=float(weight_decay),
-            evaluation_strategy="epoch",
-            save_strategy="no",  #
-            load_best_model_at_end=False,
-            report_to="none",
-            logging_steps=10,
-            warmup_steps=
-            logging_first_step=True
        )

            train_dataset=split['train'],
-            eval_dataset=split['test'],
            compute_metrics=compute_metrics,
            class_weights=weights,
-            use_focal_loss=
        )

-        info += "\n
-        info += f"\n
-        info += f"\n  - Training samples: {len(split['train'])}"
-        info += f"\n  - Test samples: {len(split['test'])}"
-        info += f"\n  - Batches/epoch: {len(split['train']) // int(batch_size)}"

        train_result = trainer.train()

        model_counter += 1
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-        model_id = f"{base_model}_{method}_{timestamp}"
        trained_models[model_id] = {
-            'model': model,
-            'tokenizer': tokenizer,
-            'results':
-            'baseline': baseline_perf,
            'config': {
-                'type': base_model,
-                'model_name': model_name,
-                'method': method,
-                'metric': best_metric
            },
-            'timestamp': timestamp
        }

-        baseline_output =

-        # Fine-tuned BERT output
-        finetuned_output = f"✅ Fine-tuned BERT\n"
-        finetuned_output += f"Model: {model_id}\n\n"
-        finetuned_output += f"📊 Performance\n"
-        finetuned_output += f"F1: {results['eval_f1']:.4f}\n"
-        finetuned_output += f"Accuracy: {results['eval_accuracy']:.4f}\n"
-        finetuned_output += f"Precision: {results['eval_precision']:.4f}\n"
-        finetuned_output += f"Recall: {results['eval_recall']:.4f}\n"
-        finetuned_output += f"Sensitivity: {results['eval_sensitivity']:.4f}\n"
-        finetuned_output += f"Specificity: {results['eval_specificity']:.4f}\n\n"
-        finetuned_output += f"Confusion matrix\n"
-        finetuned_output += f"TP: {results['eval_tp']} | TN: {results['eval_tn']}\n"
-        finetuned_output += f"FP: {results['eval_fp']} | FN: {results['eval_fn']}"

-        # Comparison output
-        comparison_output = f"📊 Plain BERT vs fine-tuned BERT comparison\n\n"
-        comparison_output += f"Metric improvements:\n"
-        comparison_output += f"F1: {baseline_perf['f1']:.4f} → {results['eval_f1']:.4f} ({format_improve(f1_improve)})\n"
-        comparison_output += f"Accuracy: {baseline_perf['accuracy']:.4f} → {results['eval_accuracy']:.4f} ({format_improve(acc_improve)})\n"
-        comparison_output += f"Precision: {baseline_perf['precision']:.4f} → {results['eval_precision']:.4f} ({format_improve(prec_improve)})\n"
-        comparison_output += f"Recall: {baseline_perf['recall']:.4f} → {results['eval_recall']:.4f} ({format_improve(rec_improve)})\n"
-        comparison_output += f"Sensitivity: {baseline_perf['sensitivity']:.4f} → {results['eval_sensitivity']:.4f} ({format_improve(sens_improve)})\n"
-        comparison_output += f"Specificity: {baseline_perf['specificity']:.4f} → {results['eval_specificity']:.4f} ({format_improve(spec_improve)})\n\n"
-        comparison_output += f"Confusion-matrix changes:\n"
-        comparison_output += f"TP: {baseline_perf['tp']} → {results['eval_tp']} ({results['eval_tp'] - baseline_perf['tp']:+d})\n"
-        comparison_output += f"TN: {baseline_perf['tn']} → {results['eval_tn']} ({results['eval_tn'] - baseline_perf['tn']:+d})\n"
-        comparison_output += f"FP: {baseline_perf['fp']} → {results['eval_fp']} ({results['eval_fp'] - baseline_perf['fp']:+d})\n"
-        comparison_output += f"FN: {baseline_perf['fn']} → {results['eval_fn']} ({results['eval_fn'] - baseline_perf['fn']:+d})"

-        info += "\n\n✅ Training complete!"
-        return info, baseline_output, finetuned_output, comparison_output

    except Exception as e:
        import traceback
-        error_msg = f"❌
def predict(model_id, text):

    if not model_id or model_id not in trained_models:
-        return "❌

    try:
        info = trained_models[model_id]
-        model
        config = info['config']
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        model.eval()
        with torch.no_grad():
-            outputs = model(**
            pred_finetuned = torch.argmax(probs_finetuned, dim=-1).item()

-        cache_key = config['model_name']
-        if cache_key not in baseline_model_cache:
-            baseline_model = BertForSequenceClassification.from_pretrained(config['model_name'], num_labels=2)
-            baseline_model = baseline_model.to(device)
-            baseline_model.eval()
-            baseline_model_cache[cache_key] = baseline_model
-        else:
-            baseline_model = baseline_model_cache[cache_key]

        with torch.no_grad():
-            outputs_baseline = baseline_model(**
            pred_baseline = torch.argmax(probs_baseline, dim=-1).item()

        agreement = "✅ Agree" if pred_finetuned == pred_baseline else "⚠️ Disagree"

-        output = f"""🔮

-🧬 Fine-tuned model ({model_id})
-Prediction: {result_finetuned}
-Confidence: {probs_finetuned[0][pred_finetuned].item():.2%}
Probability distribution:

-🔬 Baseline model (non-fine-tuned {config['type']})
-Prediction: {result_baseline}
-Confidence: {probs_baseline[0][pred_baseline].item():.2%}
Probability distribution:

-{'='*

-📊
Both models predicted: {agreement}
"""

        if pred_finetuned != pred_baseline:
-            output += f"

        output += f"""

-📈
Fine-tuned model F1: {info['results']['eval_f1']:.4f}
Baseline model F1: {info['baseline']['f1']:.4f}
-Improvement: {
"""

        return output

    except Exception as e:
        import traceback
-        return f"❌

    if not trained_models:
-        return "❌

-    for
        r = info['results']
        c = info['config']

-    text += "|------|-----|-----|------|--------|------|------|\n"

-    best = max(trained_models.items(), key=lambda x: x[1]['results'][f'eval_{metric}'])
-    baseline_val = best[1]['baseline'][metric]
-    finetuned_val = best[1]['results'][f'eval_{metric}']
-    improvement = calculate_improvement(baseline_val, finetuned_val)
-    text += f"**{metric.upper()}**: {best[0]} ({finetuned_val:.4f}, improved {format_improve(improvement)})\n\n"

-# Gradio UI
-with gr.Blocks(title="BERT Fine-tuning Teaching Platform", theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# 🧬 BERT Fine-tuning Teaching Platform")
-    gr.Markdown("### Compare the performance of the baseline model vs the fine-tuned model")

-    )
-        info="Both are parameter-efficient methods; starting with LoRA is recommended"
    )

-    gr.Markdown("### 🎯 Basic training parameters")
-    with gr.Row():
-        num_epochs = gr.Number(value=5, label="Epochs", minimum=1, maximum=100, precision=0,
-                               info="5-8 epochs recommended")
-        batch_size = gr.Number(value=4, label="Batch size", minimum=1, maximum=128, precision=0,
-                               info="Drop to 4 when memory is tight")
-        learning_rate = gr.Number(value=5e-5, label="Learning rate", minimum=0, maximum=1,
-                                  info="5e-5 is a balanced choice")

-    gr.Markdown("### ⚙️ Advanced parameters")
-    with gr.Row():
-        weight_decay = gr.Number(value=0.01, label="Weight decay", minimum=0, maximum=1)
-        dropout = gr.Number(value=0.1, label="Dropout probability", minimum=0, maximum=1)

-    gr.Markdown("### 🔧 LoRA parameters")
-    with gr.Row():
-        lora_r = gr.Number(value=32, label="LoRA Rank (r)", minimum=1, maximum=256, precision=0,
-                           info="Raised to 32 for more expressive power")
-        lora_alpha = gr.Number(value=64, label="LoRA Alpha", minimum=1, maximum=512, precision=0,
-                               info="Alpha = Rank × 2")
-        lora_dropout = gr.Number(value=0.05, label="LoRA Dropout", minimum=0, maximum=1,
-                                 info="Lower dropout to avoid underfitting")

-    gr.Markdown("### ⚖️ Evaluation settings")
-    with gr.Row():
-        weight_mult = gr.Number(value=1.0, label="Class-weight multiplier", minimum=0, maximum=5,
-                                info="⚠️ With extreme imbalance use 0.5-1.5; do not exceed 2.0")
-        best_metric = gr.Dropdown(
-            choices=["f1", "accuracy", "precision", "recall", "sensitivity", "specificity"],
-            value="f1",
-            label="Best-model selection metric",
-            info="The metric used to pick the best model during training"
-        )

-    train_btn = gr.Button("🚀 Start training", variant="primary", size="lg")

-    gr.Markdown("## 📊 Training results")

-    data_info = gr.Textbox(label="📋 Data info", lines=10)

-    with gr.Row():
-        baseline_result = gr.Textbox(label="🔬 Plain BERT (not fine-tuned)", lines=14)
-        finetuned_result = gr.Textbox(label="✅ Fine-tuned BERT", lines=14)

-    train_btn.click(
-        train_bert_model,
-        inputs=[csv_file, base_model, method, num_epochs, batch_size, learning_rate,
-                weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
-                weight_mult, best_metric],
-        outputs=[data_info, baseline_result, finetuned_result, comparison_result]
-    )

-    with gr.Row():
-        model_drop = gr.Dropdown(label="Select model", choices=list(trained_models.keys()))
-        refresh = gr.Button("🔄 Refresh")

-    text_input = gr.Textbox(label="Enter a case description", lines=4,
-                            placeholder="Patient diagnosed with...")
-    predict_btn = gr.Button("Predict", variant="primary", size="lg")
-    pred_output = gr.Textbox(label="Prediction (with baseline comparison)", lines=20)

-    refresh.click(refresh_model_list, outputs=[model_drop])
-    predict_btn.click(predict, inputs=[model_drop, text_input], outputs=[pred_output])

-    gr.Examples(
-        examples=[
-            ["Patient with stage II breast cancer, good response to treatment."],
-            ["Advanced metastatic cancer, multiple organ involvement."]
-        ],
-        inputs=text_input
-    )

-    - ⚠️ **Extremely imbalanced (>10:1)**: 0.5-1.0 (your case!)
-    - Moderately imbalanced (3-10:1): 1.0-1.5
-    - Mildly imbalanced (<3:1): 1.5-2.5
-    - **Learning rate**: 3e-5 to 5e-5 (a higher rate to pair with LoRA)
-    - **Epochs**: 5-10 (extreme imbalance needs more epochs)
-    - **Batch size**: 8-16 (adjust to GPU memory)

-    ### Data format

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
-        max_threads=4
    )
 import gradio as gr
 import pandas as pd
+import numpy as np
 import torch
 from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
 from peft import LoraConfig, AdaLoraConfig, get_peft_model, TaskType
 from datasets import Dataset
 from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
 from torch import nn
+from torch.utils.data import DataLoader, WeightedRandomSampler
 import os
 from datetime import datetime
 import gc
+import json
+from functools import lru_cache
+from typing import Dict, List, Tuple, Optional
+import warnings
+warnings.filterwarnings('ignore')

+# Environment setup
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

+# Optimized CUDA settings
 torch.backends.cudnn.benchmark = False
 if torch.cuda.is_available():
     torch.cuda.empty_cache()

+# ==================== Global variables ====================
 trained_models = {}
 model_counter = 0
+training_histories = {}  # new: stores per-model training history

+# ==================== Training monitor class ====================
+class TrainingMonitor:
+    """Monitors the training process"""
+    def __init__(self):
+        self.history = {
+            'epoch': [],
+            'train_loss': [],
+            'eval_loss': [],
+            'eval_accuracy': [],
+            'eval_f1': [],
+            'eval_precision': [],
+            'eval_recall': [],
+            'learning_rate': [],
+            'best_epoch': None,
+            'best_metric_value': None
+        }
+
+    def log_epoch(self, epoch: int, train_loss: float, eval_metrics: Dict, lr: float):
+        """Record the results of one epoch"""
+        self.history['epoch'].append(epoch)
+        self.history['train_loss'].append(train_loss)
+        self.history['eval_loss'].append(eval_metrics.get('eval_loss', 0))
+        self.history['eval_accuracy'].append(eval_metrics.get('eval_accuracy', 0))
+        self.history['eval_f1'].append(eval_metrics.get('eval_f1', 0))
+        self.history['eval_precision'].append(eval_metrics.get('eval_precision', 0))
+        self.history['eval_recall'].append(eval_metrics.get('eval_recall', 0))
+        self.history['learning_rate'].append(lr)
+
+    def update_best(self, epoch: int, metric_value: float):
+        """Update the best result"""
+        self.history['best_epoch'] = epoch
+        self.history['best_metric_value'] = metric_value
+
+    def get_summary(self) -> str:
+        """Return a training summary"""
+        if not self.history['epoch']:
+            return "No training records yet"
+
+        summary = "📈 Training history summary\n"
+        summary += f"Total epochs trained: {len(self.history['epoch'])}\n"
+        summary += f"Best epoch: {self.history['best_epoch']}\n"
+        summary += f"Best metric value: {self.history['best_metric_value']:.4f}\n\n"
+
+        summary += "Per-epoch performance:\n"
+        for i, epoch in enumerate(self.history['epoch']):
+            summary += f"Epoch {epoch}: Loss={self.history['train_loss'][i]:.4f}, "
+            summary += f"F1={self.history['eval_f1'][i]:.4f}, "
+            summary += f"Acc={self.history['eval_accuracy'][i]:.4f}\n"
+
+        return summary

+# ==================== Improved weight calculation ====================
+def calculate_class_weights(n0: int, n1: int, weight_mult: float = 1.0,
+                            method: str = 'sqrt') -> Tuple[float, float]:
+    """
+    Improved class-weight calculation
+
+    Args:
+        n0: number of negative samples (survived)
+        n1: number of positive samples (died)
+        weight_mult: weight multiplier
+        method: weighting method ('balanced', 'sqrt', 'log', 'custom')
+
+    Returns:
+        (w0, w1): class weights
+    """
+    if n1 == 0:
+        return 1.0, 1.0
+
+    ratio = n0 / n1
+    total = n0 + n1
+
+    if method == 'balanced':
+        # sklearn-style balanced weights
+        w0 = total / (2 * n0) if n0 > 0 else 1.0
+        w1 = total / (2 * n1) if n1 > 0 else 1.0
+        w1 *= weight_mult
+    elif method == 'sqrt':
+        # square root softens extreme weights (recommended for severe imbalance)
+        w0 = 1.0
+        w1 = min(np.sqrt(ratio) * weight_mult, 10.0)  # capped at 10
+    elif method == 'log':
+        # logarithm softens even further
+        w0 = 1.0
+        w1 = min(np.log1p(ratio) * weight_mult, 8.0)  # capped at 8
+    elif method == 'custom':
+        # custom logic, adjusted to the degree of imbalance
+        if ratio > 20:  # extremely imbalanced
+            w0 = 1.0
+            w1 = min(5.0 * weight_mult, 10.0)
+        elif ratio > 10:  # highly imbalanced
+            w0 = 1.0
+            w1 = min(ratio * 0.3 * weight_mult, 8.0)
+        elif ratio > 5:  # moderately imbalanced
+            w0 = 1.0
+            w1 = min(ratio * 0.5 * weight_mult, 6.0)
+        else:  # mildly imbalanced
+            w0 = 1.0
+            w1 = ratio * weight_mult
+    else:
+        # default to the sqrt method
+        w0 = 1.0
+        w1 = min(np.sqrt(ratio) * weight_mult, 10.0)
+
+    return w0, w1
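The four weighting modes trade off how aggressively the minority class is boosted. A minimal sketch comparing them on a hypothetical 19:1 split (950 survivors, 50 deaths; the numbers are illustrative, not from the app):

    # Hypothetical 19:1 split; prints one (w0, w1) pair per method.
    for m in ('balanced', 'sqrt', 'log', 'custom'):
        w0, w1 = calculate_class_weights(950, 50, weight_mult=1.0, method=m)
        print(f"{m:>8}: w0={w0:.3f}, w1={w1:.3f}")
    # 'balanced' gives w1 = 1000/(2*50) = 10.0, while 'sqrt' gives
    # w1 = min(sqrt(19), 10) ≈ 4.36, a much softer push toward the minority class.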
+# ==================== Metric computation ====================
 def compute_metrics(pred):
+    """Compute the full set of evaluation metrics"""
     try:
         labels = pred.label_ids
         preds = pred.predictions.argmax(-1)
+
+        # basic metrics
+        precision, recall, f1, _ = precision_recall_fscore_support(
+            labels, preds, average='binary', pos_label=1, zero_division=0
+        )
         acc = accuracy_score(labels, preds)
+
+        # confusion matrix
         cm = confusion_matrix(labels, preds)
+        tn = fp = fn = tp = 0
         if cm.shape == (2, 2):
             tn, fp, fn, tp = cm.ravel()
+
+        # sensitivity and specificity
         sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
         specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
+
+        # extra metrics
+        ppv = tp / (tp + fp) if (tp + fp) > 0 else 0  # positive predictive value
+        npv = tn / (tn + fn) if (tn + fn) > 0 else 0  # negative predictive value
+
         return {
+            'accuracy': acc,
+            'f1': f1,
+            'precision': precision,
+            'recall': recall,
+            'sensitivity': sensitivity,
+            'specificity': specificity,
+            'ppv': ppv,
+            'npv': npv,
+            'tp': int(tp),
+            'tn': int(tn),
+            'fp': int(fp),
+            'fn': int(fn)
         }
     except Exception as e:
         print(f"Error in compute_metrics: {e}")
+        return {k: 0 for k in ['accuracy', 'f1', 'precision', 'recall',
+                               'sensitivity', 'specificity', 'ppv', 'npv',
+                               'tp', 'tn', 'fp', 'fn']}
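As a sanity check, compute_metrics can be fed a stand-in for the EvalPrediction object the Trainer passes in; the field names label_ids and predictions match what the function reads above, and the four made-up examples below are constructed to hit each confusion-matrix cell once:

    import numpy as np
    from types import SimpleNamespace

    pred = SimpleNamespace(
        label_ids=np.array([0, 0, 1, 1]),
        predictions=np.array([[0.9, 0.1],    # true negative
                              [0.2, 0.8],    # false positive
                              [0.3, 0.7],    # true positive
                              [0.6, 0.4]]))  # false negative
    m = compute_metrics(pred)
    # tp=tn=fp=fn=1, so accuracy, sensitivity and specificity all come out 0.5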
+# ==================== Baseline model evaluation (fixed: only one copy kept) ====================
+def evaluate_baseline(model, tokenizer, test_dataset, device, batch_size=16):
     """Evaluate the non-fine-tuned baseline model"""
     model.eval()
     all_preds = []
     all_labels = []

     def collate_fn(batch):
         return {
             'input_ids': torch.stack([torch.tensor(item['input_ids']) for item in batch]),
             'attention_mask': torch.stack([torch.tensor(item['attention_mask']) for item in batch]),
             'labels': torch.tensor([item['label'] for item in batch])
         }

+    dataloader = DataLoader(
+        test_dataset,
+        batch_size=batch_size,
+        collate_fn=collate_fn,
+        pin_memory=torch.cuda.is_available(),
+        num_workers=0  # avoid multiprocessing issues
+    )

     with torch.no_grad():
         for batch in dataloader:
             labels = batch.pop('labels')
             batch = {k: v.to(device) for k, v in batch.items()}
             outputs = model(**batch)
             preds = torch.argmax(outputs.logits, dim=-1)
             all_preds.extend(preds.cpu().numpy())
             all_labels.extend(labels.numpy())

+    # compute all metrics
+    precision, recall, f1, _ = precision_recall_fscore_support(
+        all_labels, all_preds, average='binary', pos_label=1, zero_division=0
+    )
     acc = accuracy_score(all_labels, all_preds)

     cm = confusion_matrix(all_labels, all_preds)
+    tn = fp = fn = tp = 0
     if cm.shape == (2, 2):
         tn, fp, fn, tp = cm.ravel()

     sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
     specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
+    ppv = tp / (tp + fp) if (tp + fp) > 0 else 0
+    npv = tn / (tn + fn) if (tn + fn) > 0 else 0

     return {
+        'accuracy': acc,
+        'f1': f1,
+        'precision': precision,
+        'recall': recall,
+        'sensitivity': sensitivity,
+        'specificity': specificity,
+        'ppv': ppv,
+        'npv': npv,
+        'tp': int(tp),
+        'tn': int(tn),
+        'fp': int(fp),
+        'fn': int(fn)
     }
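To see the helper end to end, here is a minimal sketch that runs evaluate_baseline on a two-row toy dataset; the texts and labels are made up for illustration, and it assumes the function above is in scope:

    from datasets import Dataset
    from transformers import BertTokenizer, BertForSequenceClassification
    import torch

    tok = BertTokenizer.from_pretrained("bert-base-uncased")
    mdl = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
    toy = Dataset.from_dict({"text": ["stable vitals, improving", "multi-organ failure"],
                             "label": [0, 1]})
    toy = toy.map(lambda ex: tok(ex["text"], truncation=True, padding="max_length",
                                 max_length=128), batched=True, remove_columns=["text"])
    device = torch.device("cpu")
    perf = evaluate_baseline(mdl.to(device), tok, toy, device, batch_size=2)
    print(perf["accuracy"], perf["f1"])  # untrained head, so expect near-chance numbers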
+# ==================== Custom Trainer with early stopping ====================
+class CustomTrainer(Trainer):
+    """Trainer supporting class weights, focal loss, and early stopping"""
+
+    def __init__(self, *args, class_weights=None, use_focal_loss=False,
+                 focal_gamma=2.0, monitor=None, early_stopping_patience=3,
+                 early_stopping_metric='eval_f1', **kwargs):
         super().__init__(*args, **kwargs)
         self.class_weights = class_weights
         self.use_focal_loss = use_focal_loss
+        self.focal_gamma = focal_gamma
+        self.monitor = monitor
+        self.early_stopping_patience = early_stopping_patience
+        self.early_stopping_metric = early_stopping_metric
+        self.best_metric = -float('inf')
+        self.best_model_state = None
+        self.patience_counter = 0
+        self.current_epoch = 0

     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+        """Compute the loss"""
         labels = inputs.pop("labels")
         outputs = model(**inputs)
         logits = outputs.logits

+        if self.use_focal_loss and self.class_weights is not None:
+            # focal loss implementation
             ce_loss = nn.CrossEntropyLoss(weight=self.class_weights, reduction='none')(
                 logits.view(-1, 2), labels.view(-1)
             )
             pt = torch.exp(-ce_loss)
+            focal_loss = ((1 - pt) ** self.focal_gamma * ce_loss).mean()
             loss = focal_loss
+        elif self.class_weights is not None:
+            # standard weighted cross-entropy
             loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
             loss = loss_fct(logits.view(-1, 2), labels.view(-1))
+        else:
+            # standard cross-entropy
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, 2), labels.view(-1))

         return (loss, outputs) if return_outputs else loss
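The focal-loss branch above rescales the per-example weighted cross-entropy by (1 - pt)^γ, where pt = exp(-CE); well-classified examples contribute almost nothing, so gradients concentrate on the hard minority-class cases. A standalone sketch of the same computation, with made-up logits:

    import torch
    from torch import nn

    logits = torch.tensor([[2.0, 0.5], [0.1, 1.2]])   # two examples, two classes
    labels = torch.tensor([0, 1])
    weights = torch.tensor([1.0, 4.0])                # minority class up-weighted
    gamma = 2.0

    ce = nn.CrossEntropyLoss(weight=weights, reduction='none')(logits, labels)
    pt = torch.exp(-ce)     # ≈ true-class probability (exact when all weights are 1)
    focal = ((1 - pt) ** gamma * ce).mean()
    print(ce, focal)        # focal is far below ce.mean() for easy examples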
+    def on_epoch_end(self, args, state, control, **kwargs):
+        """Callback at the end of each epoch"""
+        self.current_epoch += 1
+
+        # evaluate the model
+        metrics = self.evaluate()
+
+        # log to the monitor
+        if self.monitor:
+            self.monitor.log_epoch(
+                epoch=self.current_epoch,
+                train_loss=state.log_history[-1].get('loss', 0) if state.log_history else 0,
+                eval_metrics=metrics,
+                lr=self.get_learning_rate()
+            )
+
+        # early-stopping check
+        current_metric = metrics.get(self.early_stopping_metric, 0)
+
+        if current_metric > self.best_metric:
+            self.best_metric = current_metric
+            self.best_model_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
+            self.patience_counter = 0
+
+            if self.monitor:
+                self.monitor.update_best(self.current_epoch, current_metric)
+
+            print(f"✅ Epoch {self.current_epoch}: new best {self.early_stopping_metric} = {current_metric:.4f}")
+        else:
+            self.patience_counter += 1
+            print(f"⏳ Epoch {self.current_epoch}: no improvement (patience: {self.patience_counter}/{self.early_stopping_patience})")
+
+        if self.patience_counter >= self.early_stopping_patience:
+            print(f"🛑 Early stopping at epoch {self.current_epoch}")
+            control.should_training_stop = True
+
+        return control

+    def get_learning_rate(self):
+        """Return the current learning rate"""
+        if self.optimizer is None:
+            return 0
+        return self.optimizer.param_groups[0]['lr']

+    def load_best_model(self):
+        """Load the best model"""
+        if self.best_model_state:
+            self.model.load_state_dict(self.best_model_state)
+            print(f"✅ Loaded best model (best {self.early_stopping_metric} = {self.best_metric:.4f})")

+# ==================== Baseline model cache (improved) ====================
+@lru_cache(maxsize=3)
+def get_cached_baseline_model(model_name: str, num_labels: int = 2):
+    """Manage baseline models with an LRU cache"""
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
+    return model.to(device)
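With functools.lru_cache, the cache key is the argument tuple (model_name, num_labels), so repeated predictions against the same baseline reuse one loaded model, and maxsize=3 bounds how many checkpoints stay resident. A small sketch of the behaviour:

    m1 = get_cached_baseline_model("bert-base-uncased")
    m2 = get_cached_baseline_model("bert-base-uncased")
    assert m1 is m2                                # second call is a cache hit
    print(get_cached_baseline_model.cache_info())  # hits=1, misses=1, maxsize=3, currsize=1

One caveat: an evicted model is only freed once no other reference holds it, and lru_cache itself knows nothing about GPU memory.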
+# ==================== Improvement-rate calculation ====================
+def calculate_improvement(baseline_val: float, finetuned_val: float) -> float:
+    """Safely compute the improvement rate"""
+    if baseline_val == 0:
+        return float('inf') if finetuned_val > 0 else 0.0
+    return (finetuned_val - baseline_val) / baseline_val * 100
+
+def format_improvement(val: float) -> str:
+    """Format the improvement rate for display"""
+    if val == float('inf'):
+        return "N/A (baseline=0)"
+    elif val > 0:
+        return f"↑ {val:.1f}%"
+    elif val < 0:
+        return f"↓ {abs(val):.1f}%"
     else:
+        return "→ 0.0%"
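A quick spot-check of the two helpers above, with illustrative values:

    print(format_improvement(calculate_improvement(0.40, 0.50)))   # ↑ 25.0%
    print(format_improvement(calculate_improvement(0.50, 0.40)))   # ↓ 20.0%
    print(format_improvement(calculate_improvement(0.00, 0.30)))   # N/A (baseline=0)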
+# ==================== Main training function (improved) ====================
+def train_bert_model(csv_file, base_model, method, num_epochs, batch_size, learning_rate,
+                     weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
+                     weight_mult, weight_method, best_metric, use_early_stopping, patience):
+    """
+    Improved BERT training function
+    """
+    global trained_models, model_counter, training_histories

     model_mapping = {
         "BERT-base": "bert-base-uncased",
+        "BERT-base-chinese": "bert-base-chinese",
+        "BioBERT": "dmis-lab/biobert-base-cased-v1.2",
+        "SciBERT": "allenai/scibert_scivocab_uncased"
     }

     model_name = model_mapping.get(base_model, "bert-base-uncased")

     try:
+        # ========== Data validation and loading ==========
         if csv_file is None:
+            return "❌ Please upload a CSV file", "", "", "", ""

         df = pd.read_csv(csv_file.name)
         if 'Text' not in df.columns or 'label' not in df.columns:
+            return "❌ The CSV must contain 'Text' and 'label' columns", "", "", "", ""

+        # data cleaning
         df_clean = pd.DataFrame({
+            'text': df['Text'].astype(str),
             'label': df['label'].astype(int)
         }).dropna()

+        # data statistics
         n0 = int(sum(df_clean['label'] == 0))
         n1 = int(sum(df_clean['label'] == 1))

+        if n1 == 0:
+            return "❌ The dataset contains no positive (death) samples", "", "", "", ""
+
+        ratio = n0 / n1 if n1 > 0 else 0
+
+        # ========== Compute class weights ==========
+        w0, w1 = calculate_class_weights(n0, n1, weight_mult, method=weight_method)
+
+        # ========== Assemble the data summary ==========
+        info = f"📊 Dataset statistics\n"
+        info += f"{'='*50}\n"
+        info += f"Total samples: {len(df_clean):,}\n"
+        info += f"Survived (0): {n0:,} ({n0/len(df_clean)*100:.1f}%)\n"
+        info += f"Died (1): {n1:,} ({n1/len(df_clean)*100:.1f}%)\n"
+        info += f"Imbalance ratio: {ratio:.2f}:1\n"
+        info += f"\n⚖️ Class-weight settings\n"
+        info += f"{'='*50}\n"
+        info += f"Method: {weight_method}\n"
+        info += f"Weight (survived): {w0:.3f}\n"
+        info += f"Weight (died): {w1:.3f}\n"
+        info += f"Weight ratio: 1:{w1/w0:.2f}\n"
+
+        # ========== Model and tokenizer initialization ==========
+        info += f"\n🤖 Model configuration\n"
+        info += f"{'='*50}\n"
+        info += f"Base model: {base_model}\n"
+        info += f"Model path: {model_name}\n"
+        info += f"Fine-tuning method: {method.upper()}\n"

         tokenizer = BertTokenizer.from_pretrained(model_name)
+
+        # ========== Dataset preparation ==========
         dataset = Dataset.from_pandas(df_clean[['text', 'label']])

+        def preprocess(examples):
+            return tokenizer(
+                examples['text'],
+                truncation=True,
+                padding='max_length',
+                max_length=128
+            )

         tokenized = dataset.map(preprocess, batched=True, remove_columns=['text'])
+        split = tokenized.train_test_split(test_size=0.2, seed=42, stratify_by_column='label')  # requires 'label' to be a ClassLabel feature

+        # ========== Device configuration ==========
         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        info += f"Compute device: {'GPU ✅ (' + torch.cuda.get_device_name(0) + ')' if torch.cuda.is_available() else 'CPU ⚠️'}\n"
+
+        # ========== Evaluate the baseline model ==========
+        info += f"\n📏 Baseline model evaluation\n"
+        info += f"{'='*50}\n"
+        info += f"Evaluating the non-fine-tuned {base_model}...\n"

         baseline_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
         baseline_model = baseline_model.to(device)

+        baseline_perf = evaluate_baseline(
+            baseline_model, tokenizer, split['test'], device, batch_size=batch_size*2
+        )
+
+        info += f"Baseline F1 score: {baseline_perf['f1']:.4f}\n"
+        info += f"Baseline accuracy: {baseline_perf['accuracy']:.4f}\n"

+        # free the baseline model's memory
         del baseline_model
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         gc.collect()

+        # ========== Configure the fine-tuned model ==========
+        info += f"\n🔧 Fine-tuning configuration\n"
+        info += f"{'='*50}\n"
+
         model = BertForSequenceClassification.from_pretrained(
+            model_name,
+            num_labels=2,
             hidden_dropout_prob=dropout,
             attention_probs_dropout_prob=dropout
         )

+        # apply the PEFT method
         peft_applied = False
         if method == "lora":
+            from peft import LoraConfig, get_peft_model, TaskType
+
             config = LoraConfig(
+                task_type=TaskType.SEQ_CLS,
+                r=int(lora_r),
                 lora_alpha=int(lora_alpha),
+                lora_dropout=lora_dropout,
+                target_modules=["query", "value"],
                 bias="none"
             )
             model = get_peft_model(model, config)
             peft_applied = True
+            info += f"✅ LoRA applied\n"
+            info += f"  - Rank (r): {int(lora_r)}\n"
+            info += f"  - Alpha: {int(lora_alpha)}\n"
+            info += f"  - Dropout: {lora_dropout}\n"
+
         elif method == "adalora":
+            from peft import AdaLoraConfig, get_peft_model, TaskType
+
             config = AdaLoraConfig(
+                task_type=TaskType.SEQ_CLS,
+                r=int(lora_r),
                 lora_alpha=int(lora_alpha),
+                lora_dropout=lora_dropout,
                 target_modules=["query", "value"],
+                init_r=12,
+                target_r=int(lora_r),
+                tinit=200,
+                tfinal=1000,
+                deltaT=10
             )
             model = get_peft_model(model, config)
             peft_applied = True
+            info += f"✅ AdaLoRA applied\n"
+            info += f"  - Initial Rank: 12\n"
+            info += f"  - Target Rank: {int(lora_r)}\n"
+            info += f"  - Alpha: {int(lora_alpha)}\n"

+        elif method == "full":
+            info += f"✅ Full fine-tuning mode\n"
+            peft_applied = False

         model = model.to(device)

+        # parameter statistics
+        total_params = sum(p.numel() for p in model.parameters())
+        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+        info += f"\n💾 Model parameters\n"
+        info += f"{'='*50}\n"
+        info += f"Total parameters: {total_params:,}\n"
+        info += f"Trainable parameters: {trainable_params:,}\n"
+        info += f"Trainable fraction: {trainable_params/total_params*100:.2f}%\n"
+        info += f"Memory savings: {(1 - trainable_params/total_params)*100:.1f}%\n"

+        # ========== Prepare for training ==========
         weights = torch.tensor([w0, w1], dtype=torch.float).to(device)
+        use_focal = ratio > 10  # use focal loss under severe imbalance

+        if use_focal:
+            info += f"\n⚡ Special settings\n"
+            info += f"{'='*50}\n"
+            info += f"Using focal loss (γ=2.0) for severe imbalance\n"
+
+        # training arguments
+        training_args = TrainingArguments(
+            output_dir='./results',
             num_train_epochs=int(num_epochs),
+            per_device_train_batch_size=int(batch_size),
+            per_device_eval_batch_size=int(batch_size) * 2,
+            learning_rate=float(learning_rate),
             weight_decay=float(weight_decay),
+            evaluation_strategy="epoch",
+            save_strategy="no",  # a custom save strategy is used instead
+            load_best_model_at_end=False,
+            report_to="none",
+            logging_steps=max(1, len(split['train']) // (int(batch_size) * 10)),
+            warmup_steps=min(500, len(split['train']) // int(batch_size)),
+            logging_first_step=True,
+            remove_unused_columns=False,
+            label_smoothing_factor=0.1 if ratio > 20 else 0.0,  # label smoothing under extreme imbalance
         )

+        # create the monitor
+        monitor = TrainingMonitor()
+
+        # create the custom Trainer
+        trainer = CustomTrainer(
+            model=model,
+            args=training_args,
             train_dataset=split['train'],
+            eval_dataset=split['test'],
             compute_metrics=compute_metrics,
             class_weights=weights,
+            use_focal_loss=use_focal,
+            focal_gamma=2.0,
+            monitor=monitor,
+            early_stopping_patience=patience if use_early_stopping else 999,
+            early_stopping_metric=f'eval_{best_metric}'
         )

+        info += f"\n🚀 Training setup\n"
+        info += f"{'='*50}\n"
+        info += f"Training samples: {len(split['train']):,}\n"
+        info += f"Test samples: {len(split['test']):,}\n"
+        info += f"Batch size: {int(batch_size)}\n"
+        info += f"Epochs: {int(num_epochs)}\n"
+        info += f"Batches/epoch: {len(split['train']) // int(batch_size)}\n"
+        info += f"Early stopping: {'on (patience=' + str(patience) + ')' if use_early_stopping else 'off'}\n"
+        info += f"Best-model metric: {best_metric}\n"

+        info += f"\n⏳ Starting training...\n"
+        info += f"{'='*50}\n"

+        # ========== Run training ==========
         train_result = trainer.train()

+        # load the best model
+        if use_early_stopping:
+            trainer.load_best_model()

+        # final evaluation
+        final_results = trainer.evaluate()

+        # ========== Save the model and results ==========
         model_counter += 1
         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        model_id = f"{base_model}_{method}_{model_counter}_{timestamp}"
+
         trained_models[model_id] = {
+            'model': model,
+            'tokenizer': tokenizer,
+            'results': final_results,
+            'baseline': baseline_perf,
             'config': {
+                'type': base_model,
+                'model_name': model_name,
+                'method': method,
+                'metric': best_metric,
+                'epochs': int(num_epochs),
+                'batch_size': int(batch_size),
+                'learning_rate': float(learning_rate),
+                'weight_method': weight_method,
+                'weight_mult': weight_mult
             },
+            'timestamp': timestamp,
+            'monitor': monitor  # keep the training history
         }

+        training_histories[model_id] = monitor.history
+
+        info += f"\n✅ Training complete!\n"
+        info += f"Final training loss: {train_result.training_loss:.4f}\n"
+        if monitor.history['best_epoch']:
+            info += f"Best epoch: {monitor.history['best_epoch']}\n"
+
+        # ========== Assemble the outputs ==========
+        # baseline results
+        baseline_output = format_baseline_results(baseline_perf)
+
+        # fine-tuned results
+        finetuned_output = format_finetuned_results(model_id, final_results)
+
+        # comparison
+        comparison_output = format_comparison_results(baseline_perf, final_results)
+
+        # training history
+        history_output = monitor.get_summary()
+
+        return info, baseline_output, finetuned_output, comparison_output, history_output
     except Exception as e:
         import traceback
+        error_msg = f"❌ An error occurred\n\nError type: {type(e).__name__}\nError message: {str(e)}\n\n"
+        error_msg += f"Traceback:\n{traceback.format_exc()}"
+        return error_msg, "", "", "", ""
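The logging and warmup values in the TrainingArguments above are derived from the dataset size rather than hard-coded; a worked example for a hypothetical run with 1,000 training samples and batch_size=8:

    n_train, batch_size = 1000, 8
    steps_per_epoch = n_train // batch_size                  # 125
    logging_steps = max(1, n_train // (batch_size * 10))     # 12, roughly ten log lines per epoch
    warmup_steps = min(500, steps_per_epoch)                 # 125, warming up for about one epoch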
+# ==================== Output formatting functions ====================
+def format_baseline_results(baseline_perf: Dict) -> str:
+    """Format the baseline model's results"""
+    output = "🔬 Plain BERT (not fine-tuned)\n\n"
+    output += "📊 Model performance\n"
+    output += f"{'='*30}\n"
+    output += f"F1 Score: {baseline_perf['f1']:.4f}\n"
+    output += f"Accuracy: {baseline_perf['accuracy']:.4f}\n"
+    output += f"Precision: {baseline_perf['precision']:.4f}\n"
+    output += f"Recall: {baseline_perf['recall']:.4f}\n"
+    output += f"Sensitivity: {baseline_perf['sensitivity']:.4f}\n"
+    output += f"Specificity: {baseline_perf['specificity']:.4f}\n"
+    output += f"PPV: {baseline_perf['ppv']:.4f}\n"
+    output += f"NPV: {baseline_perf['npv']:.4f}\n\n"
+    output += "📈 Confusion matrix\n"
+    output += f"{'='*30}\n"
+    output += f"        Pred 0  Pred 1\n"
+    output += f"True 0  {baseline_perf['tn']:4d}    {baseline_perf['fp']:4d}\n"
+    output += f"True 1  {baseline_perf['fn']:4d}    {baseline_perf['tp']:4d}\n"
+    return output
+
+def format_finetuned_results(model_id: str, results: Dict) -> str:
+    """Format the fine-tuned model's results"""
+    output = f"✅ Fine-tuned BERT\n"
+    output += f"Model ID: {model_id}\n\n"
+    output += "📊 Model performance\n"
+    output += f"{'='*30}\n"
+    output += f"F1 Score: {results['eval_f1']:.4f}\n"
+    output += f"Accuracy: {results['eval_accuracy']:.4f}\n"
+    output += f"Precision: {results['eval_precision']:.4f}\n"
+    output += f"Recall: {results['eval_recall']:.4f}\n"
+    output += f"Sensitivity: {results['eval_sensitivity']:.4f}\n"
+    output += f"Specificity: {results['eval_specificity']:.4f}\n"
+    output += f"PPV: {results['eval_ppv']:.4f}\n"
+    output += f"NPV: {results['eval_npv']:.4f}\n\n"
+    output += "📈 Confusion matrix\n"
+    output += f"{'='*30}\n"
+    output += f"        Pred 0  Pred 1\n"
+    output += f"True 0  {results['eval_tn']:4d}    {results['eval_fp']:4d}\n"
+    output += f"True 1  {results['eval_fn']:4d}    {results['eval_tp']:4d}\n"
+    return output
+
+def format_comparison_results(baseline_perf: Dict, finetuned_results: Dict) -> str:
+    """Format the baseline vs fine-tuned comparison"""
+    output = "📊 Plain BERT vs fine-tuned BERT comparison\n\n"
+    output += "Metric improvement analysis:\n"
+    output += f"{'='*50}\n"
+    output += f"{'Metric':<12} {'Base':>8} {'Tuned':>8} {'Change':>10} {'Improve':>10}\n"
+    output += f"{'-'*50}\n"
+
+    metrics = [
+        ('F1', 'f1', 'eval_f1'),
+        ('Accuracy', 'accuracy', 'eval_accuracy'),
+        ('Precision', 'precision', 'eval_precision'),
+        ('Recall', 'recall', 'eval_recall'),
+        ('Sensitivity', 'sensitivity', 'eval_sensitivity'),
+        ('Specificity', 'specificity', 'eval_specificity'),
+        ('PPV', 'ppv', 'eval_ppv'),
+        ('NPV', 'npv', 'eval_npv')
+    ]
+
+    for name, base_key, fine_key in metrics:
+        base_val = baseline_perf[base_key]
+        fine_val = finetuned_results[fine_key]
+        change = fine_val - base_val
+        improve = calculate_improvement(base_val, fine_val)
+
+        output += f"{name:<12} {base_val:>8.4f} {fine_val:>8.4f} "
+        output += f"{change:+10.4f} {format_improvement(improve):>10}\n"
+
+    output += f"\nConfusion-matrix changes:\n"
+    output += f"{'='*40}\n"
+    output += f"{'Item':<10} {'Base':>8} {'Tuned':>8} {'Change':>10}\n"
+    output += f"{'-'*40}\n"
+
+    cm_items = [
+        ('True Pos', 'tp', 'eval_tp'),
+        ('True Neg', 'tn', 'eval_tn'),
+        ('False Pos', 'fp', 'eval_fp'),
+        ('False Neg', 'fn', 'eval_fn')
+    ]
+
+    for name, base_key, fine_key in cm_items:
+        base_val = baseline_perf[base_key]
+        fine_val = finetuned_results[fine_key]
+        change = fine_val - base_val
+
+        output += f"{name:<10} {base_val:>8d} {fine_val:>8d} {change:+10d}\n"
+
+    # summary
+    output += f"\n📈 Overall assessment:\n"
+    output += f"{'='*40}\n"
+
+    f1_improve = calculate_improvement(baseline_perf['f1'], finetuned_results['eval_f1'])
+    if f1_improve > 10:
+        output += "✅ Significant improvement: fine-tuning clearly boosted performance!\n"
+    elif f1_improve > 0:
+        output += "✅ Some improvement: fine-tuning had a positive effect.\n"
+    elif f1_improve == 0:
+        output += "➖ No change: fine-tuning had no visible effect.\n"
+    else:
+        output += "⚠️ Performance dropped: the hyperparameters may need tuning.\n"
+
+    return output

# ==================== Prediction function (improved) ====================
def predict(model_id, text):
    """Predict with the selected fine-tuned model and compare against the baseline model."""

    if not model_id or model_id not in trained_models:
        return "❌ Please select a trained model first"

    if not text or len(text.strip()) == 0:
        return "❌ Please enter some text to predict"

    try:
        # Look up the stored model, tokenizer, and training config
        info = trained_models[model_id]
        model = info['model']
        tokenizer = info['tokenizer']
        config = info['config']
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Tokenize the input text
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128
        )
        inputs_device = {k: v.to(device) for k, v in inputs.items()}

        # ========== Fine-tuned model prediction ==========
        model.eval()
        with torch.no_grad():
            outputs = model(**inputs_device)
            logits = outputs.logits
            probs_finetuned = torch.nn.functional.softmax(logits, dim=-1)
            pred_finetuned = torch.argmax(probs_finetuned, dim=-1).item()
            confidence_finetuned = probs_finetuned[0][pred_finetuned].item()

        # ========== Baseline model prediction ==========
        baseline_model = get_cached_baseline_model(config['model_name'])
        baseline_model.to(device)  # defensive: keep the cached model on the same device as the inputs
        baseline_model.eval()

        with torch.no_grad():
            outputs_baseline = baseline_model(**inputs_device)
            logits_baseline = outputs_baseline.logits
            probs_baseline = torch.nn.functional.softmax(logits_baseline, dim=-1)
            pred_baseline = torch.argmax(probs_baseline, dim=-1).item()
            confidence_baseline = probs_baseline[0][pred_baseline].item()

        # ========== Format the output ==========
        result_finetuned = "🟢 Survived" if pred_finetuned == 0 else "🔴 Deceased"
        result_baseline = "🟢 Survived" if pred_baseline == 0 else "🔴 Deceased"
        agreement = "✅ Agree" if pred_finetuned == pred_baseline else "⚠️ Disagree"

        output = f"""🔮 Prediction Comparison

📝 Input text
{'='*60}
{text[:200]}{'...' if len(text) > 200 else ''}

{'='*60}

🎯 Fine-tuned model ({model_id})
{'='*60}
Prediction: {result_finetuned}
Confidence: {confidence_finetuned:.1%}

Probability distribution:
  • Survived (0): {probs_finetuned[0][0].item():.2%}
  • Deceased (1): {probs_finetuned[0][1].item():.2%}

Model configuration:
  • Method: {config['method'].upper()}
  • Base model: {config['type']}
  • Epochs: {config['epochs']}

{'='*60}

🔬 Baseline model (not fine-tuned {config['type']})
{'='*60}
Prediction: {result_baseline}
Confidence: {confidence_baseline:.1%}

Probability distribution:
  • Survived (0): {probs_baseline[0][0].item():.2%}
  • Deceased (1): {probs_baseline[0][1].item():.2%}

{'='*60}

📊 Prediction analysis
{'='*60}
Model agreement: {agreement}
"""

        if pred_finetuned != pred_baseline:
            output += f"""
💡 Disagreement analysis:
The fine-tuned model predicts [{result_finetuned}] (confidence: {confidence_finetuned:.1%})
The baseline model predicts [{result_baseline}] (confidence: {confidence_baseline:.1%})

The disagreement shows how fine-tuning affects this particular case;
the fine-tuned model may have learned features better suited to your dataset.
"""
        else:
            output += f"""
✅ Agreement analysis:
Both models predict [{result_finetuned}]
Confidence gap: {abs(confidence_finetuned - confidence_baseline):.1%}
"""

        # Append the overall performance comparison for this model
        f1_improve = calculate_improvement(
            info['baseline']['f1'],
            info['results']['eval_f1']
        )

        output += f"""

📈 Overall model performance
{'='*60}
Fine-tuned F1: {info['results']['eval_f1']:.4f}
Baseline F1: {info['baseline']['f1']:.4f}
Improvement: {format_improvement(f1_improve)}

Fine-tuned accuracy: {info['results']['eval_accuracy']:.4f}
Baseline accuracy: {info['baseline']['accuracy']:.4f}
"""

        return output

    except Exception as e:
        import traceback
        return f"❌ Prediction failed\n\n{str(e)}\n\n{traceback.format_exc()}"
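
# Illustrative usage of predict() outside the UI; an example only, not part of
# the app's control flow (it assumes at least one model was trained in this
# session and registered in trained_models).
def _predict_usage_example():
    model_id = next(iter(trained_models), None)
    if model_id is None:
        return "Train a model first"
    return predict(model_id, "Early stage tumor detected, surgery scheduled.")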

# ==================== Model comparison ====================
def compare_models():
    """Compare all trained models."""

    if not trained_models:
        return "❌ No models have been trained yet. Please train one on the Train tab first."

    output = "# 📊 Model Comparison Report\n\n"
    output += f"{len(trained_models)} trained model(s) in total\n\n"

    # Fine-tuned model performance table
    output += "## 🎯 Fine-tuned Model Performance\n\n"
    output += "| Model ID | Base model | Method | F1 | Accuracy | Precision | Recall | Sensitivity | Specificity |\n"
    output += "|----------|------------|--------|-----|----------|-----------|--------|-------------|-------------|\n"

    for model_id, info in trained_models.items():
        r = info['results']
        c = info['config']

        # Shorten the displayed model ID
        short_id = f"{c['type']}_{c['method']}_{info['timestamp'][-6:]}"

        output += f"| {short_id} | {c['type']} | {c['method'].upper()} | "
        output += f"{r['eval_f1']:.4f} | {r['eval_accuracy']:.4f} | "
        output += f"{r['eval_precision']:.4f} | {r['eval_recall']:.4f} | "
        output += f"{r['eval_sensitivity']:.4f} | {r['eval_specificity']:.4f} |\n"

    # Baseline performance
    output += "\n## 🔬 Baseline Performance (not fine-tuned)\n\n"

    # Collect the unique baseline models
    unique_baselines = {}
    for model_id, info in trained_models.items():
        base_type = info['config']['type']
        if base_type not in unique_baselines:
            unique_baselines[base_type] = info['baseline']

    output += "| Base model | F1 | Accuracy | Precision | Recall | Sensitivity | Specificity |\n"
    output += "|------------|-----|----------|-----------|--------|-------------|-------------|\n"

    for base_type, baseline in unique_baselines.items():
        output += f"| {base_type} | {baseline['f1']:.4f} | {baseline['accuracy']:.4f} | "
        output += f"{baseline['precision']:.4f} | {baseline['recall']:.4f} | "
        output += f"{baseline['sensitivity']:.4f} | {baseline['specificity']:.4f} |\n"

    # Best model per metric
    output += "\n## 🏆 Best Model per Metric\n\n"

    metrics_to_check = [
        ('F1 Score', 'eval_f1'),
        ('Accuracy', 'eval_accuracy'),
        ('Precision', 'eval_precision'),
        ('Recall', 'eval_recall'),
        ('Sensitivity', 'eval_sensitivity'),
        ('Specificity', 'eval_specificity')
    ]

    for metric_name, metric_key in metrics_to_check:
        best_model = max(
            trained_models.items(),
            key=lambda x: x[1]['results'][metric_key]
        )

        model_id = best_model[0]
        value = best_model[1]['results'][metric_key]
        baseline_val = best_model[1]['baseline'][metric_key.replace('eval_', '')]
        improvement = calculate_improvement(baseline_val, value)

        output += f"**{metric_name}**: {model_id[:30]}... "
        output += f"({value:.4f}, improvement {format_improvement(improvement)})\n\n"

    # Improvement statistics
    output += "## 📈 Improvement Statistics\n\n"

    improvements = []
    for model_id, info in trained_models.items():
        f1_base = info['baseline']['f1']
        f1_fine = info['results']['eval_f1']
        improve = calculate_improvement(f1_base, f1_fine)

        if improve != float('inf'):
            improvements.append({
                'model': model_id,
                'improvement': improve,
                'method': info['config']['method']
            })

    if improvements:
        avg_improvement = np.mean([x['improvement'] for x in improvements])
        max_improvement = max(improvements, key=lambda x: x['improvement'])
        min_improvement = min(improvements, key=lambda x: x['improvement'])

        output += f"Average F1 improvement: {format_improvement(avg_improvement)}\n"
        output += f"Largest improvement: {max_improvement['model'][:30]}... ({format_improvement(max_improvement['improvement'])})\n"
        output += f"Smallest improvement: {min_improvement['model'][:30]}... ({format_improvement(min_improvement['improvement'])})\n\n"

        # Per-method comparison
        method_improvements = {}
        for imp in improvements:
            method = imp['method']
            if method not in method_improvements:
                method_improvements[method] = []
            method_improvements[method].append(imp['improvement'])

        output += "### Average improvement per method:\n"
        for method, imps in method_improvements.items():
            avg_imp = np.mean(imps)
            output += f"- **{method.upper()}**: {format_improvement(avg_imp)}\n"

    return output

# ==================== Gradio UI ====================
def create_demo():
    """Build the Gradio interface."""

    with gr.Blocks(
        title="BERT Fine-tuning Teaching Platform",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {font-family: 'Microsoft JhengHei', 'Arial', sans-serif;}
        """
    ) as demo:

        gr.Markdown(
            """
            # 🧬 BERT Fine-tuning Teaching Platform
            ### Compare the performance of a baseline model vs a fine-tuned model (improved)
            """
        )

        with gr.Tab("🎯 Train"):
            gr.Markdown("## Step 1: Choose a base model")

            base_model = gr.Dropdown(
                choices=["BERT-base", "BERT-base-chinese", "BioBERT", "SciBERT"],
                value="BERT-base",
                label="Base model",
                info="Pick the pre-trained model that fits your data"
            )

            gr.Markdown("## Step 2: Choose a fine-tuning method")

            method = gr.Radio(
                choices=["lora", "adalora", "full"],
                value="lora",
                label="Fine-tuning method",
                info="LoRA and AdaLoRA are parameter-efficient; Full tunes every parameter"
            )

            gr.Markdown("## Step 3: Upload data")

            csv_file = gr.File(
                label="CSV file (must contain Text and label columns)",
                file_types=[".csv"]
            )

            gr.Markdown("## Step 4: Set training parameters")

            with gr.Accordion("🎯 Basic training parameters", open=True):
                with gr.Row():
                    num_epochs = gr.Number(
                        value=5, label="Epochs", minimum=1, maximum=50, precision=0,
                        info="3-10 recommended; more may overfit"
                    )
                    batch_size = gr.Number(
                        value=8, label="Batch size", minimum=1, maximum=64, precision=0,
                        info="Lower this if GPU memory runs out"
                    )
                    learning_rate = gr.Number(
                        value=3e-5, label="Learning rate", minimum=1e-6, maximum=1e-3,
                        info="1e-5 to 5e-5 recommended"
                    )

            with gr.Accordion("⚙️ Advanced parameters"):
                with gr.Row():
                    weight_decay = gr.Number(
                        value=0.01, label="Weight decay", minimum=0, maximum=1,
                        info="Counters overfitting; 0.01-0.1 recommended"
                    )
                    dropout = gr.Number(
                        value=0.1, label="Dropout rate", minimum=0, maximum=0.5,
                        info="Counters overfitting; 0.1-0.3 recommended"
                    )

            with gr.Accordion("🔧 PEFT parameters (LoRA/AdaLoRA)"):
                with gr.Row():
                    lora_r = gr.Number(
                        value=16, label="LoRA Rank (r)", minimum=1, maximum=64, precision=0,
                        info="Higher rank means more capacity but more parameters"
                    )
                    lora_alpha = gr.Number(
                        value=32, label="LoRA Alpha", minimum=1, maximum=128, precision=0,
                        info="Usually set to twice the rank"
                    )
                    lora_dropout = gr.Number(
                        value=0.05, label="LoRA Dropout", minimum=0, maximum=0.5,
                        info="Dropout inside the LoRA layers"
                    )

            with gr.Accordion("⚖️ Class-balance settings"):
                with gr.Row():
                    weight_mult = gr.Number(
                        value=1.0, label="Weight multiplier", minimum=0.1, maximum=5.0,
                        info="Scales the minority-class weight"
                    )
                    weight_method = gr.Dropdown(
                        choices=["sqrt", "log", "balanced", "custom"],
                        value="sqrt",
                        label="Weighting method",
                        info="sqrt and log suit extremely imbalanced data"
                    )

            with gr.Accordion("🎯 Training strategy"):
                with gr.Row():
                    best_metric = gr.Dropdown(
                        choices=["f1", "accuracy", "precision", "recall", "sensitivity", "specificity"],
                        value="f1",
                        label="Best-model metric",
                        info="The best checkpoint is selected by this metric"
                    )
                    use_early_stopping = gr.Checkbox(
                        value=True, label="Enable early stopping",
                        info="Stop once the model stops improving"
                    )
                    patience = gr.Number(
                        value=3, label="Patience", minimum=1, maximum=10, precision=0,
                        info="Epochs without improvement before stopping"
                    )

            train_btn = gr.Button("🚀 Start training", variant="primary", size="lg")

            gr.Markdown("## 📊 Training results")

            with gr.Row():
                data_info = gr.Textbox(label="📋 Training info", lines=25)
                history_output = gr.Textbox(label="📈 Training history", lines=25)

            with gr.Row():
                baseline_result = gr.Textbox(label="🔬 Baseline model (not fine-tuned)", lines=15)
                finetuned_result = gr.Textbox(label="✅ Fine-tuned model", lines=15)

            comparison_result = gr.Textbox(label="📊 Performance comparison", lines=20)

            train_btn.click(
                train_bert_model,
                inputs=[
                    csv_file, base_model, method, num_epochs, batch_size, learning_rate,
                    weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
                    weight_mult, weight_method, best_metric, use_early_stopping, patience
                ],
                outputs=[data_info, baseline_result, finetuned_result, comparison_result, history_output]
            )

        with gr.Tab("🔮 Predict"):
            gr.Markdown("## Predict with a trained model")

            with gr.Row():
                model_dropdown = gr.Dropdown(
                    label="Select a model",
                    choices=list(trained_models.keys()),
                    interactive=True
                )
                refresh_btn = gr.Button("🔄 Refresh model list", size="sm")

            text_input = gr.Textbox(
                label="Text to predict",
                lines=5,
                placeholder="Enter a case description or related text..."
            )

            predict_btn = gr.Button("🎯 Predict", variant="primary", size="lg")

            pred_output = gr.Textbox(label="Prediction and analysis", lines=25)

            # Refresh the model list (returning a Dropdown updates its choices)
            refresh_btn.click(
                lambda: gr.Dropdown(choices=list(trained_models.keys())),
                outputs=[model_dropdown]
            )

            # Run the prediction
            predict_btn.click(
                predict,
                inputs=[model_dropdown, text_input],
                outputs=[pred_output]
            )

            # Examples
            gr.Examples(
                examples=[
                    ["Patient with stage II breast cancer, showing good response to chemotherapy treatment."],
                    ["Advanced metastatic cancer with multiple organ failure, poor prognosis."],
                    ["Early stage tumor detected, surgery scheduled, excellent recovery expected."],
                    ["Terminal stage disease, palliative care initiated, family counseling provided."]
                ],
                inputs=text_input
            )

        with gr.Tab("📊 Compare"):
            gr.Markdown("## Compare all trained models")

            compare_btn = gr.Button("📊 Generate comparison report", variant="primary", size="lg")
            compare_output = gr.Markdown()

            compare_btn.click(compare_models, outputs=[compare_output])

        with gr.Tab("📖 Guide"):
            gr.Markdown("""
            ## 📖 User Guide

            ### 🎯 Platform features

            This improved platform provides:

            1. **Automatic baseline comparison**: every training run also evaluates the un-tuned baseline model, making the gain from fine-tuning explicit
            2. **Training monitoring**: a detailed per-epoch training history
            3. **Early stopping**: avoids overfitting and keeps the best checkpoint automatically
            4. **Multiple weighting strategies**: several ways to handle imbalanced data
            5. **Complete metric suite**: F1, accuracy, precision, recall, sensitivity, specificity, PPV, and NPV

            ### 🤖 Supported base models

            - **BERT-base**: standard English BERT for general English text
            - **BERT-base-chinese**: Chinese BERT for Chinese text
            - **BioBERT**: BERT specialised for biomedical text
            - **SciBERT**: BERT specialised for scientific literature

            ### 🔧 Fine-tuning methods

            - **LoRA** (Low-Rank Adaptation)
              - Most parameter-efficient: trains <1% of the parameters
              - Fast training and a low memory footprint
              - A good default for most scenarios

            - **AdaLoRA** (Adaptive LoRA)
              - Allocates the rank budget adaptively
              - Can yield better results
              - Slightly longer training time

            - **Full** (full fine-tuning)
              - Trains all parameters
              - Can yield the best results
              - Needs the most memory and time

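            As a rough sketch of how the LoRA settings above fit together (for
            illustration only; the app builds its actual configuration inside
            the training code):

            ```python
            from peft import LoraConfig, TaskType, get_peft_model
            from transformers import BertForSequenceClassification

            base = BertForSequenceClassification.from_pretrained(
                "bert-base-uncased", num_labels=2
            )
            lora_cfg = LoraConfig(
                task_type=TaskType.SEQ_CLS,          # sequence classification
                r=16,                                # LoRA rank
                lora_alpha=32,                       # usually twice the rank
                lora_dropout=0.05,
                target_modules=["query", "value"],   # BERT attention projections
            )
            model = get_peft_model(base, lora_cfg)
            model.print_trainable_parameters()       # typically <1% trainable
            ```
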
            ### ⚖️ Handling imbalanced data

            #### Weighting methods:

            1. **sqrt** (square root) - recommended for extreme imbalance
               - Softens the weights with a square root
               - Avoids the overfitting that oversized weights can cause

            2. **log** (logarithm) - more conservative
               - Softens the weights further with a logarithm
               - Suits extremely imbalanced data that overfits easily

            3. **balanced**
               - sklearn-style automatic balancing
               - Suits moderate imbalance

            4. **custom**
               - Adapts automatically to the degree of imbalance
               - Combines several factors

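            The app computes its weights inside the training code; the sketch
            below only illustrates the idea behind the four strategies (an
            assumption for teaching purposes, not the actual implementation):

            ```python
            import numpy as np

            def sketch_class_weights(counts, method="sqrt", mult=1.0):
                # Illustrative class weights from per-class sample counts.
                counts = np.asarray(counts, dtype=float)
                ratio = counts.max() / counts        # imbalance ratio per class
                if method == "balanced":             # sklearn-style: n / (k * n_c)
                    w = counts.sum() / (len(counts) * counts)
                elif method == "sqrt":               # soften extreme ratios
                    w = np.sqrt(ratio)
                elif method == "log":                # soften even more
                    w = 1.0 + np.log(ratio)
                else:                                # stand-in for "custom"
                    w = ratio
                return w * mult

            print(sketch_class_weights([950, 50], method="sqrt"))  # ~[1.0, 4.36]
            ```
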
            #### Suggested settings:

            **Extreme imbalance (>20:1)**
            - Weighting method: sqrt or log
            - Weight multiplier: 0.5-1.0
            - Focal Loss (enabled automatically)
            - Early stopping: recommended

            **Severe imbalance (10-20:1)**
            - Weighting method: sqrt
            - Weight multiplier: 0.8-1.5
            - Early stopping: recommended

            **Moderate imbalance (5-10:1)**
            - Weighting method: balanced
            - Weight multiplier: 1.0-2.0

            **Mild imbalance (<5:1)**
            - Weighting method: balanced
            - Weight multiplier: 1.5-3.0

            ### 📊 Metrics

            - **F1 Score**: harmonic mean of precision and recall; well suited to imbalanced data
            - **Accuracy**: overall fraction of correct predictions
            - **Precision**: fraction of predicted positives that are actually positive
            - **Recall/Sensitivity**: fraction of actual positives that are correctly identified
            - **Specificity**: fraction of actual negatives that are correctly identified
            - **PPV**: positive predictive value
            - **NPV**: negative predictive value

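            For a binary confusion matrix with entries TN, FP, FN, TP, these
            reduce to the standard definitions (a small reference sketch;
            zero-division guards omitted for brevity):

            ```python
            def sketch_metrics(tn, fp, fn, tp):
                sensitivity = tp / (tp + fn)   # = recall
                specificity = tn / (tn + fp)
                ppv = tp / (tp + fp)           # = precision
                npv = tn / (tn + fn)
                f1 = 2 * ppv * sensitivity / (ppv + sensitivity)
                return sensitivity, specificity, ppv, npv, f1
            ```
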
            ### 🚀 Quick start

            1. **Prepare your data**
               - CSV format with `Text` and `label` columns (a minimal example follows this list)
               - label: 0 = negative class (e.g. survived), 1 = positive class (e.g. deceased)

            2. **Pick a model and method**
               - English data: BERT-base + LoRA
               - Chinese data: BERT-base-chinese + LoRA
               - Medical data: BioBERT + LoRA

            3. **Set the parameters**
               - Start from the defaults
               - Adjust the weighting settings to your data's degree of imbalance

            4. **Train and evaluate**
               - Click "Start training"
               - Review the baseline vs fine-tuned comparison
               - Inspect the training history

            5. **Test predictions**
               - Pick a model on the Predict tab
               - Enter text and run a prediction
               - Compare the models before and after fine-tuning

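            A minimal way to build such a CSV with pandas (file name and rows
            are illustrative):

            ```python
            import pandas as pd

            df = pd.DataFrame({
                "Text": [
                    "Early stage tumor detected, surgery scheduled.",
                    "Advanced metastatic cancer with multiple organ failure.",
                ],
                "label": [0, 1],   # 0 = survived, 1 = deceased
            })
            df.to_csv("train.csv", index=False)
            ```
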
            ### ⚠️ Notes

            - A GPU speeds up training considerably
            - Too large a batch size can exhaust GPU memory
            - Early stopping helps avoid overfitting
            - For extremely imbalanced data, prefer conservative weight settings

            ### 💡 Tuning tips

            1. **Out of memory**: lower the batch size or switch to LoRA
            2. **Overfitting**: raise dropout, enable early stopping, lower the learning rate
            3. **Underfitting**: more epochs, a higher learning rate, or more model capacity
            4. **Imbalanced data**: adjust the class weights and judge by a suitable metric (F1)
            """)

    return demo

# ==================== Main ====================
if __name__ == "__main__":
    demo = create_demo()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        max_threads=4
    )