Update modeling_klear.py

modeling_klear.py (+2, -77)

This commit removes the sequence-level auxiliary load-balancing loss from the Klear MoE model: the standalone load_balancing_loss_func helper is deleted, and the output_router_logits branch in KlearMoeForCausalLM.forward is stubbed out, so aux_loss is no longer computed or added to the training loss.
@@ -552,73 +552,6 @@ class KlearModel(KlearPreTrainedModel):
         )


-def load_balancing_loss_func(
-    gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
-    num_experts: Optional[int] = None,
-    top_k: int = 2,
-    attention_mask: Optional[torch.Tensor] = None,
-    moe_aux_loss_coeff: float = 1,
-) -> torch.Tensor:
-    """
-    Computes sequence-level auxiliary load balancing loss for MoE gating.
-
-    Args:
-        gate_logits: Tensor of shape [batch_size, seq_len, num_experts]
-            or a tuple of such tensors (for multiple towers).
-        num_experts: Number of experts (inferred from gate_logits if None).
-        top_k: Number of top experts chosen per token.
-        attention_mask: Optional mask [batch_size, seq_len], 1 for valid tokens, 0 for padding.
-        moe_aux_loss_coeff: Scaling coefficient for the balancing loss.
-
-    Returns:
-        A scalar tensor representing the load balancing loss.
-    """
-    # Merge towers if provided
-    if isinstance(gate_logits, tuple):
-        gate_logits = torch.cat(gate_logits, dim=0)
-
-    assert gate_logits is not None, "gate_logits must be provided"
-    batch_size, seq_len, n_experts = gate_logits.shape
-    num_experts = n_experts if num_experts is None else num_experts
-    assert num_experts == n_experts, f"num_experts ({num_experts}) != gate dimension ({n_experts})"
-
-    # Compute gating probabilities
-    gate_probs = F.softmax(gate_logits, dim=-1)
-
-    # Optionally mask padding tokens
-    if attention_mask is not None:
-        mask = attention_mask.float().unsqueeze(-1)  # [batch, seq, 1]
-    else:
-        mask = torch.ones(batch_size, seq_len, 1, device=gate_logits.device)
-
-    # Select top_k experts per token
-    topk_vals, topk_idx = torch.topk(gate_probs, top_k, dim=-1)  # both [batch, seq, top_k]
-    # Build one-hot mask of assignments
-    one_hot = F.one_hot(topk_idx, num_experts).float()  # [batch, seq, top_k, num_experts]
-    # Sum along top_k to combine multiple choices
-    expert_mask = one_hot.sum(dim=2)  # [batch, seq, num_experts]
-
-    # Apply token mask
-    expert_mask = expert_mask * mask  # zeros out padding
-    gate_probs_masked = gate_probs * mask
-
-    # Normalizer: number of valid tokens per sample
-    tokens_per_sample = mask.sum(dim=1).clamp(min=1.0)  # [batch, 1]
-
-    # Sequence-level tokens per expert: fraction of tokens routed to each expert per sample
-    tokens_per_expert = expert_mask.sum(dim=1).div_(tokens_per_sample * top_k / num_experts)  # [batch, num_experts]
-
-    # Sequence-level average probability per expert per sample
-    router_prob_per_expert = gate_probs_masked.sum(dim=1).div(tokens_per_sample)  # [batch, num_experts]
-
-    # Compute loss per sample: encourage uniform load
-    # Loss = sum_e (tokens_e * probs_e)
-    loss_per_sample = (tokens_per_expert * router_prob_per_expert).sum(dim=1)  # [batch]
-    # Average across batch and scale
-    loss = moe_aux_loss_coeff * loss_per_sample.mean()
-    return loss
-
-
 @auto_docstring
 class KlearMoeForCausalLM(KlearPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
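For readers of the diff: the deleted helper implements a per-sequence variant of the router auxiliary loss. For each sample it multiplies the fraction of tokens dispatched to each expert (scaled so a uniform router yields 1) by the mean router probability for that expert, sums over experts, then averages over the batch and scales by moe_aux_loss_coeff. Below is a minimal, self-contained sketch of the same computation on toy tensors; the shapes and random inputs are illustrative assumptions, not values from the model.

import torch
import torch.nn.functional as F

batch_size, seq_len, num_experts, top_k = 2, 6, 8, 2
gate_logits = torch.randn(batch_size, seq_len, num_experts)
attention_mask = torch.ones(batch_size, seq_len)  # 1 = valid token, 0 = padding

gate_probs = F.softmax(gate_logits, dim=-1)                    # [B, S, E]
_, topk_idx = torch.topk(gate_probs, top_k, dim=-1)            # [B, S, K]
expert_mask = F.one_hot(topk_idx, num_experts).float().sum(2)  # [B, S, E], 1 where expert chosen

mask = attention_mask.float().unsqueeze(-1)                    # [B, S, 1]
tokens_per_sample = mask.sum(dim=1).clamp(min=1.0)             # [B, 1]

# Fraction of valid tokens routed to each expert, normalized so a perfectly
# uniform router yields 1.0 for every expert.
tokens_per_expert = (expert_mask * mask).sum(dim=1) / (tokens_per_sample * top_k / num_experts)
# Mean router probability assigned to each expert over valid tokens.
router_prob_per_expert = (gate_probs * mask).sum(dim=1) / tokens_per_sample

moe_aux_loss_coeff = 1.0
aux_loss = moe_aux_loss_coeff * (tokens_per_expert * router_prob_per_expert).sum(dim=1).mean()
print(aux_loss)  # about 1.0 for a balanced router; grows as routing skews

Note this differs from the Switch-Transformer-style loss only in where the averaging happens: the statistics are computed per sample and then averaged over the batch, rather than pooled across the whole batch.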
@@ -720,16 +653,8 @@ class KlearMoeForCausalLM(KlearPreTrainedModel, GenerationMixin):
 
         aux_loss = None
         if output_router_logits:
-            aux_loss = load_balancing_loss_func(
-                outputs.router_logits,
-                self.num_experts,
-                self.num_experts_per_tok,
-                attention_mask,
-                self.moe_aux_loss_coeff,
-            )
-            if labels is not None:
-                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
-
+            pass
+
         return MoeCausalLMOutputWithPast(
             loss=loss,
             aux_loss=aux_loss,
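Net effect: with the branch body reduced to pass, aux_loss stays None and no auxiliary term is added to loss, even when router logits are requested. A hedged usage sketch of observing this from the outside; the checkpoint path is a placeholder, not part of this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

path = "path/to/klear-checkpoint"  # placeholder path, assumed for illustration
tok = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)

inputs = tok("hello world", return_tensors="pt")
out = model(**inputs, labels=inputs["input_ids"], output_router_logits=True)
print(out.aux_loss)  # None after this change
print(out.loss)      # pure language-modeling loss, no auxiliary term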