# ==============================================================================
# Smol-MoE 8x135M - Master Script
# (Final Version, All Fixes Included)
# ==============================================================================
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer  # <<< The critical fix: make sure AutoTokenizer is imported here
from transformers.models.llama.modeling_llama import LlamaMLP
from safetensors.torch import load_file  # <<< Use the proper safetensors loader
import os
import shutil
import numpy as np
import time

# --- 0. Configuration & Setup ---
MODEL_NAME = "./SmolLM2-135M-Instruct"
BASE_EXPERT_PATH = "./models"
EXPERT_DIRS = [
    "SmolLM2-135M-Instruct-Actor",
    "SmolLM2-135M-Instruct-Analyst",
    "SmolLM2-135M-Instruct-Coder",
    "SmolLM2-135M-Instruct-Encyclopedia",
    "SmolLM2-135M-Instruct-Guardian",
    "SmolLM2-135M-Instruct-Summarizer",
    "SmolLM2-135M-Instruct-Thinker",
    "SmolLM2-135M-Instruct-Writer",
]
NUM_EXPERTS = 8
TOP_K = 2
LEARNING_RATE = 0.001
EPOCHS = 20  # The simulated data cannot make the model learn anything, so 20 epochs is enough to validate the pipeline
BATCH_SIZE = 4
SEQUENCE_LENGTH = 128
LB_LOSS_COEFFICIENT = 0.01

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# --- 1. Define the MoE Architecture Components ---
class MoERouter(nn.Module):
    """A single linear layer that scores each token against every expert."""

    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        self.layer = nn.Linear(hidden_size, num_experts, bias=False)

    def forward(self, hidden_states):
        return self.layer(hidden_states)


class MoEModule(nn.Module):
    """Drop-in replacement for LlamaMLP: routes each token to its top-k experts."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.router = MoERouter(self.hidden_size, self.num_experts)
        self.experts = nn.ModuleList([LlamaMLP(config) for _ in range(self.num_experts)])
        self.most_recent_lb_loss = None  # updated on every forward pass

    def forward(self, hidden_states):
        original_shape = hidden_states.shape
        flat_hidden_states = hidden_states.view(-1, self.hidden_size)

        # Route every token: keep the top-k experts and renormalize their weights.
        router_logits = self.router(flat_hidden_states)
        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)

        # Load-balancing auxiliary loss: product of the mean router probability
        # and the mean selection fraction per expert, scaled by the expert count.
        router_probs_full = F.softmax(router_logits, dim=-1, dtype=torch.float)
        avg_expert_prob = router_probs_full.mean(dim=0)
        expert_mask_for_lb = F.one_hot(selected_experts, num_classes=self.num_experts).sum(dim=1)
        avg_expert_fraction = expert_mask_for_lb.float().mean(dim=0)
        self.most_recent_lb_loss = self.num_experts * torch.sum(avg_expert_prob * avg_expert_fraction)

        # Dispatch tokens to their selected experts and combine the weighted outputs.
        final_hidden_states = torch.zeros_like(flat_hidden_states)
        for k in range(self.top_k):
            expert_indices_k = selected_experts[:, k]
            routing_weights_k = routing_weights[:, k]
            for i in range(self.num_experts):
                mask = expert_indices_k == i
                if mask.any():
                    expert_output = self.experts[i](flat_hidden_states[mask])
                    final_hidden_states.index_add_(
                        0,
                        torch.where(mask)[0],
                        expert_output * routing_weights_k[mask].unsqueeze(1),
                    )

        return final_hidden_states.view(*original_shape)
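
# ------------------------------------------------------------------------------
# Optional sanity check (added sketch, not part of the original pipeline): runs
# a MoEModule on random activations with a tiny, made-up LlamaConfig to confirm
# that the output shape matches the input and that the load-balancing loss is
# recorded. The sizes below are illustrative and unrelated to SmolLM2's config.
# ------------------------------------------------------------------------------
def smoke_test_moe_module():
    from transformers import LlamaConfig

    tiny_config = LlamaConfig(hidden_size=64, intermediate_size=128)  # hypothetical toy sizes
    moe = MoEModule(tiny_config)
    x = torch.randn(2, 16, tiny_config.hidden_size)  # (batch, seq_len, hidden)
    y = moe(x)
    assert y.shape == x.shape, "MoE output must keep the input shape"
    assert moe.most_recent_lb_loss is not None, "LB loss should be recorded on forward"
    print(f"MoE smoke test passed, lb_loss = {moe.most_recent_lb_loss.item():.4f}")
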
# --- 2. The Grand Assembly Function ---
def create_moe_model():
    print("--- Starting Architectural Surgery ---")
    config = AutoConfig.from_pretrained(MODEL_NAME)

    print("Step 1: Loading base model skeleton...")
    base_model = AutoModelForCausalLM.from_pretrained(
        os.path.join(BASE_EXPERT_PATH, EXPERT_DIRS[0]),
        torch_dtype=torch.bfloat16,
        device_map=device,
    )

    print("Step 2: Pre-loading all expert weights into CPU memory for efficiency...")
    all_experts_state_dicts = [
        load_file(os.path.join(BASE_EXPERT_PATH, expert_dir, "model.safetensors"), device="cpu")
        for expert_dir in EXPERT_DIRS
    ]
    print("All expert weights pre-loaded.")

    print("Step 3: Replacing FFNs with MoE modules and transplanting expert weights...")
    for layer_idx, layer in enumerate(base_model.model.layers):
        layer.mlp = MoEModule(config).to(device, dtype=torch.bfloat16)
        for expert_idx in range(NUM_EXPERTS):
            expert_state_dict = all_experts_state_dicts[expert_idx]
            # Keep only this layer's MLP weights and strip the prefix so the keys
            # match LlamaMLP's (gate_proj / up_proj / down_proj).
            expert_mlp_weights = {
                k.replace(f"model.layers.{layer_idx}.mlp.", ""): v
                for k, v in expert_state_dict.items()
                if f"model.layers.{layer_idx}.mlp." in k
            }
            layer.mlp.experts[expert_idx].load_state_dict(expert_mlp_weights)

    print("Step 4: Freezing all parameters except for the routers...")
    for name, param in base_model.named_parameters():
        if "router" not in name:
            param.requires_grad = False

    print("\n--- Surgery Complete! MoE Model is assembled and ready for training. ---")
    trainable_params = sum(p.numel() for p in base_model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in base_model.parameters())
    print(f"Total Parameters: {total_params / 1e6:.2f}M")
    print(f"Trainable Parameters (Routers): {trainable_params}")
    return base_model


# --- 3. The Main Training & Saving Process ---
def main():
    moe_model = create_moe_model()
    optimizer = optim.AdamW([p for p in moe_model.parameters() if p.requires_grad], lr=LEARNING_RATE)

    print("\n--- Preparing Simulated Mixed Dataset for Training ---")
    mock_input_ids = torch.randint(
        0, moe_model.config.vocab_size, (BATCH_SIZE, SEQUENCE_LENGTH), device=device
    )
    mock_labels = mock_input_ids.clone()

    print("--- Starting Router Training Loop (Optimized & Corrected) ---")
    moe_model.train()
    start_time = time.time()
    for epoch in range(EPOCHS):
        optimizer.zero_grad()
        outputs = moe_model(input_ids=mock_input_ids, labels=mock_labels)
        main_loss = outputs.loss

        # Sum the auxiliary load-balancing loss recorded by every MoE layer.
        total_lb_loss = 0.0
        for layer in moe_model.model.layers:
            total_lb_loss += layer.mlp.most_recent_lb_loss

        total_loss = main_loss + LB_LOSS_COEFFICIENT * total_lb_loss
        total_loss.backward()
        optimizer.step()

        if (epoch + 1) % 10 == 0:
            elapsed_time = time.time() - start_time
            print(
                f"Epoch [{epoch+1:03d}/{EPOCHS}] | Total Loss: {total_loss.item():.4f} | "
                f"Main Loss: {main_loss.item():.4f} | "
                f"Avg LB Loss: {(total_lb_loss.item() / moe_model.config.num_hidden_layers):.4f} | "
                f"Time: {elapsed_time:.2f}s"
            )
            start_time = time.time()

    print("\n--- Router Training Complete! ---")

    print("\n--- Phase 5: Saving the fully trained MoE model to disk ---")
    OUTPUT_MODEL_DIR = "./SmolMoE-8x135M-Instruct-v1-Trained"
    if os.path.exists(OUTPUT_MODEL_DIR):
        shutil.rmtree(OUTPUT_MODEL_DIR)
    os.makedirs(OUTPUT_MODEL_DIR)

    print("Updating model config with MoE-specific parameters...")
    moe_model.config.moe_num_experts = NUM_EXPERTS
    moe_model.config.moe_top_k = TOP_K

    print(f"Saving model to '{OUTPUT_MODEL_DIR}'...")
    moe_model.save_pretrained(OUTPUT_MODEL_DIR)

    print("Saving tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.save_pretrained(OUTPUT_MODEL_DIR)

    print("\n--- Model successfully saved! ---")
    print("You can now load this model in other scripts, but you must re-define the custom MoE classes first.")
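
# ------------------------------------------------------------------------------
# Sketch only (not used by main() above): how the random mock batch could be
# swapped for real tokenized prompts. The helper name and its arguments are
# illustrative; it assumes the tokenizer has a pad token (for SmolLM2 you may
# need to set `tokenizer.pad_token = tokenizer.eos_token`).
# ------------------------------------------------------------------------------
def build_real_batch(tokenizer, texts, max_length=SEQUENCE_LENGTH):
    """Tokenize a list of prompts into (input_ids, labels) tensors for the router training loop."""
    encodings = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    input_ids = encodings["input_ids"].to(device)
    labels = input_ids.clone()
    # Padded positions should not contribute to the causal-LM loss.
    labels[encodings["attention_mask"].to(device) == 0] = -100
    return input_ids, labels
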
if __name__ == "__main__":
    main()
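
# ------------------------------------------------------------------------------
# Reload sketch (for reference only; in practice this belongs in the separate
# loading script mentioned above). It rebuilds the MoE architecture from the
# saved config, then loads the trained weights. It assumes the checkpoint was
# written as a single model.safetensors file; sharded checkpoints and tied
# embeddings are handled loosely via strict=False.
# ------------------------------------------------------------------------------
def load_trained_moe(model_dir="./SmolMoE-8x135M-Instruct-v1-Trained"):
    config = AutoConfig.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_config(config).to(torch.bfloat16)
    # Re-perform the "surgery" so the module tree matches the saved state dict.
    for layer in model.model.layers:
        layer.mlp = MoEModule(config).to(torch.bfloat16)
    state_dict = load_file(os.path.join(model_dir, "model.safetensors"), device="cpu")
    model.load_state_dict(state_dict, strict=False)
    return model.to(device)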