# ==============================================================================
# Smol-MoE 8x135M - Master Script
# (Final Version, All Fixes Included)
# ==============================================================================
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer # <<< This is the most critical fix! Make sure AutoTokenizer is imported here
from transformers.models.llama.modeling_llama import LlamaMLP
from safetensors.torch import load_file # <<< use the correct safetensors loader
import os
import shutil
import numpy as np
import time
# --- 0. Configuration & Setup ---
MODEL_NAME = "./SmolLM2-135M-Instruct"
BASE_EXPERT_PATH = "./models"
EXPERT_DIRS = [
"SmolLM2-135M-Instruct-Actor", "SmolLM2-135M-Instruct-Analyst",
"SmolLM2-135M-Instruct-Coder", "SmolLM2-135M-Instruct-Encyclopedia",
"SmolLM2-135M-Instruct-Guardian", "SmolLM2-135M-Instruct-Summarizer",
"SmolLM2-135M-Instruct-Thinker", "SmolLM2-135M-Instruct-Writer"
]
NUM_EXPERTS = 8
TOP_K = 2
LEARNING_RATE = 0.001
EPOCHS = 20 # Since the simulated data cannot actually teach the model anything, 20 epochs is enough to validate the pipeline
BATCH_SIZE = 4
SEQUENCE_LENGTH = 128
LB_LOSS_COEFFICIENT = 0.01
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# --- 1. Define the MoE Architecture Components ---
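# The router is a single linear gate: it scores every token against each
# expert, and softmax + top-k over these scores decides where each token goes.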
class MoERouter(nn.Module):
def __init__(self, hidden_size: int, num_experts: int):
super().__init__()
self.layer = nn.Linear(hidden_size, num_experts, bias=False)
def forward(self, hidden_states):
return self.layer(hidden_states)
class MoEModule(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size
self.top_k = TOP_K
self.num_experts = NUM_EXPERTS
self.router = MoERouter(self.hidden_size, self.num_experts)
self.experts = nn.ModuleList([LlamaMLP(config) for _ in range(self.num_experts)])
self.most_recent_lb_loss = None
def forward(self, hidden_states):
original_shape = hidden_states.shape
flat_hidden_states = hidden_states.view(-1, self.hidden_size)
router_logits = self.router(flat_hidden_states)
        # Compute the full softmax once (in float32 for numerical stability);
        # it is reused below for the load-balancing loss.
        router_probs_full = F.softmax(router_logits, dim=-1, dtype=torch.float)
        # Keep the top-k experts per token and renormalize their weights to sum to 1.
        routing_weights, selected_experts = torch.topk(router_probs_full, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)
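        # Switch-Transformer-style auxiliary load-balancing loss:
        # num_experts * sum_i(f_i * P_i), where f_i is the fraction of tokens
        # routed to expert i and P_i its mean router probability. It equals
        # top_k under perfectly uniform routing and grows as the router
        # collapses onto a few experts.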
avg_expert_prob = router_probs_full.mean(dim=0)
expert_mask_for_lb = F.one_hot(selected_experts, num_classes=self.num_experts).sum(dim=1)
avg_expert_fraction = expert_mask_for_lb.float().mean(dim=0)
self.most_recent_lb_loss = self.num_experts * torch.sum(avg_expert_prob * avg_expert_fraction)
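        # Dispatch loop: for each routing slot k, gather the tokens assigned to
        # each expert, run that expert's MLP once on the gathered batch, and
        # scatter the weighted outputs back into place with index_add_.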
final_hidden_states = torch.zeros_like(flat_hidden_states)
for k in range(self.top_k):
expert_indices_k = selected_experts[:, k]
routing_weights_k = routing_weights[:, k]
for i in range(self.num_experts):
mask = expert_indices_k == i
if mask.any():
expert_output = self.experts[i](flat_hidden_states[mask])
final_hidden_states.index_add_(0, torch.where(mask)[0], expert_output * routing_weights_k[mask].unsqueeze(1))
return final_hidden_states.view(*original_shape)
# --- 2. The Grand Assembly Function ---
def create_moe_model():
print("--- Starting Architectural Surgery ---")
config = AutoConfig.from_pretrained(MODEL_NAME)
print("Step 1: Loading base model skeleton...")
base_model = AutoModelForCausalLM.from_pretrained(
os.path.join(BASE_EXPERT_PATH, EXPERT_DIRS[0]),
torch_dtype=torch.bfloat16,
device_map=device
)
print("Step 2: Pre-loading all expert weights into CPU memory for efficiency...")
all_experts_state_dicts = [
load_file(os.path.join(BASE_EXPERT_PATH, expert_dir, 'model.safetensors'), device='cpu')
for expert_dir in EXPERT_DIRS
]
print("All expert weights pre-loaded.")
print("Step 3: Replacing FFNs with MoE modules and transplanting expert weights...")
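    # Every expert checkpoint shares the base Llama architecture, so its MLP
    # weights live under keys like "model.layers.<idx>.mlp.gate_proj.weight";
    # strip the per-layer prefix and load them into the matching expert slot.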
for layer_idx, layer in enumerate(base_model.model.layers):
layer.mlp = MoEModule(config).to(device, dtype=torch.bfloat16)
for expert_idx in range(NUM_EXPERTS):
expert_state_dict = all_experts_state_dicts[expert_idx]
expert_mlp_weights = {
k.replace(f"model.layers.{layer_idx}.mlp.", ""): v
for k, v in expert_state_dict.items()
if f"model.layers.{layer_idx}.mlp." in k
}
layer.mlp.experts[expert_idx].load_state_dict(expert_mlp_weights)
print("Step 4: Freezing all parameters except for the routers...")
for name, param in base_model.named_parameters():
if "router" not in name:
param.requires_grad = False
print("\n--- Surgery Complete! MoE Model is assembled and ready for training. ---")
trainable_params = sum(p.numel() for p in base_model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in base_model.parameters())
print(f"Total Parameters: {total_params / 1e6:.2f}M")
print(f"Trainable Parameters (Routers): {trainable_params}")
return base_model
# --- 3. The Main Training & Saving Process ---
def main():
moe_model = create_moe_model()
optimizer = optim.AdamW([p for p in moe_model.parameters() if p.requires_grad], lr=LEARNING_RATE)
print("\n--- Preparing Simulated Mixed Dataset for Training ---")
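    # Placeholder data: random token ids stand in for a real mixed-domain
    # corpus, with labels equal to inputs for the standard causal-LM loss.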
mock_input_ids = torch.randint(0, moe_model.config.vocab_size, (BATCH_SIZE, SEQUENCE_LENGTH), device=device)
mock_labels = mock_input_ids.clone()
print("--- Starting Router Training Loop (Optimized & Corrected) ---")
moe_model.train()
start_time = time.time()
for epoch in range(EPOCHS):
optimizer.zero_grad()
outputs = moe_model(input_ids=mock_input_ids, labels=mock_labels)
main_loss = outputs.loss
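        # Total objective = language-modeling loss + the per-layer load-balancing
        # losses summed across layers, scaled by a small coefficient so the
        # balancing pressure does not swamp the main objective.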
total_lb_loss = 0.0
for layer in moe_model.model.layers:
total_lb_loss += layer.mlp.most_recent_lb_loss
total_loss = main_loss + LB_LOSS_COEFFICIENT * total_lb_loss
total_loss.backward()
optimizer.step()
if (epoch + 1) % 10 == 0:
elapsed_time = time.time() - start_time
print(f"Epoch [{epoch+1:03d}/{EPOCHS}] | Total Loss: {total_loss.item():.4f} | "
f"Main Loss: {main_loss.item():.4f} | "
f"Avg LB Loss: {(total_lb_loss.item() / moe_model.config.num_hidden_layers):.4f} | "
f"Time: {elapsed_time:.2f}s")
start_time = time.time()
print("\n--- Router Training Complete! ---")
    print("\n--- Saving the fully trained MoE model to disk ---")
OUTPUT_MODEL_DIR = "./SmolMoE-8x135M-Instruct-v1-Trained"
if os.path.exists(OUTPUT_MODEL_DIR):
shutil.rmtree(OUTPUT_MODEL_DIR)
os.makedirs(OUTPUT_MODEL_DIR)
print("Updating model config with MoE-specific parameters...")
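    # Note: moe_num_experts / moe_top_k are custom config keys that transformers
    # itself ignores; they only record the MoE layout for downstream scripts.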
moe_model.config.moe_num_experts = NUM_EXPERTS
moe_model.config.moe_top_k = TOP_K
print(f"Saving model to '{OUTPUT_MODEL_DIR}'...")
moe_model.save_pretrained(OUTPUT_MODEL_DIR)
print("Saving tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.save_pretrained(OUTPUT_MODEL_DIR)
print("\n--- Model successfully saved! ---")
print("You can now load this model in other scripts, but you must re-define the custom MoE classes first.")
if __name__ == "__main__":
main()
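# ------------------------------------------------------------------------------
# Reloading sketch (an untested assumption, for illustration only): since the
# MoE layers are custom modules, AutoModelForCausalLM.from_pretrained() alone
# cannot rebuild them. One plausible route is to re-run the surgery, then load
# the trained weights over the freshly assembled (router-untrained) model:
#
#   model = create_moe_model()
#   state = load_file("./SmolMoE-8x135M-Instruct-v1-Trained/model.safetensors")
#   model.load_state_dict(state)
# ------------------------------------------------------------------------------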