Create custom processor for easier inference

#11 by pcuenq (HF Staff) - opened
README.md CHANGED
@@ -55,56 +55,42 @@ python predict.py --model-path /path/to/checkpoint-dir \
  To run inference with transformers we can leverage `trust_remote_code` along with the following snippet:

  ```python
- import torch
- from PIL import Image
- from transformers import AutoTokenizer, AutoModelForCausalLM
+ from transformers import AutoModelForCausalLM, AutoProcessor

- MID = "apple/FastVLM-0.5B"
- IMAGE_TOKEN_INDEX = -200  # what the model code looks for
+ model_id = "apple/FastVLM-0.5B"

- # Load
- tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
  model = AutoModelForCausalLM.from_pretrained(
-     MID,
-     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-     device_map="auto",
+     model_id,
      trust_remote_code=True,
  )

- # Build chat -> render to string (not tokens) so we can place <image> exactly
+ image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
  messages = [
-     {"role": "user", "content": "<image>\nDescribe this image in detail."}
+     {
+         "role": "user",
+         "content": [
+             {"type": "image", "image": image_url},
+             {"type": "text", "text": "Describe this image in detail."},
+         ]
+     }
  ]
- rendered = tok.apply_chat_template(
-     messages, add_generation_prompt=True, tokenize=False
- )
-
- pre, post = rendered.split("<image>", 1)
-
- # Tokenize the text *around* the image token (no extra specials!)
- pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
- post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
-
- # Splice in the IMAGE token id (-200) at the placeholder position
- img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
- input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
- attention_mask = torch.ones_like(input_ids, device=model.device)

- # Preprocess image via the model's own processor
- img = Image.open("test-2.jpg").convert("RGB")
- px = model.get_vision_tower().image_processor(images=img, return_tensors="pt")["pixel_values"]
- px = px.to(model.device, dtype=model.dtype)
+ inputs = processor.apply_chat_template(
+     messages,
+     add_generation_prompt=True,
+     tokenize=True,
+     return_tensors="pt",
+     return_dict=True,
+ )

- # Generate
- with torch.no_grad():
-     out = model.generate(
-         inputs=input_ids,
-         attention_mask=attention_mask,
-         images=px,
-         max_new_tokens=128,
-     )
+ out = model.generate(
+     **inputs,
+     do_sample=False,
+     max_new_tokens=150,
+ )

- print(tok.decode(out[0], skip_special_tokens=True))
+ print(processor.tokenizer.decode(out[0], skip_special_tokens=False))
  ```

  ## Citation
@@ -117,4 +103,4 @@ If you found this model useful, please cite the following paper:
  month = {June},
  year = {2025},
  }
- ```
+ ```
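
Note (not part of the diff): the updated snippet loads the model with the default device and dtype. If you want the GPU / half-precision behavior of the previous snippet, the standard `from_pretrained` arguments still apply, and the processor output (a `BatchFeature`) can be moved in one call. A minimal sketch, assuming CUDA is available:

```python
# Sketch only: restores the device/dtype handling of the previous README snippet.
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

model_id = "apple/FastVLM-0.5B"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
    trust_remote_code=True,
)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"},
            {"type": "text", "text": "Describe this image in detail."},
        ],
    }
]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
)
# BatchFeature.to() casts only floating-point tensors (pixel_values), so input_ids stay int64.
inputs = inputs.to(model.device, dtype=model.dtype)
out = model.generate(**inputs, do_sample=False, max_new_tokens=150)
print(processor.tokenizer.decode(out[0], skip_special_tokens=False))
```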
config.json CHANGED
@@ -1,12 +1,12 @@
  {
-   "_name_or_path": "./llava-v1.5-13b",
    "architectures": [
      "LlavaQwen2ForCausalLM"
    ],
    "auto_map": {
      "AutoConfig": "llava_qwen.LlavaConfig",
-     "AutoModelForCausalLM": "llava_qwen.LlavaQwen2ForCausalLM"
-   },
+     "AutoModelForCausalLM": "llava_qwen.LlavaQwen2ForCausalLM",
+     "AutoProcessor": "processing_fastvlm.FastVLMProcessor"
+   },
    "attention_dropout": 0.0,
    "bos_token_id": 151643,
    "eos_token_id": 151645,
@@ -45,5 +45,24 @@
    "use_cache": true,
    "use_mm_proj": true,
    "use_sliding_window": false,
+   "vision_config": {
+     "cls_ratio": 2.0,
+     "down_patch_size": 7,
+     "down_stride": 2,
+     "downsamples": [true, true, true, true, true],
+     "embed_dims": [96, 192, 384, 768, 1536],
+     "hidden_size": 1024,
+     "image_size": 1024,
+     "intermediate_size": 3072,
+     "layer_scale_init_value": 1e-5,
+     "layers": [2, 12, 24, 4, 2],
+     "mlp_ratios": [4, 4, 4, 4, 4],
+     "num_classes": 1000,
+     "patch_size": 64,
+     "pos_embs_shapes": [null, null, null, [7, 7], [7, 7]],
+     "projection_dim": 768,
+     "repmixer_kernel_size": 3,
+     "token_mixers": ["repmixer", "repmixer", "repmixer", "attention", "attention"]
+   },
    "vocab_size": 151936
  }
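
The new `AutoProcessor` entry in `auto_map` is what lets the Auto classes resolve to the custom processor shipped in this repo when `trust_remote_code=True` is passed. A quick check (sketch):

```python
from transformers import AutoProcessor

# Resolves via the auto_map entry above to the class defined in processing_fastvlm.py.
processor = AutoProcessor.from_pretrained("apple/FastVLM-0.5B", trust_remote_code=True)
print(type(processor).__name__)  # expected: FastVLMProcessor
```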
llava_qwen.py CHANGED
@@ -2140,8 +2140,8 @@ class LlavaQwen2ForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
      @torch.no_grad()
      def generate(
          self,
-         inputs: Optional[torch.Tensor] = None,
-         images: Optional[torch.Tensor] = None,
+         input_ids: Optional[torch.Tensor] = None,
+         pixel_values: Optional[torch.Tensor] = None,
          image_sizes: Optional[torch.Tensor] = None,
          **kwargs,
      ) -> Union[GenerateOutput, torch.LongTensor]:
@@ -2150,21 +2150,21 @@ class LlavaQwen2ForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
          if "inputs_embeds" in kwargs:
              raise NotImplementedError("`inputs_embeds` is not supported")

-         if images is not None:
+         if pixel_values is not None:
              (
-                 inputs,
+                 input_ids,
                  position_ids,
                  attention_mask,
                  _,
                  inputs_embeds,
                  _
              ) = self.prepare_inputs_labels_for_multimodal(
-                 inputs,
+                 input_ids,
                  position_ids,
                  attention_mask,
                  None,
                  None,
-                 images,
+                 pixel_values,
                  image_sizes=image_sizes
              )
          else:
@@ -2179,17 +2179,17 @@ class LlavaQwen2ForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):

      def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                        inputs_embeds=None, **kwargs):
-         images = kwargs.pop("images", None)
+         images = kwargs.pop("pixel_values", None)
          image_sizes = kwargs.pop("image_sizes", None)
          inputs = super().prepare_inputs_for_generation(
              input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
          )
          if images is not None:
-             inputs['images'] = images
+             inputs['pixel_values'] = images
          if image_sizes is not None:
              inputs['image_sizes'] = image_sizes
          return inputs


  AutoConfig.register("llava_qwen2", LlavaConfig)
- AutoModelForCausalLM.register(LlavaConfig, LlavaQwen2ForCausalLM)
+ AutoModelForCausalLM.register(LlavaConfig, LlavaQwen2ForCausalLM)
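
Renaming the `generate()` arguments from `inputs`/`images` to `input_ids`/`pixel_values` makes the signature line up with the keys that `FastVLMProcessor` returns, so the README's `model.generate(**inputs, ...)` call works without manual splicing. A minimal sketch of the key match, assuming the `model`, `processor`, and `messages` from the README snippet above are already defined:

```python
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
)
# The processor output keys should mirror the renamed generate() arguments.
print(sorted(inputs.keys()))  # expected: ['attention_mask', 'image_sizes', 'input_ids', 'pixel_values']
out = model.generate(**inputs, max_new_tokens=150)
```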
preprocessor_config.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "auto_map": {
+     "AutoImageProcessor": "processing_fastvlm.FastVLMImageProcessor"
+   },
+   "image_processor_type": "FastVLMImageProcessor",
+   "crop_size": {
+     "height": 1024,
+     "width": 1024
+   },
+   "do_center_crop": true,
+   "do_convert_rgb": true,
+   "do_normalize": true,
+   "do_rescale": true,
+   "do_resize": true,
+   "image_mean": [
+     0.0,
+     0.0,
+     0.0
+   ],
+   "image_std": [
+     1.0,
+     1.0,
+     1.0
+   ],
+   "resample": 3,
+   "rescale_factor": 0.00392156862745098,
+   "size": {
+     "shortest_edge": 1024
+   }
+ }
+
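
With `image_mean` of 0 and `image_std` of 1, preprocessing only rescales pixel values to `[0, 1]` (no mean/std normalization); images are padded to a square, resized so the shortest edge is 1024, and center-cropped to 1024×1024. A sketch of loading just the image processor, using a dummy 640×480 input (the shape comments are what I would expect, not verified output):

```python
from PIL import Image
from transformers import AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained("apple/FastVLM-0.5B", trust_remote_code=True)
batch = image_processor(images=Image.new("RGB", (640, 480)))  # dummy black image, just to inspect shapes
print(batch["pixel_values"].shape)  # expected: torch.Size([1, 3, 1024, 1024])
print(batch["image_sizes"])         # should be the original (width, height) before padding/resize
```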
processing_fastvlm.py ADDED
@@ -0,0 +1,88 @@
+ import re
+ import torch
+ from transformers import ProcessorMixin, BatchFeature, CLIPImageProcessorFast
+ from transformers.image_processing_utils import BaseImageProcessor
+ from transformers.image_utils import ImageInput
+ from typing import Any, Dict, List, Optional, Union
+ from PIL import Image
+
+ from .llava_qwen import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+
+ # Adapted from transformers.models.llava_next.image_processing_llava_next.expand_to_square
+ def expand_to_square(image: torch.Tensor, background_color=0) -> torch.Tensor:
+     """
+     Expands an image to a square by adding a background color.
+     """
+     c, height, width = image.shape
+     if width == height:
+         return image
+     elif width > height:
+         result = torch.ones((c, width, width), dtype=image.dtype) * background_color
+         result[:, (width - height) // 2 : (width - height) // 2 + height, :] = image
+         return result
+     else:
+         result = torch.ones((c, height, height), dtype=image.dtype) * background_color
+         result[:, :, (height - width) // 2 : (height - width) // 2 + width] = image
+         return result
+
+
+ class FastVLMImageProcessor(CLIPImageProcessorFast):
+     def _preprocess(self, images, **kwargs):
+         image_sizes = [image.shape[-2:][::-1] for image in images]
+         images = [expand_to_square(image) for image in images]
+         images = super()._preprocess(images, **kwargs)
+         pixel_values = torch.stack(images.pixel_values, dim=0)
+         return BatchFeature(data={"pixel_values": pixel_values, "image_sizes": image_sizes})
+
+ class FastVLMProcessor(ProcessorMixin):
+     attributes = ["tokenizer", "image_processor"]
+     image_processor_class = "AutoImageProcessor"
+     tokenizer_class = "AutoTokenizer"
+
+     def __init__(
+         self,
+         tokenizer,
+         image_processor,
+         chat_template=None,
+         **kwargs
+     ):
+         super().__init__(tokenizer, image_processor, chat_template=chat_template, **kwargs)
+
+     def __call__(
+         self,
+         images: ImageInput = None,
+         text: Optional[Union[str, List[str]]] = None,
+         return_tensors: Optional[str] = "pt",
+         **kwargs,
+     ) -> BatchFeature:
+         if isinstance(text, str):
+             text = [text]
+         elif not isinstance(text, list) and not isinstance(text[0], str):
+             raise TypeError("Invalid input text. Please provide a string, or a list of strings")
+
+         image_inputs = {}
+         if images is not None:
+             image_inputs = self.image_processor(images=images)
+
+         image_token = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=torch.int64)
+         input_ids = torch.tensor([], dtype=torch.int64)
+         attention_mask = torch.tensor([], dtype=torch.int64)
+         for prompt in text:
+             image_indexes = [m.start() for m in re.finditer(DEFAULT_IMAGE_TOKEN, prompt)]
+             if len(image_indexes) > 1:
+                 raise ValueError(
+                     f"Expected up to 1 image tokens per prompt, got {len(image_indexes)} instead."
+                 )
+
+             # DEFAULT_IMAGE_TOKEN maps to IMAGE_TOKEN_INDEX (-200), which is not in the vocab (so we can't tokenize the full string)
+             pre, _, post = prompt.partition(DEFAULT_IMAGE_TOKEN)
+             pre_ids = self.tokenizer(pre, return_tensors="pt", add_special_tokens=False).input_ids
+             post_ids = self.tokenizer(post, return_tensors="pt", add_special_tokens=False).input_ids
+
+             sample_ids = torch.cat([pre_ids, image_token, post_ids], dim=1).to(dtype=torch.int64)
+             sample_mask = torch.ones_like(sample_ids)
+
+             input_ids = torch.cat([input_ids, sample_ids], dim=0)
+             attention_mask = torch.cat([attention_mask, sample_mask], dim=0)
+
+         return BatchFeature(data={"input_ids": input_ids, "attention_mask": attention_mask, **image_inputs}, tensor_type=return_tensors)
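
The processor can also be called directly, without the chat template, on a prompt that already contains the `<image>` placeholder (`DEFAULT_IMAGE_TOKEN`); it splices `IMAGE_TOKEN_INDEX` (-200) between the tokenized text segments, which is what the previous README snippet did by hand. A sketch with a hand-written Qwen-style prompt and a dummy image:

```python
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("apple/FastVLM-0.5B", trust_remote_code=True)

prompt = "<|im_start|>user\n<image>\nDescribe this image in detail.<|im_end|>\n<|im_start|>assistant\n"
batch = processor(images=Image.new("RGB", (640, 480)), text=prompt)

# The <image> placeholder becomes a single -200 id spliced into the token sequence.
print((batch["input_ids"] == -200).nonzero())
print(batch["pixel_values"].shape)  # expected: torch.Size([1, 3, 1024, 1024])
```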
processor_config.json ADDED
@@ -0,0 +1,3 @@
+ {
+   "chat_template": "{%- if messages is string -%}\n {{- messages -}}\n{%- else -%}\n {%- for message in messages -%}\n {%- if loop.first and messages[0]['role'] != 'system' -%}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' -}}\n {%- endif -%}\n {{- '<|im_start|>' + message['role'] + '\\n' -}}\n {%- if message['content'] is string -%}\n {{- message['content'] -}}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{- '<image>\\n' -}}\n {%- elif item['type'] == 'text' -%}\n {{- item['text'] -}}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{- raise_exception(\"Invalid content type\") -}}\n {%- endif -%}\n {{- '<|im_end|>' + '\\n' -}}\n {%- endfor -%}\n {%- if add_generation_prompt -%}\n {{- '<|im_start|>assistant\\n' -}}\n {%- endif -%}\n{%- endif -%}\n"
+ }
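
The chat template is Qwen2-style: it injects a default system prompt when none is given, wraps each turn in `<|im_start|>` / `<|im_end|>`, and replaces every `{"type": "image"}` content item with the literal `<image>` placeholder that `FastVLMProcessor` later swaps for the -200 token id. A sketch rendering it as a string (`tokenize=False`) to see exactly what gets tokenized:

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("apple/FastVLM-0.5B", trust_remote_code=True)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image in detail."},
        ],
    }
]
print(processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False))
# Expected rendering, per the template above:
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# <image>
# Describe this image in detail.<|im_end|>
# <|im_start|>assistant
```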