Chengyue Wu committed
Commit 7eec723 · 1 Parent(s): 19930b4

update training

Files changed (1):
  1. modeling.py +129 -23
modeling.py CHANGED
@@ -36,6 +36,55 @@ class BaseModelOutputWithPastAndBlockCache(BaseModelOutputWithPast):
     block_past_key_values: Optional[Cache] = None
 
 
+@torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
+def fused_flex_attention(q, k, v, mask=None):
+    return flex_attention(q, k, v, block_mask=mask, enable_gqa=True)
+
+def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
+    """
+    Constructs the specialized block diffusion attention mask for training
+    composed of three masks:
+    - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
+    - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
+    - **Block Causal Mask (M_BC)**: Attention to update x0
+
+    Args:
+        b, h: Batch and head indices (ignored for mask logic).
+        q_idx, kv_idx: Query and Key indices.
+        seq_len: Total sequence length.
+        block_size: Defines the block structure.
+
+    Returns:
+        A boolean attention mask.
+    """
+    # Indicate whether token belongs to xt or x0
+    x0_flag_q = (q_idx >= n)
+    x0_flag_kv = (kv_idx >= n)
+
+    # Compute block indices
+    block_q = torch.where(x0_flag_q == 1,
+                          (q_idx - n) // block_size,
+                          q_idx // block_size)
+    block_kv = torch.where(x0_flag_kv == 1,
+                           (kv_idx - n) // block_size,
+                           kv_idx // block_size)
+
+    # **1. Block Diagonal Mask (M_BD) **
+    block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
+
+    # **2. Offset Block-Causal Mask (M_OBC) **
+    offset_block_causal = (
+        (block_q > block_kv)
+        & (x0_flag_kv == 1)
+        & (x0_flag_q == 0)
+    )
+
+    # **3. Block-Causal Mask (M_BC) **
+    block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
+
+    # **4. Combine Masks **
+    return block_diagonal | offset_block_causal | block_causal
+
 def eval_block_diff_mask(q_idx, kv_idx, block_size=None):
     # Compute block indices
     block_q = q_idx // block_size
@@ -180,20 +229,24 @@ class Fast_dLLM_QwenAttention(nn.Module):
             key_states = torch.cat((past_key_value[self.layer_idx][0], key_states), dim=-2)
             value_states = torch.cat((past_key_value[self.layer_idx][1], value_states), dim=-2)
 
-        attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
-
-        attn_output, attn_weights = attention_interface(
-            self,
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            is_causal=False,
-            dropout=0.0 if not self.training else self.attention_dropout,
-            scaling=self.scaling,
-            sliding_window=self.sliding_window,  # main diff with Llama
-            **kwargs,
-        )
+        if self.training:
+            attn_output = fused_flex_attention(query_states, key_states, value_states, mask=attention_mask)
+            attn_output = attn_output.transpose(1, 2).contiguous()
+        else:
+            attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
+
+            attn_output, attn_weights = attention_interface(
+                self,
+                query_states,
+                key_states,
+                value_states,
+                attention_mask,
+                is_causal=False,
+                dropout=0.0 if not self.training else self.attention_dropout,
+                scaling=self.scaling,
+                sliding_window=self.sliding_window,  # main diff with Llama
+                **kwargs,
+            )
 
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output)
@@ -376,6 +429,13 @@ class Fast_dLLM_QwenModel(Fast_dLLM_QwenPreTrainedModel):
         )
         return mask
 
+    def gen_mask(self, seqlen, block_size, B, H):
+        mask = create_block_mask(
+            partial(block_diff_mask, block_size=block_size, n=seqlen),
+            B=B, H=H, Q_LEN=seqlen*2, KV_LEN=seqlen*2)
+
+        return mask
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -407,23 +467,31 @@ class Fast_dLLM_QwenModel(Fast_dLLM_QwenPreTrainedModel):
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            if use_block_cache:
-                block_start_position = past_seen_tokens+replace_position if replace_position is not None else past_seen_tokens
+            if self.training:
                 cache_position = torch.arange(
-                    block_start_position, block_start_position + inputs_embeds.shape[1], device=inputs_embeds.device
+                    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]//2, device=inputs_embeds.device
                 )
             else:
-                cache_position = torch.arange(
-                    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] if not self.training else inputs_embeds.shape[1]//2, device=inputs_embeds.device
-                )
+                if use_block_cache:
+                    block_start_position = past_seen_tokens+replace_position if replace_position is not None else past_seen_tokens
+                    cache_position = torch.arange(
+                        block_start_position, block_start_position + inputs_embeds.shape[1], device=inputs_embeds.device
+                    )
+                else:
+                    cache_position = torch.arange(
+                        past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] if not self.training else inputs_embeds.shape[1]//2, device=inputs_embeds.device
+                    )
 
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        if use_block_cache and block_past_key_values.get_seq_length() != 0:
-            attention_mask = None
+        if self.training:
+            attention_mask = self.gen_mask(labels.shape[1], self.bd_size, labels.shape[0], self.config.num_attention_heads).to(device=inputs_embeds.device)
         else:
-            attention_mask = self.eval_mask(input_ids.shape[1], block_size, past_key_values.get_seq_length() if past_key_values is not None else 0).to(device=inputs_embeds.device)
+            if use_block_cache and block_past_key_values.get_seq_length() != 0:
+                attention_mask = None
+            else:
+                attention_mask = self.eval_mask(input_ids.shape[1], block_size, past_key_values.get_seq_length() if past_key_values is not None else 0).to(device=inputs_embeds.device)
 
         hidden_states = inputs_embeds
 
@@ -503,9 +571,45 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
         use_block_cache: Optional[bool] = False,
         block_past_key_values: Optional[Cache] = None,
         replace_position: Optional[int] = None,
+        mask_id: Optional[int] = 151665,
         **kwargs
     ) -> CausalLMOutputWithPastAndBlockCache:
 
+        if self.training:
+            original_labels = labels.clone()
+            original_input_ids = input_ids.clone()
+
+            noisy_input_ids = input_ids.clone()
+
+            input_ids = input_ids.reshape(input_ids.shape[0] * input_ids.shape[1] // self.model.bd_size, self.model.bd_size)
+            b, l = input_ids.shape
+            t = torch.rand((b,), device=input_ids.device)
+            eps=1e-3
+            p_mask = (1 - eps) * t + eps
+            p_mask = p_mask[:, None].repeat(1, l)
+
+            mask_indices = torch.rand((b, l), device=input_ids.device) < p_mask
+            x_t = torch.where(mask_indices, mask_id, input_ids).reshape(labels.shape)
+            noisy_input_ids[labels != -100] = x_t[labels != -100]
+            mask = (noisy_input_ids != mask_id)
+            labels[mask] = -100
+            input_ids = torch.cat([noisy_input_ids, input_ids.reshape(labels.shape)], dim=1)
+
+            complementary_noisy_input_ids = original_input_ids.clone()
+            complementary_labels = original_labels.clone()
+
+            complementary_input_ids = original_input_ids.reshape(original_input_ids.shape[0] * original_input_ids.shape[1] // self.model.bd_size, self.model.bd_size)
+
+            complementary_mask_indices = ~mask_indices
+            complementary_x_t = torch.where(complementary_mask_indices, mask_id, complementary_input_ids).reshape(labels.shape)
+            complementary_noisy_input_ids[complementary_labels != -100] = complementary_x_t[complementary_labels != -100]
+            complementary_mask = (complementary_noisy_input_ids != mask_id)
+            complementary_labels[complementary_mask] = -100
+            complementary_input_ids = torch.cat([complementary_noisy_input_ids, complementary_input_ids.reshape(complementary_labels.shape)], dim=1)
+
+            input_ids = torch.cat([input_ids, complementary_input_ids], dim=0)
+            labels = torch.cat([labels, complementary_labels], dim=0)
+
         outputs: BaseModelOutputWithPastAndBlockCache = self.model(
             input_ids=input_ids,
             labels=labels,
@@ -524,6 +628,8 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
         )
 
         hidden_states = outputs.last_hidden_state
+        if self.training:
+            hidden_states = hidden_states[:, :hidden_states.shape[1]//2, :]
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.lm_head(hidden_states[:, slice_indices, :])
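
The new training path is easiest to reason about outside the model: gen_mask builds a BlockMask over the concatenated [x_t; x0] sequence (length 2 * seqlen) from block_diff_mask, and fused_flex_attention consumes it. The sketch below is not part of the commit; it restates the mask predicate in standalone form, materializes it densely on toy sizes to check the three sub-masks, and then runs the same create_block_mask/flex_attention wiring (torch >= 2.5, GPU assumed for that part).

import torch
from functools import partial

def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
    # Standalone restatement of the commit's predicate: positions [0, n) are the
    # noised sequence x_t, positions [n, 2n) are the clean sequence x0.
    x0_flag_q = q_idx >= n
    x0_flag_kv = kv_idx >= n
    block_q = torch.where(x0_flag_q, (q_idx - n) // block_size, q_idx // block_size)
    block_kv = torch.where(x0_flag_kv, (kv_idx - n) // block_size, kv_idx // block_size)
    block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)    # M_BD
    offset_block_causal = (block_q > block_kv) & x0_flag_kv & ~x0_flag_q  # M_OBC
    block_causal = (block_q >= block_kv) & x0_flag_kv & x0_flag_q         # M_BC
    return block_diagonal | offset_block_causal | block_causal

# Materialize the mask densely on toy sizes (training uses n = labels.shape[1],
# block_size = bd_size) and check the intended visibility pattern.
n, block_size = 8, 4
idx = torch.arange(2 * n)
dense = block_diff_mask(None, None, idx[:, None], idx[None, :], block_size=block_size, n=n)

assert dense[5, 4]          # x_t: bidirectional attention inside its own noised block
assert not dense[5, 0]      # ...but no attention to other noised blocks
assert dense[5, n + 0]      # M_OBC: strictly earlier clean blocks are visible
assert not dense[5, n + 5]  # ...the clean copy of the current block is not
assert dense[n + 5, n + 0]  # M_BC: the clean half is block-causal over itself

# The same predicate drives create_block_mask/flex_attention in the commit.
if torch.cuda.is_available():
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention
    seqlen, bd, heads, dim = 128, 32, 2, 64
    mask = create_block_mask(
        partial(block_diff_mask, block_size=bd, n=seqlen),
        B=1, H=heads, Q_LEN=2 * seqlen, KV_LEN=2 * seqlen, device="cuda",
    )
    q = k = v = torch.randn(1, heads, 2 * seqlen, dim, device="cuda")
    out = flex_attention(q, k, v, block_mask=mask, enable_gqa=True)
    print(out.shape)  # torch.Size([1, 2, 256, 64])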
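
On the data side, the training branch of Fast_dLLM_QwenForCausalLM.forward masks each bd_size block at a random rate and then repeats the batch with the complementary mask, so each response token is supervised exactly once across the two copies. Below is a minimal standalone sketch of that data path under stated simplifications: make_complementary_views is a hypothetical helper name, bd_size is assumed to divide the sequence length, and the prompt handling via labels != -100 is omitted.

import torch

def make_complementary_views(input_ids, bd_size, mask_id):
    b, seq_len = input_ids.shape
    blocks = input_ids.reshape(b * seq_len // bd_size, bd_size)

    # One masking rate per block, kept away from 0 with eps = 1e-3 as in the commit.
    t = torch.rand(blocks.shape[0], device=input_ids.device)
    p_mask = ((1 - 1e-3) * t + 1e-3)[:, None].expand_as(blocks)
    mask_indices = torch.rand_like(p_mask) < p_mask

    views, labels = [], []
    for m in (mask_indices, ~mask_indices):               # complementary masks
        x_t = torch.where(m, mask_id, blocks).reshape(b, seq_len)
        lab = input_ids.clone()
        lab[x_t != mask_id] = -100                        # loss only on masked positions
        views.append(torch.cat([x_t, input_ids], dim=1))  # concatenated [x_t ; x0] input
        labels.append(lab)
    return torch.cat(views, dim=0), torch.cat(labels, dim=0)

ids = torch.randint(0, 1000, (2, 16))
x, y = make_complementary_views(ids, bd_size=4, mask_id=151665)
print(x.shape, y.shape)  # torch.Size([4, 32]) torch.Size([4, 16])

This also explains why the last hunk keeps only the first half of the hidden states during training: the model sees 2 * seq_len positions per view, but the loss is computed on the noised half, whose labels have length seq_len.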