# ToxcityDetector / train_model.py
"""
Train toxic comment classification model on Jigsaw dataset
Uses 70% of data and Mac M2 GPU (MPS)
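Expects train.csv from the Kaggle Jigsaw Toxic Comment Classification
Challenge, with a comment_text column and the six label columns listed below.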
"""
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from torch.optim import AdamW  # transformers' AdamW is deprecated; use the PyTorch optimizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import numpy as np
import os
from tqdm import tqdm
import time
import re
# Configuration
MODEL_NAME = "distilbert-base-uncased"
DATA_PATH = "train.csv"
OUTPUT_DIR = "models"
BATCH_SIZE = 16
LEARNING_RATE = 2e-5
NUM_EPOCHS = 3
MAX_LENGTH = 128
TRAIN_SPLIT = 0.7  # Use 70% of the dataset; that subset is split 70/15/15 below
VAL_SPLIT = 0.15
TEST_SPLIT = 0.15
# Labels for classification
LABELS = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
NUM_LABELS = len(LABELS)
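# A comment can carry several of these labels at once, so this is a multi-label
# problem: one independent sigmoid score per label rather than a single softmax.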
# Check for GPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"πŸš€ Using device: {device}")
class ToxicCommentDataset(Dataset):
"""Custom dataset for toxic comment classification"""
def __init__(self, texts, labels, tokenizer, max_length):
self.texts = texts
self.labels = labels
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
text = str(self.texts[idx])
label = self.labels[idx]
encoding = self.tokenizer.encode_plus(
text,
add_special_tokens=True,
max_length=self.max_length,
padding='max_length',
truncation=True,
return_attention_mask=True,
return_tensors='pt'
)
return {
'input_ids': encoding['input_ids'].flatten(),
'attention_mask': encoding['attention_mask'].flatten(),
'labels': torch.FloatTensor(label)
}
def preprocess_text(text):
"""Clean and preprocess text"""
if pd.isna(text):
return ""
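    # e.g. preprocess_text('Check THIS out: http://x.co  now') -> 'check this out: now'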
text = str(text).lower()
# Remove URLs
text = re.sub(r'http\S+|www\S+|https\S+', '', text)
# Remove extra whitespace
text = ' '.join(text.split())
return text
def load_data():
"""Load and split the data"""
print("πŸ“‚ Loading data...")
df = pd.read_csv(DATA_PATH)
# Preprocess texts
print("🧹 Preprocessing texts...")
df['comment_text'] = df['comment_text'].apply(preprocess_text)
# Extract labels
labels = df[LABELS].values
    # Use only the first 70% of the rows
    # (rows are taken in file order; train_test_split shuffles within the subset)
    total_samples = len(df)
    used_samples = int(total_samples * TRAIN_SPLIT)
    print(f"πŸ“Š Dataset Statistics:")
    print(f"   Total samples: {total_samples:,}")
    print(f"   Using {TRAIN_SPLIT*100:.0f}%: {used_samples:,} samples")
    # Split the subset: 70% train, 15% val, 15% test
    X_used = df['comment_text'].iloc[:used_samples].values
    y_used = labels[:used_samples]
    X_train, X_temp, y_train, y_temp = train_test_split(
        X_used, y_used, test_size=VAL_SPLIT + TEST_SPLIT, random_state=42
    )
    # Split the held-out 30% evenly into validation and test
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=0.5, random_state=42
    )
print(f" Train: {len(X_train):,} samples")
print(f" Val: {len(X_val):,} samples")
print(f" Test: {len(X_test):,} samples")
return X_train, X_val, X_test, y_train, y_val, y_test
def train_epoch(model, dataloader, optimizer, device):
"""Train for one epoch"""
model.train()
total_loss = 0
predictions = []
actuals = []
progress_bar = tqdm(dataloader, desc="Training", leave=False)
for batch in progress_bar:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['labels'].to(device)
optimizer.zero_grad()
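        # Because labels are passed, the forward call also returns the
        # (BCE-with-logits) loss as outputs.loss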
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
loss.backward()
optimizer.step()
total_loss += loss.item()
        # Sigmoid converts each logit to an independent per-label probability
        predictions.append(torch.sigmoid(outputs.logits).detach().cpu().numpy())
        actuals.append(labels.cpu().numpy())
progress_bar.set_postfix({'loss': loss.item()})
avg_loss = total_loss / len(dataloader)
predictions = np.vstack(predictions)
actuals = np.vstack(actuals)
return avg_loss, predictions, actuals
def evaluate(model, dataloader, device):
"""Evaluate the model"""
model.eval()
total_loss = 0
predictions = []
actuals = []
with torch.no_grad():
progress_bar = tqdm(dataloader, desc="Evaluating", leave=False)
for batch in progress_bar:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['labels'].to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
total_loss += loss.item()
predictions.append(torch.sigmoid(outputs.logits).cpu().numpy())
actuals.append(labels.cpu().numpy())
progress_bar.set_postfix({'loss': loss.item()})
avg_loss = total_loss / len(dataloader)
predictions = np.vstack(predictions)
actuals = np.vstack(actuals)
return avg_loss, predictions, actuals
def compute_metrics(predictions, actuals, threshold=0.5):
"""Compute classification metrics"""
binary_predictions = (predictions >= threshold).astype(int)
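    # (0.5 is a simple global default; per-label thresholds tuned on the
    # validation set are a common refinement)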
# Per-label metrics
f1_scores = []
auc_scores = []
for i, label in enumerate(LABELS):
        # zero_division=0 avoids warnings when a label is never predicted positive
        f1 = f1_score(actuals[:, i], binary_predictions[:, i], average='binary', zero_division=0)
        try:
            auc = roc_auc_score(actuals[:, i], predictions[:, i])
        except ValueError:
            # roc_auc_score raises if only one class is present for this label
            auc = 0.0
f1_scores.append(f1)
auc_scores.append(auc)
# Overall metrics
macro_f1 = np.mean(f1_scores)
macro_auc = np.mean(auc_scores)
return {
'macro_f1': macro_f1,
'macro_auc': macro_auc,
'per_label_f1': f1_scores,
'per_label_auc': auc_scores
}
def main():
"""Main training function"""
print("=" * 80)
print("πŸš€ Starting Toxic Comment Classification Training")
print("=" * 80)
# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Load data
X_train, X_val, X_test, y_train, y_val, y_test = load_data()
# Initialize tokenizer and model
print(f"\nπŸ“₯ Loading model: {MODEL_NAME}")
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
model = DistilBertForSequenceClassification.from_pretrained(
MODEL_NAME,
num_labels=NUM_LABELS,
problem_type="multi_label_classification"
)
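    # problem_type="multi_label_classification" makes the model use
    # BCEWithLogitsLoss, scoring the six labels independently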
model.to(device)
print(f"βœ“ Model loaded successfully")
print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
# Create datasets and dataloaders
print("\nπŸ”„ Creating datasets...")
train_dataset = ToxicCommentDataset(X_train, y_train, tokenizer, MAX_LENGTH)
val_dataset = ToxicCommentDataset(X_val, y_val, tokenizer, MAX_LENGTH)
test_dataset = ToxicCommentDataset(X_test, y_test, tokenizer, MAX_LENGTH)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
print(f"βœ“ Datasets created")
# Setup optimizer
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
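    # Optional refinement (a sketch, not required here): a linear warmup/decay schedule, e.g.
    #   from transformers import get_linear_schedule_with_warmup
    #   scheduler = get_linear_schedule_with_warmup(
    #       optimizer, num_warmup_steps=0,
    #       num_training_steps=len(train_loader) * NUM_EPOCHS)
    # with scheduler.step() called after each optimizer.step().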
# Training loop
print("\n" + "=" * 80)
print("πŸ‹οΈ Starting Training")
print("=" * 80)
    best_val_f1 = 0
    model_path = os.path.join(OUTPUT_DIR, "best_model")  # defined up front so the test phase can always reference it
    training_history = []
for epoch in range(NUM_EPOCHS):
print(f"\n{'='*80}")
print(f"πŸ“… EPOCH {epoch + 1}/{NUM_EPOCHS}")
print(f"{'='*80}")
epoch_start = time.time()
# Train
print(f"\n▢️ Training Phase")
train_loss, train_preds, train_actuals = train_epoch(model, train_loader, optimizer, device)
train_metrics = compute_metrics(train_preds, train_actuals)
# Validate
print(f"\nβœ… Validation Phase")
val_loss, val_preds, val_actuals = evaluate(model, val_loader, device)
val_metrics = compute_metrics(val_preds, val_actuals)
epoch_time = time.time() - epoch_start
# Print results
print(f"\nπŸ“Š Epoch {epoch + 1} Results:")
print(f" Time: {epoch_time:.2f}s ({epoch_time/60:.2f} min)")
print(f" Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
print(f" Train F1: {train_metrics['macro_f1']:.4f} | Val F1: {val_metrics['macro_f1']:.4f}")
print(f" Train AUC: {train_metrics['macro_auc']:.4f} | Val AUC: {val_metrics['macro_auc']:.4f}")
print(f"\nπŸ“‹ Per-Label Validation F1 Scores:")
for label, f1, auc in zip(LABELS, val_metrics['per_label_f1'], val_metrics['per_label_auc']):
print(f" {label:20s}: F1={f1:.4f} | AUC={auc:.4f}")
# Save history
training_history.append({
'epoch': epoch + 1,
'train_loss': train_loss,
'val_loss': val_loss,
'train_f1': train_metrics['macro_f1'],
'val_f1': val_metrics['macro_f1'],
'train_auc': train_metrics['macro_auc'],
'val_auc': val_metrics['macro_auc']
})
# Save best model
if val_metrics['macro_f1'] > best_val_f1:
best_val_f1 = val_metrics['macro_f1']
            print(f"\nπŸ’Ύ Saving best model (F1: {best_val_f1:.4f}) to {model_path}")
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)
print("\n" + "=" * 80)
print("πŸ§ͺ Final Testing")
print("=" * 80)
# Load best model and test
print("\nπŸ“₯ Loading best model for testing...")
model = DistilBertForSequenceClassification.from_pretrained(model_path)
model.to(device)
test_loss, test_preds, test_actuals = evaluate(model, test_loader, device)
test_metrics = compute_metrics(test_preds, test_actuals)
print(f"\n🎯 Final Test Results:")
print(f" Test Loss: {test_loss:.4f}")
print(f" Test F1: {test_metrics['macro_f1']:.4f}")
print(f" Test AUC: {test_metrics['macro_auc']:.4f}")
print(f"\nπŸ“‹ Per-Label Test F1 Scores:")
for label, f1, auc in zip(LABELS, test_metrics['per_label_f1'], test_metrics['per_label_auc']):
print(f" {label:20s}: F1={f1:.4f} | AUC={auc:.4f}")
# Save training history
history_df = pd.DataFrame(training_history)
history_path = os.path.join(OUTPUT_DIR, "training_history.csv")
history_df.to_csv(history_path, index=False)
print(f"\nπŸ’Ύ Training history saved to {history_path}")
# Print summary
print("\n" + "=" * 80)
print("βœ… Training Complete!")
print("=" * 80)
print(f"\nπŸ“ Model saved to: {model_path}")
print(f"πŸ“Š Training history: {history_path}")
print(f"🎯 Best validation F1: {best_val_f1:.4f}")
print(f"πŸ§ͺ Final test F1: {test_metrics['macro_f1']:.4f}")
print("\nπŸš€ You can now use this model in your Streamlit app!")
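
# Hedged sketch (not wired into training): one way to score a single comment
# with the model saved above. The helper name and the 0.5 threshold are
# illustrative assumptions, not fixed elsewhere in this script.
def predict_toxicity(text, model, tokenizer, threshold=0.5):
    """Return {label: (probability, flagged)} for one comment."""
    model.eval()
    encoding = tokenizer(
        preprocess_text(text),
        max_length=MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    ).to(device)
    with torch.no_grad():
        logits = model(**encoding).logits
    # One sigmoid probability per label; flag labels at or above the threshold
    probs = torch.sigmoid(logits).squeeze(0).cpu().numpy()
    return {label: (float(p), bool(p >= threshold)) for label, p in zip(LABELS, probs)}
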
if __name__ == "__main__":
main()