shilinxu committed (verified)
Commit 050b696 · 1 Parent(s): 3940890

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
added_tokens.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "</think>": 151668,
+   "</tool_call>": 151658,
+   "</tool_response>": 151666,
+   "<think>": 151667,
+   "<tool_call>": 151657,
+   "<tool_response>": 151665,
+   "<|box_end|>": 151649,
+   "<|box_start|>": 151648,
+   "<|endoftext|>": 151643,
+   "<|file_sep|>": 151664,
+   "<|fim_middle|>": 151660,
+   "<|fim_pad|>": 151662,
+   "<|fim_prefix|>": 151659,
+   "<|fim_suffix|>": 151661,
+   "<|im_end|>": 151645,
+   "<|im_start|>": 151644,
+   "<|image_pad|>": 151655,
+   "<|object_ref_end|>": 151647,
+   "<|object_ref_start|>": 151646,
+   "<|quad_end|>": 151651,
+   "<|quad_start|>": 151650,
+   "<|repo_name|>": 151663,
+   "<|video_pad|>": 151656,
+   "<|vision_end|>": 151653,
+   "<|vision_pad|>": 151654,
+   "<|vision_start|>": 151652
+ }
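These added token ids line up with the ids referenced in config.json (for example image_token_id = 151655). A minimal sanity-check sketch, assuming "shilinxu/smallvlm" is a placeholder for the actual repo id or local folder:

from transformers import AutoTokenizer

# Placeholder repo id; point this at the actual model path.
tokenizer = AutoTokenizer.from_pretrained("shilinxu/smallvlm", trust_remote_code=True)

# <|image_pad|> should resolve to the id used as image_token_id in config.json.
assert tokenizer.convert_tokens_to_ids("<|image_pad|>") == 151655
# <|im_end|> is the chat end-of-turn token (also the eos id in config.json).
assert tokenizer.convert_tokens_to_ids("<|im_end|>") == 151645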
chat_template.jinja ADDED
@@ -0,0 +1,7 @@
+ {% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system
+ You are a helpful assistant.<|im_end|>
+ {% endif %}<|im_start|>{{ message['role'] }}
+ {% if message['role'] == 'assistant' %}{% generation %}{{ message['content'][0]['text'] }}<|im_end|>
+ {% endgeneration %}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}<|vision_start|><|image_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>
+ {% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant
+ {% endif %}
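The template expects each message's content to be a list of typed parts; image parts render as <|vision_start|><|image_pad|><|vision_end|>. A minimal rendering sketch (placeholder repo id):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("shilinxu/smallvlm", trust_remote_code=True)  # placeholder id
messages = [
    {"role": "user", "content": [
        {"type": "image", "image": "example.jpg"},
        {"type": "text", "text": "Describe this image."},
    ]},
]
# Renders the default system preamble, the vision placeholder, the user text, and ends
# with "<|im_start|>assistant\n" because add_generation_prompt=True.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)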
config.json ADDED
@@ -0,0 +1,190 @@
+ {
+   "architectures": [
+     "SmallVLMForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_smallvlm.SmallVLMConfig",
+     "AutoModelForCausalLM": "modeling_smallvlm.SmallVLMForCausalLM"
+   },
+   "image_token_id": 151655,
+   "language_model_config": {
+     "_name_or_path": "pretrained/Qwen/Qwen3-1.7B",
+     "add_cross_attention": false,
+     "architectures": [
+       "Qwen3ForCausalLM"
+     ],
+     "attention_bias": false,
+     "attention_dropout": 0.0,
+     "bad_words_ids": null,
+     "begin_suppress_tokens": null,
+     "bos_token_id": 151643,
+     "chunk_size_feed_forward": 0,
+     "cross_attention_hidden_size": null,
+     "decoder_start_token_id": null,
+     "diversity_penalty": 0.0,
+     "do_sample": false,
+     "early_stopping": false,
+     "encoder_no_repeat_ngram_size": 0,
+     "eos_token_id": 151645,
+     "exponential_decay_length_penalty": null,
+     "finetuning_task": null,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "head_dim": 128,
+     "hidden_act": "silu",
+     "hidden_size": 2048,
+     "id2label": {
+       "0": "LABEL_0",
+       "1": "LABEL_1"
+     },
+     "initializer_range": 0.02,
+     "intermediate_size": 6144,
+     "is_decoder": false,
+     "is_encoder_decoder": false,
+     "label2id": {
+       "LABEL_0": 0,
+       "LABEL_1": 1
+     },
+     "length_penalty": 1.0,
+     "max_length": 20,
+     "max_position_embeddings": 40960,
+     "max_window_layers": 28,
+     "min_length": 0,
+     "model_type": "qwen3",
+     "no_repeat_ngram_size": 0,
+     "num_attention_heads": 16,
+     "num_beam_groups": 1,
+     "num_beams": 1,
+     "num_hidden_layers": 28,
+     "num_key_value_heads": 8,
+     "num_return_sequences": 1,
+     "output_attentions": false,
+     "output_hidden_states": false,
+     "output_scores": false,
+     "pad_token_id": null,
+     "prefix": null,
+     "problem_type": null,
+     "pruned_heads": {},
+     "remove_invalid_values": false,
+     "repetition_penalty": 1.0,
+     "return_dict": true,
+     "return_dict_in_generate": false,
+     "rms_norm_eps": 1e-06,
+     "rope_scaling": null,
+     "rope_theta": 1000000,
+     "sep_token_id": null,
+     "sliding_window": null,
+     "suppress_tokens": null,
+     "task_specific_params": null,
+     "temperature": 1.0,
+     "tf_legacy_loss": false,
+     "tie_encoder_decoder": false,
+     "tie_word_embeddings": true,
+     "tokenizer_class": null,
+     "top_k": 50,
+     "top_p": 1.0,
+     "torch_dtype": "bfloat16",
+     "torchscript": false,
+     "typical_p": 1.0,
+     "use_bfloat16": false,
+     "use_cache": true,
+     "use_sliding_window": false,
+     "vocab_size": 151936
+   },
+   "model_type": "smallvlm",
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.52.1",
+   "video_token_id": 151656,
+   "vision_abstractor_config": null,
+   "vision_model_config": {
+     "_name_or_path": "",
+     "add_cross_attention": false,
+     "architectures": null,
+     "auto_map": {
+       "AutoConfig": "configuration_moonvit.MoonViTConfig",
+       "AutoModel": "modeling_moonvit.MoonVitPretrainedModel"
+     },
+     "bad_words_ids": null,
+     "begin_suppress_tokens": null,
+     "bos_token_id": null,
+     "chunk_size_feed_forward": 0,
+     "cross_attention_hidden_size": null,
+     "decoder_start_token_id": null,
+     "diversity_penalty": 0.0,
+     "do_sample": false,
+     "early_stopping": false,
+     "encoder_no_repeat_ngram_size": 0,
+     "eos_token_id": null,
+     "exponential_decay_length_penalty": null,
+     "finetuning_task": null,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "hidden_size": 1152,
+     "id2label": {
+       "0": "LABEL_0",
+       "1": "LABEL_1"
+     },
+     "init_pos_emb_height": 64,
+     "init_pos_emb_width": 64,
+     "intermediate_size": 4304,
+     "is_decoder": false,
+     "is_encoder_decoder": false,
+     "label2id": {
+       "LABEL_0": 0,
+       "LABEL_1": 1
+     },
+     "length_penalty": 1.0,
+     "max_length": 20,
+     "max_position_embeddings": 128000,
+     "merge_kernel_size": [
+       2,
+       2
+     ],
+     "min_length": 0,
+     "model_type": "moonvit",
+     "no_repeat_ngram_size": 0,
+     "num_attention_heads": 16,
+     "num_beam_groups": 1,
+     "num_beams": 1,
+     "num_hidden_layers": 27,
+     "num_return_sequences": 1,
+     "output_attentions": false,
+     "output_hidden_states": false,
+     "output_scores": false,
+     "pad_token_id": null,
+     "patch_size": 14,
+     "prefix": null,
+     "problem_type": null,
+     "pruned_heads": {},
+     "remove_invalid_values": false,
+     "repetition_penalty": 1.0,
+     "return_dict": true,
+     "return_dict_in_generate": false,
+     "rope_scaling": {
+       "mrope_section": [
+         12,
+         12,
+         12
+       ],
+       "rope_type": "default",
+       "type": "default"
+     },
+     "rope_theta": 1000000.0,
+     "sep_token_id": null,
+     "suppress_tokens": null,
+     "task_specific_params": null,
+     "temperature": 1.0,
+     "text_hidden_size": 2048,
+     "tf_legacy_loss": false,
+     "tie_encoder_decoder": false,
+     "tie_word_embeddings": true,
+     "tokenizer_class": null,
+     "top_k": 50,
+     "top_p": 1.0,
+     "torch_dtype": null,
+     "torchscript": false,
+     "typical_p": 1.0,
+     "use_bfloat16": false
+   },
+   "vision_start_token_id": 151652
+ }
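Because auto_map points at the bundled configuration_smallvlm.py / modeling_smallvlm.py, loading goes through remote code. A loading sketch (placeholder repo id; dtype taken from the torch_dtype field above):

import torch
from transformers import AutoConfig, AutoModelForCausalLM

repo = "shilinxu/smallvlm"  # placeholder; use the actual repo id or local path
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
print(config.model_type)                         # "smallvlm"
print(config.language_model_config.hidden_size)  # 2048 (Qwen3-1.7B backbone)

model = AutoModelForCausalLM.from_pretrained(repo, torch_dtype=torch.bfloat16, trust_remote_code=True)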
configuration_moonvit.py ADDED
@@ -0,0 +1,39 @@
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class MoonViTConfig(PretrainedConfig):
+     model_type = "moonvit"
+
+     def __init__(
+         self,
+         patch_size: int = 14,
+         init_pos_emb_height: int = 64,
+         init_pos_emb_width: int = 64,
+         num_attention_heads: int = 16,
+         num_hidden_layers: int = 27,
+         hidden_size: int = 1152,
+         text_hidden_size: int = 2048,
+         intermediate_size: int = 4304,
+         merge_kernel_size: tuple[int, int] = (2, 2),
+         rope_theta: float = 1000000.0,
+         max_position_embeddings: int = 128000,
+         rope_scaling: dict = {'type': 'default', 'mrope_section': [12, 12, 12], 'rope_type': 'default'},
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.patch_size = patch_size
+         # Positional embedding config
+         self.init_pos_emb_height = init_pos_emb_height
+         self.init_pos_emb_width = init_pos_emb_width
+         # Transformer config
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.hidden_size = hidden_size
+         self.text_hidden_size = text_hidden_size
+         self.intermediate_size = intermediate_size
+         # Patch merger config
+         self.merge_kernel_size = merge_kernel_size
+
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+         self.rope_scaling = rope_scaling
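The constructor defaults mirror the vision_model_config block serialized in config.json, so an argument-free instantiation reproduces that configuration; a small sketch (module name as uploaded in this repo):

from configuration_moonvit import MoonViTConfig

cfg = MoonViTConfig()
assert cfg.hidden_size == 1152 and cfg.patch_size == 14 and cfg.num_hidden_layers == 27
assert cfg.merge_kernel_size == (2, 2)
print(cfg.rope_scaling)  # {'type': 'default', 'mrope_section': [12, 12, 12], 'rope_type': 'default'}

Note that rope_scaling uses a mutable dict as its default, so default-constructed instances share the same dict object.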
configuration_smallvlm.py ADDED
@@ -0,0 +1,52 @@
+ from transformers import PretrainedConfig, AutoConfig, CONFIG_MAPPING
+ from transformers.dynamic_module_utils import get_class_from_dynamic_module
+
+
+ class SmallVLMConfig(PretrainedConfig):
+     model_type = "smallvlm"
+     is_composition = True
+
+     def __init__(
+         self,
+         language_model_config=None,
+         vision_model_config=None,
+         image_token_id=None,
+         **kwargs):
+         super().__init__(**kwargs)
+         if isinstance(language_model_config, dict):
+             if '_name_or_path' not in language_model_config:
+                 language_model_config['_name_or_path'] = self._name_or_path
+             language_model_type = language_model_config.get('model_type', '')
+             is_remote_code = '.' in language_model_config.get('auto_map', {}).get('AutoConfig', '')
+             if language_model_type in CONFIG_MAPPING and not is_remote_code:
+                 language_model_config = AutoConfig.for_model(**language_model_config)
+             elif language_model_type:
+                 Config = get_class_from_dynamic_module(language_model_config["auto_map"]["AutoConfig"], language_model_config['_name_or_path'])
+                 language_model_config = Config(**language_model_config)
+         self.language_model_config = language_model_config
+
+         if isinstance(vision_model_config, dict):
+             # if '_name_or_path' not in vision_model_config:
+             vision_model_config['_name_or_path'] = self._name_or_path
+             vision_model_type = vision_model_config.get('model_type', '')
+             is_remote_code = '.' in vision_model_config.get('auto_map', {}).get('AutoConfig', '')
+             if vision_model_type in CONFIG_MAPPING and not is_remote_code:
+                 vision_model_config = AutoConfig.for_model(**vision_model_config)
+             elif vision_model_type:
+                 Config = get_class_from_dynamic_module(vision_model_config["auto_map"]["AutoConfig"], vision_model_config['_name_or_path'])
+                 vision_model_config = Config(**vision_model_config)
+         self.vision_model_config = vision_model_config
+
+         self.image_token_id = image_token_id
+         self.video_token_id = 151656
+         self.vision_start_token_id = 151652
+
+     @property
+     def hidden_size(self):
+         return self.language_model_config.hidden_size
+
+     @classmethod
+     def from_dict(cls, config_dict, **kwargs):
+         if 'name_or_path' in kwargs:
+             config_dict['_name_or_path'] = kwargs.pop('name_or_path')
+         return super().from_dict(config_dict, **kwargs)
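When the sub-configs arrive as plain dicts, SmallVLMConfig resolves them either through AutoConfig.for_model (for model types registered in CONFIG_MAPPING) or through the dynamic-module loader (for remote code). A minimal sketch with a hand-written language-model dict (module name as uploaded in this repo):

from configuration_smallvlm import SmallVLMConfig

config = SmallVLMConfig(
    language_model_config={"model_type": "qwen3", "hidden_size": 2048},  # resolved via AutoConfig.for_model
    vision_model_config=None,   # the real checkpoint supplies the full "moonvit" dict from config.json
    image_token_id=151655,
)
print(config.hidden_size)  # 2048, delegated to language_model_config.hidden_size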
generation_config.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "bos_token_id": 151643,
+   "do_sample": true,
+   "eos_token_id": [
+     151645,
+     151643
+   ],
+   "pad_token_id": 151643,
+   "temperature": 0.6,
+   "top_k": 20,
+   "top_p": 0.95,
+   "transformers_version": "4.52.1"
+ }
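generate() picks these sampling defaults up automatically when the model is loaded from the repo; they can also be inspected or overridden explicitly. A small sketch (placeholder repo id):

from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("shilinxu/smallvlm")  # placeholder id
print(gen_cfg.temperature, gen_cfg.top_k, gen_cfg.top_p)  # 0.6 20 0.95

# Example override at call time (model/inputs assumed to exist):
# outputs = model.generate(**inputs, generation_config=gen_cfg, max_new_tokens=256)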
image_processing_moonvit.py ADDED
@@ -0,0 +1,126 @@
+ """Image processor class for KimiVL."""
+
+ import math
+ import numpy as np
+ from PIL import Image
+ from typing import Optional, Union
+
+ import torch
+ from torchvision.transforms import functional as TF
+ from transformers.image_utils import ImageInput, make_list_of_images, valid_images
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
+ from transformers.utils import TensorType
+
+
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
+
+
+ class MoonViTImageProcessor(BaseImageProcessor):
+     model_type = "moonvit"
+
+     def __init__(
+         self,
+         patch_size: int = 14,
+         pad_input: bool = False,
+         image_mean: tuple[float, float, float] = OPENAI_DATASET_MEAN,
+         image_std: tuple[float, float, float] = OPENAI_DATASET_STD,
+         in_token_limit: int = 4096,
+         merge_kernel_size: list[int, int] = [2, 2],
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.in_token_limit = in_token_limit
+         self.patch_size = patch_size
+         self.pad_input = pad_input
+         self.image_mean = image_mean
+         self.image_std = image_std
+         self.merge_kernel_size = merge_kernel_size
+
+     def rescale(
+         self, image: Image.Image, merge_kernel_size: list[int, int] = [2, 2]
+     ) -> Image.Image:
+         w, h = image.size
+         patch_size = self.patch_size
+
+         if (w // patch_size) * (h // patch_size) > self.in_token_limit:
+             scale = math.sqrt(self.in_token_limit / ((w // patch_size) * (h // patch_size)))
+             new_w, new_h = int(w * scale), int(h * scale)
+             image = image.resize((new_w, new_h), Image.Resampling.BICUBIC)
+         if self.pad_input:
+             new_w, new_h = image.size
+             pad_size_h = merge_kernel_size[0] * patch_size
+             pad_size_w = merge_kernel_size[1] * patch_size
+
+             pad_h = (pad_size_h - new_h % pad_size_h) % pad_size_h
+             pad_w = (pad_size_w - new_w % pad_size_w) % pad_size_w
+
+             image = TF.pad(image, (0, 0, pad_w, pad_h))
+         else:
+             new_w, new_h = image.size
+             new_w = new_w - new_w % patch_size
+             new_h = new_h - new_h % patch_size
+             image = TF.center_crop(image, (new_h, new_w))
+
+         w, h = image.size
+         if w // patch_size >= 512 or h // patch_size >= 512:
+             raise ValueError("Exceed pos emb")
+
+         return image
+
+     def to_tensor(self, image: Image.Image) -> torch.Tensor:
+         return TF.to_tensor(image.convert("RGB"))
+
+     def normalize(self, image: torch.Tensor) -> torch.Tensor:
+         return TF.normalize(image, self.image_mean, self.image_std)
+
+     def patchify(self, image: torch.Tensor) -> tuple[torch.Tensor, list[int, int]]:
+         patch_size = self.patch_size
+         C, H, W = image.shape
+         patches = image.reshape(C, H // patch_size, patch_size, W // patch_size, patch_size)
+         patches = patches.permute(1, 3, 0, 2, 4)
+         patches = patches.contiguous().view(-1, C, patch_size, patch_size)
+         grid_hw = (H // patch_size, W // patch_size)
+         return patches, grid_hw
+
+     def _preprocess(self, image: ImageInput) -> tuple[torch.Tensor, list[int, int]]:
+         """
+         Preprocess image and patchify it.
+
+         Args:
+             image (`ImageInput`):
+                 Image to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
+
+         Returns:
+             patches: torch.Tensor
+             grid_hw: list[int, int]
+         """
+         image = self.rescale(image, self.merge_kernel_size)
+         image = self.to_tensor(image)
+         image = self.normalize(image)
+         patches, grid_hw = self.patchify(image)
+         return patches, grid_hw
+
+     def preprocess(
+         self,
+         images: ImageInput,
+         return_tensors: Optional[Union[str, TensorType]] = None,
+     ) -> BatchFeature:
+         images = make_list_of_images(images)
+
+         if not valid_images(images):
+             raise ValueError(
+                 "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                 "torch.Tensor, tf.Tensor or jax.ndarray."
+             )
+
+         pixel_values, image_grid_hws = [], []
+         for image in images:
+             patches, image_grid_hw = self._preprocess(image)
+             pixel_values.append(patches)
+             image_grid_hws.append(image_grid_hw)
+         pixel_values = torch.concat(pixel_values, dim=0)
+         image_grid_hws = np.array(image_grid_hws)
+         data = {"pixel_values": pixel_values, "image_grid_hws": image_grid_hws}
+
+         return BatchFeature(data=data, tensor_type=return_tensors)
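A minimal usage sketch for the processor above; with the default 14-pixel patches, a 448x448 image yields a 32x32 patch grid (module name as uploaded in this repo):

from PIL import Image
from image_processing_moonvit import MoonViTImageProcessor

processor = MoonViTImageProcessor()            # patch_size=14, in_token_limit=4096
image = Image.new("RGB", (448, 448))           # dummy image; any RGB image works

batch = processor.preprocess(image, return_tensors="pt")
print(batch["pixel_values"].shape)             # torch.Size([1024, 3, 14, 14]) -> one row per patch
print(batch["image_grid_hws"])                 # tensor([[32, 32]]) patch-grid height/width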
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:574587a763e7adc19186a15a9da4969aaa0563310c1d695f732c06e9831cd95c
+ size 4938445104
modeling_moonvit.py ADDED
@@ -0,0 +1,729 @@
1
+
2
+
3
+
4
+ import math
5
+ from copy import deepcopy
6
+ from typing import Union, Tuple, Sequence, Optional, List
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from transformers.activations import GELUActivation, ACT2FN, PytorchGELUTanh
13
+
14
+ from transformers.activations import PytorchGELUTanh
15
+ from transformers.modeling_utils import PreTrainedModel
16
+ from transformers.utils import is_flash_attn_2_available
17
+
18
+ from .configuration_moonvit import MoonViTConfig
19
+
20
+ if is_flash_attn_2_available():
21
+ from flash_attn import flash_attn_varlen_func
22
+ else:
23
+ flash_attn_varlen_func = None
24
+
25
+
26
+ def rotate_half(x):
27
+ x1 = x[..., : x.shape[-1] // 2]
28
+ x2 = x[..., x.shape[-1] // 2 :]
29
+ return torch.cat((-x2, x1), dim=-1)
30
+
31
+ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section=[12, 12, 12], unsqueeze_dim=1):
32
+ mrope_section = mrope_section * 2
33
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
34
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
35
+ q_embed = (q * cos) + (rotate_half(q) * sin)
36
+ k_embed = (k * cos) + (rotate_half(k) * sin)
37
+ return q_embed, k_embed
38
+
39
+ def get_rope_index(
40
+ image_token_id,
41
+ video_token_id,
42
+ vision_start_token_id,
43
+ spatial_merge_size: int = 2,
44
+ input_ids: Optional[torch.LongTensor] = None,
45
+ image_grid_thw: Optional[torch.LongTensor] = None,
46
+ video_grid_thw: Optional[torch.LongTensor] = None,
47
+ second_per_grid_ts: Optional[torch.Tensor] = None,
48
+ attention_mask: Optional[torch.Tensor] = None,
49
+ ) -> tuple[torch.Tensor, torch.Tensor]:
50
+
51
+ mrope_position_deltas = []
52
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
53
+ total_input_ids = input_ids
54
+ if attention_mask is None:
55
+ attention_mask = torch.ones_like(total_input_ids)
56
+ position_ids = torch.ones(3,input_ids.shape[0],input_ids.shape[1],dtype=input_ids.dtype,device=input_ids.device)
57
+ image_index, video_index = 0, 0
58
+ attention_mask = attention_mask.to(total_input_ids.device)
59
+ for i, input_ids in enumerate(total_input_ids):
60
+ input_ids = input_ids[attention_mask[i] == 1]
61
+ image_nums, video_nums = 0, 0
62
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
63
+ vision_tokens = input_ids[vision_start_indices + 1]
64
+ image_nums = (vision_tokens == image_token_id).sum()
65
+ video_nums = (vision_tokens == video_token_id).sum()
66
+ input_tokens = input_ids.tolist()
67
+ llm_pos_ids_list: list = []
68
+ st = 0
69
+ remain_images, remain_videos = image_nums, video_nums
70
+ for _ in range(image_nums + video_nums):
71
+ if image_token_id in input_tokens and remain_images > 0:
72
+ ed_image = input_tokens.index(image_token_id, st)
73
+ else:
74
+ ed_image = len(input_tokens) + 1
75
+ if video_token_id in input_tokens and remain_videos > 0:
76
+ ed_video = input_tokens.index(video_token_id, st)
77
+ else:
78
+ ed_video = len(input_tokens) + 1
79
+ if ed_image < ed_video:
80
+ t, h, w = (image_grid_thw[image_index][0],image_grid_thw[image_index][1],image_grid_thw[image_index][2])
81
+ second_per_grid_t = 0
82
+ image_index += 1
83
+ remain_images -= 1
84
+ ed = ed_image
85
+
86
+ else:
87
+ t, h, w = (video_grid_thw[video_index][0],video_grid_thw[video_index][1],video_grid_thw[video_index][2])
88
+ if second_per_grid_ts is not None:
89
+ second_per_grid_t = second_per_grid_ts[video_index]
90
+ else:
91
+ second_per_grid_t = 1.0
92
+ video_index += 1
93
+ remain_videos -= 1
94
+ ed = ed_video
95
+ llm_grid_t, llm_grid_h, llm_grid_w = (t.item(),h.item() // spatial_merge_size,w.item() // spatial_merge_size)
96
+ text_len = ed - st
97
+
98
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
99
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
100
+
101
+ range_tensor = torch.arange(llm_grid_t).view(-1, 1)
102
+ expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
103
+
104
+ ## normalize type, send to device.
105
+ second_per_grid_t = torch.as_tensor(second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device)
106
+
107
+ time_tensor = expanded_range * second_per_grid_t * 2
108
+
109
+ time_tensor_long = time_tensor.long()
110
+ t_index = time_tensor_long.flatten()
111
+
112
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
113
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
114
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
115
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
116
+
117
+ if st < len(input_tokens):
118
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
119
+ text_len = len(input_tokens) - st
120
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
121
+
122
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
123
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
124
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
125
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
126
+ return position_ids, mrope_position_deltas
127
+ else:
128
+ if attention_mask is not None:
129
+ position_ids = attention_mask.long().cumsum(-1) - 1
130
+ position_ids.masked_fill_(attention_mask == 0, 1)
131
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
132
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
133
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
134
+ else:
135
+ position_ids = (torch.arange(input_ids.shape[1], device=input_ids.device).view(1, 1, -1).expand(3, input_ids.shape[0], -1))
136
+ mrope_position_deltas = torch.zeros([input_ids.shape[0], 1],device=input_ids.device,dtype=input_ids.dtype,)
137
+
138
+ return position_ids, mrope_position_deltas
139
+
140
+ def multihead_attention(
141
+ q: torch.Tensor,
142
+ k: torch.Tensor,
143
+ v: torch.Tensor,
144
+ q_cu_seqlens: Optional[torch.Tensor] = None,
145
+ k_cu_seqlens: Optional[torch.Tensor] = None,
146
+ ):
147
+ """Multi-head attention using flash attention 2.
148
+ Args:
149
+ q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
150
+ or (tot_seqlens, num_heads, head_dim) if packing.
151
+ q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
152
+ The first element should be 0 and the last element should be q.shape[0].
153
+ k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
154
+ The first element should be 0 and the last element should be k.shape[0].
155
+ Returns:
156
+ output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
157
+ where dim = num_heads * head_dim
158
+ """
159
+ # Unified format legal check
160
+ assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
161
+ assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]"
162
+ assert (
163
+ k_cu_seqlens[-1] == k.shape[0] == v.shape[0]
164
+ ), "k_cu_seqlens must sum to k.shape[0]"
165
+ assert q.dtype in [
166
+ torch.bfloat16,
167
+ torch.float16,
168
+ ], f"unsupported dtype {q.dtype} for multihead attn"
169
+
170
+ max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item()
171
+ max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item()
172
+ attn_out = flash_attn_varlen_func(
173
+ q,
174
+ k,
175
+ v,
176
+ q_cu_seqlens,
177
+ k_cu_seqlens,
178
+ max_seqlen_q,
179
+ max_seqlen_k,
180
+ causal=False,
181
+ )
182
+ attn_out = attn_out.flatten(start_dim=-2)
183
+
184
+ return attn_out
185
+
186
+
187
+ def sdpa_attention(
188
+ q: torch.Tensor,
189
+ k: torch.Tensor,
190
+ v: torch.Tensor,
191
+ attention_mask: torch.Tensor,
192
+ ) -> torch.Tensor:
193
+ """SDPA attention.
194
+ Args:
195
+ q, k, v: tensor of shape (batch_size, num_heads, seqlen, head_dim),
196
+ or (batch_size, seqlen, num_heads, head_dim) if packing.
197
+ """
198
+ # bs, num_heads, seq_length, head_dim = q.shape
199
+ # attention_mask = attention_mask.repeat(1, num_heads, 1, 1)
200
+ attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
201
+ attn_output = attn_output.transpose(1, 2)
202
+ return attn_output
203
+
204
+
205
+ def eager_attention(
206
+ q: torch.Tensor,
207
+ k: torch.Tensor,
208
+ v: torch.Tensor,
209
+ q_cu_seqlens: Optional[torch.Tensor] = None,
210
+ k_cu_seqlens: Optional[torch.Tensor] = None,
211
+ ) -> torch.Tensor:
212
+ seq_length = q.shape[0]
213
+ attention_mask = torch.zeros(
214
+ [1, seq_length, seq_length], device=q.device, dtype=torch.bool
215
+ )
216
+ for i in range(1, len(q_cu_seqlens)):
217
+ attention_mask[
218
+ ...,
219
+ q_cu_seqlens[i - 1] : q_cu_seqlens[i],
220
+ q_cu_seqlens[i - 1] : q_cu_seqlens[i],
221
+ ] = True
222
+ q = q.transpose(0, 1)
223
+ k = k.transpose(0, 1)
224
+ v = v.transpose(0, 1)
225
+
226
+ attn_weight = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
227
+ attn_weight += attention_mask
228
+ attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
229
+
230
+ attn_output = attn_weight @ v
231
+ attn_output = attn_output.transpose(0, 1)
232
+ attn_output = attn_output.reshape(seq_length, -1)
233
+ return attn_output
234
+
235
+
236
+ VL_VISION_ATTENTION_FUNCTIONS = {
237
+ "flash_attention_2": multihead_attention,
238
+ "sdpa": sdpa_attention,
239
+ "eager": eager_attention,
240
+ }
241
+
242
+
243
+ def _apply_rope_input_validation(x, freqs_cis):
244
+ assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
245
+ assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape)
246
+ assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape)
247
+ assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype
248
+
249
+
250
+ def apply_rope(
251
+ xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
252
+ ) -> tuple[torch.Tensor, torch.Tensor]:
253
+ """
254
+ Args: (The leading dimensions of all inputs should be the same)
255
+ xq: query, tensor of shape (..., num_heads, head_dim)
256
+ xk: key, tensor of shape (..., num_heads, head_dim)
257
+ freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid.
258
+ Returns:
259
+ xq_out, xk_out: tensors of shape (..., num_heads, head_dim)
260
+ """
261
+ _apply_rope_input_validation(xq, freqs_cis)
262
+ _apply_rope_input_validation(xk, freqs_cis)
263
+
264
+ freqs_cis = freqs_cis.unsqueeze(-2) # ..., 1, head_dim/2
265
+ # ..., num_heads, head_dim/2
266
+ xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
267
+ xk_ = torch.view_as_complex(xk.float().view(*xk.shape[:-1], -1, 2))
268
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2) # ..., num_heads, head_dim
269
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2) # ..., num_heads, head_dim
270
+ return xq_out.type_as(xq), xk_out.type_as(xk)
271
+
272
+
273
+ class Learnable2DInterpPosEmb(nn.Module):
274
+ def __init__(
275
+ self, height: int, width: int, dim: int, interpolation_mode: str = "bicubic"
276
+ ) -> None:
277
+ super().__init__()
278
+ self.height = height
279
+ self.width = width
280
+ self.interpolation_mode = interpolation_mode
281
+ self.weight = nn.Parameter(torch.empty(height, width, dim))
282
+ self.reset_parameters()
283
+
284
+ def reset_parameters(self):
285
+ nn.init.normal_(self.weight)
286
+
287
+ def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
288
+ pos_embs = []
289
+ for shape in grid_hws.tolist():
290
+ if shape == self.weight.shape[:-1]:
291
+ pos_embs.append(self.weight.flatten(end_dim=1))
292
+ else:
293
+ pos_embs.append(
294
+ F.interpolate(
295
+ self.weight.permute((2, 0, 1)).unsqueeze(0),
296
+ size=shape,
297
+ mode=self.interpolation_mode,
298
+ )
299
+ .squeeze(0)
300
+ .permute((1, 2, 0))
301
+ .flatten(end_dim=1)
302
+ )
303
+ out = x + torch.cat(pos_embs)
304
+ return out
305
+
306
+
307
+ class MoonVisionPatchEmbed(nn.Module):
308
+
309
+ def __init__(
310
+ self,
311
+ out_dim: int,
312
+ in_dim: int = 3,
313
+ patch_size: Union[int, Tuple[int, int]] = (14, 14),
314
+ pos_emb_height: int = 14,
315
+ pos_emb_width: int = 14,
316
+ ):
317
+ super().__init__()
318
+ assert isinstance(
319
+ patch_size, (int, Sequence)
320
+ ), f"Invalid patch_size type: {type(patch_size)}"
321
+ if isinstance(patch_size, int):
322
+ patch_size = (patch_size, patch_size)
323
+ assert (
324
+ len(patch_size) == 2
325
+ ), f"Expected patch_size to be a tuple of 2, got {patch_size}"
326
+ self.patch_size = patch_size
327
+
328
+ self.proj = nn.Conv2d(
329
+ in_dim, out_dim, kernel_size=patch_size, stride=patch_size
330
+ )
331
+
332
+ self.pos_emb = Learnable2DInterpPosEmb(
333
+ height=pos_emb_height, width=pos_emb_width, dim=out_dim
334
+ )
335
+
336
+ def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
337
+ """
338
+ Args:
339
+ x (L, Channels): input tensor
340
+ grid_hws (N, 2): grid height and width
341
+ Returns:
342
+ (L, Cout) tensor
343
+ """
344
+ x = self.proj(x).view(x.size(0), -1)
345
+ # apply positional embedding
346
+ x = self.pos_emb(x, grid_hws)
347
+ return x
348
+
349
+
350
+ class Rope2DPosEmb(nn.Module):
351
+ """2D rotary position embedding with multi-resolution support.
352
+ This class is intended to be used in the following way:
353
+ 1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis.
354
+ 2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration.
355
+ 3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation.
356
+ The rope is shared across all attention layers and all heads.
357
+ Refs:
358
+ - RoFormer: https://arxiv.org/abs/2104.09864
359
+ - VisionLLaMA: https://arxiv.org/abs/2403.00522
360
+ - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py
361
+ Args:
362
+ dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed)
363
+ max_height (int): the maximum height of the 2D grid
364
+ max_width (int): the maximum width of the 2D grid
365
+ theta_base (float): the base of the theta
366
+ device (str): the device to store the precomputed cis
367
+ """
368
+
369
+ def __init__(self, dim: int, max_height: int, max_width: int, theta_base=10000):
370
+ super().__init__()
371
+ self.dim = dim
372
+ assert self.dim % 4 == 0, "dim must be divisible by 4"
373
+ self.max_height = max_height
374
+ self.max_width = max_width
375
+ self.theta_base = theta_base
376
+
377
+ self.freqs_cis = None
378
+
379
+ def extra_repr(self):
380
+ return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"
381
+
382
+ def _precompute_freqs_cis(self, device: torch.device) -> torch.Tensor:
383
+ """Calculate the cis(freqs) for each position in the 2D grid.
384
+ Return: complex tensor of shape (max_height, max_width, dim//2) and value:
385
+ height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim))
386
+ weight axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)) with (i in [0, dim//4))
387
+ note: `cis` is a mathematical notation defined by cis x = cos x + i sin x,
388
+ """
389
+ N = self.max_height * self.max_width
390
+ flat_pos = torch.arange(0, N).float().to(device)
391
+ x_pos = flat_pos % self.max_width
392
+ y_pos = flat_pos // self.max_width
393
+ dim_range = (
394
+ torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(device)
395
+ ) # C/4
396
+ freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
397
+ x_freqs = torch.outer(x_pos, freqs).float() # N, C/4
398
+ y_freqs = torch.outer(y_pos, freqs).float() # N, C/4
399
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) # N, C/4
400
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) # N, C/4
401
+ # N, C/4, 2
402
+ freqs_cis = torch.cat(
403
+ [x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1
404
+ )
405
+ # max_height, max_width, C/2
406
+ freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
407
+ return freqs_cis
408
+
409
+ def get_freqs_cis(self, grid_hws: torch.Tensor) -> torch.Tensor:
410
+ """
411
+ Args:
412
+ grid_hws (torch.Tensor): grid height and width
413
+ Returns:
414
+ freqs_cis: tensor of shape (sum(t * height * width), dim//2)
415
+ """
416
+ if self.freqs_cis is None:
417
+ self.freqs_cis = self._precompute_freqs_cis(grid_hws.device)
418
+
419
+ shapes = grid_hws.tolist()
420
+ assert all(
421
+ 1 <= h <= self.max_height and 1 <= w <= self.max_width for h, w in shapes
422
+ ), (
423
+ shapes,
424
+ self.max_height,
425
+ self.max_width,
426
+ )
427
+ # freqs_cis = torch.cat(
428
+ # [self.freqs_cis[:h, :w].reshape(-1, self.dim // 2) for h, w in shapes],
429
+ # dim=0,
430
+ # )
431
+ max_h, max_w = grid_hws.max(dim=0).values.tolist()
432
+ max_h, max_w = max_h // 2, max_w // 2
433
+ freqs_cis = self.freqs_cis[:max_h, :max_w].reshape(-1, self.dim // 2).repeat(len(shapes), 1, 1)
434
+ return freqs_cis
435
+
436
+
437
+ class MLP2(nn.Module):
438
+ """
439
+ Args:
440
+ dims: [in_dim, hidden_dim, out_dim]
441
+ bias: whether to use bias in linear layer.
442
+ """
443
+
444
+ def __init__(self, dims: list[int], activation, bias=True):
445
+ super().__init__()
446
+ assert len(dims) == 3
447
+ self.fc0 = nn.Linear(dims[0], dims[1], bias=bias)
448
+ self.fc1 = nn.Linear(dims[1], dims[2], bias=bias)
449
+ self.activation = activation
450
+ for m in [self.fc0, self.fc1]:
451
+ nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features))
452
+ if m.bias is not None:
453
+ nn.init.zeros_(m.bias)
454
+
455
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
456
+ x = self.fc0(x)
457
+ x = self.activation(x)
458
+ return self.fc1(x)
459
+
460
+
461
+ class MoonVitEncoderLayer(nn.Module):
462
+
463
+ def __init__(
464
+ self,
465
+ layer_idx: int,
466
+ num_heads: int,
467
+ hidden_dim: int,
468
+ mlp_dim: int,
469
+ attn_implementation: str = "eager",
470
+ activation=F.gelu,
471
+ attn_bias: bool = False,
472
+ ):
473
+ super().__init__()
474
+ self.layer_idx = layer_idx
475
+ self.num_heads = num_heads
476
+ self.hidden_dim = hidden_dim
477
+ self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
478
+ self.attn_implementation = attn_implementation
479
+
480
+ self.norm0 = nn.LayerNorm(hidden_dim)
481
+ self.norm1 = nn.LayerNorm(hidden_dim)
482
+ self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
483
+ self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
484
+ self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)
485
+
486
+ def attention_qkvpacked(
487
+ self,
488
+ x: torch.Tensor,
489
+ attention_mask: torch.Tensor,
490
+ rope_freqs_cis: Optional[torch.Tensor] = None,
491
+ past_key_value = None
492
+ ):
493
+ """
494
+ Args:
495
+ x (torch.Tensor): (batch_size, seqlen, hidden_dim)
496
+ cu_seqlens (torch.Tensor):
497
+ """
498
+ batch_size, seqlen, hidden_dim = x.shape
499
+ xqkv = self.wqkv(x)
500
+ xqkv = xqkv.view(batch_size, seqlen, 3, self.num_heads, self.hidden_size_per_attention_head)
501
+ xq, xk, xv = torch.unbind(xqkv, dim=-3)
502
+
503
+ xq = xq.transpose(1, 2)
504
+ xk = xk.transpose(1, 2)
505
+ xv = xv.transpose(1, 2)
506
+
507
+ # xq, xk = apply_rope(xq, xk, rope_freqs_cis)
508
+ cos, sin = rope_freqs_cis
509
+ xq, xk = apply_multimodal_rotary_pos_emb(xq, xk, cos, sin)
510
+
511
+ if past_key_value is not None:
512
+ xk, xv = past_key_value.update(xk, xv, self.layer_idx)
513
+
514
+ attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
515
+ attn_out = attn_func(xq, xk, xv, attention_mask)
516
+ attn_out = attn_out.reshape(batch_size, seqlen, hidden_dim).contiguous()
517
+ attn_out = self.wo(attn_out)
518
+ return attn_out
519
+
520
+ def forward(
521
+ self,
522
+ hidden_states: torch.Tensor,
523
+ attention_mask: torch.Tensor,
524
+ rope_freqs_cis: Union[torch.Tensor, None] = None,
525
+ past_key_value = None
526
+ ) -> torch.Tensor:
527
+ """
528
+ Args:
529
+ hidden_states: non-packed (B, N, D) or packed (L, D). if non-packed, seqlens should be None, if packed, seqlens should be set
530
+ Returns:
531
+ output: same shape of input, non-packed (B, N, D) for non-packed input, (L, D) for packed input
532
+ """
533
+ residual = hidden_states
534
+ hidden_states = self.norm0(hidden_states)
535
+ attn_out = self.attention_qkvpacked(
536
+ hidden_states, attention_mask, rope_freqs_cis=rope_freqs_cis, past_key_value=past_key_value,
537
+ )
538
+ hidden_states = residual + attn_out
539
+
540
+ residual = hidden_states
541
+ hidden_states = self.mlp(self.norm1(hidden_states))
542
+ hidden_states = residual + hidden_states
543
+ return hidden_states
544
+
545
+
546
+ class MoonVitEncoder(nn.Module):
547
+
548
+ def __init__(
549
+ self,
550
+ hidden_dim: int,
551
+ num_layers: int,
552
+ block_cfg: dict,
553
+ ) -> None:
554
+ super().__init__()
555
+ self.blocks = nn.ModuleList(
556
+ [MoonVitEncoderLayer(layer_idx, **block_cfg) for layer_idx in range(num_layers)]
557
+ )
558
+ self.final_layernorm = nn.LayerNorm(hidden_dim)
559
+ self.gradient_checkpointing = False
560
+
561
+ def forward(self, hidden_states, attention_mask, rope_freqs_cis, past_key_value=None) -> torch.Tensor:
562
+
563
+ for _, block in enumerate(self.blocks):
564
+ if self.gradient_checkpointing and self.training:
565
+ # hidden_states = self._gradient_checkpointing_func(
566
+ # block.__call__, hidden_states, attention_mask, rope_freqs_cis
567
+ # )
568
+ hidden_states = torch.utils.checkpoint.checkpoint(
569
+ block.__call__, hidden_states, attention_mask, rope_freqs_cis, past_key_value
570
+ )
571
+ else:
572
+ hidden_states = block(
573
+ hidden_states, attention_mask, rope_freqs_cis=rope_freqs_cis, past_key_value=past_key_value,
574
+ )
575
+
576
+ hidden_states = self.final_layernorm(hidden_states)
577
+
578
+ return hidden_states
579
+
580
+
581
+ def patch_merger(
582
+ x: torch.Tensor,
583
+ grid_hws: torch.Tensor,
584
+ merge_kernel_size: list[int, int] = (2, 2),
585
+ ) -> List[torch.Tensor]:
586
+ d_model = x.size(-1)
587
+
588
+ outputs = []
589
+ pre_sum = 0
590
+ for i, x_shape in enumerate(grid_hws.tolist()):
591
+ height, width = x_shape[0], x_shape[1]
592
+ # Get the current sequence
593
+ seq = x[pre_sum:pre_sum+height * width]
594
+ # Reshape along self.merge_kernel_size and concat to the last dimension
595
+ kernel_height, kernel_width = merge_kernel_size
596
+ new_height, new_width = height // kernel_height, width // kernel_width
597
+ reshaped_seq = seq.view(
598
+ new_height, kernel_height, new_width, kernel_width, d_model
599
+ )
600
+ reshaped_seq = reshaped_seq.permute(0, 2, 1, 3, 4).contiguous()
601
+ padded_seq = reshaped_seq.view(
602
+ new_height * new_width, kernel_height * kernel_width, -1
603
+ )
604
+ outputs.append(padded_seq)
605
+ pre_sum += height * width
606
+
607
+ return outputs
608
+
609
+
610
+ class MultiModalProjector(nn.Module):
611
+
612
+ def __init__(self, config):
613
+ super().__init__()
614
+
615
+ self.hidden_size = (
616
+ config.hidden_size
617
+ * config.merge_kernel_size[0]
618
+ * config.merge_kernel_size[1]
619
+ )
620
+
621
+ self.pre_norm = torch.nn.LayerNorm(config.hidden_size, eps=1e-05)
622
+ self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
623
+ self.act = GELUActivation()
624
+ self.linear_2 = nn.Linear(self.hidden_size, config.text_hidden_size, bias=True)
625
+ # self.linear_2 = nn.Linear(self.hidden_size, config.hidden_size, bias=True)
626
+
627
+ def forward(self, image_features: list[torch.Tensor]) -> torch.Tensor:
628
+ # image_features = torch.cat(image_features, dim=0)
629
+ # hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size)
630
+ hidden_states = self.pre_norm(image_features)
631
+ hidden_states = self.linear_1(hidden_states)
632
+ hidden_states = self.act(hidden_states)
633
+ hidden_states = self.linear_2(hidden_states)
634
+
635
+ return hidden_states
636
+
637
+
638
+ class MoonVitPretrainedModel(PreTrainedModel):
639
+ config_class = MoonViTConfig
640
+ model_type = "moonvit"
641
+ supports_gradient_checkpointing = True
642
+ _no_split_modules = ["PackingTransformer"]
643
+ _supports_flash_attn_2 = True
644
+ _supports_sdpa = True
645
+
646
+ def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
647
+ super().__init__(config, *inputs, **kwargs)
648
+ config = deepcopy(config)
649
+ self.merge_kernel_size = config.merge_kernel_size
650
+ self.patch_size = config.patch_size
651
+ self.patch_embed = MoonVisionPatchEmbed(
652
+ out_dim=config.hidden_size,
653
+ patch_size=config.patch_size,
654
+ pos_emb_height=config.init_pos_emb_height,
655
+ pos_emb_width=config.init_pos_emb_width,
656
+ )
657
+ self.rope_2d = Rope2DPosEmb(
658
+ config.hidden_size // config.num_attention_heads, 512, 512
659
+ )
660
+
661
+ self.encoder = MoonVitEncoder(
662
+ hidden_dim=config.hidden_size,
663
+ num_layers=config.num_hidden_layers,
664
+ block_cfg={
665
+ "num_heads": config.num_attention_heads,
666
+ "hidden_dim": config.hidden_size,
667
+ "mlp_dim": config.intermediate_size,
668
+ "activation": PytorchGELUTanh(),
669
+ "attn_bias": True,
670
+ "attn_implementation": config._attn_implementation,
671
+ },
672
+ )
673
+
674
+ self.pixel_merger = nn.Sequential(
675
+ nn.Linear(config.hidden_size*4, config.hidden_size),
676
+ nn.GELU(),
677
+ nn.Linear(config.hidden_size, config.hidden_size)
678
+ )
679
+
680
+ self.projector = nn.Sequential(
681
+ nn.LayerNorm(config.hidden_size),
682
+ nn.Linear(config.hidden_size, config.hidden_size, bias=True),
683
+ nn.GELU(),
684
+ nn.Linear(config.hidden_size, config.text_hidden_size, bias=True),
685
+ )
686
+
687
+ def _init_weights(self, module):
688
+ """Initialize the weights"""
689
+ if isinstance(module, nn.Linear):
690
+ nn.init.xavier_uniform_(module.weight)
691
+ nn.init.normal_(module.bias, std=1e-6)
692
+ elif isinstance(module, nn.LayerNorm):
693
+ module.bias.data.zero_()
694
+ module.weight.data.fill_(1.0)
695
+
696
+ def forward(
697
+ self, pixel_values: torch.Tensor, image_grid_hws: torch.Tensor
698
+ ) -> torch.Tensor:
699
+ """
700
+ Args:
701
+ pixel_values (torch.Tensor): The input pixel values.
702
+ grid_hws (torch.Tensor): The grid height and width.
703
+ Returns:
704
+ torch.Tensor: The output tokens.
705
+ """
706
+ hidden_states = self.patch_embed(pixel_values, image_grid_hws)
707
+
708
+ hidden_states_list = patch_merger(
709
+ hidden_states, image_grid_hws, merge_kernel_size=self.merge_kernel_size
710
+ )
711
+ hidden_states = self.pixel_merger(torch.cat(hidden_states_list).view(-1, hidden_states.shape[-1] * 4))
712
+
713
+ num_tokens = (image_grid_hws.prod(dim=1) // 4).tolist()
714
+ hidden_states_list = hidden_states.split(num_tokens, dim=0)
715
+ max_length = max(num_tokens)
716
+ max_h, max_w = image_grid_hws.max(dim=0).values.tolist()
717
+ max_length = max_h * max_w // 4
718
+ hidden_states = torch.stack([F.pad(h, (0, 0, 0, max_length - h.shape[0])) for h in hidden_states_list])
719
+ attention_mask = torch.zeros(len(image_grid_hws), max_length, device=hidden_states.device, dtype=torch.bool)
720
+ for i in range(len(image_grid_hws)):
721
+ attention_mask[i][:num_tokens[i]] = True
722
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
723
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
724
+
725
+ rope_freqs_cis = self.rope_2d.get_freqs_cis(grid_hws=image_grid_hws)
726
+ hidden_states = self.encoder(hidden_states, attention_mask, rope_freqs_cis)
727
+ hidden_states = torch.cat([hidden_states[i][:num_tokens[i]] for i in range(len(image_grid_hws))])
728
+ # hidden_states = self.projector(hidden_states)
729
+ return hidden_states
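The vision tower above groups each 2x2 neighbourhood of patch tokens before the encoder (patch_merger), then flattens each group for pixel_merger. A standalone sketch of the same reshaping arithmetic, using the checkpoint's hidden size of 1152:

import torch

d_model, h, w = 1152, 32, 32                   # hidden size and patch-grid shape
tokens = torch.randn(h * w, d_model)           # one token per 14x14 patch

kh, kw = 2, 2                                  # merge_kernel_size
grouped = tokens.view(h // kh, kh, w // kw, kw, d_model)
grouped = grouped.permute(0, 2, 1, 3, 4).reshape(-1, kh * kw, d_model)
print(grouped.shape)                           # torch.Size([256, 4, 1152]) -> 4 neighbours per merged token

# pixel_merger consumes the group flattened to 4 * d_model features per merged token.
merged = grouped.reshape(-1, kh * kw * d_model)
print(merged.shape)                            # torch.Size([256, 4608])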
modeling_smallvlm.py ADDED
@@ -0,0 +1,452 @@
1
+ from typing import Optional, List
2
+ import torch
3
+ from torch import nn
4
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
5
+ from transformers import PreTrainedModel, AutoModel, AutoModelForCausalLM
6
+ from transformers.modeling_outputs import ModelOutput
7
+ from transformers.generation.utils import GenerationMixin
8
+ from transformers.cache_utils import Cache, DynamicCache
9
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
10
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
11
+ from transformers.models.qwen3.modeling_qwen3 import eager_attention_forward, BaseModelOutputWithPast
12
+
13
+ from .modeling_moonvit import patch_merger, get_rope_index, apply_multimodal_rotary_pos_emb
14
+ from .configuration_smallvlm import SmallVLMConfig
15
+
16
+ class Qwen2_5_VLRotaryEmbedding(nn.Module):
17
+ def __init__(self, config, device=None):
18
+ super().__init__()
19
+ # BC: "rope_type" was originally "type"
20
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
21
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
22
+ else:
23
+ self.rope_type = "default"
24
+ self.max_seq_len_cached = config.max_position_embeddings
25
+ self.original_max_seq_len = config.max_position_embeddings
26
+
27
+ self.config = config
28
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
29
+
30
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
31
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
32
+ self.original_inv_freq = self.inv_freq
33
+
34
+ @torch.no_grad()
35
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
36
+ def forward(self, x, position_ids):
37
+ # In contrast to other models, Qwen2_5_VL has different position ids for the grids
38
+ # So we expand the inv_freq to shape (3, ...)
39
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
40
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
41
+
42
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
43
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
44
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
45
+ emb = torch.cat((freqs, freqs), dim=-1)
46
+ cos = emb.cos() * self.attention_scaling
47
+ sin = emb.sin() * self.attention_scaling
48
+
49
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
50
+
51
+
52
+ def build_vision_model(config, model=None):
53
+ if model is None:
54
+ model = AutoModel.from_config(config, trust_remote_code=True)
55
+ return model
56
+
57
+ def mrope_forward(
58
+ self,
59
+ hidden_states: torch.Tensor,
60
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
61
+ attention_mask: Optional[torch.Tensor],
62
+ past_key_value: Optional[Cache] = None,
63
+ cache_position: Optional[torch.LongTensor] = None,
64
+ **kwargs,
65
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
66
+ input_shape = hidden_states.shape[:-1]
67
+ hidden_shape = (*input_shape, -1, self.head_dim)
68
+
69
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
70
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
71
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
72
+
73
+ cos, sin = position_embeddings
74
+ query_states, key_states = apply_multimodal_rotary_pos_emb(query_states, key_states, cos, sin, [16, 24, 24], unsqueeze_dim=1)
75
+
76
+ if past_key_value is not None:
77
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
78
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
79
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
80
+ attention_interface: Callable = eager_attention_forward
81
+ if self.config._attn_implementation != "eager":
82
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
83
+ pass
84
+ else:
85
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
86
+
87
+ attn_output, attn_weights = attention_interface(
88
+ self,
89
+ query_states,
90
+ key_states,
91
+ value_states,
92
+ attention_mask,
93
+ dropout=0.0 if not self.training else self.attention_dropout,
94
+ scaling=self.scaling,
95
+ sliding_window=self.sliding_window, # diff with Llama
96
+ **kwargs,
97
+ )
98
+
99
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
100
+ attn_output = self.o_proj(attn_output)
101
+ return attn_output, attn_weights
102
+
103
+ import transformers
104
+ transformers.models.qwen3.modeling_qwen3.Qwen3Attention.forward = mrope_forward
105
+
106
+
107
+ def forward(
108
+ self,
109
+ input_ids: Optional[torch.LongTensor] = None,
110
+ attention_mask: Optional[torch.Tensor] = None,
111
+ position_ids: Optional[torch.LongTensor] = None,
112
+ past_key_values: Optional[Cache] = None,
113
+ inputs_embeds: Optional[torch.FloatTensor] = None,
114
+ use_cache: Optional[bool] = None,
115
+ output_attentions: Optional[bool] = None,
116
+ output_hidden_states: Optional[bool] = None,
117
+ cache_position: Optional[torch.LongTensor] = None,
118
+ **flash_attn_kwargs,
119
+ ) -> BaseModelOutputWithPast:
120
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
121
+ output_hidden_states = (
122
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
123
+ )
124
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
125
+
126
+ if (input_ids is None) ^ (inputs_embeds is not None):
127
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
128
+
129
+ if self.gradient_checkpointing and self.training and use_cache:
130
+ use_cache = False
131
+
132
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
133
+ if not isinstance(past_key_values, (type(None), Cache)):
134
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
135
+
136
+ if inputs_embeds is None:
137
+ inputs_embeds = self.embed_tokens(input_ids)
138
+
139
+ if use_cache and past_key_values is None:
140
+ past_key_values = DynamicCache()
141
+
142
+ if cache_position is None:
143
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
144
+ cache_position = torch.arange(
145
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
146
+ )
147
+
148
+ if position_ids is None:
149
+ position_ids = cache_position.unsqueeze(0)
150
+
151
+ causal_mask = self._update_causal_mask(
152
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
153
+ )
154
+
155
+ hidden_states = inputs_embeds
156
+
157
+ # create position embeddings to be shared across the decoder layers
158
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
159
+
160
+ # decoder layers
161
+ all_hidden_states = () if output_hidden_states else None
162
+ all_self_attns = () if output_attentions else None
163
+
164
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
165
+ if output_hidden_states:
166
+ all_hidden_states += (hidden_states,)
167
+
168
+ if self.gradient_checkpointing and self.training:
169
+ layer_outputs = self._gradient_checkpointing_func(
170
+ decoder_layer.__call__,
171
+ hidden_states,
172
+ causal_mask,
173
+ position_ids,
174
+ past_key_values,
175
+ output_attentions,
176
+ use_cache,
177
+ cache_position,
178
+ position_embeddings,
179
+ )
180
+ else:
181
+ layer_outputs = decoder_layer(
182
+ hidden_states,
183
+ attention_mask=causal_mask,
184
+ position_ids=position_ids,
185
+ past_key_value=past_key_values,
186
+ output_attentions=output_attentions,
187
+ use_cache=use_cache,
188
+ cache_position=cache_position,
189
+ position_embeddings=position_embeddings,
190
+ **flash_attn_kwargs,
191
+ )
192
+
193
+ hidden_states = layer_outputs[0]
194
+
195
+ if output_attentions:
196
+ all_self_attns += (layer_outputs[1],)
197
+
198
+ hidden_states = self.norm(hidden_states)
199
+
200
+ # add hidden states from the last decoder layer
201
+ if output_hidden_states:
202
+ all_hidden_states += (hidden_states,)
203
+
204
+ return BaseModelOutputWithPast(
205
+ last_hidden_state=hidden_states,
206
+ past_key_values=past_key_values if use_cache else None,
207
+ hidden_states=all_hidden_states,
208
+ attentions=all_self_attns,
209
+ )
210
+ transformers.models.qwen3.modeling_qwen3.Qwen3Model.forward = forward
211
+
212
+ class SmallVLMForCausalLM(PreTrainedModel, GenerationMixin):
213
+ config_class = SmallVLMConfig
214
+ supports_gradient_checkpointing = True
215
+ _skip_keys_device_placement = "past_key_values"
216
+ _supports_cache_class = True
217
+ _supports_flash_attn_2 = True
218
+ _supports_sdpa = True
219
+
220
+ def __init__(self, config, language_model=None, vision_model=None):
221
+ super().__init__(config)
222
+ self.rope_deltas = None # cache rope_deltas here
223
+
224
+ vision_model = build_vision_model(config.vision_model_config, vision_model)
225
+ if language_model is None:
226
+ kwargs_ = {}
227
+ if config._attn_implementation_internal is not None:
228
+ kwargs_['attn_implementation'] = config._attn_implementation_internal
229
+ language_model = AutoModelForCausalLM.from_config(config.language_model_config, trust_remote_code=True, **kwargs_)
230
+
231
+ self.vision_model = vision_model
232
+
233
+ self.language_model = language_model
234
+
235
+ self.vision_to_text_proj = nn.Sequential( # project vision-encoder features to the language-model hidden size
236
+ nn.Linear(self.config.vision_model_config.hidden_size, self.config.language_model_config.hidden_size),
237
+ nn.GELU(),
238
+ nn.Linear(self.config.language_model_config.hidden_size, self.config.language_model_config.hidden_size)
239
+ )
240
+
241
+ self.text_to_vision_proj = nn.Sequential( # project text embeddings into the vision-encoder hidden size
242
+ nn.Linear(self.config.language_model_config.hidden_size, self.config.vision_model_config.hidden_size),
243
+ nn.GELU(),
244
+ nn.Linear(self.config.vision_model_config.hidden_size, self.config.vision_model_config.hidden_size)
245
+ )
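+ # reuse Qwen2.5-VL's multimodal rotary embedding for both towers; the language model's default rotary_emb is replaced just below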
246
+ self.vision_rotary_emb = Qwen2_5_VLRotaryEmbedding(config.vision_model_config)
247
+ self.text_rotary_emb = Qwen2_5_VLRotaryEmbedding(config.language_model_config)
248
+ self.language_model.model.rotary_emb = self.text_rotary_emb
249
+
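+ # shift each language-model layer's layer_idx past the vision-encoder layers so the two towers occupy distinct slots in the shared KV cache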
250
+ for layer in self.language_model.model.layers:
251
+ setattr(layer.self_attn, 'layer_idx', layer.self_attn.layer_idx + self.vision_model.config.num_hidden_layers)
252
+
253
+ self.gradient_checkpointing = False
254
+
255
+ def forward(
256
+ self,
257
+ input_ids: Optional[torch.LongTensor] = None,
258
+ attention_mask: Optional[torch.Tensor] = None,
259
+ position_ids: Optional[torch.LongTensor] = None,
260
+ past_key_values: Optional[Cache] = None,
261
+ inputs_embeds: Optional[torch.FloatTensor] = None,
262
+ labels: Optional[torch.LongTensor] = None,
263
+ use_cache: Optional[bool] = None,
264
+ output_attentions: Optional[bool] = None,
265
+ output_hidden_states: Optional[bool] = None,
266
+ return_dict: Optional[bool] = None,
267
+ cache_position: Optional[torch.LongTensor] = None,
268
+ pixel_values: Optional[torch.FloatTensor] = None,
269
+ grid_hws: Optional[torch.LongTensor] = None,
270
+ ):
271
+
272
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
273
+ output_hidden_states = (
274
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
275
+ )
276
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
277
+
278
+ if (input_ids is None) ^ (inputs_embeds is not None):
279
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
280
+
281
+ if self.gradient_checkpointing and self.training and use_cache:
282
+ use_cache = False
283
+
284
+ if use_cache and past_key_values is None:
285
+ past_key_values = DynamicCache()
286
+
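+ # embed the text and project it into the vision-encoder space; the mixed text/image sequence runs through the vision encoder first and is projected back to the language-model space afterwards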
287
+ inputs_embeds = self.get_input_embeddings()(input_ids)
288
+ inputs_embeds = self.text_to_vision_proj(inputs_embeds)
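+ # text-only batches arrive with an empty pixel_values tensor; substitute a tiny all-zero dummy image so the vision tower still runs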
289
+ is_dummy_input = pixel_values is not None and pixel_values.size(0) == 0
290
+ if is_dummy_input:
291
+ pixel_values = torch.zeros((4,) + pixel_values.shape[1:], dtype=pixel_values.dtype, device=pixel_values.device)
292
+ grid_hws = torch.tensor([[1, 2, 2]], dtype=torch.int32).to(pixel_values.device)
293
+
294
+ if pixel_values is not None:
295
+ vision_embeds = self.vision_model.patch_embed(pixel_values, grid_hws[:, 1:])
296
+ vision_embeds_list = patch_merger(
297
+ vision_embeds, grid_hws[:, 1:], merge_kernel_size=self.vision_model.merge_kernel_size
298
+ )
299
+ vision_embeds = self.vision_model.pixel_merger(torch.cat(vision_embeds_list).view(-1, vision_embeds.shape[-1] * 4))
300
+
301
+ vision_mask = (input_ids == self.config.image_token_id).to(inputs_embeds.device)
302
+ inputs_embeds[vision_mask] = vision_embeds
303
+
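+ # build a causal mask from attention_mask, then open full bidirectional attention inside each image-token span while text tokens stay causal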
304
+ image_token_lens = (grid_hws.prod(dim=1) // 4)
305
+ bsz, src_len = attention_mask.size()
306
+ causal_mask = attention_mask[:, None, None, :].expand(bsz, 1, src_len, src_len).to(inputs_embeds.dtype)
307
+ causal_mask.tril_()
308
+ idx = 0
309
+ for i, _ in enumerate(causal_mask):
310
+ vision_mask = input_ids[i] == self.config.image_token_id
311
+ while (vision_mask.sum() > 0):
312
+ start = torch.nonzero(vision_mask)[0][0]
313
+ num = image_token_lens[idx]
314
+ idx += 1
315
+ causal_mask[i, 0, start:start+num, start:start+num] = 1
316
+ vision_mask[start:start+num] = 0
317
+
318
+ causal_mask = 1.0 - causal_mask
319
+ causal_mask = causal_mask.masked_fill(causal_mask.to(torch.bool), torch.finfo(vision_embeds.dtype).min)
320
+ else:
321
+ causal_mask = None
322
+
323
+ if self.is_gradient_checkpointing and torch.is_grad_enabled() and self.training:
324
+ inputs_embeds.requires_grad_(True)
325
+
326
+ if cache_position is None:
327
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
328
+ cache_position = torch.arange(
329
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
330
+ )
331
+
332
+ if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
333
+ # calculate RoPE index once per generation in the pre-fill stage only
334
+ if (
335
+ (cache_position is not None and cache_position[0] == 0)
336
+ or self.rope_deltas is None
337
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
338
+ ):
339
+ position_ids, rope_deltas = get_rope_index(
340
+ self.config.image_token_id,
341
+ self.config.video_token_id,
342
+ self.config.vision_start_token_id,
343
+ spatial_merge_size=2,
344
+ input_ids=input_ids,
345
+ image_grid_thw=grid_hws,
346
+ video_grid_thw=None,
347
+ attention_mask=attention_mask
348
+ )
349
+ self.rope_deltas = rope_deltas
350
+ # otherwise, reuse the previously computed rope_deltas to derive the correct position ids
351
+ else:
352
+ batch_size, seq_length, _ = inputs_embeds.shape
353
+ delta = (
354
+ (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
355
+ if cache_position is not None
356
+ else 0
357
+ )
358
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
359
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
360
+ if cache_position is not None: # otherwise `delta` is an int `0`
361
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
362
+ position_ids = position_ids.add(delta)
363
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
364
+
365
+ position_embeddings = self.vision_rotary_emb(inputs_embeds, position_ids)
366
+ inputs_embeds = self.vision_model.encoder(inputs_embeds, causal_mask, position_embeddings, past_key_values)
367
+
368
+ # return ModelOutput(
369
+ # last_hidden_state=self.vision_model.projector(inputs_embeds),
370
+ # text_hidden_state=self.vision_to_text_proj(inputs_embeds),
371
+ # )
372
+
373
+ inputs_embeds = self.vision_to_text_proj(inputs_embeds)
374
+
375
+ outputs = self.language_model(
376
+ input_ids=None,
377
+ labels=labels,
378
+ attention_mask=causal_mask,
379
+ position_ids=position_ids,
380
+ past_key_values=past_key_values,
381
+ inputs_embeds=inputs_embeds,
382
+ use_cache=use_cache,
383
+ output_attentions=output_attentions,
384
+ output_hidden_states=output_hidden_states,
385
+ cache_position=cache_position,
386
+ return_dict=True,
387
+ )
388
+
389
+ return ModelOutput(
390
+ loss=outputs.loss,
391
+ logits=outputs.logits,
392
+ past_key_values=outputs.past_key_values,
393
+ hidden_states=outputs.hidden_states,
394
+ attentions=outputs.attentions,
395
+ )
396
+
397
+ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
398
+ super().gradient_checkpointing_enable(gradient_checkpointing_kwargs)
399
+ self.language_model.enable_input_require_grads()
400
+
401
+ def get_input_embeddings(self):
402
+ return self.language_model.get_input_embeddings()
403
+
404
+ def set_input_embeddings(self, value):
405
+ self.language_model.set_input_embeddings(value)
406
+
407
+ def get_output_embeddings(self):
408
+ return self.language_model.get_output_embeddings()
409
+
410
+ def set_output_embeddings(self, new_embeddings):
411
+ self.language_model.set_output_embeddings(new_embeddings)
412
+
413
+ def set_decoder(self, decoder):
414
+ self.language_model.set_decoder(decoder)
415
+
416
+ def get_decoder(self):
417
+ return self.language_model.get_decoder()
418
+
419
+ def tie_weights(self):
420
+ return self.language_model.tie_weights()
421
+
422
+ def prepare_inputs_for_generation(
423
+ self,
424
+ input_ids,
425
+ past_key_values=None,
426
+ attention_mask=None,
427
+ inputs_embeds=None,
428
+ cache_position=None,
429
+ position_ids=None,
430
+ use_cache=True,
431
+ pixel_values=None,
432
+ **kwargs,
433
+ ):
434
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
435
+ model_inputs = super().prepare_inputs_for_generation(
436
+ input_ids,
437
+ past_key_values=past_key_values,
438
+ attention_mask=attention_mask,
439
+ inputs_embeds=inputs_embeds,
440
+ cache_position=cache_position,
441
+ position_ids=position_ids,
442
+ pixel_values=pixel_values,
443
+ use_cache=use_cache,
444
+ **kwargs,
445
+ )
446
+
447
+ # Qwen2.5-VL-style position_ids are prepared with rope_deltas inside forward
448
+ model_inputs["position_ids"] = None
449
+ if cache_position[0] != 0:
450
+ model_inputs["pixel_values"] = None
451
+
452
+ return model_inputs
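
A rough usage sketch (not part of the committed files): the repo id and image path are placeholders, loading relies on the auto_map entries with trust_remote_code=True, and the hand-written ChatML-style prompt is an assumption; using processor.apply_chat_template with the repo's own template may be preferable. The prompt must contain one <|image_pad|> placeholder per image, as SmallVLMProcessor expects.

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

repo = "path/to/this-checkpoint"  # placeholder
processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True).eval()

image = Image.open("example.jpg")  # placeholder image
prompt = (
    "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe this image.<|im_end|>\n"
    "<|im_start|>assistant\n"
)
inputs = processor(text=prompt, images=[image], return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=64)
print(processor.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
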
preprocessor_config.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_smallvlm.SmallVLMProcessor",
4
+ "AutoImageProcessor": "image_processing_moonvit.MoonViTImageProcessor"
5
+ },
6
+ "image_mean": [
7
+ 0.5,
8
+ 0.5,
9
+ 0.5
10
+ ],
11
+ "image_processor_type": "MoonViTImageProcessor",
12
+ "image_std": [
13
+ 0.5,
14
+ 0.5,
15
+ 0.5
16
+ ],
17
+ "in_token_limit": 16384,
18
+ "merge_kernel_size": [
19
+ 2,
20
+ 2
21
+ ],
22
+ "num_pooled_tokens": 1024,
23
+ "pad_input": true,
24
+ "patch_size": 14,
25
+ "processor_class": "SmallVLMProcessor"
26
+ }
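
A quick note on how these values map to language-model tokens: with patch_size 14 and a 2x2 merge_kernel_size, an image that ends up as an h x w patch grid contributes h*w // 4 image tokens, which is what get_image_token_length in processing_smallvlm.py (below) computes from image_grid_hws. A back-of-the-envelope sketch, assuming the image is used at roughly its input resolution (the actual MoonViT processor may resize or pad it, e.g. to respect in_token_limit):

def approx_image_tokens(height_px: int, width_px: int, patch_size: int = 14, merge: int = 2) -> int:
    # number of 14x14 patches along each side, then a 2x2 merge before the language model
    h = height_px // patch_size
    w = width_px // patch_size
    return (h * w) // (merge * merge)

print(approx_image_tokens(448, 448))  # 32 * 32 patches -> 256 image tokens
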
processing_smallvlm.py ADDED
@@ -0,0 +1,182 @@
1
+ import torch
2
+
3
+ from collections import UserDict, OrderedDict
4
+ from typing import Union, List, Dict, Any
5
+
6
+ from transformers.processing_utils import ProcessorMixin
7
+ from transformers.feature_extraction_utils import BatchFeature
8
+ from transformers.utils.chat_template_utils import render_jinja_template
9
+
10
+ from .image_processing_moonvit import MoonViTImageProcessor
11
+
12
+ class SmallVLMProcessor(ProcessorMixin):
13
+ attributes = ["tokenizer", "image_processor"]
14
+ optional_attributes = ['chat_template']
15
+ model_input_names = ['input_ids', 'attention_mask', 'pixel_values']
16
+ image_processor_class = "AutoImageProcessor"
17
+ tokenizer_class = "AutoTokenizer"
18
+
19
+ image_token = '<|image_pad|>'
20
+
21
+ def __init__(self, tokenizer, image_processor, chat_template, **kwargs):
22
+ super().__init__(tokenizer=tokenizer, image_processor=image_processor, chat_template=chat_template)
23
+ self.tokenizer.add_special_tokens({'additional_special_tokens': [self.image_token]}, replace_additional_special_tokens=False)
24
+ self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
25
+
26
+ def __call__(self, inputs=None, images=[], text=None, **kwargs) -> BatchFeature:
27
+
28
+ truncation = kwargs.pop('truncation', False)
29
+ max_length = kwargs.pop('max_length', 1024)
30
+ padding = kwargs.pop('padding', False)
31
+
32
+ if inputs is None:
33
+ inputs = {}
34
+ if isinstance(inputs, UserDict):
35
+ inputs = inputs.data
36
+
37
+ if 'input_ids' not in inputs:
38
+ input_ids = self.tokenizer(text, padding=False, truncation=False, return_attention_mask=False, **kwargs)['input_ids'][0]
39
+ inputs['input_ids'] = input_ids.tolist()
40
+
41
+ inputs = self.process_images(images, inputs=inputs)
42
+
43
+ if 'attention_mask' not in inputs:
44
+ inputs['attention_mask'] = [1] * len(inputs['input_ids'])
45
+
46
+ if 'assistant_masks' in inputs:
47
+ inputs['prompt_mask'] = [1-x for x in inputs.pop('assistant_masks')]
48
+
49
+ inputs = self.process_inputs(inputs)
50
+
51
+
52
+ if truncation and len(inputs['input_ids']) > max_length:
53
+ inputs = self.truncate(inputs, max_length)
54
+
55
+ if padding and len(inputs['input_ids']) < max_length:
56
+ inputs = self.padding(inputs, max_length)
57
+
58
+ inputs = self.to_tensor(inputs)
59
+
60
+ self.check(inputs)
61
+
62
+ new_inputs = {
63
+ "input_ids": inputs["input_ids"],
64
+ "attention_mask": inputs["attention_mask"],
65
+ }
66
+ if "pixel_values" in inputs:
67
+ new_inputs['pixel_values'] = inputs['pixel_values']
68
+ new_inputs['grid_hws'] = torch.cat([torch.ones_like(inputs['image_grid_hws'])[:, :1], inputs['image_grid_hws']], dim=1)
69
+ if 'prompt_mask' in inputs:
70
+ new_inputs['prompt_mask'] = inputs['prompt_mask']
71
+
72
+ return BatchFeature(new_inputs)
73
+
74
+ def process_images(self, images, inputs):
75
+ if len(images) > 0:
76
+ pixel_values, image_grid_hws = self.image_transform(images)
77
+ else:
78
+ pixel_values = torch.zeros((0, 3, 14, 14), dtype=torch.float32)
79
+ image_grid_hws = torch.zeros((0, 2), dtype=torch.int64)
80
+
81
+ inputs['pixel_values'] = pixel_values
82
+ inputs['image_grid_hws'] = image_grid_hws
83
+ return inputs
84
+
85
+ def image_transform(self, images):
86
+ image_inputs = self.image_processor(images, return_tensors='pt')
87
+ return image_inputs['pixel_values'], image_inputs['image_grid_hws']
88
+
89
+ def truncate(self, inputs: Dict[str, Any], max_length: int):
90
+ assert self.image_token_id not in inputs['input_ids'][max_length:], "Truncating image tokens is not allowed."
91
+
92
+ inputs['input_ids'] = inputs['input_ids'][:max_length]
93
+ inputs['attention_mask'] = inputs['attention_mask'][:max_length]
94
+ if 'prompt_mask' in inputs:
95
+ inputs['prompt_mask'] = inputs['prompt_mask'][:max_length]
96
+
97
+ return inputs
98
+
99
+ def get_image_token_length(self, inputs: Dict[str, Any]) -> List[int]:
100
+ image_grid_hws = inputs.get('image_grid_hws', None)
101
+ if image_grid_hws is None:
102
+ return []
103
+ image_token_lens = (image_grid_hws.prod(dim=1) // 4).tolist()
104
+ return image_token_lens
105
+
106
+ def process_inputs(self, inputs: Dict[str, Any]):
107
+ graft_token_lens = self._get_graft_token_length(inputs)
108
+
109
+ inputs['input_ids'] = self._graft_token(inputs['input_ids'], graft_token_lens, self.image_token_id)
110
+ inputs['attention_mask'] = self._graft_token(inputs['attention_mask'], graft_token_lens, 'replicate')
111
+ if 'prompt_mask' in inputs:
112
+ inputs['prompt_mask'] = self._graft_token(inputs['prompt_mask'], graft_token_lens, 'replicate')
113
+
114
+ return inputs
115
+
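+ # expand the single placeholder at each grafting position into graft_token_lens[i] entries: a fixed value for input_ids, or replication for the attention/prompt masks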
116
+ def _graft_token(self, seq, graft_token_lens, value):
117
+ if value == 'replicate':
118
+ for i in reversed(graft_token_lens.keys()):
119
+ seq[i:] = [seq[i]] * graft_token_lens[i] + seq[i+1:]
120
+ else:
121
+ for i in reversed(graft_token_lens.keys()):
122
+ assert value == seq[i]
123
+ seq[i:] = [value] * graft_token_lens[i] + seq[i+1:]
124
+ return seq
125
+
126
+ def _get_graft_token_length(self, inputs: Dict[str, Any]) -> Dict[int, int]:
127
+ image_token_pos = [i for i, x in enumerate(inputs['input_ids']) if x == self.image_token_id]
128
+ image_token_lens = self.get_image_token_length(inputs)
129
+
130
+ assert len(image_token_pos) == len(image_token_lens), \
131
+ "Wrong image token count, " \
132
+ f"image_token_count({len(image_token_pos)}) != image_count({len(image_token_lens)})"
133
+
134
+ graft_token_lens = OrderedDict(item for item in zip(image_token_pos, image_token_lens))
135
+
136
+ return graft_token_lens
137
+
138
+ def check(self, inputs: Dict[str, Any]):
139
+ image_embed_token_count = torch.count_nonzero(inputs['input_ids'] == self.image_token_id).item()
140
+ image_embed_count = sum(self.get_image_token_length(inputs))
141
+ assert image_embed_token_count == image_embed_count, "Wrong image embed token count"
142
+
143
+ def padding(self, inputs: Dict[str, Any], max_length: int):
144
+ padding_len = max_length - len(inputs['input_ids'])
145
+ inputs['input_ids'] += [self.pad_token_id] * padding_len
146
+ inputs['attention_mask'] += [0] * padding_len
147
+ if 'prompt_mask' in inputs:
148
+ inputs['prompt_mask'] += [0] * padding_len
149
+ return inputs
150
+
151
+ def decode(self, token_ids: Union[List[int], torch.Tensor], **kwargs):
152
+ if isinstance(token_ids, torch.Tensor):
153
+ token_ids = token_ids.tolist()
154
+ text = self.tokenizer.decode(token_ids, **kwargs)
155
+ return text
156
+
157
+ def batch_decode(self, sequences: Union[List[List[int]], torch.Tensor], **kwargs):
158
+ if isinstance(sequences, torch.Tensor):
159
+ sequences = sequences.tolist()
160
+ texts = self.tokenizer.batch_decode(sequences, **kwargs)
161
+ return texts
162
+
163
+ def to_tensor(self, inputs):
164
+ inputs['input_ids'] = torch.tensor([inputs['input_ids']], dtype=torch.long)
165
+ inputs['attention_mask'] = torch.tensor([inputs['attention_mask']], dtype=torch.bool)
166
+ if 'prompt_mask' in inputs:
167
+ inputs['prompt_mask'] = torch.tensor([inputs['prompt_mask']], dtype=torch.bool)
168
+ return inputs
169
+
170
+ @property
171
+ def pad_token_id(self):
172
+ return self.tokenizer.pad_token_id
173
+
174
+ @property
175
+ def special_tokens(self):
176
+ return [token.content for token in self.tokenizer.added_tokens_decoder.values()]
177
+
178
+ def __repr__(self):
179
+ return ''  # __repr__ must return a str; returning None raises TypeError
180
+
181
+ def __str__(self):
182
+ return ''
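
For reference, a standalone sketch of the placeholder expansion performed by process_inputs / _graft_token above: each single <|image_pad|> in the tokenized text is replaced in place by that image's token count (grid h * w // 4) so the sequence length matches the merged vision embeddings the model scatters into it. The non-image token ids below are arbitrary stand-ins.

from collections import OrderedDict

IMAGE_PAD = 151655  # <|image_pad|> id from added_tokens.json

def graft_image_placeholders(input_ids, image_token_lens):
    # mirrors _get_graft_token_length + _graft_token: expand each placeholder
    # occurrence into the per-image token count, working right-to-left so
    # earlier positions stay valid while the list grows
    positions = [i for i, t in enumerate(input_ids) if t == IMAGE_PAD]
    grafts = OrderedDict(zip(positions, image_token_lens))
    for i in reversed(grafts):
        input_ids[i:i + 1] = [IMAGE_PAD] * grafts[i]
    return input_ids

ids = [151644, 100, IMAGE_PAD, 200, 151645]  # arbitrary text ids around one image placeholder
print(graft_image_placeholders(ids, [4]))
# -> [151644, 100, 151655, 151655, 151655, 151655, 200, 151645]
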
processor_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_smallvlm.SmallVLMProcessor"
4
+ },
5
+ "processor_class": "SmallVLMProcessor"
6
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
3
+ size 11422654
tokenizer_config.json ADDED
@@ -0,0 +1,243 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|vision_start|>",
224
+ "<|vision_end|>",
225
+ "<|vision_pad|>",
226
+ "<|image_pad|>",
227
+ "<|video_pad|>"
228
+ ],
229
+ "auto_map": {
230
+ "AutoProcessor": "processing_smallvlm.SmallVLMProcessor"
231
+ },
232
+ "bos_token": null,
233
+ "clean_up_tokenization_spaces": false,
234
+ "eos_token": "<|im_end|>",
235
+ "errors": "replace",
236
+ "extra_special_tokens": {},
237
+ "model_max_length": 131072,
238
+ "pad_token": "<|endoftext|>",
239
+ "processor_class": "SmallVLMProcessor",
240
+ "split_special_tokens": false,
241
+ "tokenizer_class": "Qwen2Tokenizer",
242
+ "unk_token": null
243
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff