update llama_action model
modeling_llama_action.py (+4 -14)
@@ -200,29 +200,19 @@ class LlamaActionForCausalLM(LlamaForCausalLM):
             past_key_values=None,
             attention_mask=None,
             use_cache=None,
-
-            prefix="",
-            total=0,
+            progress_bar=None,
             **kwargs):
         batch_size = input_ids.size(0)
         seq_length = input_ids.size(1)
         n_frames = seq_length // self.num_image_patches
         attention_mask_length = n_frames * (self.num_image_patches + self.num_action_embeddings)
-        if
-
-            pbar = tqdm(total=total - len(input_ids[0]), desc=prefix, leave=False)
-            postfix = f"Frame [{n_frames + 1}/{total // self.num_image_patches}]"
-            pbar.set_postfix_str(postfix)
-        else:
-            pbar.update()
+        if progress_bar is not None:
+            progress_bar.update()

         if seq_length % self.num_image_patches != 0:
             n_last_frame_tokens = seq_length % self.num_image_patches
             attention_mask_length += n_last_frame_tokens
-
-            if show_progress:
-                postfix = f"Frame [{n_frames + 1}/{total // self.num_image_patches}]"
-                pbar.set_postfix_str(postfix)
+
         attention_mask = torch.ones((batch_size, attention_mask_length), device=input_ids.device, dtype=torch.long)
         # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None and len(past_key_values) > 0:
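In short, the commit removes the tqdm bookkeeping (the prefix and total arguments, the pbar creation and the show_progress checks) from the modeling code: the caller now owns the progress bar and passes it in as progress_bar, and the model only calls progress_bar.update() once per generation step. A minimal usage sketch, assuming this method is the model's prepare_inputs_for_generation and that generate() forwards extra keyword arguments to it; the frame count and variable names below are illustrative, not part of the repo:

    from tqdm import tqdm

    # Illustrative setup: `model` is a loaded LlamaActionForCausalLM and
    # `input_ids` holds the image-patch tokens of the conditioning frames.
    frames_to_generate = 4
    tokens_per_frame = model.num_image_patches
    max_new_tokens = frames_to_generate * tokens_per_frame

    # The caller owns the bar; the model only advances it by one step each
    # time prepare_inputs_for_generation runs during decoding.
    with tqdm(total=max_new_tokens, desc="generating") as pbar:
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            progress_bar=pbar,  # passed through as a model kwarg
        )

With progress_bar left at its default of None, the new code skips the update entirely, so generation without a bar no longer touches any tqdm state.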