| | from functools import cached_property |
| | from typing import Iterable, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor |
| | from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize |
| | from vllm import ModelRegistry |
| | from vllm.config import VllmConfig |
| | from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler |
| | from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsMultiModal |
| | from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM |
| | from vllm.model_executor.models.qwen2_5_vl import ( |
| | Qwen2_5_VLMultiModalProcessor, |
| | Qwen2_5_VLProcessingInfo, |
| | ) |
| | from vllm.model_executor.models.qwen2_vl import Qwen2VLDummyInputsBuilder |
| | from vllm.model_executor.models.utils import ( |
| | AutoWeightsLoader, |
| | WeightsMapper, |
| | init_vllm_registered_model, |
| | maybe_prefix, |
| | merge_multimodal_embeddings, |
| | ) |
| | from vllm.model_executor.sampling_metadata import SamplingMetadata |
| | from vllm.multimodal import MULTIMODAL_REGISTRY |
| | from vllm.multimodal.inputs import MultiModalDataDict |
| | from vllm.multimodal.parse import ImageSize |
| | from vllm.sequence import IntermediateTensors |
| |
|
| | from .configuration_dots import DotsVisionConfig |
| | from .configuration_dots import DotsOCRConfig |
| | from .modeling_dots_vision import DotsVisionTransformer |
| |
|
| |
|
class DotsOCRImagePixelInputs(TypedDict):
    """Image inputs given as raw pixel patches plus one patch grid per image."""

    # Variant tag distinguishing this dict from DotsOCRImageEmbeddingInputs; the only
    # valid value is "pixel_values" ("image_grid_thw" is a field, not a tag).
    type: Literal["pixel_values"]

    # Pixel patch rows for all images concatenated along the first axis
    # (presumably (total_patches, patch_dim) — TODO confirm against the vision tower).
    pixel_values: torch.Tensor

    # Per-image patch grid; the model asserts this is 2D and multiplies along the last
    # axis, presumably (num_images, 3) holding (t, h, w) — confirm.
    image_grid_thw: torch.Tensor
| |
|
| |
|
class DotsOCRImageEmbeddingInputs(TypedDict):
    """Image inputs given as precomputed image embeddings plus one patch grid per image.

    ``image_embeds`` holds the features of all images concatenated into a single tensor
    of shape ``(num_image_features, hidden_size)``:

    - ``num_image_features`` varies with the number and resolution of the images.
    - ``hidden_size`` must match the hidden size of the language model backbone.

    Callers may hand the model a list of per-image tensors; the model concatenates such
    a list into one tensor before building this dict.
    """

    # Variant tag distinguishing this dict from DotsOCRImagePixelInputs; the only
    # valid value is "image_embeds" ("image_grid_thw" is a field, not a tag).
    type: Literal["image_embeds"]

    image_embeds: torch.Tensor

    # Per-image patch grid used to split image_embeds back into per-image chunks;
    # presumably (num_images, 3) holding (t, h, w) — confirm.
    image_grid_thw: torch.Tensor
| |
|
| |
|
# Parsed image input: either raw pixel patches or precomputed embeddings. Both
# variants carry an ``image_grid_thw`` tensor and are told apart by their "type" key.
DotsOCRImageInputs = Union[
    DotsOCRImagePixelInputs,
    DotsOCRImageEmbeddingInputs,
]
| |
|
| |
|
class DotsOCRMultiModalProcessor(Qwen2_5_VLMultiModalProcessor):
    """Multimodal processor for dots.ocr.

    Adds no behavior of its own: everything is inherited unchanged from the
    Qwen2.5-VL multimodal processor.
    """
| |
|
| |
|
class DotsOCRDummyInputsBuilder(Qwen2VLDummyInputsBuilder):
    """Dummy-input builder for dots.ocr that produces image data only.

    Every dummy image uses the resolution reported by
    ``self.info.get_image_size_with_most_features()``.
    """

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return a mapping with a single ``"image"`` entry holding the dummy images.

        Args:
            seq_len: unused here; part of the builder interface.
            mm_counts: requested item counts per modality; only ``"image"`` is read
                (missing means zero images).
        """
        width, height = self.info.get_image_size_with_most_features()
        image_count = mm_counts.get("image", 0)

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
        )
        return {"image": dummy_images}
| |
|
| |
|
class DotsOCRProcessingInfo(Qwen2_5_VLProcessingInfo):
    """Processing metadata for dots.ocr, built on the Qwen2.5-VL processing info.

    Overrides visible here:

    - The HF config must be a ``DotsOCRConfig``.
    - Only images are supported (the video limit is 0).
    - The HF processor uses ``<|imgpad|>`` as its image token.
    - Vision-token counts come from the dots vision config (patch size, spatial merge
      size and temporal patch size).
    """

    def get_hf_config(self) -> DotsOCRConfig:
        """Return the model's HF config, upgrading a dict ``vision_config`` in place.

        Raises:
            TypeError: if the config's class is not named ``DotsOCRConfig``.
        """
        config = self.ctx.get_hf_config()
        # Checked by class name rather than isinstance — presumably because the config
        # class may be loaded separately (e.g. via remote code). TODO confirm.
        if config.__class__.__name__ != "DotsOCRConfig":
            raise TypeError(f"Expected DotsOCRConfig, got {type(config)}")

        if hasattr(config, "vision_config") and isinstance(config.vision_config, dict):
            config.vision_config = DotsVisionConfig(**config.vision_config)

        return config

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Images have no fixed limit (``None``); video is disabled (``0``)."""
        return {"image": None, "video": 0}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the per-item token budget: the max image token count, and 0 for video."""
        max_image_tokens = self.get_max_image_tokens()
        return {"image": max_image_tokens, "video": 0}

    def get_hf_processor(
        self,
        *,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        size: Optional[dict[str, int]] = None,
        **kwargs: object,
    ) -> Qwen2VLProcessor:
        """Build a ``Qwen2VLProcessor`` configured for dots.ocr's special tokens.

        Side effects: sets ``image_token = "<|imgpad|>"`` on the shared tokenizer.
        The returned processor gets ``image_token = "<|imgpad|>"`` and
        ``video_token = "<|video_pad|>"``.
        """
        self.get_tokenizer().image_token = "<|imgpad|>"
        processor = self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            image_processor=self.get_image_processor(min_pixels=min_pixels, max_pixels=max_pixels, size=size),
            **kwargs,
        )
        processor.image_token = "<|imgpad|>"
        processor.video_token = "<|video_pad|>"
        return processor

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Optional[Qwen2VLImageProcessor],
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed image size and the number of vision tokens.

        Args:
            image_width: original width in pixels.
            image_height: original height in pixels.
            num_frames: number of frames (1 for a still image).
            do_resize: if True, resize with ``smart_resize`` using factor
                ``patch_size * spatial_merge_size`` and the processor's
                ``min_pixels``/``max_pixels``; otherwise keep the original size.
            image_processor: source of the pixel bounds; ``None`` means
                ``self.get_image_processor()``.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``, where the token count is the
            number of ``t * h * w`` patches divided by ``spatial_merge_size ** 2``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config: DotsOCRConfig = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # Round the frame count up to the next multiple of temporal_patch_size.
        # ``num_frames % temporal_patch_size`` would only be the right padding when
        # temporal_patch_size is 1 or 2; ``(-n) % k`` is correct for any k.
        padded_num_frames = num_frames + (-num_frames) % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        # Each merge_size x merge_size block of patches becomes one vision token.
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens
| |
|
| |
|
@MULTIMODAL_REGISTRY.register_processor(
    DotsOCRMultiModalProcessor,
    info=DotsOCRProcessingInfo,
    dummy_inputs=DotsOCRDummyInputsBuilder,
)
class DotsOCRForCausalLM(nn.Module, SupportsMultiModal):
    """dots.ocr for vLLM: a ``DotsVisionTransformer`` vision tower plus a Qwen2 LM.

    Image features are merged into the token embeddings at the image placeholder
    positions. They come either from the vision tower run on ``pixel_values`` or from
    precomputed ``image_embeds``. The result is run through the wrapped
    ``Qwen2ForCausalLM``.
    """

    # Checkpoint keys for the language model use ``model.``/``lm_head.``; in this
    # wrapper those modules live under the ``language_model`` submodule.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )
    _tp_plan = {}

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for one image; ``None`` for other modalities."""
        if modality in ("image",):
            return "<|img|><|imgpad|><|endofimg|>"
        return None

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        self.config: DotsOCRConfig = vllm_config.model_config.hf_config
        self.quant_config = vllm_config.quant_config
        self.multimodal_config = vllm_config.model_config.multimodal_config

        # The vision config may arrive as a plain dict. Convert it and write it back so
        # later readers of self.config (e.g. _process_image_input) get the typed object.
        if isinstance(self.config.vision_config, dict):
            vision_config = DotsVisionConfig(**self.config.vision_config)
            self.config.vision_config = vision_config
        else:
            vision_config = self.config.vision_config

        self.vision_tower = DotsVisionTransformer(vision_config)
        self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=self.config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

    @cached_property
    def sampler(self):
        """Use the language model's sampler when it has one; else vLLM's default sampler."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_and_reshape_mm_tensor(self, mm_input: object, name: str) -> torch.Tensor:
        """Normalize a multimodal kwarg into a single concatenated tensor.

        A 2D tensor is returned as-is. A 3D tensor is concatenated along its first axis
        (giving 2D). A list of tensors is concatenated with ``torch.concat``.

        Raises:
            ValueError: for any other type, or a tensor that is neither 2D nor 3D.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. " f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(
                    f"{name} should be 2D or batched 3D tensor. "
                    f"Got ndim: {mm_input.ndim} "
                    f"(shape={mm_input.shape})"
                )
            return torch.concat(list(mm_input))
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_image_input(self, **kwargs: object) -> Optional[DotsOCRImageInputs]:
        """Build a typed image-input dict from model kwargs.

        ``pixel_values`` takes precedence over ``image_embeds``. Returns ``None`` when
        neither is provided.

        Raises:
            ValueError: if a provided tensor (including ``image_grid_thw``) has an
                unsupported type or rank.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        # _validate_and_reshape_mm_tensor always returns a tensor or raises, so its
        # results need no further type checks.
        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(pixel_values, "image pixel values")
            image_grid_thw = self._validate_and_reshape_mm_tensor(image_grid_thw, "image grid_thw")
            return DotsOCRImagePixelInputs(
                type="pixel_values", pixel_values=pixel_values, image_grid_thw=image_grid_thw
            )

        image_embeds = self._validate_and_reshape_mm_tensor(image_embeds, "image embeds")
        image_grid_thw = self._validate_and_reshape_mm_tensor(image_grid_thw, "image grid_thw")
        return DotsOCRImageEmbeddingInputs(
            type="image_embeds", image_embeds=image_embeds, image_grid_thw=image_grid_thw
        )

    def vision_forward(self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor):
        """Run the vision tower with images sharded across tensor-parallel ranks.

        ``image_grid_thw`` is split with ``Tensor.chunk`` into at most ``tp`` groups of
        consecutive images. Each image owns ``t * h * w`` consecutive rows of
        ``pixel_values``. Rank ``r`` encodes only chunk ``r`` and writes the result into
        its slice of a zero-filled buffer. A (sum) all-reduce over the TP device group
        then gives every rank the full embedding; ranks without a chunk add zeros.

        Returns:
            Tensor of shape ``(pixel_values.shape[0] // spatial_merge_size**2,
            vision_tower.config.hidden_size)`` with ``pixel_values``' dtype and device.
        """
        from vllm.distributed import (
            get_tensor_model_parallel_group,
            get_tensor_model_parallel_rank,
            get_tensor_model_parallel_world_size,
        )

        assert self.vision_tower is not None

        tp_rank = get_tensor_model_parallel_rank()
        tp = get_tensor_model_parallel_world_size()

        image_grid_thw_chunk = image_grid_thw.chunk(tp)
        # Cumulative pixel-row offsets at the end of each chunk.
        image_sizes_consum = torch.tensor([i.prod(-1).sum() for i in image_grid_thw_chunk]).cumsum(dim=0)
        merge_size_square = self.vision_tower.config.spatial_merge_size**2
        image_embedding = torch.zeros(
            (
                pixel_values.shape[0] // merge_size_square,
                self.vision_tower.config.hidden_size,
            ),
            device=pixel_values.device,
            dtype=pixel_values.dtype,
        )

        # chunk() may return fewer than tp pieces; surplus ranks contribute zeros.
        if tp_rank < len(image_sizes_consum):
            idx_start = 0 if tp_rank == 0 else image_sizes_consum[tp_rank - 1].item()
            idx_end = image_sizes_consum[tp_rank].item()
            pixel_values_part = pixel_values[idx_start:idx_end]
            image_grid_thw_part = image_grid_thw_chunk[tp_rank]
            image_embedding_part = self.vision_tower(pixel_values_part, image_grid_thw_part)
            image_embedding[idx_start // merge_size_square : idx_end // merge_size_square] = image_embedding_part

        group = get_tensor_model_parallel_group().device_group
        torch.distributed.all_reduce(image_embedding, group=group)
        return image_embedding

    def _process_image_input(self, image_input: DotsOCRImageInputs) -> tuple[torch.Tensor, ...]:
        """Return per-image embeddings for a parsed image input.

        Embeddings are either taken from ``image_embeds`` or computed by
        ``vision_forward`` (keeping the first ``config.hidden_size`` columns). They are
        split into one tensor per image of ``t * h * w // spatial_merge_size**2`` rows.
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(self.vision_tower.dtype)
        else:
            pixel_values = image_input["pixel_values"].type(self.vision_tower.dtype)
            image_embeds = self.vision_forward(pixel_values, grid_thw)[
                :, : self.config.hidden_size
            ]

        merge_size = self.vision_tower.config.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size

        return image_embeds.split(sizes.tolist())

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Collect parsed inputs per modality; only ``"images"`` is supported.

        A modality is recorded only when it actually parses to an input. So kwargs that
        carry ``pixel_values``/``image_embeds`` set to ``None`` yield an empty dict
        rather than ``{"images": None}``.
        """
        modalities = {}
        if any(key in kwargs for key in ("pixel_values", "image_embeds")):
            image_input = self._parse_and_validate_image_input(**kwargs)
            if image_input is not None:
                modalities["images"] = image_input
        return modalities

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped Qwen2 language model."""
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Return a tuple of per-image embeddings, or ``None`` if there is no image input."""
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return None

        multimodal_embeddings: tuple[torch.Tensor, ...] = ()
        if "images" in modalities:
            multimodal_embeddings += self._process_image_input(modalities["images"])

        return multimodal_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and, if given, splice in multimodal embeddings.

        Multimodal embeddings replace positions holding the config's image or video
        token id.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                [self.config.image_token_id, self.config.video_token_id],
            )

        return inputs_embeds

    def get_input_embeddings_v0(
        self,
        input_ids: torch.Tensor,
        image_input: Optional[DotsOCRImageInputs] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge image features at ``image_token_id`` positions."""
        inputs_embeds = self.get_input_embeddings(input_ids)
        if image_input is not None:
            image_embeds = self._process_image_input(image_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                image_embeds,
                placeholder_token_id=self.config.image_token_id,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language model, first merging image inputs from kwargs if needed.

        When no ``inputs_embeds`` are given but ``pixel_values`` or ``image_embeds``
        are present in kwargs, the image features are merged into the token
        embeddings here. The LM is then fed those embeddings instead of
        ``input_ids``.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None and (
            kwargs.get("pixel_values") is not None or kwargs.get("image_embeds") is not None
        ):
            image_input = self._parse_and_validate_image_input(**kwargs)
            if image_input is not None:
                assert input_ids is not None
                inputs_embeds = self.get_input_embeddings_v0(
                    input_ids,
                    image_input=image_input,
                )
                input_ids = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states, sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens with ``self.sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, remapping LM prefixes via ``hf_to_vllm_mapper``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
| |
|
| |
|
def patch_vllm_chat_placeholder():
    """Monkey-patch vLLM's chat placeholder lookup for dots.ocr on vLLM <= 0.9.1.

    Wraps ``BaseMultiModalItemTracker._placeholder_str`` so that image items for a
    model whose ``hf_config.model_type`` is ``"dots_ocr"`` get the
    ``<|img|><|imgpad|><|endofimg|>`` placeholder. All other cases defer to the
    original method.

    On vLLM releases newer than 0.9.1 this returns without patching — presumably
    because those releases ask the model class (``get_placeholder_str``) instead;
    confirm when bumping vLLM.
    """
    import vllm

    # Compare the (major, minor, patch) release triple lexicographically. A
    # per-component test (minor <= 9 and patch <= 1) would wrongly skip e.g. 0.8.5.
    if tuple(vllm.__version_tuple__[:3]) > (0, 9, 1):
        return
    from vllm.entrypoints.chat_utils import BaseMultiModalItemTracker

    original_placeholder_str = BaseMultiModalItemTracker._placeholder_str

    def _placeholder_str(self, modality, current_count: int) -> Optional[str]:
        hf_config = self._model_config.hf_config
        model_type = hf_config.model_type
        if modality in ("image",) and model_type in ["dots_ocr"]:
            return "<|img|><|imgpad|><|endofimg|>"
        return original_placeholder_str(self, modality, current_count)

    BaseMultiModalItemTracker._placeholder_str = _placeholder_str
| |
|
# Import-time side effects: make the "DotsOCRForCausalLM" architecture resolvable
# through vLLM's model registry, then install the chat-placeholder shim (which
# returns immediately on vLLM versions that do not need it).
ModelRegistry.register_model("DotsOCRForCausalLM", DotsOCRForCausalLM)
patch_vllm_chat_placeholder()