import re
import torch
from transformers import ProcessorMixin, BatchFeature, CLIPImageProcessorFast
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_utils import ImageInput
from typing import Any, Dict, List, Optional, Union
from PIL import Image

from .llava_qwen import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN


# Adapted from transformers.models.llava_next.image_processing_llava_next.expand_to_square
def expand_to_square(image: torch.Tensor, background_color=0) -> torch.Tensor:
    """Pad a channels-first image to a square, keeping the content centred.

    Args:
        image: Tensor indexed as ``(channels, height, width)``.
        background_color: Scalar fill value for the added border.

    Returns:
        ``image`` itself when it is already square; otherwise a new tensor of
        shape ``(channels, side, side)`` with ``side = max(height, width)``,
        created with the same dtype and on the same device as ``image``.
    """
    channels, height, width = image.shape
    if width == height:
        return image
    side = max(height, width)
    # Allocate on the input's device so e.g. GPU-resident images stay there.
    result = torch.full(
        (channels, side, side), background_color, dtype=image.dtype, device=image.device
    )
    # Exactly one of these offsets is non-zero: the short axis is centred.
    top = (side - height) // 2
    left = (side - width) // 2
    result[:, top : top + height, left : left + width] = image
    return result


class FastVLMImageProcessor(CLIPImageProcessorFast):
    """CLIP fast image processor that pads every image to a square first.

    The output also carries ``image_sizes``: each image's ``(width, height)``
    before padding.
    """

    def _preprocess(self, images, **kwargs):
        # Record sizes before padding; shape[-2:] is (height, width), reversed to (width, height).
        image_sizes = [image.shape[-2:][::-1] for image in images]
        images = [expand_to_square(image) for image in images]
        images = super()._preprocess(images, **kwargs)
        pixel_values = images.pixel_values
        # Depending on the transformers version and ``return_tensors``, the parent may
        # already hand back a stacked tensor; only stack a list of per-image tensors.
        if not isinstance(pixel_values, torch.Tensor):
            pixel_values = torch.stack(pixel_values, dim=0)
        return BatchFeature(data={"pixel_values": pixel_values, "image_sizes": image_sizes})


class FastVLMProcessor(ProcessorMixin):
    """Combine the tokenizer and the image processor for FastVLM.

    The image placeholder ``DEFAULT_IMAGE_TOKEN`` in a prompt is replaced by the
    single id ``IMAGE_TOKEN_INDEX``; the surrounding text is tokenized without
    special tokens.
    """

    attributes = ["tokenizer", "image_processor"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, tokenizer, image_processor, chat_template=None, **kwargs):
        super().__init__(tokenizer, image_processor, chat_template=chat_template, **kwargs)

    def __call__(
        self,
        images: ImageInput = None,
        text: Optional[Union[str, List[str]]] = None,
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ) -> BatchFeature:
        """Tokenize prompts and (optionally) preprocess images.

        Args:
            images: Optional images, forwarded unchanged to ``self.image_processor``.
            text: A prompt or a list of prompts; each may contain at most one
                ``DEFAULT_IMAGE_TOKEN`` placeholder.
            return_tensors: Tensor type for the returned ``BatchFeature``.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            ``BatchFeature`` with ``input_ids`` and ``attention_mask`` of shape
            ``(batch, max_len)`` (shorter prompts are padded, mask 0 on padding),
            plus whatever the image processor returned.

        Raises:
            TypeError: if ``text`` is not a string or a list/tuple of strings.
            ValueError: if a prompt has more than one image placeholder, or
                padding is needed but the tokenizer has no ``pad_token_id``.
        """
        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, (list, tuple)) or not all(isinstance(t, str) for t in text):
            raise TypeError("Invalid input text. Please provide a string, or a list of strings")

        image_inputs = {}
        if images is not None:
            image_inputs = self.image_processor(images=images)

        sequences = [self._encode_prompt(prompt) for prompt in text]
        input_ids, attention_mask = self._pad_sequences(sequences)

        return BatchFeature(
            data={"input_ids": input_ids, "attention_mask": attention_mask, **image_inputs},
            tensor_type=return_tensors,
        )

    def _encode_segment(self, segment: str) -> torch.Tensor:
        """Tokenize a plain-text span (no special tokens) into a ``(1, n)`` int64 tensor."""
        ids = self.tokenizer(segment, return_tensors="pt", add_special_tokens=False).input_ids
        # Normalise to int64 whatever dtype the tokenizer produced.
        return ids.to(dtype=torch.int64)

    def _encode_prompt(self, prompt: str) -> torch.Tensor:
        """Encode one prompt into a ``(1, n)`` int64 tensor of token ids.

        Raises:
            ValueError: if the prompt contains more than one image placeholder.
        """
        num_images = prompt.count(DEFAULT_IMAGE_TOKEN)
        if num_images > 1:
            raise ValueError(
                f"Expected up to 1 image tokens per prompt, got {num_images} instead."
            )
        # IMAGE_TOKEN_INDEX is not in the vocab, so the full string can't be tokenized:
        # tokenize the text on each side of the placeholder and splice the id in between.
        pre, sep, post = prompt.partition(DEFAULT_IMAGE_TOKEN)
        if not sep:
            # Text-only prompt: do not inject an image id with no image behind it.
            return self._encode_segment(prompt)
        image_token = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=torch.int64)
        return torch.cat(
            [self._encode_segment(pre), image_token, self._encode_segment(post)], dim=1
        )

    def _pad_sequences(self, sequences: List[torch.Tensor]):
        """Pad ``(1, n_i)`` id tensors to a common length and build the attention mask.

        Padding uses ``tokenizer.pad_token_id`` on ``tokenizer.padding_side``
        (``"right"`` when the attribute is missing). An empty input yields empty
        1-D int64 tensors.

        Returns:
            ``(input_ids, attention_mask)``, each of shape ``(len(sequences), max_len)``.
        """
        if not sequences:
            return torch.tensor([], dtype=torch.int64), torch.tensor([], dtype=torch.int64)

        max_len = max(seq.shape[1] for seq in sequences)
        pad_id = None
        if any(seq.shape[1] != max_len for seq in sequences):
            pad_id = self.tokenizer.pad_token_id
            if pad_id is None:
                raise ValueError(
                    "Prompts tokenize to different lengths but the tokenizer has no pad_token_id."
                )
        pad_left = getattr(self.tokenizer, "padding_side", "right") == "left"

        id_rows, mask_rows = [], []
        for seq in sequences:
            mask = torch.ones_like(seq)
            n_pad = max_len - seq.shape[1]
            if n_pad:
                id_pad = torch.full((1, n_pad), pad_id, dtype=torch.int64)
                mask_pad = torch.zeros((1, n_pad), dtype=torch.int64)
                if pad_left:
                    seq = torch.cat([id_pad, seq], dim=1)
                    mask = torch.cat([mask_pad, mask], dim=1)
                else:
                    seq = torch.cat([seq, id_pad], dim=1)
                    mask = torch.cat([mask, mask_pad], dim=1)
            id_rows.append(seq)
            mask_rows.append(mask)
        return torch.cat(id_rows, dim=0), torch.cat(mask_rows, dim=0)