Update modeling_c_cubed.py
Browse files- modeling_c_cubed.py +27 -1
modeling_c_cubed.py
CHANGED
|
@@ -707,4 +707,30 @@ class CcubedForConditionalGeneration(CcubedPreTrainedModel):
|
|
| 707 |
hidden_states=outputs.hidden_states,
|
| 708 |
attentions=outputs.attentions,
|
| 709 |
context_hidden_states=context_features,
|
| 710 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 707 |
hidden_states=outputs.hidden_states,
|
| 708 |
attentions=outputs.attentions,
|
| 709 |
context_hidden_states=context_features,
|
| 710 |
+
)
|
| 711 |
+
|
| 712 |
+
def prepare_inputs_for_generation(
    self,
    input_ids,
    inputs_embeds=None,
    past_key_values=None,
    attention_mask=None,
    context_attention_mask=None,
    **kwargs
):
    """Assemble the keyword arguments for a single generation step.

    Args:
        input_ids: Token ids indexable as ``input_ids[:, -1:]``. When a
            populated cache is present only the last position is kept.
        inputs_embeds: Optional embeddings. They are forwarded instead of
            ``input_ids`` only while no tokens have been cached yet, i.e.
            on the first generation step.
        past_key_values: Optional cache. ``None``, an empty container, or a
            cache object whose ``get_seq_length()`` is 0 all count as
            "nothing cached yet".
        attention_mask: Passed through unchanged.
        context_attention_mask: Passed through unchanged.
        **kwargs: Only ``use_cache`` is read; other entries are ignored.

    Returns:
        dict: Either ``"inputs_embeds"`` or ``"input_ids"``, plus
        ``"past_key_values"``, ``"use_cache"``, ``"attention_mask"`` and
        ``"context_attention_mask"``.
    """
    # Decide once whether the cache already holds tokens, and use that single
    # answer for both decisions below. Testing truthiness in one place and
    # `is None` in the other made an empty-but-not-None cache drop
    # `inputs_embeds` on the very first step.
    seq_length_fn = getattr(past_key_values, "get_seq_length", None)
    if callable(seq_length_fn):
        # Cache objects exposing `get_seq_length` report how many positions
        # are stored; this is more reliable than their truthiness.
        has_past = bool(seq_length_fn() > 0)
    else:
        # Plain containers (e.g. tuples) and None: empty/None means no past.
        has_past = bool(past_key_values)

    if has_past:
        # Earlier positions are already cached; feed only the newest token.
        input_ids = input_ids[:, -1:]

    # If `inputs_embeds` are passed, we only want to use them in the 1st
    # generation step.
    if inputs_embeds is not None and not has_past:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update(
        {
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            "context_attention_mask": context_attention_mask,
        }
    )
    return model_inputs
|