| from typing import List | |
| import torch | |
| import torch.nn as nn | |
| from .args import MoeArgs | |
class MoeLayer(nn.Module):
    """Sparse mixture-of-experts layer with top-k routing per token.

    Each token (a vector along the last dimension of the input) is scored by
    ``gate``. The ``moe_args.num_experts_per_tok`` highest-scoring expert
    indices are selected, and their logits are softmax-normalised into mixing
    weights. The token's output is the weighted sum of the selected experts'
    outputs.
    """

    def __init__(
        self, experts: List[nn.Module], gate: nn.Module, moe_args: "MoeArgs"
    ):
        """Build the layer.

        Args:
            experts: Expert modules; expert ``i`` receives the tokens whose
                top-k gate indices include ``i``. Must be non-empty.
            gate: Module producing routing logits for the tokens. Presumably
                its last output dimension has size ``len(experts)``; TODO
                confirm against the caller.
            moe_args: Configuration object; only ``num_experts_per_tok`` is
                read by this layer.

        Raises:
            ValueError: If ``experts`` is empty.
        """
        super().__init__()
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if len(experts) == 0:
            raise ValueError("MoeLayer requires at least one expert")
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.args = moe_args

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Route every token through its top-k experts and mix the results.

        Args:
            inputs: Tensor whose last dimension is the feature dimension. Any
                leading dimensions are treated as independent tokens, so both
                ``(tokens, dim)`` and ``(batch, seq, dim)`` layouts work.

        Returns:
            Tensor with the same shape and dtype as ``inputs``.
        """
        # Collapse all leading dims into a single token axis so each token is
        # addressed by one index below. This is a no-op for 2-D input.
        tokens = inputs.reshape(-1, inputs.shape[-1])
        gate_logits = self.gate(tokens)
        weights, selected_experts = torch.topk(
            gate_logits, self.args.num_experts_per_tok
        )
        # Normalise over the k selected logits (the last axis). Softmax runs in
        # float32 for numerical stability, then the result is cast back.
        weights = torch.nn.functional.softmax(weights, dim=-1, dtype=torch.float).to(
            inputs.dtype
        )
        results = torch.zeros_like(tokens)
        for i, expert in enumerate(self.experts):
            # token_idx: rows routed to expert i.
            # slot: which of the k picks for that row it was.
            # topk returns distinct indices per row, so token_idx has no
            # duplicates and the indexed += below is safe.
            token_idx, slot = torch.where(selected_experts == i)
            results[token_idx] += weights[token_idx, slot, None] * expert(
                tokens[token_idx]
            )
        return results.reshape(inputs.shape)