Chengyue Wu committed
Commit · 7eec723 · 1 Parent(s): 19930b4
update training

Files changed: modeling.py (+129 -23)
modeling.py CHANGED
@@ -36,6 +36,55 @@ class BaseModelOutputWithPastAndBlockCache(BaseModelOutputWithPast):
     block_past_key_values: Optional[Cache] = None
 
 
+@torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
+def fused_flex_attention(q, k, v, mask=None):
+    return flex_attention(q, k, v, block_mask=mask, enable_gqa=True)
+
+def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
+    """
+    Constructs the specialized block diffusion attention mask for training
+    composed of three masks:
+    - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
+    - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
+    - **Block Causal Mask (M_BC)**: Attention to update x0
+
+    Args:
+        b, h: Batch and head indices (ignored for mask logic).
+        q_idx, kv_idx: Query and Key indices.
+        seq_len: Total sequence length.
+        block_size: Defines the block structure.
+
+    Returns:
+        A boolean attention mask.
+    """
+    # Indicate whether token belongs to xt or x0
+    x0_flag_q = (q_idx >= n)
+    x0_flag_kv = (kv_idx >= n)
+
+    # Compute block indices
+    block_q = torch.where(x0_flag_q == 1,
+                          (q_idx - n) // block_size,
+                          q_idx // block_size)
+    block_kv = torch.where(x0_flag_kv == 1,
+                           (kv_idx - n) // block_size,
+                           kv_idx // block_size)
+
+    # **1. Block Diagonal Mask (M_BD) **
+    block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
+
+    # **2. Offset Block-Causal Mask (M_OBC) **
+    offset_block_causal = (
+        (block_q > block_kv)
+        & (x0_flag_kv == 1)
+        & (x0_flag_q == 0)
+    )
+
+    # **3. Block-Causal Mask (M_BC) **
+    block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
+
+    # **4. Combine Masks **
+    return block_diagonal | offset_block_causal | block_causal
+
 def eval_block_diff_mask(q_idx, kv_idx, block_size=None):
     # Compute block indices
     block_q = q_idx // block_size
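For reference, a minimal standalone sketch (not part of the commit) of how the combined mask can be inspected: it evaluates block_diff_mask densely on index grids for a toy [x_t, x_0] layout. The toy sizes and the `from modeling import ...` path are assumptions for illustration only.

# Sketch only: inspect the combined mask on a toy doubled sequence.
import torch
from modeling import block_diff_mask  # hypothetical import path for the function above

n, block_size = 8, 4                      # 8 tokens per half, 2 blocks of 4 (toy values)
q_idx = torch.arange(2 * n)[:, None]      # rows: queries over [x_t, x_0]
kv_idx = torch.arange(2 * n)[None, :]     # cols: keys/values over [x_t, x_0]

mask = block_diff_mask(None, None, q_idx, kv_idx, block_size=block_size, n=n)
print(mask.int())
# Upper-left quadrant: block-diagonal (M_BD) within x_t.
# Upper-right quadrant: strictly earlier x_0 blocks visible to x_t (M_OBC).
# Lower-right quadrant: block-causal (M_BC) within x_0.
# Lower-left quadrant: all zeros -- x_0 never attends to x_t.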
@@ -180,20 +229,24 @@ class Fast_dLLM_QwenAttention(nn.Module):
             key_states = torch.cat((past_key_value[self.layer_idx][0], key_states), dim=-2)
             value_states = torch.cat((past_key_value[self.layer_idx][1], value_states), dim=-2)
 
-        [14 removed lines; their content was not captured in this view]
+        if self.training:
+            attn_output = fused_flex_attention(query_states, key_states, value_states, mask=attention_mask)
+            attn_output = attn_output.transpose(1, 2).contiguous()
+        else:
+            attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
+
+            attn_output, attn_weights = attention_interface(
+                self,
+                query_states,
+                key_states,
+                value_states,
+                attention_mask,
+                is_causal=False,
+                dropout=0.0 if not self.training else self.attention_dropout,
+                scaling=self.scaling,
+                sliding_window=self.sliding_window,  # main diff with Llama
+                **kwargs,
+            )
 
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output)
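The training branch relies on flex_attention's enable_gqa=True to let the grouped query heads share the smaller set of KV heads, and its (B, H, L, D) output layout is why the result is transposed back before the reshape. A rough shape check follows; the head counts and dims are placeholders, not the model's config, and PyTorch 2.5+ is assumed.

# Shape sketch only; sizes are placeholders, not the model configuration.
import torch
from torch.nn.attention.flex_attention import flex_attention

B, L, D = 2, 16, 64
n_q_heads, n_kv_heads = 8, 2              # GQA: 4 query heads per KV head

q = torch.randn(B, n_q_heads, L, D)
k = torch.randn(B, n_kv_heads, L, D)
v = torch.randn(B, n_kv_heads, L, D)

# block_mask=None means dense attention; the commit passes the block-diffusion mask.
out = flex_attention(q, k, v, block_mask=None, enable_gqa=True)
print(out.shape)                          # (B, H_q, L, D) = (2, 8, 16, 64)
print(out.transpose(1, 2).shape)          # (B, L, H_q, D), ready for the reshape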
@@ -376,6 +429,13 @@ class Fast_dLLM_QwenModel(Fast_dLLM_QwenPreTrainedModel):
         )
         return mask
 
+    def gen_mask(self, seqlen, block_size, B, H):
+        mask = create_block_mask(
+            partial(block_diff_mask, block_size=block_size, n=seqlen),
+            B=B, H=H, Q_LEN=seqlen*2, KV_LEN=seqlen*2)
+
+        return mask
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
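gen_mask materializes the same predicate as a flex-attention BlockMask over the doubled sequence [x_t, x_0], which is why Q_LEN and KV_LEN are both 2 * seqlen. A standalone sketch of an equivalent call; the sizes, the CPU device, and the import path are assumptions, not taken from the commit.

# Sketch of what gen_mask builds; toy sizes, CPU device, hypothetical import.
from functools import partial

from torch.nn.attention.flex_attention import create_block_mask
from modeling import block_diff_mask  # hypothetical import path

seqlen, block_size = 128, 32
block_mask = create_block_mask(
    partial(block_diff_mask, block_size=block_size, n=seqlen),
    B=None, H=None,                       # broadcast over batch and heads
    Q_LEN=2 * seqlen, KV_LEN=2 * seqlen,
    device="cpu",
)
print(block_mask)                         # summary of which attention tiles are kept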
@@ -407,23 +467,31 @@ class Fast_dLLM_QwenModel(Fast_dLLM_QwenPreTrainedModel):
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            if
-            block_start_position = past_seen_tokens+replace_position if replace_position is not None else past_seen_tokens
+            if self.training:
                 cache_position = torch.arange(
-
+                    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]//2, device=inputs_embeds.device
                 )
             else:
-
-
-
+                if use_block_cache:
+                    block_start_position = past_seen_tokens+replace_position if replace_position is not None else past_seen_tokens
+                    cache_position = torch.arange(
+                        block_start_position, block_start_position + inputs_embeds.shape[1], device=inputs_embeds.device
+                    )
+                else:
+                    cache_position = torch.arange(
+                        past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] if not self.training else inputs_embeds.shape[1]//2, device=inputs_embeds.device
+                    )
 
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        if
-        attention_mask =
+        if self.training:
+            attention_mask = self.gen_mask(labels.shape[1], self.bd_size, labels.shape[0], self.config.num_attention_heads).to(device=inputs_embeds.device)
         else:
-
+            if use_block_cache and block_past_key_values.get_seq_length() != 0:
+                attention_mask = None
+            else:
+                attention_mask = self.eval_mask(input_ids.shape[1], block_size, past_key_values.get_seq_length() if past_key_values is not None else 0).to(device=inputs_embeds.device)
 
         hidden_states = inputs_embeds
 
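The three cache_position branches above reduce to simple arange windows. A toy illustration with made-up numbers (not from the commit) of what each branch produces:

# Made-up numbers; illustrates the three cache_position branches above.
import torch

# Training: the input is [x_t, x_0] concatenated, so positions span only half
# the input length (the x_t and x_0 halves presumably share positions; no cache used).
inputs_len = 16                                        # 2 * L, toy value
print(torch.arange(0, inputs_len // 2))                # tensor([0, 1, ..., 7])

# Block cache: re-decode a block starting `replace_position` tokens into the cache.
past_seen_tokens, replace_position, cur_len = 12, 4, 4
start = past_seen_tokens + replace_position
print(torch.arange(start, start + cur_len))            # tensor([16, 17, 18, 19])

# Default: continue right after the cached prefix.
print(torch.arange(past_seen_tokens, past_seen_tokens + cur_len))  # tensor([12, 13, 14, 15])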
@@ -503,9 +571,45 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
         use_block_cache: Optional[bool] = False,
         block_past_key_values: Optional[Cache] = None,
         replace_position: Optional[int] = None,
+        mask_id: Optional[int] = 151665,
         **kwargs
     ) -> CausalLMOutputWithPastAndBlockCache:
 
+        if self.training:
+            original_labels = labels.clone()
+            original_input_ids = input_ids.clone()
+
+            noisy_input_ids = input_ids.clone()
+
+            input_ids = input_ids.reshape(input_ids.shape[0] * input_ids.shape[1] // self.model.bd_size, self.model.bd_size)
+            b, l = input_ids.shape
+            t = torch.rand((b,), device=input_ids.device)
+            eps=1e-3
+            p_mask = (1 - eps) * t + eps
+            p_mask = p_mask[:, None].repeat(1, l)
+
+            mask_indices = torch.rand((b, l), device=input_ids.device) < p_mask
+            x_t = torch.where(mask_indices, mask_id, input_ids).reshape(labels.shape)
+            noisy_input_ids[labels != -100] = x_t[labels != -100]
+            mask = (noisy_input_ids != mask_id)
+            labels[mask] = -100
+            input_ids = torch.cat([noisy_input_ids, input_ids.reshape(labels.shape)], dim=1)
+
+            complementary_noisy_input_ids = original_input_ids.clone()
+            complementary_labels = original_labels.clone()
+
+            complementary_input_ids = original_input_ids.reshape(original_input_ids.shape[0] * original_input_ids.shape[1] // self.model.bd_size, self.model.bd_size)
+
+            complementary_mask_indices = ~mask_indices
+            complementary_x_t = torch.where(complementary_mask_indices, mask_id, complementary_input_ids).reshape(labels.shape)
+            complementary_noisy_input_ids[complementary_labels != -100] = complementary_x_t[complementary_labels != -100]
+            complementary_mask = (complementary_noisy_input_ids != mask_id)
+            complementary_labels[complementary_mask] = -100
+            complementary_input_ids = torch.cat([complementary_noisy_input_ids, complementary_input_ids.reshape(complementary_labels.shape)], dim=1)
+
+            input_ids = torch.cat([input_ids, complementary_input_ids], dim=0)
+            labels = torch.cat([labels, complementary_labels], dim=0)
+
         outputs: BaseModelOutputWithPastAndBlockCache = self.model(
             input_ids=input_ids,
             labels=labels,
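The training branch applies the discrete-diffusion forward process per block: each block draws a noise level t, tokens are replaced by mask_id with probability p = (1 - eps) * t + eps, and a complementary copy masks exactly the opposite positions, so across the doubled batch every supervised token is masked exactly once. A self-contained sketch of just the masking step; the sizes and vocabulary are placeholders, and the labels != -100 handling that keeps prompt tokens un-noised is omitted here.

# Minimal sketch of the per-block masking used in the training branch
# (sizes and token values are illustrative placeholders).
import torch

mask_id, bd_size, eps = 151665, 4, 1e-3
input_ids = torch.randint(0, 1000, (2, 16))               # (batch, seq_len); toy tokens

blocks = input_ids.reshape(-1, bd_size)                   # one row per block
b, l = blocks.shape
t = torch.rand((b,), device=blocks.device)                # one noise level per block
p_mask = ((1 - eps) * t + eps)[:, None].repeat(1, l)      # masking probability in [eps, 1]

mask_indices = torch.rand((b, l), device=blocks.device) < p_mask
x_t = torch.where(mask_indices, mask_id, blocks).reshape(input_ids.shape)

# Complementary copy: mask exactly the positions left visible above.
x_t_comp = torch.where(~mask_indices, mask_id, blocks).reshape(input_ids.shape)

# Across the two copies, every position is masked exactly once.
assert ((x_t == mask_id) ^ (x_t_comp == mask_id)).all()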
@@ -524,6 +628,8 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
         )
 
         hidden_states = outputs.last_hidden_state
+        if self.training:
+            hidden_states = hidden_states[:, :hidden_states.shape[1]//2, :]
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.lm_head(hidden_states[:, slice_indices, :])