|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.optim as optim
|
|
|
import torch.nn.functional as F
|
|
|
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
|
|
from transformers.models.llama.modeling_llama import LlamaMLP
|
|
|
from safetensors.torch import load_file
|
|
|
import os
|
|
|
import shutil
|
|
|
import numpy as np
|
|
|
import time
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Directory used to load the model config and the tokenizer.
MODEL_NAME = "./SmolLM2-135M-Instruct"

# Parent directory that holds one sub-directory per expert checkpoint.
BASE_EXPERT_PATH = "./models"

# One checkpoint directory per expert; list order defines the expert index.
EXPERT_DIRS = [
    "SmolLM2-135M-Instruct-Actor",
    "SmolLM2-135M-Instruct-Analyst",
    "SmolLM2-135M-Instruct-Coder",
    "SmolLM2-135M-Instruct-Encyclopedia",
    "SmolLM2-135M-Instruct-Guardian",
    "SmolLM2-135M-Instruct-Summarizer",
    "SmolLM2-135M-Instruct-Thinker",
    "SmolLM2-135M-Instruct-Writer",
]

# Derived from EXPERT_DIRS (currently 8) so the expert count and the list of
# checkpoints cannot silently drift apart.
NUM_EXPERTS = len(EXPERT_DIRS)

# Number of experts each token is routed to.
TOP_K = 2

# Router training hyper-parameters.
LEARNING_RATE = 0.001
EPOCHS = 20
BATCH_SIZE = 4
SEQUENCE_LENGTH = 128

# Weight of the auxiliary load-balancing loss added to the LM loss.
LB_LOSS_COEFFICIENT = 0.01

# Prefer CUDA when available, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
|
|
|
|
|
|
|
|
|
|
|
|
class MoERouter(nn.Module):
    """Bias-free linear gate that produces one logit per expert.

    Args:
        hidden_size: Size of the last dimension of the incoming hidden states.
        num_experts: Number of logits (one per expert) to emit.
    """

    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        # The attribute must stay named `layer`: it is part of the parameter
        # name ("...layer.weight") that ends up in the state dict.
        gate = nn.Linear(hidden_size, num_experts, bias=False)
        self.layer = gate

    def forward(self, hidden_states):
        """Return raw (un-normalised) expert logits for ``hidden_states``."""
        expert_logits = self.layer(hidden_states)
        return expert_logits
|
|
|
|
|
|
class MoEModule(nn.Module):
    """Sparse mixture-of-experts feed-forward block with top-k token routing.

    Every token is scored by a linear router, sent to its ``top_k`` highest
    scoring experts, and the expert outputs are combined using the
    re-normalised router probabilities. After each forward pass the auxiliary
    load-balancing loss is stored on ``most_recent_lb_loss`` so a training
    loop can add it to the main loss.

    Args:
        config: Model config. ``config.hidden_size`` is required and the whole
            object is passed to every ``LlamaMLP`` expert. If the config
            carries ``moe_top_k`` / ``moe_num_experts`` they are honoured;
            otherwise the module-level ``TOP_K`` / ``NUM_EXPERTS`` are used.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Prefer values recorded on the config, falling back to the
        # module-level defaults when the config does not define them.
        self.top_k = getattr(config, "moe_top_k", TOP_K)
        self.num_experts = getattr(config, "moe_num_experts", NUM_EXPERTS)
        self.router = MoERouter(self.hidden_size, self.num_experts)
        self.experts = nn.ModuleList(
            [LlamaMLP(config) for _ in range(self.num_experts)]
        )
        # Load-balancing loss of the latest forward pass (None before any).
        self.most_recent_lb_loss = None

    def forward(self, hidden_states):
        """Route tokens through their top-k experts.

        Args:
            hidden_states: Tensor whose last dimension is ``hidden_size``;
                all leading dimensions are flattened into a token axis.

        Returns:
            Tensor with the same shape as ``hidden_states``.

        Side effects:
            Overwrites ``self.most_recent_lb_loss`` with
            ``num_experts * sum(mean_router_prob * mean_selection_count)``,
            both averaged over the tokens of this call.
        """
        original_shape = hidden_states.shape
        # reshape (not view) so non-contiguous activations are accepted too.
        flat_hidden_states = hidden_states.reshape(-1, self.hidden_size)

        router_logits = self.router(flat_hidden_states)
        # Computed once in float32 and shared by the gating weights and the
        # load-balancing statistics.
        router_probs_full = F.softmax(router_logits, dim=-1, dtype=torch.float)

        routing_weights, selected_experts = torch.topk(
            router_probs_full, self.top_k, dim=-1
        )
        # Re-normalise so the kept weights of every token sum to one.
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)

        # --- auxiliary load-balancing loss ---
        avg_expert_prob = router_probs_full.mean(dim=0)
        # (tokens, num_experts) indicator of which experts each token selected.
        expert_mask_for_lb = F.one_hot(
            selected_experts, num_classes=self.num_experts
        ).sum(dim=1)
        avg_expert_fraction = expert_mask_for_lb.float().mean(dim=0)
        self.most_recent_lb_loss = self.num_experts * torch.sum(
            avg_expert_prob * avg_expert_fraction
        )

        # --- expert dispatch ---
        final_hidden_states = torch.zeros_like(flat_hidden_states)
        for expert_idx, expert in enumerate(self.experts):
            # top-k indices are distinct per token, so a token appears at most
            # once here and each expert needs only a single batched call.
            token_idx, slot_idx = torch.where(selected_experts == expert_idx)
            if token_idx.numel() == 0:
                continue
            expert_output = expert(flat_hidden_states[token_idx])
            gate = routing_weights[token_idx, slot_idx].unsqueeze(1)
            final_hidden_states.index_add_(0, token_idx, expert_output * gate)

        return final_hidden_states.reshape(original_shape)
|
|
|
|
|
|
|
|
|
def create_moe_model():
    """Build the MoE model by grafting expert MLPs onto a base checkpoint.

    Steps:
        1. Load the first expert checkpoint as the model skeleton (bfloat16).
        2. Read the ``model.safetensors`` file of each expert onto the CPU.
        3. Replace the ``mlp`` of every decoder layer with a ``MoEModule`` and
           copy each expert's MLP weights for that layer into it.
        4. Freeze every parameter whose name does not contain ``"router"``.

    Returns:
        The modified causal-LM model, placed on ``device``.

    Raises:
        ValueError: If ``EXPERT_DIRS`` lists fewer than ``NUM_EXPERTS``
            checkpoints (checked up front, before any weights are loaded).
    """
    # Fail fast on a misconfiguration instead of hitting an IndexError after
    # the base model and all expert weights have already been loaded.
    if len(EXPERT_DIRS) < NUM_EXPERTS:
        raise ValueError(
            f"NUM_EXPERTS is {NUM_EXPERTS} but EXPERT_DIRS lists only "
            f"{len(EXPERT_DIRS)} expert checkpoint(s)."
        )

    print("--- Starting Architectural Surgery ---")
    config = AutoConfig.from_pretrained(MODEL_NAME)

    print("Step 1: Loading base model skeleton...")
    base_model = AutoModelForCausalLM.from_pretrained(
        os.path.join(BASE_EXPERT_PATH, EXPERT_DIRS[0]),
        torch_dtype=torch.bfloat16,
        device_map=device,
    )

    print("Step 2: Pre-loading all expert weights into CPU memory for efficiency...")
    # Only the first NUM_EXPERTS checkpoints are ever consumed below.
    all_experts_state_dicts = [
        load_file(
            os.path.join(BASE_EXPERT_PATH, expert_dir, 'model.safetensors'),
            device='cpu',
        )
        for expert_dir in EXPERT_DIRS[:NUM_EXPERTS]
    ]
    print("All expert weights pre-loaded.")

    print("Step 3: Replacing FFNs with MoE modules and transplanting expert weights...")
    for layer_idx, layer in enumerate(base_model.model.layers):
        layer.mlp = MoEModule(config).to(device, dtype=torch.bfloat16)
        # Checkpoint keys of this layer's MLP all start with this prefix;
        # stripping it yields the key names of a bare MLP module.
        mlp_prefix = f"model.layers.{layer_idx}.mlp."
        for expert_idx in range(NUM_EXPERTS):
            expert_state_dict = all_experts_state_dicts[expert_idx]
            expert_mlp_weights = {
                key[len(mlp_prefix):]: tensor
                for key, tensor in expert_state_dict.items()
                if key.startswith(mlp_prefix)
            }
            layer.mlp.experts[expert_idx].load_state_dict(expert_mlp_weights)

    print("Step 4: Freezing all parameters except for the routers...")
    for name, param in base_model.named_parameters():
        if "router" not in name:
            param.requires_grad = False

    print("\n--- Surgery Complete! MoE Model is assembled and ready for training. ---")
    trainable_params = sum(p.numel() for p in base_model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in base_model.parameters())
    print(f"Total Parameters: {total_params / 1e6:.2f}M")
    print(f"Trainable Parameters (Routers): {trainable_params}")

    return base_model
|
|
|
|
|
|
|
|
|
def main():
    """Assemble the MoE model, train its routers on mock data, and save it.

    The routers are optimised with AdamW for ``EPOCHS`` steps on a single
    batch of random token ids (labels are a copy of the inputs). The loss is
    the language-model loss plus ``LB_LOSS_COEFFICIENT`` times the sum of the
    per-layer load-balancing losses. The model and tokenizer are then written
    to ``./SmolMoE-8x135M-Instruct-v1-Trained``, replacing any existing
    directory of that name.
    """
    moe_model = create_moe_model()

    # Only parameters left trainable on the model are handed to the optimizer.
    trainable = [p for p in moe_model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable, lr=LEARNING_RATE)

    print("\n--- Preparing Simulated Mixed Dataset for Training ---")
    batch_shape = (BATCH_SIZE, SEQUENCE_LENGTH)
    mock_input_ids = torch.randint(
        0, moe_model.config.vocab_size, batch_shape, device=device
    )
    mock_labels = mock_input_ids.clone()

    print("--- Starting Router Training Loop (Optimized & Corrected) ---")
    moe_model.train()
    start_time = time.time()

    for epoch in range(EPOCHS):
        optimizer.zero_grad()

        outputs = moe_model(input_ids=mock_input_ids, labels=mock_labels)
        main_loss = outputs.loss

        # Each MoE layer stored its auxiliary loss during the forward pass.
        total_lb_loss = sum(
            (layer.mlp.most_recent_lb_loss for layer in moe_model.model.layers),
            0.0,
        )
        total_loss = main_loss + LB_LOSS_COEFFICIENT * total_lb_loss

        total_loss.backward()
        optimizer.step()

        # Report (and restart the timer) every tenth step only.
        if (epoch + 1) % 10 != 0:
            continue
        elapsed_time = time.time() - start_time
        print(f"Epoch [{epoch+1:03d}/{EPOCHS}] | Total Loss: {total_loss.item():.4f} | "
              f"Main Loss: {main_loss.item():.4f} | "
              f"Avg LB Loss: {(total_lb_loss.item() / moe_model.config.num_hidden_layers):.4f} | "
              f"Time: {elapsed_time:.2f}s")
        start_time = time.time()

    print("\n--- Router Training Complete! ---")

    print("\n--- Phase 5: Saving the fully trained MoE model to disk ---")
    OUTPUT_MODEL_DIR = "./SmolMoE-8x135M-Instruct-v1-Trained"

    # Start from an empty output directory.
    if os.path.exists(OUTPUT_MODEL_DIR):
        shutil.rmtree(OUTPUT_MODEL_DIR)
    os.makedirs(OUTPUT_MODEL_DIR)

    print("Updating model config with MoE-specific parameters...")
    model_config = moe_model.config
    model_config.moe_num_experts = NUM_EXPERTS
    model_config.moe_top_k = TOP_K

    print(f"Saving model to '{OUTPUT_MODEL_DIR}'...")
    moe_model.save_pretrained(OUTPUT_MODEL_DIR)

    print("Saving tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.save_pretrained(OUTPUT_MODEL_DIR)

    print("\n--- Model successfully saved! ---")
    print("You can now load this model in other scripts, but you must re-define the custom MoE classes first.")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the full build/train/save pipeline only when executed as a script,
    # not when this module is imported.
    main()