AndyZijianZhang committed (verified)
Commit 7d97786 · 1 Parent(s): 6bde1de

Upload files with `vila-upload`.

Upload processing_vila.py
Upload modeling_vila.py

Files changed (2)
  1. modeling_vila.py +39 -35
  2. processing_vila.py +267 -207
modeling_vila.py CHANGED
@@ -1,9 +1,10 @@
-from typing import List, Optional, Type
+from typing import List, Optional, Type, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch import Tensor
+from torch import LongTensor, Tensor
+from transformers.cache_utils import Cache
 from transformers.configuration_utils import PretrainedConfig
 from transformers.generation.utils import GenerationMixin
 from transformers.modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
@@ -55,23 +56,22 @@ class MultimodalProjector(nn.Module):
     ):
         super().__init__(*args, **kwargs)
 
-        match config.mm_projector_type:
-            case "mlp_downsample_3x3_fix":
-                self.layers = nn.Sequential(
-                    DownSample3x3BlockFix(),
-                    nn.LayerNorm(config.mm_hidden_size * 9),
-                    nn.Linear(
-                        config.mm_hidden_size * 9,
-                        config.mm_hidden_size * 3,
-                    ),
-                    nn.GELU(),
-                    nn.LayerNorm(config.vision_config.hidden_size * 3),
-                    nn.Linear(config.vision_config.hidden_size * 3, config.hidden_size),
-                    nn.GELU(),
-                    nn.Linear(config.hidden_size, config.hidden_size),
-                )
-            case _:
-                raise NotImplementedError(f"Unsupported mm_projector_type: {config.mm_projector_type}")
+        if config.mm_projector_type == "mlp_downsample_3x3_fix":
+            self.layers = nn.Sequential(
+                DownSample3x3BlockFix(),
+                nn.LayerNorm(config.mm_hidden_size * 9),
+                nn.Linear(
+                    config.mm_hidden_size * 9,
+                    config.mm_hidden_size * 3,
+                ),
+                nn.GELU(),
+                nn.LayerNorm(config.vision_config.hidden_size * 3),
+                nn.Linear(config.vision_config.hidden_size * 3, config.hidden_size),
+                nn.GELU(),
+                nn.Linear(config.hidden_size, config.hidden_size),
+            )
+        else:
+            raise NotImplementedError(f"Unsupported mm_projector_type: {config.mm_projector_type}")
 
         self.layers.type(config.torch_dtype)
 
@@ -131,22 +131,29 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
         attention_mask: Optional[Tensor] = None,
         input_ids: Optional[Tensor] = None,
         inputs_embeds: Optional[Tensor] = None,
+        past_key_values: Optional[Cache] = None,
         pixel_values: Optional[Tensor] = None,
+        position_ids: Optional[LongTensor] = None,
+        logits_to_keep: Union[int, Tensor] = 0,
         **kwargs,
     ) -> CausalLMOutputWithPast:
-        # Vision info is only used for prefilling.
-        if kwargs.get("past_key_values", None) is not None:
-            pixel_values = None
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")
 
-        if inputs_embeds is None:
-            if input_ids is None:
-                raise ValueError("input_ids is required when inputs_embeds is None")
-
-            inputs_embeds = self._embed(input_ids, pixel_values)
+        if past_key_values is None:  # Prefill
+            if input_ids is not None:
+                inputs_embeds = self._embed(input_ids, pixel_values)
+                input_ids = None
 
         outputs = self.llm.__call__(
-            inputs_embeds=inputs_embeds.to(device=self.llm.device, dtype=self.llm.dtype),
             attention_mask=(attention_mask.to(device=self.llm.device) if attention_mask is not None else None),
+            input_ids=(input_ids.to(device=self.llm.device) if input_ids is not None else None),
+            inputs_embeds=(
+                inputs_embeds.to(device=self.llm.device, dtype=self.llm.dtype) if inputs_embeds is not None else None
+            ),
+            past_key_values=past_key_values,
+            position_ids=(position_ids.to(device=self.llm.device) if position_ids is not None else None),
+            logits_to_keep=logits_to_keep,
             **kwargs,
         )
 
@@ -208,10 +215,7 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
 
         selected_layer_hidden_states = vision_tower_output.hidden_states[self.config.mm_vision_select_layer]
 
-        match self.config.mm_vision_select_feature:
-            case "cls_patch":
-                return selected_layer_hidden_states
-            case _:
-                raise NotImplementedError(
-                    f"Unsupported mm_vision_select_feature: {self.config.mm_vision_select_feature}"
-                )
+        if self.config.mm_vision_select_feature == "cls_patch":
+            return selected_layer_hidden_states
+        else:
+            raise NotImplementedError(f"Unsupported mm_vision_select_feature: {self.config.mm_vision_select_feature}")
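The functional change above is concentrated in forward(): a prefill call (past_key_values=None) embeds the prompt and image tiles through _embed(), while later cached decode steps pass input_ids, past_key_values, position_ids, and logits_to_keep straight through to the wrapped LLM. Below is a minimal sketch of driving that split by hand; it assumes an already-loaded model and processor (e.g. via trust_remote_code), that use_cache and return_tensors pass through unchanged, and that the prompt already contains the tokenizer's image token. In practice GenerationMixin.generate() runs this loop for you.

import torch


@torch.no_grad()
def greedy_decode(model, processor, text, image, max_new_tokens=32):
    # "text" must already contain the tokenizer's image token (e.g. "<image>\nDescribe this photo.").
    # Key names below mirror the processor/tokenizer output and are assumptions about this repo.
    inputs = processor(text=text, images=[image], return_tensors="pt")
    input_ids = inputs["input_ids"]
    pixel_values = inputs.get("pixel_values")

    # Prefill: past_key_values is None, so forward() embeds text + image tiles via _embed().
    outputs = model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        past_key_values=None,
        use_cache=True,
        logits_to_keep=1,  # only the last position's logits are needed
    )

    generated = []
    for _ in range(max_new_tokens):
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated.append(next_token)
        if next_token.item() == processor.tokenizer.eos_token_id:
            break
        # Decode: only the new token ids plus the cache; pixel_values are no longer needed.
        outputs = model(
            input_ids=next_token,
            past_key_values=outputs.past_key_values,
            use_cache=True,
            logits_to_keep=1,
        )

    return processor.tokenizer.decode(torch.cat(generated, dim=-1)[0])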
 
 
 
processing_vila.py CHANGED
@@ -1,3 +1,4 @@
+import uuid
 from typing import List, Optional, Tuple, cast
 
 import transformers.image_transforms as image_transforms
@@ -14,7 +15,7 @@ from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor
 from transformers.models.siglip.image_processing_siglip_fast import SiglipImageProcessorFast
 from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
 from transformers.tokenization_utils import PreTrainedTokenizer
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase, TextInput
+from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase, TextInput
 from transformers.video_utils import VideoInput
 
 logger = transformers.utils.logging.get_logger(__name__)
@@ -83,7 +84,6 @@ class VILAProcessor(ProcessorMixin):
         text: TextInput | List[TextInput],
         images: Optional[ImageInput] = None,
         videos: Optional[VideoInput] = None,
-        audio: None = None,
         **kwargs: Unpack[ProcessingKwargs],
     ) -> VILAProcessorOutput:
         """Preprocesses inputs for VILA.
@@ -92,7 +92,6 @@ class VILAProcessor(ProcessorMixin):
             text: The text to be processed.
             images: The images to be processed.
             videos: The videos to be processed.
-            audio: Not available.
             **kwargs: Additional arguments for processing.
 
         Returns:
@@ -105,58 +104,32 @@ class VILAProcessor(ProcessorMixin):
             **kwargs,
         )
 
-        prepared_text, prepared_images, prepared_videos = self._prepare_inputs(
+        normalized_text, normalized_images, normalized_videos = self._normalize_inputs(
            text=text,
            images=images,
            videos=videos,
        )
 
-        # Process videos.
-        prepared_text, prepared_images, video_flags = self._treat_videos_as_image_seqs(
-            text=prepared_text,
-            images=prepared_images,
-            videos=prepared_videos,
+        preprocessed_text, preprocessed_media_tiles = self._preprocess_inputs(
+            text=normalized_text,
+            images=normalized_images,
+            videos=normalized_videos,
        )
 
-        # Process images.
-        image_inputs, num_cropped_images = self._process_images(
-            images=prepared_images,
-            video_flags=video_flags,
-            **merged_kwargs["images_kwargs"],
-        )
-
-        # Process text.
-        prepared_text = self._pad_image_tokens_by_num_crops(
-            prepared_text,
-            num_cropped_images=num_cropped_images,
-            video_flags=video_flags,
-        )
-
-        prepared_text = self._pad_image_tokens_by_num_embeddings(prepared_text)
-
        text_inputs = self.tokenizer.__call__(
-            prepared_text,
+            preprocessed_text,
            **merged_kwargs["text_kwargs"],
        )
 
-        # Find the last image token of each image tile and replace to "\n".
-        lf_token_id = self.tokenizer.encode("\n")[0]
-        image_token_id = self.tokenizer.image_token_id
-
-        for i in range(len(text_inputs.input_ids)):
-            input_ids = text_inputs.input_ids[i]
-
-            idx = 0
-            while idx < len(input_ids):
-                if input_ids[idx] != image_token_id:
-                    idx += 1
-                    continue
-
-                if idx + self.image_pad_len < len(input_ids):
-                    input_ids[idx + self.image_pad_len] = lf_token_id
-                    idx += self.image_pad_len + 1
-                else:
-                    break
+        if len(preprocessed_media_tiles) > 0:
+            image_inputs = self.image_processor.__call__(
+                preprocessed_media_tiles,
+                **merged_kwargs["images_kwargs"],
+            )
+        else:
+            image_inputs = BatchFeature()
+
+        text_inputs = self._replace_image_tile_suffix(text_inputs)
 
        return VILAProcessorOutput(
            data={
@@ -165,119 +138,118 @@ class VILAProcessor(ProcessorMixin):
            }
        )
 
-    def _crop_image(
-        self,
-        image: Image,
-        *,
-        is_video_frame: bool,
-    ) -> List[Image]:
-        """Crops the image into multiple tiles.
-
-        Args:
-            image: The image to be cropped.
-
-        Returns:
-            The cropped images.
-        """
-
-        # TODO: Support more image processors.
-        if not isinstance(self.image_processor, (SiglipImageProcessor, SiglipImageProcessorFast)):
-            raise NotImplementedError
-
-        assert self.image_processor.size["height"] == self.image_processor.size["width"]
-        cropped_size = self.image_processor.size["height"]
-
-        cropped_images: List[Image] = dynamic_preprocess(
-            image,
-            min_num=self.min_tiles,
-            max_num=self.max_tiles if not is_video_frame else self.video_max_tiles,
-            image_size=cropped_size,
-        )
-
-        return cropped_images
-
-    def _pad_image_tokens_by_num_crops(
-        self,
-        text: List[str],
-        *,
-        num_cropped_images: List[int],
-        video_flags: List[bool],
-    ) -> List[str]:
-        """Pads each \\<image> to num_cropped_images of "\\<image>\\n" for images and "\\<video>" for videos.
-
-        Args:
-            text: The text to be padded.
-            num_cropped_images: The number of cropped images for each image token.
-            video_flags: A list of flags indicating whether the num_cropped_images item is a video.
-
-        Returns:
-            The padded text.
-        """
-
-        assert len(num_cropped_images) == len(
-            video_flags
-        ), "num_cropped_images and video_flags must have the same length."
-
-        image_token: str = cast(str, self.tokenizer.image_token)
-
-        return_text: List[str] = []
-
-        for text_item in text:
-            return_text_item: str = ""
-
-            # Repeatedly find image_token in the text.
-            while image_token in text_item:
-                image_pos = text_item.find(image_token)
-
-                if image_pos != -1 and len(num_cropped_images) > 0:
-                    num_crops = num_cropped_images.pop(0)
-                    video_flag = video_flags.pop(0)
-
-                    return_text_item += (
-                        text_item[:image_pos] + (image_token if video_flag else (image_token + "\n")) * num_crops
-                    )
-                    text_item = text_item[image_pos + len(image_token) :]
-
-                else:
-                    break
-
-            # Must place outside the while loop.
-            if image_token in text_item:
-                raise ValueError("Too many image tokens in the text.")
-
-            return_text_item += text_item
-            text_item = ""
-
-            return_text.append(return_text_item)
-
-        if len(num_cropped_images) != 0:
-            raise ValueError("Too many images provided.")
-
-        return return_text
-
-    def _pad_image_tokens_by_num_embeddings(
-        self,
-        text: List[str],
-    ) -> List[str]:
-        """Pads each \\<image> to image_pad_len times of "\\<image>".
-
-        Args:
-            text: The text to be padded.
-
-        Returns:
-            The padded text.
-        """
-
-        image_token = cast(str, self.tokenizer.image_token)
-
-        return [text_item.replace(image_token, image_token * (self.image_pad_len + 1)) for text_item in text]
-
-    @staticmethod
-    def _prepare_inputs(
+    def _find_media_token_order(self, text: List[str]) -> List[str]:
+        """Finds the order of media tokens in the text.
+
+        Args:
+            text: The text to be processed.
+
+        Returns:
+            The order of media tokens in the text. Each item is either an image token or a video
+            token.
+        """
+
+        image_token = cast(str, self.tokenizer.image_token)
+        video_token = cast(str, self.tokenizer.video_token)
+
+        return_order: List[str] = []
+
+        for text_item in text:
+            while image_token in text_item or video_token in text_item:
+                image_pos = text_item.find(image_token)
+                video_pos = text_item.find(video_token)
+
+                if image_pos == -1 and video_pos == -1:
+                    # If no media token found, move to the next text item.
+                    break
+
+                elif image_pos == -1:
+                    # If only video token found, add it to the return order.
+                    return_order.append(video_token)
+                    text_item = text_item[video_pos + len(video_token) :]
+
+                elif video_pos == -1:
+                    # If only image token found, add it to the return order.
+                    return_order.append(image_token)
+                    text_item = text_item[image_pos + len(image_token) :]
+
+                else:
+                    # If both tokens found, choose the one that appears first.
+                    if image_pos < video_pos:
+                        return_order.append(image_token)
+                        text_item = text_item[image_pos + len(image_token) :]
+                    else:
+                        return_order.append(video_token)
+                        text_item = text_item[video_pos + len(video_token) :]
+
+        return return_order
+
+    def _generate_image_token_placeholder(self, text: List[str]) -> str:
+        while True:
+            placeholder = f"<|image_placeholder_{str(uuid.uuid4())}|>"
+            if all(placeholder not in text_item for text_item in text):
+                return placeholder
+
+    def _merge_media_tiles(
+        self,
+        image_tiles: List[List[Image]],
+        video_tiles: List[List[List[Image]]],
+        media_token_order: List[str],
+    ) -> List[Image]:
+        """Merges the media tiles by the media token order.
+
+        Args:
+            image_tiles: The image tiles.
+            video_tiles: The video tiles.
+            media_token_order: The order of media tokens in the text.
+
+        Returns:
+            The merged media tiles.
+        """
+
+        image_token = cast(str, self.tokenizer.image_token)
+        video_token = cast(str, self.tokenizer.video_token)
+
+        image_tiles_idx = 0
+        video_tiles_idx = 0
+
+        return_tiles: List[Image] = []
+
+        for media_token in media_token_order:
+            if media_token == image_token:
+                return_tiles.extend(image_tiles[image_tiles_idx])
+                image_tiles_idx += 1
+            elif media_token == video_token:
+                for video_tile in video_tiles[video_tiles_idx]:
+                    return_tiles.extend(video_tile)
+                video_tiles_idx += 1
+            else:
+                raise ValueError(f"Invalid media token: {media_token}")
+
+        return return_tiles
+
+    def _normalize_inputs(
+        self,
        text: TextInput | List[TextInput],
        images: Optional[ImageInput],
        videos: Optional[VideoInput],
    ) -> Tuple[List[str], List[Image], List[List[Image]]]:
+        """Normalizes text, image, and video inputs for processing.
+
+        This method converts various input formats into standardized lists of PIL images
+        and text strings that can be processed by the model.
+
+        Args:
+            text: The original input text.
+            images: The original input images.
+            videos: The original input videos.
+
+        Returns:
+            The text as a list of strings.
+            The images as a list of PIL images.
+            The videos as a list of lists of PIL images.
+        """
+
        prepared_text = text if isinstance(text, list) else [text]
 
        if images is not None:
@@ -296,117 +268,205 @@ class VILAProcessor(ProcessorMixin):
 
        return prepared_text, prepared_images, prepared_videos
 
-    def _process_images(
-        self,
-        images: List[Image],
-        *,
-        video_flags: List[bool],
-        **kwargs,
-    ) -> Tuple[BatchFeature, List[int]]:
-        cropped_images: List[Image] = []
-        num_cropped_images: List[int] = []
-
-        for image, video_flag in zip(images, video_flags):
-            single_cropped_images = self._crop_image(image, is_video_frame=video_flag)
-
-            cropped_images.extend(single_cropped_images)
-            num_cropped_images.append(len(single_cropped_images))
-
-        if len(cropped_images) == 0:
-            # The image processor may not properly handle empty image lists.
-            # This is a workaround to avoid errors.
-            return BatchFeature(), num_cropped_images
-
-        image_inputs = self.image_processor.__call__(
-            cropped_images,
-            **kwargs,
-        )
-
-        return image_inputs, num_cropped_images
-
-    def _treat_videos_as_image_seqs(
-        self, text: List[str], images: List[Image], videos: List[List[Image]]
-    ) -> Tuple[List[str], List[Image], List[bool]]:
-        """Treats videos as image sequences.
-
-        This method will replace all video tokens in the text with #frame image tokens,
-        and insert the corresponding images into the images list.
-
-        Args:
-            text: The text to be processed.
-            images: The images to be processed.
-            videos: The videos to be processed.
-
-        Returns:
-            The processed text and images, and a list of flags indicating whether the images are from videos.
-        """
-
-        image_token = cast(str, self.tokenizer.image_token)
-        video_token = cast(str, self.tokenizer.video_token)
-
-        return_text: List[str] = []
-        return_images: List[Image] = []
-        return_video_flags: List[bool] = []
-
-        for text_item in text:
-            return_text_item: str = ""
-
-            # Repeatedly find image_token or video_token in the text.
-            while image_token in text_item or video_token in text_item:
-                image_pos = text_item.find(image_token)
-                video_pos = text_item.find(video_token)
-
-                # If not found, set position to the end of the text.
-                if image_pos == -1:
-                    image_pos = len(text_item)
-                if video_pos == -1:
-                    video_pos = len(text_item)
-
-                if image_pos != len(text_item) and len(images) > 0 and image_pos < video_pos:
-                    # Take an image and keep the image token if:
-                    # - an image token is found, and
-                    # - there are images left, and
-                    # - the image token is before the first video token.
-
-                    image = images.pop(0)
-                    return_images.append(image)
-                    return_video_flags.append(False)
-
-                    return_text_item += text_item[: image_pos + len(image_token)]
-                    text_item = text_item[image_pos + len(image_token) :]
-
-                elif video_pos != len(text_item) and len(videos) > 0 and video_pos < image_pos:
-                    # Take a video and replace the video token with #frame image tokens if:
-                    # - a video token is found, and
-                    # - there are videos left, and
-                    # - the video token is before the first image token.
-
-                    video = videos.pop(0)
-                    return_images.extend(video)
-                    return_video_flags.extend([True] * len(video))
-
-                    return_text_item += text_item[:video_pos] + image_token * len(video)
-                    text_item = text_item[video_pos + len(video_token) :]
-                else:
-                    break
-
-            # Must place outside the while loop.
-            if image_token in text_item:
-                raise ValueError("Too many image tokens in the text.")
-            if video_token in text_item:
-                raise ValueError("Too many video tokens in the text.")
-
-            return_text_item += text_item
-            text_item = ""
-
-            return_text.append(return_text_item)
-
-        if len(images) != 0:
-            raise ValueError("Too many images provided.")
-        if len(videos) != 0:
-            raise ValueError("Too many videos provided.")
-
-        return return_text, return_images, return_video_flags
+    def _pad_image_tiles(
+        self,
+        text: List[str],
+    ) -> List[str]:
+        """Pads each media tile.
+
+        This will pad each <image> to (self.image_pad_len + 1) times. The additional one padding is
+        for the \\n token suffix.
+
+        Args:
+            text: The text to be padded.
+
+        Returns:
+            The padded text.
+        """
+
+        image_token = cast(str, self.tokenizer.image_token)
+
+        return [text_item.replace(image_token, image_token * (self.image_pad_len + 1)) for text_item in text]
+
+    def _preprocess_inputs(
+        self,
+        text: List[str],
+        images: List[Image],
+        videos: List[List[Image]],
+    ) -> Tuple[List[str], List[Image]]:
+        """Preprocesses the input data for the VILA model.
+
+        This method takes a list of texts, images, and videos, and prepares them for the model.
+        It handles the interleaving of text and media, and returns the processed text and a
+        list of media tiles (images or video frames).
+
+        Args:
+            text: The input text.
+            images: The input images.
+            videos: The input videos.
+
+        Returns:
+            The text ready to be tokenized.
+            The media tiles ready to be processed.
+        """
+
+        media_token_order = self._find_media_token_order(text)
+
+        image_token_placeholder = self._generate_image_token_placeholder(text)
+
+        preprocessed_text = text
+        preprocessed_text, preprocessed_image_tiles = self._preprocess_images(
+            preprocessed_text,
+            images,
+            image_token_placeholder=image_token_placeholder,
+        )
+        preprocessed_text, preprocessed_video_tiles = self._preprocess_videos(
+            preprocessed_text,
+            videos,
+            image_token_placeholder=image_token_placeholder,
+        )
+
+        # Convert back to the original image token.
+        image_token = cast(str, self.tokenizer.image_token)
+        preprocessed_text = [text_item.replace(image_token_placeholder, image_token) for text_item in preprocessed_text]
+
+        preprocessed_text = self._pad_image_tiles(preprocessed_text)
+
+        preprocessed_media_tiles = self._merge_media_tiles(
+            preprocessed_image_tiles,
+            preprocessed_video_tiles,
+            media_token_order,
+        )
+
+        return preprocessed_text, preprocessed_media_tiles
+
+    def _preprocess_images(
+        self,
+        text: List[str],
+        images: List[Image],
+        *,
+        image_token_placeholder: str,
+    ) -> Tuple[List[str], List[List[Image]]]:
+        single_image_token_placeholder = self._generate_image_token_placeholder(text)
+
+        preprocessed_text = text
+        preprocessed_image_tiles: List[List[Image]] = []
+
+        for image in images:
+            preprocessed_text, preprocessed_single_image_tiles = self._preprocess_single_image(
+                text,
+                image,
+                image_token_placeholder=single_image_token_placeholder,
+                is_video_frame=False,
+                use_dynamic_preprocess=(len(images) == 1),
+            )
+
+            preprocessed_text = [
+                text_item.replace(
+                    single_image_token_placeholder,
+                    (image_token_placeholder + "\n") if len(images) == 1 else image_token_placeholder,
+                )
+                for text_item in preprocessed_text
+            ]
+
+            preprocessed_image_tiles.append(preprocessed_single_image_tiles)
+
+        return preprocessed_text, preprocessed_image_tiles
+
+    def _preprocess_single_image(
+        self,
+        text: List[str],
+        image: Image,
+        *,
+        image_token_placeholder: str,
+        is_video_frame: bool,
+        use_dynamic_preprocess: bool,
+    ) -> Tuple[List[str], List[Image]]:
+        assert isinstance(self.image_processor, (SiglipImageProcessor, SiglipImageProcessorFast))
+        assert self.image_processor.size["height"] == self.image_processor.size["width"]
+        cropped_size = self.image_processor.size["height"]
+
+        if use_dynamic_preprocess:
+            if is_video_frame:
+                max_num = self.video_max_tiles
+            else:
+                max_num = self.max_tiles
+        else:
+            max_num = 1
+
+        image = image.convert("RGB")
+
+        cropped_images: List[Image] = dynamic_preprocess(
+            image,
+            min_num=self.min_tiles,
+            max_num=max_num,
+            image_size=cropped_size,
+        )
+
+        image_token = cast(str, self.tokenizer.image_token)
+
+        for i in range(len(text)):
+            if image_token in text[i]:
+                text[i] = text[i].replace(image_token, image_token_placeholder * len(cropped_images))
+                break
+
+        return text, cropped_images
+
+    def _preprocess_videos(
+        self,
+        text: List[str],
+        videos: List[List[Image]],
+        *,
+        image_token_placeholder: str,
+    ) -> Tuple[List[str], List[List[List[Image]]]]:
+        image_token = cast(str, self.tokenizer.image_token)
+        video_token = cast(str, self.tokenizer.video_token)
+
+        processed_text = text
+        processed_video_tiles: List[List[List[Image]]] = []
+
+        for video in videos:
+            # Replace the first video token with #frame image tokens.
+            for i in range(len(processed_text)):
+                if video_token in processed_text[i]:
+                    processed_text[i] = processed_text[i].replace(video_token, image_token * len(video))
+                    break
+
+            processed_frame_tiles: List[List[Image]] = []
+            for frame in video:
+                processed_text, processed_single_frame_tiles = self._preprocess_single_image(
+                    processed_text,
+                    frame,
+                    image_token_placeholder=image_token_placeholder,
+                    is_video_frame=True,
+                    use_dynamic_preprocess=(self.video_max_tiles > 1),
+                )
+                processed_frame_tiles.append(processed_single_frame_tiles)
+
+            processed_video_tiles.append(processed_frame_tiles)
+
+        return processed_text, processed_video_tiles
+
+    def _replace_image_tile_suffix(self, text_inputs: BatchEncoding) -> BatchEncoding:
+        lf_token_id = cast(int, self.tokenizer.encode("\n")[0])
+        image_token_id = cast(int, self.tokenizer.image_token_id)
+
+        for i in range(len(text_inputs.input_ids)):
+            input_ids = text_inputs.input_ids[i]
+
+            idx = 0
+            while idx < len(input_ids):
+                if input_ids[idx] != image_token_id:
+                    idx += 1
+                    continue
+
+                if idx + self.image_pad_len < len(input_ids):
+                    input_ids[idx + self.image_pad_len] = lf_token_id
+                    idx += self.image_pad_len + 1
+                else:
+                    break
+
+        return text_inputs
 
 
 def dynamic_preprocess(image: Image, min_num: int, max_num: int, image_size: int, use_thumbnail=True) -> List[Image]:
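End to end, the rewritten processor records the order of image and video tokens, expands each image into per-tile placeholders via dynamic_preprocess and each video into per-frame placeholders, repeats every placeholder image_pad_len times, and finally rewrites the trailing pad of each tile to the "\n" token id after tokenization (_replace_image_tile_suffix). A minimal usage sketch follows; the checkpoint path, frame count, and literal "<image>"/"<video>" token strings are assumptions, as is the expectation that the repo's auto_map exposes VILAProcessor through AutoProcessor and that the usual input_ids / pixel_values keys come back.

from PIL import Image
from transformers import AutoProcessor

# Hypothetical checkpoint path; assumes the repo's auto_map registers VILAProcessor.
processor = AutoProcessor.from_pretrained("path/to/vila-checkpoint", trust_remote_code=True)

image = Image.new("RGB", (640, 480))
video = [Image.new("RGB", (640, 480)) for _ in range(8)]  # 8 synthetic frames

out = processor(
    text="<image>\nDescribe the photo, then summarize the clip: <video>",
    images=[image],
    videos=[video],
    return_tensors="pt",
)

# Each <image> becomes one placeholder per tile from dynamic_preprocess, each <video> one
# placeholder per (tiled) frame; every placeholder is then repeated image_pad_len times, and
# _replace_image_tile_suffix turns the last pad of each tile into the "\n" token id.
print(out["input_ids"].shape, out["pixel_values"].shape)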