The script below loads the shared MoE weights from the upstream artifact directory and sets up the GPT-OSS reference implementation:

```python
# /// script
# dependencies = [
#     "torch",
#     "numpy",
# ]
# ///
import os
from pathlib import Path

import torch
from torch import nn
from torch.nn import functional as F

from utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)

# Discover the upstream artifact directory from the environment.
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')

# List the files available in the artifact directory.
print(f"Loading weights from: {data_dir}")
print(f"Files in directory: {list(Path(data_dir).glob('*'))}")

router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')

print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")

class GptOssRouter(nn.Module):
    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
        router_logits = F.linear(hidden_states, self.weight, self.bias)
        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
        router_top_value = F.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
        return router_scores, router_indices
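
# Routing recap: for each token, take the top-k router logits, softmax over
# just those k values, and scatter the results back into a dense
# [num_tokens, num_experts] matrix. Each row is zero outside the selected
# experts, and its nonzero entries sum to 1.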

class GptOssExperts(nn.Module):
    def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.expert_dim = self.hidden_size
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
        self.alpha = 1.702
        self.limit = 7.0
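
    # alpha = 1.702 is the usual constant in the sigmoid approximation of
    # GELU, gelu(x) ~= x * sigmoid(1.702 * x), so the gate in forward()
    # computes an approximate GELU. limit = 7.0 clamps the gate/up
    # activations, bounding the intermediates before the down projection.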
    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        num_experts = routing_weights.shape[1]
        if hidden_states.device.type == "cpu" or self.training:
            # Sparse path: loop only over experts that actually received tokens.
            next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
            with torch.no_grad():
                expert_mask = F.one_hot(router_indices, num_classes=num_experts)
                expert_mask = expert_mask.permute(2, 1, 0)
                expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
            for expert_idx in expert_hit:
                expert_idx = expert_idx[0]
                with torch.no_grad():
                    _, token_idx = torch.where(expert_mask[expert_idx])
                current_state = hidden_states[token_idx]
                gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
                gate, up = gate_up[..., ::2], gate_up[..., 1::2]
                gate = gate.clamp(min=None, max=self.limit)
                up = up.clamp(min=-self.limit, max=self.limit)
                glu = gate * torch.sigmoid(gate * self.alpha)
                gated_output = (up + 1) * glu
                out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
                weighted_output = out * routing_weights[token_idx, expert_idx, None]
                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
            next_states = next_states.view(batch_size, -1, self.hidden_size)
        else:
            # Dense path: run every expert on every token with batched matmuls,
            # then weight each expert's output by its routing score and sum.
            hidden_states = hidden_states.repeat(num_experts, 1)
            hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
            gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
            gate = gate.clamp(min=None, max=self.limit)
            up = up.clamp(min=-self.limit, max=self.limit)
            glu = gate * torch.sigmoid(gate * self.alpha)
            next_states = torch.bmm(((up + 1) * glu), self.down_proj)
            next_states = next_states + self.down_proj_bias[..., None, :]
            next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
            next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
            next_states = next_states.sum(dim=0)
        return next_states
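
# Both branches compute the same function. The sparse path only does work for
# (token, expert) pairs that were actually routed, which suits the CPU and
# training cases it is gated on; the dense path spends extra FLOPs running
# all experts on all tokens, but maps onto large batched matmuls that
# parallelize well on GPU.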

class GptOssMoEMLP(nn.Module):
    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.router = GptOssRouter(router_weight, router_bias)
        self.experts = GptOssExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)

    def forward(self, hidden_states):
        router_scores, router_indices = self.router(hidden_states)
        routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
        return routed_out, router_scores
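
# Shape flow: hidden_states [batch, seq, hidden] -> router_scores
# [batch * seq, num_experts], router_indices [batch * seq, top_k] ->
# routed_out [batch, seq, hidden].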

# Run the model
set_seed(GENERAL_SEED)
device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== GPT-OSS Implementation ===")

# Initialize the model with the loaded weights.
model = GptOssMoEMLP(
    router_weight.to(device, dtype=dtype),
    router_bias.to(device, dtype=dtype),
    gate_up_proj.to(device, dtype=dtype),
    gate_up_proj_bias.to(device, dtype=dtype),
    down_proj.to(device, dtype=dtype),
    down_proj_bias.to(device, dtype=dtype)
).to(device=device, dtype=dtype)

print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")

# Benchmark the model, using a different input tensor on each iteration.
tokens = BATCH_SIZE * SEQ_LEN
input_shape = (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens,
                   save_json="gptoss_results.json", input_shape=input_shape, input_seed_base=INPUT_SEED) as bench:
    output, stats = bench(model)
    print(f"\nOutput sum: {output[0].sum().item():.6f}")
```