BubbleQ commited on
Commit
e8cd9b8
·
verified ·
1 Parent(s): 6091cc6

Update modeling_klear.py

Browse files
Files changed (1) hide show
  1. modeling_klear.py +2 -77
modeling_klear.py CHANGED
@@ -552,73 +552,6 @@ class KlearModel(KlearPreTrainedModel):
552
  )
553
 
554
 
555
- def load_balancing_loss_func(
556
- gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
557
- num_experts: Optional[int] = None,
558
- top_k: int = 2,
559
- attention_mask: Optional[torch.Tensor] = None,
560
- moe_aux_loss_coeff: float = 1,
561
- ) -> torch.Tensor:
562
- """
563
- Computes sequence-level auxiliary load balancing loss for MoE gating.
564
-
565
- Args:
566
- gate_logits: Tensor of shape [batch_size, seq_len, num_experts]
567
- or a tuple of such tensors (for multiple towers).
568
- num_experts: Number of experts (inferred from gate_logits if None).
569
- top_k: Number of top experts chosen per token.
570
- attention_mask: Optional mask [batch_size, seq_len], 1 for valid tokens, 0 for padding.
571
- moe_aux_loss_coeff: Scaling coefficient for the balancing loss.
572
-
573
- Returns:
574
- A scalar tensor representing the load balancing loss.
575
- """
576
- # Merge towers if provided
577
- if isinstance(gate_logits, tuple):
578
- gate_logits = torch.cat(gate_logits, dim=0)
579
-
580
- assert gate_logits is not None, "gate_logits must be provided"
581
- batch_size, seq_len, n_experts = gate_logits.shape
582
- num_experts = n_experts if num_experts is None else num_experts
583
- assert num_experts == n_experts, f"num_experts ({num_experts}) != gate dimension ({n_experts})"
584
-
585
- # Compute gating probabilities
586
- gate_probs = F.softmax(gate_logits, dim=-1)
587
-
588
- # Optionally mask padding tokens
589
- if attention_mask is not None:
590
- mask = attention_mask.float().unsqueeze(-1) # [batch, seq, 1]
591
- else:
592
- mask = torch.ones(batch_size, seq_len, 1, device=gate_logits.device)
593
-
594
- # Select top_k experts per token
595
- topk_vals, topk_idx = torch.topk(gate_probs, top_k, dim=-1) # both [batch, seq, top_k]
596
- # Build one-hot mask of assignments
597
- one_hot = F.one_hot(topk_idx, num_experts).float() # [batch, seq, top_k, num_experts]
598
- # Sum along top_k to combine multiple choices
599
- expert_mask = one_hot.sum(dim=2) # [batch, seq, num_experts]
600
-
601
- # Apply token mask
602
- expert_mask = expert_mask * mask # zeros out padding
603
- gate_probs_masked = gate_probs * mask
604
-
605
- # Normalizer: number of valid tokens per sample
606
- tokens_per_sample = mask.sum(dim=1).clamp(min=1.0) # [batch, 1]
607
-
608
- # Sequence-level tokens per expert: fraction of tokens routed to each expert per sample
609
- tokens_per_expert = expert_mask.sum(dim=1).div_(tokens_per_sample * top_k / num_experts) # [batch, num_experts]
610
-
611
- # Sequence-level average probability per expert per sample
612
- router_prob_per_expert = gate_probs_masked.sum(dim=1).div(tokens_per_sample) # [batch, num_experts]
613
-
614
- # Compute loss per sample: encourage uniform load
615
- # Loss = sum_e (tokens_e * probs_e)
616
- loss_per_sample = (tokens_per_expert * router_prob_per_expert).sum(dim=1) # [batch]
617
- # Average across batch and scale
618
- loss = moe_aux_loss_coeff * loss_per_sample.mean()
619
- return loss
620
-
621
-
622
  @auto_docstring
623
  class KlearMoeForCausalLM(KlearPreTrainedModel, GenerationMixin):
624
  _tied_weights_keys = ["lm_head.weight"]
@@ -720,16 +653,8 @@ class KlearMoeForCausalLM(KlearPreTrainedModel, GenerationMixin):
720
 
721
  aux_loss = None
722
  if output_router_logits:
723
- aux_loss = load_balancing_loss_func(
724
- outputs.router_logits,
725
- self.num_experts,
726
- self.num_experts_per_tok,
727
- attention_mask,
728
- self.moe_aux_loss_coeff,
729
- )
730
- if labels is not None:
731
- loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
732
-
733
  return MoeCausalLMOutputWithPast(
734
  loss=loss,
735
  aux_loss=aux_loss,
 
552
  )
553
 
554
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
555
  @auto_docstring
556
  class KlearMoeForCausalLM(KlearPreTrainedModel, GenerationMixin):
557
  _tied_weights_keys = ["lm_head.weight"]
 
653
 
654
  aux_loss = None
655
  if output_router_logits:
656
+ pass
657
+
 
 
 
 
 
 
 
 
658
  return MoeCausalLMOutputWithPast(
659
  loss=loss,
660
  aux_loss=aux_loss,