# FastVLM-0.5B / processing_fastvlm.py
# Custom processor for easier inference.

import re
import torch
from transformers import ProcessorMixin, BatchFeature, CLIPImageProcessorFast
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_utils import ImageInput
from typing import Any, Dict, List, Optional, Union
from PIL import Image
from .llava_qwen import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN


# Adapted from transformers.models.llava_next.image_processing_llava_next.expand_to_square
def expand_to_square(image: torch.Tensor, background_color=0) -> torch.Tensor:
"""
Expands an image to a square by adding a background color.
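
    Expects a CHW tensor and returns a CHW tensor whose two spatial dims both equal
    max(height, width), with the original content centered. Illustrative example
    (shapes only):

        >>> expand_to_square(torch.zeros(3, 2, 4)).shape
        torch.Size([3, 4, 4])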
"""
c, height, width = image.shape
if width == height:
return image
elif width > height:
        result = torch.ones((c, width, width), dtype=image.dtype, device=image.device) * background_color
result[:, (width - height) // 2 : (width - height) // 2 + height, :] = image
return result
else:
        result = torch.ones((c, height, height), dtype=image.dtype, device=image.device) * background_color
result[:, :, (height - width) // 2 : (height - width) // 2 + width] = image
return result


class FastVLMImageProcessor(CLIPImageProcessorFast):
    def _preprocess(self, images, **kwargs):
        # Record each image's original (width, height) before padding.
        image_sizes = [image.shape[-2:][::-1] for image in images]
        # Pad to a square (zero background) so the aspect ratio is preserved through the square resize.
        images = [expand_to_square(image) for image in images]
        images = super()._preprocess(images, **kwargs)
        pixel_values = torch.stack(images.pixel_values, dim=0)
        return BatchFeature(data={"pixel_values": pixel_values, "image_sizes": image_sizes})


class FastVLMProcessor(ProcessorMixin):
    attributes = ["tokenizer", "image_processor"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
self,
tokenizer,
image_processor,
chat_template=None,
**kwargs
):
        super().__init__(tokenizer, image_processor, chat_template=chat_template, **kwargs)

    def __call__(
self,
images: ImageInput = None,
text: Optional[Union[str, List[str]]] = None,
return_tensors: Optional[str] = "pt",
**kwargs,
) -> BatchFeature:
        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, list) or not all(isinstance(t, str) for t in text):
            raise TypeError("Invalid input text. Please provide a string, or a list of strings.")
image_inputs = {}
if images is not None:
image_inputs = self.image_processor(images=images)
        # IMAGE_TOKEN_INDEX is a placeholder id outside the tokenizer vocabulary; the model
        # later swaps it for the projected image features.
        image_token = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=torch.int64)
        input_ids = torch.tensor([], dtype=torch.int64)
        attention_mask = torch.tensor([], dtype=torch.int64)
for prompt in text:
image_indexes = [m.start() for m in re.finditer(DEFAULT_IMAGE_TOKEN, prompt)]
            if len(image_indexes) > 1:
                raise ValueError(
                    f"Expected at most one image token per prompt, got {len(image_indexes)} instead."
                )
            # IMAGE_TOKEN_INDEX (-200) is not in the vocab, so the full prompt cannot be tokenized
            # directly; split on DEFAULT_IMAGE_TOKEN and splice the placeholder id in between.
            pre, _, post = prompt.partition(DEFAULT_IMAGE_TOKEN)
pre_ids = self.tokenizer(pre, return_tensors="pt", add_special_tokens=False).input_ids
post_ids = self.tokenizer(post, return_tensors="pt", add_special_tokens=False).input_ids
sample_ids = torch.cat([pre_ids, image_token, post_ids], dim=1).to(dtype=torch.int64)
sample_mask = torch.ones_like(sample_ids)
input_ids = torch.cat([input_ids, sample_ids], dim=0)
attention_mask = torch.cat([attention_mask, sample_mask], dim=0)
return BatchFeature(data={"input_ids": input_ids, "attention_mask": attention_mask, **image_inputs}, tensor_type=return_tensors)
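

# Minimal usage sketch, guarded so nothing runs on import. The repo id below is an
# assumption for illustration; any checkpoint that ships this processor (e.g. via
# trust_remote_code / an auto_map entry) should behave the same way. Since this module
# is normally loaded as remote code (note the relative import above), treat this as an
# illustrative sketch rather than a standalone script.
if __name__ == "__main__":
    from transformers import AutoProcessor

    repo_id = "apple/FastVLM-0.5B"  # assumed checkpoint id
    processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)

    # A prompt with a single image placeholder, paired with a stand-in PIL image.
    prompt = f"{DEFAULT_IMAGE_TOKEN}\nDescribe this image."
    image = Image.new("RGB", (640, 480), color=(127, 127, 127))

    inputs = processor(images=image, text=prompt)
    print(inputs.input_ids.shape)     # (1, sequence_length), with one IMAGE_TOKEN_INDEX entry
    print(inputs.pixel_values.shape)  # (1, 3, H, W) after square padding and the CLIP-style resize
    print(inputs.image_sizes)         # original (width, height) per image, recorded before padding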