Yukang committed on
Commit d407812 · verified · 1 Parent(s): e62d034

Upload 19 files

Files changed (19)
  1. NV_LICENSE +35 -0
  2. README.md +184 -0
  3. auto_processor.py +495 -0
  4. base_projector.py +228 -0
  5. builder.py +247 -0
  6. config.json +295 -0
  7. configuration_vila.py +92 -0
  8. constants.py +83 -0
  9. conversation.py +191 -0
  10. distributed.py +73 -0
  11. loss.py +48 -0
  12. media.py +130 -0
  13. media_encoder.py +158 -0
  14. mm_utils.py +575 -0
  15. model_utils_packing.py +35 -0
  16. modeling_vila.py +1256 -0
  17. siglip_encoder.py +286 -0
  18. tokenizer_utils.py +181 -0
  19. utils.py +211 -0
NV_LICENSE ADDED
@@ -0,0 +1,35 @@
1
+ NVIDIA License
2
+
3
+ 1. Definitions
4
+
5
+ “Licensor” means any person or entity that distributes its Work.
6
+ “Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license.
7
+ The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
8
+ Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license.
9
+
10
+ 2. License Grant
11
+
12
+ 2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
13
+
14
+ 3. Limitations
15
+
16
+ 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
17
+
18
+ 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
19
+
20
+ 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or educational purposes only.
21
+
22
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately.
23
+
24
+ 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license.
25
+
26
+ 3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately.
27
+
28
+ 4. Disclaimer of Warranty.
29
+
30
+ THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
31
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
32
+
33
+ 5. Limitation of Liability.
34
+
35
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
README.md ADDED
@@ -0,0 +1,184 @@
1
+ # LongVILA-R1-7B
2
+ [![Paper](https://img.shields.io/badge/ArXiv-Paper-brown)](https://arxiv.org/abs/2507.07966)
3
+ [![Code](https://img.shields.io/badge/GitHub-Long%20RL-blue)](https://github.com/NVlabs/Long-RL)
4
+ [![Model](https://img.shields.io/badge/HF-Model-yellow)](https://huggingface.co/Efficient-Large-Model/LongVILA-R1-7B)
5
+ [![Video](https://img.shields.io/badge/YouTube-Video-red)](https://www.youtube.com/watch?v=ykbblK2jiEg)
6
+ [![Demo](https://img.shields.io/badge/Gradio-Demo-brown)](https://6d8b5579459b555d59.gradio.live)
7
+
8
+
9
+ ## Introduction:
10
+ <p>
11
+ <strong>LongVILA-R1-7B</strong> supports both <u>multiple-choice</u> questions and <u>open-ended</u> questions. It can switch between thinking and non-thinking modes.<br>
12
+ <strong>LongVILA-R1-7B</strong> demonstrates strong performance in long video reasoning, achieving <strong>70.7%</strong> on VideoMME (w/ sub.) and surpassing Gemini-1.5-Pro across diverse reasoning tasks.<br>
13
+ <strong>Long-RL</strong> is a codebase that accelerates long video RL training by up to <strong>2.1×</strong> through its MR-SP system. It supports RL training on image, video, and omni inputs across VILA, Qwen/Qwen-VL, and diffusion models.
14
+ </p>
15
+
16
+ ## Evaluation:
17
+ ### Video QA Benchmarks
18
+ | Models | VideoMME (w/o sub) | VideoMME (w sub) | ActivityNet-QA (test) | LongVideoBench (val) | PerceptionTest (val) | NExT-QA | VNBench (val) |
19
+ |:-------------------|:------------------:|:----------------:|:---------------------:|:--------------------:|:--------------------:|:--------:|:-------------:|
20
+ | **LongVILA-7B** | **60.1** | **65.1** | **59.5** | **57.1** | **58.1** | **80.7** | **63.0** |
21
+ | **LongVILA-R1-7B** | **65.0** | **70.7** | **64.8** | **58.0** | **68.9** | **81.5** | **75.5** |
22
+
23
+ ### LongVideo-Reason-eval
24
+ | Models | Temporal | Goal | Plot | Spatial | Overall|
25
+ | :--- | :---: | :---: | :---: | :---: | :---: |
27
+ | **LongVILA-R1-7B** | **68.1** | **85.7** | **70.6** | **53.3** | **72.0** |
28
+
29
+
30
+ ## Usage
31
+
32
+ ### Generation
33
+ ```python
34
+ from transformers import AutoModel
35
+
36
+ model_path = "Efficient-Large-Model/LongVILA-R1-7B"
37
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto")
38
+
39
+ use_thinking = True # Switching between thinking and non-thinking modes
40
+ system_prompt_thinking = "You are a helpful assistant. The user asks a question, and then you solves it.\n\nPlease first think deeply about the question based on the given video, and then provide the final answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.\n\n Question: {question}"
41
+
42
+ prompt = "What is the main purpose of the video?"
43
+ video_path = "video.mp4"
44
+
45
+ if use_thinking:
46
+ prompt = system_prompt_thinking.format(question=prompt)
47
+
48
+ response = model.generate_content([prompt, {"path": video_path}])
49
+ print("Response: ", response)
50
+ ```
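+
+ In thinking mode, the response wraps its reasoning and final answer in the `<think> </think>` and `<answer> </answer>` tags defined by the system prompt above. Below is a minimal sketch for pulling out only the final answer; `extract_answer` is an illustrative helper (not part of this repository) and assumes the model emits well-formed tags.
+ ```python
+ import re
+
+ def extract_answer(response: str) -> str:
+     # Return the text inside <answer>...</answer>; fall back to the raw response if the tags are missing.
+     match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
+     return match.group(1).strip() if match else response.strip()
+
+ print("Answer: ", extract_answer(response))
+ ```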
51
+
52
+ ### with vLLM engine
53
+ Tested on `vllm==0.9.1`. We need to get the remote code first.
54
+ ```bash
55
+ mkdir remote_code
56
+ cp path_to/Efficient-Large-Model/LongVILA-R1-7B/*.py remote_code
57
+ ```
58
+ Then, you can use the following code for model generation.
59
+ ```python
60
+ import os
61
+ from transformers import AutoModel
62
+ from vllm import LLM, SamplingParams
63
+ from remote_code.media import extract_media
64
+ from remote_code.mm_utils import process_images
65
+ from remote_code.tokenizer_utils import tokenize_conversation
66
+
67
+ model_path = "path_to/Efficient-Large-Model/LongVILA-R1-7B"
68
+
69
+ model_encoder = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto", llm_only_need_embed=True)
70
+ # you can change gpu_memory_utilization according to GPU memory
71
+ llm = LLM(model=os.path.join(model_path, "llm"), enable_prompt_embeds=True, gpu_memory_utilization=0.5)
72
+
73
+ use_thinking = True # Switching between thinking and non-thinking modes
74
+ system_prompt_thinking = "You are a helpful assistant. The user asks a question, and then you solves it.\n\nPlease first think deeply about the question based on the given video, and then provide the final answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.\n\n Question: {question}"
75
+
76
+ prompt = "What is the main purpose of the video?"
77
+ video_path = "video.mp4"
78
+
79
+ if use_thinking:
80
+ prompt = system_prompt_thinking.format(question=prompt)
81
+
82
+ conversation = [{"from": "human", "value": [prompt, {"path": video_path}]}]
83
+ media = extract_media(conversation, model_encoder.config)
84
+ input_ids = tokenize_conversation(conversation, model_encoder.tokenizer, add_generation_prompt=True).unsqueeze(0).cuda()
85
+ media["video"] = [
86
+ process_images(images, model_encoder.vision_tower.image_processor, model_encoder.config).half()
87
+ for images in media["video"]
88
+ ]
89
+
90
+ inputs_embeds, _, _ = model_encoder._embed(input_ids, media, {"video": {}}, None, None)
91
+
92
+ completions = llm.generate(prompts=[{"prompt_embeds": inputs_embeds.squeeze(0)}], sampling_params=SamplingParams(max_tokens=1024))
93
+ response = completions[0].outputs[0].text
94
+ print("Response: ", response)
95
+ ```
96
+
97
+
98
+ # LongVILA-R1 Model Card
99
+
100
+ ## Model details
101
+
102
+ **Model type:**
103
+ LongVILA-R1 addresses the unique challenges of long video reasoning by integrating three critical components: (1) a large-scale dataset, LongVideo-Reason, comprising 104K long video QA pairs with high-quality reasoning annotations across diverse domains such as sports, games, and vlogs; (2) a two-stage training pipeline that extends VLMs with chain-of-thought supervised fine-tuning (CoT-SFT) and reinforcement learning (RL); and (3) a training infrastructure for long video RL, named Multi-modal Reinforcement Sequence Parallelism (MR-SP), which incorporates sequence parallelism and a vLLM-based engine tailored for long video, using cached video embeddings for efficient rollout and prefilling. In our experiments, LongVILA-R1-7B achieves strong performance on video benchmarks, reaching 65.0% and 70.7% accuracy on VideoMME without and with subtitles, respectively, and consistently outperforming LongVILA-7B across multiple benchmarks. Moreover, LongVILA-R1 shows steady performance improvements as the number of input video frames increases.
104
+ **Model date:**
105
+ LongVILA-R1-7B was trained in July 2025.
106
+
107
+ **Paper or resources for more information:**
108
+ - Paper https://arxiv.org/abs/2507.07966
109
+ - Code https://github.com/NVlabs/Long-RL
110
+ - Model https://huggingface.co/Efficient-Large-Model/LongVILA-R1-7B
111
+ - Video https://www.youtube.com/watch?v=ykbblK2jiEg
112
+ - Demo https://6d8b5579459b555d59.gradio.live
113
+
114
+ ```bibtex
115
+ @misc{long-rl,
116
+ title = {Long-RL: Scaling RL to Long Sequences},
117
+ author = {Yukang Chen and Wei Huang and Shuai Yang and Qinghao Hu and Baifeng Shi and Hanrong Ye and Ligeng Zhu and Zhijian Liu and Pavlo Molchanov and Jan Kautz and Xiaojuan Qi and Sifei Liu and Hongxu Yin and Yao Lu and Song Han},
118
+ year = {2025},
119
+ publisher = {GitHub},
120
+ journal = {GitHub repository},
121
+ howpublished = {\url{https://github.com/NVlabs/Long-RL}},
122
+ }
123
+ ```
124
+ ```bibtex
125
+ @article{chen2025longvila-r1,
126
+ title={Scaling RL to Long Videos},
127
+ author={Yukang Chen and Wei Huang and Baifeng Shi and Qinghao Hu and Hanrong Ye and Ligeng Zhu and Zhijian Liu and Pavlo Molchanov and Jan Kautz and Xiaojuan Qi and Sifei Liu and Hongxu Yin and Yao Lu and Song Han},
128
+ year={2025},
129
+ eprint={2507.07966},
130
+ archivePrefix={arXiv},
131
+ primaryClass={cs.CV}
132
+ }
133
+ ```
134
+ ```bibtex
135
+ @inproceedings{chen2024longvila,
136
+ title={LongVILA: Scaling Long-Context Visual Language Models for Long Videos},
137
+ author={Yukang Chen and Fuzhao Xue and Dacheng Li and Qinghao Hu and Ligeng Zhu and Xiuyu Li and Yunhao Fang and Haotian Tang and Shang Yang and Zhijian Liu and Ethan He and Hongxu Yin and Pavlo Molchanov and Jan Kautz and Linxi Fan and Yuke Zhu and Yao Lu and Song Han},
138
+ booktitle={The International Conference on Learning Representations (ICLR)},
139
+ year={2025},
140
+ }
141
+ ```
142
+
143
+ ## License
144
+ - The weights are released under the [CC-BY-NC-SA-4.0 license](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en).
145
+ - The service is a research preview intended for non-commercial use only, and is subject to the following licenses and terms:
146
+ - [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI
147
+ - [Dataset Licenses](https://github.com/Efficient-Large-Model/VILA/blob/main/data_prepare/LICENSE) for each dataset used during training.
148
+ - [NVIDIA Licenses](https://huggingface.co/Efficient-Large-Model/LongVILA-R1-7B/blob/main/NV_LICENSE)
149
+
150
+ **Where to send questions or comments about the model:**
151
+ https://github.com/NVlabs/Long-RL/issues
152
+
153
+ ## Intended use
154
+ **Primary intended uses:**
155
+ The primary use of LongVILA-R1 is research on large multimodal models and chatbots.
156
+
157
+ **Primary intended users:**
158
+ The primary intended users of the model are researchers and hobbyists in computer vision, natural language processing, machine learning, and artificial intelligence.
159
+
160
+ ## Input:
161
+ **Input Type:** Video and Text
162
+ **Input Format:** MP4 and other video formats
163
+
164
+ ## Output:
165
+ **Output Type:** Text
166
+ **Output Format:** String
167
+
168
+
169
+ **Supported Operating System(s):** <br>
170
+ Linux
171
+
172
+
173
+ ## Inference:
174
+ **Engine:**
175
+ * PyTorch
176
+
177
+
178
+ **Test Hardware:**
179
+ * A100
180
+ * H100
181
+ * A6000
182
+
183
+ ## Ethical Considerations
184
+ NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse.
auto_processor.py ADDED
@@ -0,0 +1,495 @@
1
+ import base64
+ import copy
2
+ import os
3
+ import os.path as osp
4
+ import warnings
5
+ from collections import defaultdict
6
+ from io import BytesIO
7
+ from typing import List, Optional, Union
8
+
9
+ import PIL.Image
10
+ import requests
11
+ import torch
12
+ from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoProcessor, AutoTokenizer
13
+ from transformers.feature_extraction_utils import BatchFeature
14
+ from transformers.image_utils import ImageInput
15
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
16
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
17
+ from transformers.utils import logging
18
+
19
+ from .constants import DEFAULT_IMAGE_TOKEN, MEDIA_TOKENS
20
+ from .media import Image, Video, extract_media
21
+ from .mm_utils import process_image, process_images
22
+ from .tokenizer_utils import tokenize_conversation
23
+
24
+
25
+ def to_rgb(pil_image: PIL.Image.Image) -> PIL.Image.Image:
26
+ if pil_image.mode == "RGBA":
27
+ white_background = PIL.Image.new("RGB", pil_image.size, (255, 255, 255))
28
+ white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
29
+ return white_background
30
+ else:
31
+ return pil_image.convert("RGB")
32
+
33
+
34
+ def fetch_image(ele: dict[str, str | PIL.Image.Image], size_factor=None) -> PIL.Image.Image:
35
+ if "image" in ele:
36
+ image = ele["image"]
37
+ else:
38
+ image = ele["image_url"]
39
+ image_obj = None
40
+ if isinstance(image, PIL.Image.Image):
41
+ image_obj = image
42
+ elif image.startswith("http://") or image.startswith("https://"):
43
+ response = requests.get(image, stream=True)
44
+ image_obj = PIL.Image.open(BytesIO(response.content))
45
+ elif image.startswith("file://"):
46
+ image_obj = PIL.Image.open(image[7:])
47
+ elif image.startswith("data:image"):
48
+ if "base64," in image:
49
+ _, base64_data = image.split("base64,", 1)
50
+ data = base64.b64decode(base64_data)
51
+ image_obj = PIL.Image.open(BytesIO(data))
52
+ else:
53
+ image_obj = PIL.Image.open(image)
54
+ if image_obj is None:
55
+ raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
56
+ image = to_rgb(image_obj)
57
+
58
+ return image
59
+
60
+
61
+ def fetch_image_url_or_fpath(url_or_fpath):
62
+ if url_or_fpath.startswith("http") or url_or_fpath.startswith("https"):
63
+ import tempfile
64
+
65
+ import requests
66
+
67
+ # Download the image to a temporary file
68
+ temp_dir = tempfile.mkdtemp()
69
+ temp_file = os.path.join(temp_dir, os.path.basename(url_or_fpath))
70
+
71
+ response = requests.get(url_or_fpath, stream=True)
72
+ response.raise_for_status()
73
+
74
+ with open(temp_file, "wb") as f:
75
+ for chunk in response.iter_content(chunk_size=8192):
76
+ f.write(chunk)
77
+
78
+ return temp_file
79
+ elif url_or_fpath.startswith("file://"):
80
+ fpath = url_or_fpath.replace("file://", "")
81
+ assert osp.exists(fpath), f"File {fpath} does not exist"
82
+ return fpath
83
+ elif osp.exists(url_or_fpath):
84
+ assert osp.isfile(url_or_fpath), f"File {url_or_fpath} does not exist"
85
+ return url_or_fpath
86
+ else:
87
+ raise ValueError(f"Unsupported image path: {url_or_fpath}")
88
+
89
+
90
+ def pad_fn(input_ids_list: List[torch.Tensor], padding_value=0, target_len=None, padding_side="left") -> torch.Tensor:
91
+ # tensor shape is (batch_size, seq_len)
92
+ max_len = max([ids.shape[1] for ids in input_ids_list])
93
+ if target_len is not None:
94
+ assert target_len >= max_len, "target_len must be greater than or equal to max_len"
95
+ max_len = target_len
96
+
97
+ new_input_ids_list = []
98
+ for i, input_ids in enumerate(input_ids_list):
99
+ pad_tensor = torch.ones_like(input_ids) * padding_value
100
+ curr_len = input_ids.shape[1]
101
+ pad_tensor = pad_tensor[:, : max_len - curr_len]
102
+ if padding_side == "right":
103
+ input_ids = torch.cat((input_ids, pad_tensor), dim=1)
104
+ else:
105
+ input_ids = torch.cat((pad_tensor, input_ids), dim=1)
106
+ new_input_ids_list.append(input_ids)
107
+ return torch.cat(new_input_ids_list, dim=0)
108
+
109
+
110
+ def extract_value_from_conv(chat):
111
+ value = []
112
+ if isinstance(chat["content"], str):
113
+ # vila_chat["value"].append(chat["content"])
114
+ value.append(chat["content"])
115
+ return value
116
+
117
+ # otherwise, it's a list of content
118
+ for content in chat["content"]:
119
+ if content["type"] == "image":
120
+ if "path" in content:
121
+ # VILA style, can be either filepath or http url
122
+ value.append(Image(fetch_image_url_or_fpath(content["path"])))
123
+ elif "image" in content:
124
+ # Qwen style
125
+ value.append(Image(fetch_image_url_or_fpath(content["image"])))
126
+ elif "image_pil" in content:
127
+ # Qwen style
128
+ assert isinstance(content["image_pil"], PIL.Image.Image), "image_pil must be a PIL.Image.Image"
129
+ value.append(content["image_pil"])
130
+ else:
131
+ raise ValueError(f"Type = `image` , but no `path` or `image` in | {content=}, {conversation=}")
132
+ elif content["type"] == "video":
133
+ if "video" in content:
134
+ # Qwen style
135
+ value.append(Video(fetch_image_url_or_fpath(content["video"])))
136
+ else:
137
+ raise ValueError(f"Type = `video` , but no `video` in | {content=}, {conversation=}")
138
+ elif content["type"] == "text":
139
+ value.append(content["text"])
140
+ else:
141
+ raise ValueError(f"Unsupported content type: {content['type']}")
142
+ return value
143
+
144
+
145
+ class VILAProcessorKwargs(ProcessingKwargs, total=False):
146
+ _defaults = {
147
+ "text_kwargs": {
148
+ "padding": False,
149
+ },
150
+ }
151
+
152
+
153
+ class VILAProcessor(ProcessorMixin):
154
+ # attributes = ["image_processor", "tokenizer"]
155
+ attributes = []
156
+ # valid_kwargs = ["chat_template"]
157
+ valid_kwargs = []
158
+ # image_processor_class = "VILAImageProcessor"
159
+ # tokenizer_class = ("VILATokenizer", "VILATokenizerFast")
160
+
161
+ def __init__(
162
+ self, image_processor=None, tokenizer=None, chat_template=None, config=None, padding_side="left", **kwargs
163
+ ):
164
+ self.image_token = MEDIA_TOKENS["image"]
165
+ self.video_token = MEDIA_TOKENS["video"]
166
+ self.config = config
167
+ self.image_processor = image_processor
168
+ self.tokenizer = tokenizer
169
+ self.padding_side = padding_side
170
+
171
+ # This is a special setting for Qwen.
172
+ # self.pad_token_id = tokenizer.pad_token_id
173
+ self.pad_token_id = self.tokenizer("<|endoftext|>").input_ids[0] # 151643
174
+ self.eos_token_id = self.tokenizer.eos_token_id
175
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
176
+
177
+ @staticmethod
178
+ def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
179
+ """
180
+ Reference: qwen_vl_utils.
181
+ """
182
+ vision_infos = []
183
+ if isinstance(conversations[0], dict):
184
+ conversations = [conversations]
185
+ for conversation in conversations:
186
+ for message in conversation:
187
+ if isinstance(message["content"], list):
188
+ for ele in message["content"]:
189
+ if (
190
+ "image" in ele
191
+ or "image_url" in ele
192
+ or "video" in ele
193
+ or ele["type"] in ("image", "image_url", "video")
194
+ ):
195
+ vision_infos.append(ele)
196
+ return vision_infos
197
+
198
+ @staticmethod
199
+ def process_vision_info(
200
+ conversations: list[dict] | list[list[dict]],
201
+ return_video_kwargs: bool = False,
202
+ ) -> tuple[list[PIL.Image.Image] | None, list[torch.Tensor | list[PIL.Image.Image]] | None, Optional[dict]]:
203
+ """
204
+ Reference: qwen_vl_utils.
205
+ NVILA does not depend on this function, but the interface is kept the same.
206
+ """
207
+ vision_infos = VILAProcessor.extract_vision_info(conversations)
208
+ ## Read images or videos
209
+ image_inputs = []
210
+ video_inputs = []
211
+ video_sample_fps_list = []
212
+ for vision_info in vision_infos:
213
+ if "image" in vision_info or "image_url" in vision_info:
214
+ image_inputs.append(fetch_image(vision_info))
215
+ elif "video" in vision_info:
216
+ video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
217
+ video_sample_fps_list.append(video_sample_fps)
218
+ video_inputs.append(video_input)
219
+ else:
220
+ raise ValueError("image, image_url or video should in content.")
221
+ if len(image_inputs) == 0:
222
+ image_inputs = None
223
+ if len(video_inputs) == 0:
224
+ video_inputs = None
225
+ if return_video_kwargs:
226
+ return image_inputs, video_inputs, {"fps": video_sample_fps_list}
227
+ return image_inputs, video_inputs
228
+
229
+ @staticmethod
230
+ def move_data_to_device(cls, prompt_inputs):
231
+ def _move_data_to_device(item):
232
+ # wrap function grpo trainer _prepare_input
233
+ kwargs = {"device": cls.args.device}
234
+ if cls.is_deepspeed_enabled and (torch.is_floating_point(item) or torch.is_complex(item)):
235
+ kwargs.update({"dtype": cls.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
236
+ return item.to(**kwargs)
237
+
238
+ prompt_inputs.input_ids = _move_data_to_device(prompt_inputs.input_ids)
239
+ prompt_inputs.attention_mask = _move_data_to_device(prompt_inputs.attention_mask)
240
+ if "image" in prompt_inputs.media:
241
+ prompt_inputs.media["image"] = [_move_data_to_device(img) for img in prompt_inputs.media["image"]]
242
+ return prompt_inputs
243
+
244
+ @classmethod
245
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
246
+ padding_side = kwargs.get("padding_side", "left")
247
+ if os.path.isdir(pretrained_model_name_or_path):
248
+ pretrained_model_name_or_path = pretrained_model_name_or_path
249
+ else:
250
+ print(f"pretrained_model_name_or_path {pretrained_model_name_or_path} is not a directory, downloading")
251
+ from huggingface_hub import snapshot_download
252
+
253
+ pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path)
254
+
255
+ image_processor = AutoImageProcessor.from_pretrained(
256
+ osp.join(pretrained_model_name_or_path, "vision_tower"), trust_remote_code=True
257
+ )
258
+ tokenizer = AutoTokenizer.from_pretrained(
259
+ osp.join(pretrained_model_name_or_path, "llm"), trust_remote_code=True
260
+ )
261
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
262
+ return cls(image_processor=image_processor, tokenizer=tokenizer, config=config, padding_side=padding_side)
263
+
264
+ def __repr__(self):
265
+ return f"VILAProcessor(image_processor=SigLip, tokenizer={self.tokenizer}, config={self.config})"
266
+
267
+ def __call__(
268
+ self,
269
+ conversation=None,
270
+ **kwargs: Unpack[VILAProcessorKwargs],
271
+ ) -> BatchFeature:
272
+ """
273
+ The `conv` will be look like
274
+ [
275
+ {
276
+ 'from': 'human',
277
+ 'value': [
278
+ <transformers_modules.NVILA-Lite-2B-hf-preview.media.Image object at 0x154e68e4c460>,
279
+ 'What are the common elements in these pictures?'
280
+ ]
281
+ }
282
+ ]
283
+ and `conversation` will be a list of such `conv`s
284
+ """
285
+ if kwargs.get("text", None) is not None:
286
+ conversation = kwargs.get("text")
287
+ assert conversation is not None, "`conversation` or `text` is required"
288
+ padding_side = kwargs.get("padding_side", self.padding_side)
289
+
290
+ input_ids_list = []
291
+ attention_mask = []
292
+ media = defaultdict(list)
293
+ media_config = defaultdict(dict)
294
+ for conv in conversation:
295
+ feat = self.__single_call__(conv, **kwargs)
296
+ input_ids_list.append(feat.input_ids)
297
+ attention_mask.append(feat.attention_mask)
298
+ for name in feat.media:
299
+ media[name] += feat.media[name]
300
+ for name in feat.media_config:
301
+ media_config[name].update(feat.media_config[name])
302
+
303
+ # pad the input_ids to batchfy
304
+ input_ids = pad_fn(
305
+ input_ids_list,
306
+ padding_value=self.pad_token_id,
307
+ padding_side=padding_side,
308
+ )
309
+ # ignore the pad token in the attention mask
310
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
311
+ attention_mask[input_ids == self.pad_token_id] = False
312
+ input_texts = self.tokenizer.batch_decode(input_ids)
313
+ bdata = BatchFeature(
314
+ data={
315
+ # "input_texts": input_texts,
316
+ "input_ids": input_ids,
317
+ "attention_mask": attention_mask,
318
+ "media": media,
319
+ "media_config": media_config,
320
+ }
321
+ )
322
+ # NOTE: hard coded to cuda
323
+ # bdata.input_ids = bdata.input_ids.cuda()
324
+ # bdata.attention_mask = bdata.attention_mask.cuda()
325
+ # bdata.media["image"] = [img.cuda() for img in bdata.media["image"]]
326
+ return bdata
327
+
328
+ def __single_call__(
329
+ self,
330
+ conversation,
331
+ images: ImageInput = None,
332
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
333
+ videos = None,
334
+ **kwargs: Unpack[VILAProcessorKwargs],
335
+ ) -> BatchFeature:
336
+ conversation = copy.deepcopy(conversation)
337
+ media = extract_media(conversation, self.config)
338
+ # Process media
339
+ media_config = defaultdict(dict)
340
+ for name in media:
341
+ if name == "image":
342
+ if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
343
+ self.config.image_processor = self.image_processor
344
+ if self.config.image_aspect_ratio == "dynamic":
345
+ images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
346
+ # NOTE: this only works for images appears at the first conversation
347
+ conversation[0]["value"] = conversation[0]["value"].replace(
348
+ DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
349
+ )
350
+ else:
351
+ if type(self.config.s2_scales) is str:
352
+ self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
353
+ images, block_sizes = process_image(
354
+ media["image"][0], self.config, None, enable_dynamic_s2=True
355
+ )
356
+ images = images.half()
357
+ media_config[name]["block_sizes"] = [block_sizes]
358
+ else:
359
+ images = process_images(media["image"], self.image_processor, self.config).half()
360
+ media[name] = [image for image in images]
361
+ elif name == "video":
362
+ media[name] = [
363
+ process_images(images, self.image_processor, self.config).half() for images in media[name]
364
+ ]
365
+ else:
366
+ raise ValueError(f"Unsupported media type: {name}")
367
+
368
+ inputs = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True, return_ids_only=False)
369
+ input_ids = inputs.input_ids[0].unsqueeze(0).cuda()
370
+
371
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
372
+ return BatchFeature(
373
+ data={
374
+ "input_ids": input_ids,
375
+ "attention_mask": attention_mask,
376
+ "media": media,
377
+ "media_config": media_config,
378
+ }
379
+ )
380
+
381
+ def batch_decode(self, *args, **kwargs):
382
+ """
383
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
384
+ refer to the docstring of this method for more information.
385
+ """
386
+ return self.tokenizer.batch_decode(*args, **kwargs)
387
+
388
+ def decode(self, *args, **kwargs):
389
+ """
390
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
391
+ the docstring of this method for more information.
392
+ """
393
+ return self.tokenizer.decode(*args, **kwargs)
394
+
395
+ def post_process_image_text_to_text(self, generated_outputs):
396
+ """
397
+ Post-process the output of the model to decode the text.
398
+
399
+ Args:
400
+ generated_outputs (`torch.Tensor` or `np.ndarray`):
401
+ The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
402
+ or `(sequence_length,)`.
403
+
404
+ Returns:
405
+ `List[str]`: The decoded text.
406
+ """
407
+ return self.tokenizer.batch_decode(
408
+ generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
409
+ )
410
+
411
+ @property
412
+ def model_input_names(self):
413
+ tokenizer_input_names = self.tokenizer.model_input_names
414
+ image_processor_input_names = self.image_processor.model_input_names
415
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
416
+
417
+ def convert_gpt_conv_to_vila_conv(self, conversation):
418
+ vila_conv = []
419
+ for chat in conversation:
420
+ vila_chat = {"from": "", "value": []}
421
+ if chat["role"] in ("user", "system"):
422
+ # user allows to input image and text
423
+ vila_chat["from"] = "human" if chat["role"] == "user" else "system"
424
+ vila_chat["value"] = extract_value_from_conv(chat)
425
+ elif chat["role"] == "assistant":
426
+ vila_chat["from"] = "gpt"
427
+ vila_chat["value"] = extract_value_from_conv(chat)
428
+ else:
429
+ raise ValueError(f"Unsupported role: {chat['role']} in chat {chat}")
430
+ vila_conv.append(vila_chat)
431
+
432
+ return vila_conv
433
+
434
+ def apply_chat_template(self, conversation, add_generation_prompt=True, **kwargs):
435
+ return self.convert_gpt_conv_to_vila_conv(conversation)
436
+
437
+
438
+ if __name__ == "__main__":
439
+ # gpt style: user, assistant
440
+ # vila style: human, gpt
441
+ gpt_conv = [
442
+ {
443
+ "role": "user",
444
+ "content": [
445
+ {"type": "image", "path": "demo_images/demo_img_1.png"},
446
+ {"type": "text", "text": "Describe this image."},
447
+ ],
448
+ }
449
+ ]
450
+
451
+ llavaconv = [
452
+ {
453
+ "from": "human",
454
+ "value": [
455
+ PIL.Image.open("demo_images/demo_img_1.png"),
456
+ "Describe this image.",
457
+ ],
458
+ }
459
+ ]
460
+
461
+ model_path = "NVILA-Lite-2B-hf-preview"
+ processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
462
+ inputs = processor.apply_chat_template(conversation=gpt_conv, padding=True, return_tensors="pt")
463
+ # model = llava.load("Efficient-Large-Model/qwen25_2B_3x3-sft").cuda()
464
+ # print(model)
465
+ model_path = "NVILA-Lite-2B-hf-preview"
466
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto")
467
+ # res = model.generate_content(["how are you today?"])
468
+ # print(model.config)
469
+ # print(model.tokenizer)
470
+ # print(res)
471
+
472
+ processor = VILAProcessor(
473
+ config=model.config,
474
+ image_processor=model.vision_tower.image_processor,
475
+ tokenizer=model.tokenizer,
476
+ )
477
+
478
+ inputs = processor(conversation=llavaconv, padding=True, return_tensors="pt")
479
+ print(inputs.keys(), inputs.input_ids.shape, [_.shape for _ in inputs.media["image"]])
480
+ print("vila conv pass")
481
+
482
+ inputs = processor.apply_chat_template(conversation=gpt_conv, padding=True, return_tensors="pt")
483
+ print(inputs.keys(), inputs.input_ids.shape, [_.shape for _ in inputs.media["image"]])
484
+ print("gpt conv pass")
485
+
486
+ output_ids = model.generate(
487
+ input_ids=inputs.input_ids,
488
+ media={
489
+ "image": inputs.image,
490
+ },
491
+ media_config={"image": {}},
492
+ generation_config=model.generation_config,
493
+ max_new_tokens=100,
494
+ )
495
+ print(output_ids)
base_projector.py ADDED
@@ -0,0 +1,228 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import re
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
22
+
23
+
24
+ class IdentityMap(nn.Module):
25
+ def __init__(self):
26
+ super().__init__()
27
+
28
+ def forward(self, x, *args, **kwargs):
29
+ return x
30
+
31
+ @property
32
+ def config(self):
33
+ return {"mm_projector_type": "identity"}
34
+
35
+
36
+ class SimpleResBlock(nn.Module):
37
+ def __init__(self, channels):
38
+ super().__init__()
39
+ self.pre_norm = nn.LayerNorm(channels)
40
+
41
+ self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
42
+
43
+ def forward(self, x):
44
+ x = self.pre_norm(x)
45
+ return x + self.proj(x)
46
+
47
+
48
+ class DownSampleBlock(nn.Module):
49
+ def forward(self, x):
50
+ vit_embeds = x
51
+ h = w = int(vit_embeds.shape[1] ** 0.5)
52
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
53
+ vit_embeds = self.flat_square(vit_embeds)
54
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
55
+ return vit_embeds
56
+
57
+ def flat_square(self, x):
58
+ n, w, h, c = x.size()
59
+ if w % 2 == 1:
60
+ x = torch.concat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
61
+ n, w, h, c = x.size()
62
+ if h % 2 == 1:
63
+ x = torch.concat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
64
+ n, w, h, c = x.size()
65
+ x = x.contiguous()
66
+ x = x.view(n, w, int(h / 2), int(c * 2))
67
+ x = x.permute(0, 2, 1, 3).contiguous()
68
+ x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
69
+ x = x.permute(0, 2, 1, 3).contiguous()
70
+ return x
71
+
72
+
73
+ class DownSample2x2BlockFix(nn.Module):
74
+ def forward(self, x):
75
+ vit_embeds = x
76
+ h = w = int(vit_embeds.shape[1] ** 0.5)
77
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
78
+ vit_embeds = flat_square_2x2(vit_embeds)
79
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
80
+ return vit_embeds
81
+
82
+
83
+ def flat_square_2x2(x):
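+ # Folds each 2x2 spatial patch into the channel dimension, (n, w, h, c) -> (n, w/2, h/2, 4c),
+ # zero-padding odd spatial sizes first, so the number of visual tokens is reduced by 4x.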
84
+ n, w, h, c = x.size()
85
+ if w % 2 == 1:
86
+ x = torch.concat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
87
+ n, w, h, c = x.size()
88
+ x = x.contiguous()
89
+ if h % 2 == 1:
90
+ x = torch.concat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
91
+ n, w, h, c = x.size()
92
+ x = x.view(n, w, int(h / 2), int(c * 2))
93
+ x = x.permute(0, 2, 1, 3).contiguous()
94
+ x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
95
+ x = x.permute(0, 2, 1, 3).contiguous()
96
+ return x
97
+
98
+
99
+ class DownSample3x3BlockFix(nn.Module):
100
+ def forward(self, x):
101
+ vit_embeds = x
102
+ h = w = int(vit_embeds.shape[1] ** 0.5)
103
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
104
+ vit_embeds = flat_square_3x3(vit_embeds)
105
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
106
+ return vit_embeds
107
+
108
+
109
+ def flat_square_3x3(x):
110
+ n, w, h, c = x.size()
111
+ if w % 3 != 0:
112
+ x = torch.concat([x, torch.zeros((n, 3 - (w % 3), h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
113
+ n, w, h, c = x.size()
114
+ x = x.contiguous()
115
+ if h % 3 != 0:
116
+ x = torch.concat([x, torch.zeros((n, w, 3 - (h % 3), c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
117
+ n, w, h, c = x.size()
118
+ x = x.view(n, w, int(h / 3), int(c * 3))
119
+ x = x.permute(0, 2, 1, 3).contiguous()
120
+ x = x.view(n, int(h / 3), int(w / 3), int(c * 9))
121
+ x = x.permute(0, 2, 1, 3).contiguous()
122
+ return x
123
+
124
+
125
+ class MultimodalProjectorConfig(PretrainedConfig):
126
+ model_type = "v2l_projector"
127
+
128
+ def __init__(self, mm_projector_type: str = None, **kwargs):
129
+ super().__init__()
130
+ self.mm_projector_type = mm_projector_type
131
+
132
+
133
+ class MultimodalProjector(PreTrainedModel):
134
+ config_class = MultimodalProjectorConfig
135
+
136
+ def __init__(self, mm_projector_cfg: MultimodalProjectorConfig, config: PretrainedConfig):
137
+ super().__init__(mm_projector_cfg)
138
+ mm_projector_type = mm_projector_cfg.mm_projector_type
139
+ self.downsample_rate = 1
140
+ if mm_projector_type == "identity":
141
+ self.layers = IdentityMap()
142
+ elif mm_projector_type == "linear":
143
+ self.layers = nn.Linear(config.mm_hidden_size, config.hidden_size)
144
+ elif mm_projector_type == "mlp_downsample":
145
+ self.layers = nn.Sequential(
146
+ DownSampleBlock(),
147
+ nn.LayerNorm(config.mm_hidden_size * 4),
148
+ nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
149
+ nn.GELU(),
150
+ nn.Linear(config.hidden_size, config.hidden_size),
151
+ )
152
+ self.downsample_rate = 2
153
+ elif mm_projector_type == "mlp_downsample_2x2_fix":
154
+ self.layers = nn.Sequential(
155
+ DownSample2x2BlockFix(),
156
+ nn.LayerNorm(config.mm_hidden_size * 4),
157
+ nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
158
+ nn.GELU(),
159
+ nn.Linear(config.hidden_size, config.hidden_size),
160
+ )
161
+ self.downsample_rate = 2
162
+ elif mm_projector_type == "mlp_downsample_3x3_fix":
163
+ self.layers = nn.Sequential(
164
+ DownSample3x3BlockFix(),
165
+ nn.LayerNorm(config.mm_hidden_size * 9),
166
+ nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 3),
167
+ nn.GELU(),
168
+ nn.LayerNorm(config.mm_hidden_size * 3),
169
+ nn.Linear(config.mm_hidden_size * 3, config.hidden_size),
170
+ nn.GELU(),
171
+ nn.Linear(config.hidden_size, config.hidden_size),
172
+ )
173
+ self.downsample_rate = 3
174
+ elif mm_projector_type == "mlp_downsample_3x3_s2":
175
+ self.layers = nn.Sequential(
176
+ DownSample3x3BlockFix(),
177
+ nn.LayerNorm(config.mm_hidden_size * 9),
178
+ nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 3),
179
+ nn.GELU(),
180
+ nn.LayerNorm(config.mm_hidden_size * 3),
181
+ nn.Linear(config.mm_hidden_size * 3, config.mm_hidden_size),
182
+ nn.GELU(),
183
+ nn.LayerNorm(config.mm_hidden_size),
184
+ nn.Linear(config.mm_hidden_size, config.mm_hidden_size // 3),
185
+ nn.GELU(),
186
+ nn.LayerNorm(config.mm_hidden_size // 3),
187
+ nn.Linear(config.mm_hidden_size // 3, config.hidden_size),
188
+ nn.GELU(),
189
+ nn.Linear(config.hidden_size, config.hidden_size),
190
+ )
191
+ elif mm_projector_type == "mlp_downsample_3x3_s2_new":
192
+ self.layers = nn.Sequential(
193
+ DownSample3x3BlockFix(),
194
+ nn.LayerNorm(config.mm_hidden_size * 9),
195
+ nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 4),
196
+ nn.GELU(),
197
+ nn.LayerNorm(config.mm_hidden_size * 4),
198
+ nn.Linear(config.mm_hidden_size * 4, config.mm_hidden_size * 2),
199
+ nn.GELU(),
200
+ nn.LayerNorm(config.mm_hidden_size * 2),
201
+ nn.Linear(config.mm_hidden_size * 2, config.mm_hidden_size),
202
+ nn.GELU(),
203
+ nn.LayerNorm(config.mm_hidden_size),
204
+ nn.Linear(config.mm_hidden_size, config.mm_hidden_size // 3),
205
+ nn.GELU(),
206
+ nn.LayerNorm(config.mm_hidden_size // 3),
207
+ nn.Linear(config.mm_hidden_size // 3, config.hidden_size),
208
+ nn.GELU(),
209
+ nn.Linear(config.hidden_size, config.hidden_size),
210
+ )
211
+ else:
212
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", mm_projector_type)
213
+ if mlp_gelu_match:
214
+ mlp_depth = int(mlp_gelu_match.group(1))
215
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
216
+ for _ in range(1, mlp_depth):
217
+ modules.append(nn.GELU())
218
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
219
+ self.layers = nn.Sequential(*modules)
220
+ else:
221
+ raise ValueError(f"Unknown projector type: {mm_projector_type}")
222
+
223
+ def forward(self, x, *args, **kwargs):
224
+ return self.layers(x)
225
+
226
+
227
+ # AutoConfig.register("v2l_projector", MultimodalProjectorConfig)
228
+ # AutoModel.register(MultimodalProjectorConfig, MultimodalProjector)
builder.py ADDED
@@ -0,0 +1,247 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import math
18
+ import os
19
+ import os.path as osp
20
+ import warnings
21
+ from dataclasses import asdict
22
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
23
+
24
+ import torch
25
+ import transformers
26
+ from huggingface_hub import file_exists, repo_exists
27
+ from huggingface_hub.utils import HFValidationError
28
+ from transformers import (
29
+ AutoConfig,
30
+ AutoModelForCausalLM,
31
+ AutoTokenizer,
32
+ PretrainedConfig,
33
+ PreTrainedModel,
34
+ PreTrainedTokenizer,
35
+ )
36
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
37
+
38
+ # from .conversation import *
39
+ from .conversation import SeparatorStyle, default_conversation
40
+
41
+ SENTINEL_TOKEN = "<vila/sentinel>"
42
+ MEDIA_TOKENS = {
43
+ "image": "<image>",
44
+ "video": "<vila/video>",
45
+ }
46
+
47
+ # from llava.model.utils import packing
48
+ # from llava.utils.logging import logger
49
+ # from llava.utils.tokenizer import infer_stop_tokens
50
+
51
+ DUMMY_CONVERSATION = [
52
+ {"from": "human", "value": "question"},
53
+ {"from": "gpt", "value": "answer"},
54
+ ] * 10
55
+
56
+
57
+ def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
58
+ return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
59
+
60
+
61
+ def has_tokenizer(repo_id_or_path: str) -> bool:
62
+ # Check if the tokenizer is in a local directory
63
+ if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
64
+ return True
65
+
66
+ # Check if the tokenizer is in a Hugging Face Hub repo
67
+ try:
68
+ return repo_exists(repo_id_or_path) and file_exists(repo_id_or_path, "tokenizer_config.json")
69
+ except HFValidationError:
70
+ return False
71
+
72
+
73
+ def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
74
+ if not hasattr(tokenizer, "sentinel_token"):
75
+ tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
76
+ tokenizer.sentinel_token = SENTINEL_TOKEN
77
+ tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)
78
+
79
+
80
+ def tokenize_conversation_legacy(
81
+ messages: Sequence[Dict[str, str]],
82
+ tokenizer: transformers.PreTrainedTokenizer,
83
+ add_generation_prompt: bool = False,
84
+ overrides: Optional[Dict[str, str]] = None,
85
+ no_system_prompt: bool = False,
86
+ ) -> torch.Tensor:
87
+ conv = default_conversation.copy()
88
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
89
+
90
+ if no_system_prompt:
91
+ conv.system = ""
92
+
93
+ # Skip the first message if it is not from human
94
+ if messages[0]["from"] != "human":
95
+ messages = messages[1:]
96
+
97
+ # Add a generation prompt if needed
98
+ if add_generation_prompt:
99
+ messages.append({"from": "gpt", "value": None})
100
+
101
+ conv.messages = []
102
+ for turn, message in enumerate(messages):
103
+ role = roles[message["from"]]
104
+ assert role == conv.roles[turn % 2]
105
+ if overrides is not None and message["from"] in overrides:
106
+ conv.append_message(role, overrides[message["from"]])
107
+ else:
108
+ conv.append_message(role, message["value"])
109
+
110
+ return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")
111
+
112
+
113
+ def tokenize_conversation(
114
+ messages: Sequence[Dict[str, str]],
115
+ tokenizer: transformers.PreTrainedTokenizer,
116
+ add_generation_prompt: bool = False,
117
+ overrides: Optional[Dict[str, str]] = None,
118
+ no_system_prompt: bool = False,
119
+ ) -> torch.Tensor:
120
+ # Normalize the conversation before tokenization
121
+ for message in messages:
122
+ message["value"] = message["value"].strip()
123
+
124
+ if default_conversation.sep_style != SeparatorStyle.AUTO:
125
+ return tokenize_conversation_legacy(
126
+ messages,
127
+ tokenizer,
128
+ add_generation_prompt=add_generation_prompt,
129
+ overrides=overrides,
130
+ no_system_prompt=no_system_prompt,
131
+ )
132
+
133
+ conversation = []
134
+ for m in messages:
135
+ message = {}
136
+ if m["from"] == "human":
137
+ message["role"] = "user"
138
+ elif m["from"] == "gpt":
139
+ message["role"] = "assistant"
140
+ else:
141
+ raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")
142
+
143
+ message["content"] = m["value"]
144
+ if overrides is not None and m["from"] in overrides:
145
+ message["content"] = overrides[m["from"]]
146
+ conversation.append(message)
147
+
148
+ if no_system_prompt:
149
+ conversation = [{"role": "system", "content": ""}] + conversation
150
+
151
+ text = tokenizer.apply_chat_template(
152
+ conversation,
153
+ add_generation_prompt=add_generation_prompt,
154
+ tokenize=False,
155
+ )
156
+ return tokenizer_image_token(text, tokenizer, return_tensors="pt")
157
+
158
+
159
+ def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
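+ # Tokenizes a dummy conversation with every assistant turn replaced by a sentinel token; whatever
+ # token immediately follows a sentinel in the rendered template is collected as a stop token.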
160
+ _maybe_add_sentinel_token(tokenizer)
161
+ template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})
162
+
163
+ stop_tokens = {tokenizer.eos_token}
164
+ for k in range(template.size(0) - 1):
165
+ if template[k] == tokenizer.sentinel_token_id:
166
+ stop_token = tokenizer.decode(template[k + 1])
167
+ stop_tokens.add(stop_token)
168
+ return list(stop_tokens)
169
+
170
+
171
+ def context_length_extension(config):
172
+ orig_ctx_len = getattr(config, "max_position_embeddings", None)
173
+ model_max_length = getattr(config, "model_max_length", None)
174
+ if orig_ctx_len and model_max_length > orig_ctx_len:
175
+ print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
176
+ scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
177
+ config.rope_scaling = {"type": "linear", "factor": scaling_factor}
178
+ return config
179
+
180
+
181
+ def build_llm_and_tokenizer(
182
+ model_name_or_path: str,
183
+ config: PretrainedConfig,
184
+ attn_implementation=None,
185
+ model_max_length=None,
186
+ *args,
187
+ **kwargs,
188
+ ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
189
+ # print(model_name_or_path)
190
+ llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
191
+ llm_cfg._attn_implementation = attn_implementation
192
+ llm_cfg.model_max_length = model_max_length
193
+ if model_max_length is not None:
194
+ context_length_extension(llm_cfg)
195
+
196
+ # Quantization related
197
+ quantization_restore_from_checkpoint = False
198
+
199
+ if quantization_restore_from_checkpoint:
200
+ fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)
201
+
202
+ llm = AutoModelForCausalLM.from_pretrained(
203
+ fp8_model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
204
+ )
205
+ else:
206
+ if is_deepspeed_zero3_enabled():
207
+ # NOTE: found by wei, need to pop out device_map when using zero3
208
+ kwargs.pop("device_map")
209
+ llm = AutoModelForCausalLM.from_pretrained(
210
+ model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
211
+ )
212
+ # packing.patch(llm)
213
+
214
+ # Locate the tokenizer.
215
+ llm_path = model_name_or_path
216
+ if not has_tokenizer(llm_path):
217
+ llm_path = osp.join(llm_path, "llm")
218
+ if not has_tokenizer(llm_path):
219
+ raise ValueError(f"Cannot find tokenizer in {llm_path}.")
220
+
221
+ tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side="right", use_fast=True, legacy=False)
222
+ if model_max_length is not None:
223
+ tokenizer.model_max_length = model_max_length
224
+
225
+ # Load chat template if specified.
226
+ if getattr(config, "chat_template", None) is not None:
227
+ print(f"Using chat template: {config.chat_template}")
228
+ fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
229
+ if not os.path.exists(fpath):
230
+ fpath = os.path.join(os.path.dirname(model_name_or_path), f"{config.chat_template}.jinja")
231
+ with open(fpath) as fd:
232
+ chat_template = fd.read()
233
+ tokenizer.chat_template = chat_template.replace(" ", "").replace("\n", "")
234
+
235
+ # Set stop tokens for the tokenizer
236
+ tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
237
+ tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)
238
+
239
+ # Add media tokens to the tokenizer
240
+ tokenizer.media_tokens = MEDIA_TOKENS
241
+ tokenizer.media_token_ids = {}
242
+ for name, token in MEDIA_TOKENS.items():
243
+ tokenizer.add_tokens([token], special_tokens=True)
244
+ tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)
245
+
246
+ config.hidden_size = llm.config.hidden_size
247
+ return llm, tokenizer
config.json ADDED
@@ -0,0 +1,295 @@
1
+ {
2
+ "_attn_implementation_autoset": true,
3
+ "_name_or_path": "./LongVILA-R1-7B",
4
+ "architectures": [
5
+ "VILAForCausalLM"
6
+ ],
7
+ "chat_template": null,
8
+ "drop_path_rate": 0.0,
9
+ "fps": 0.0,
10
+ "hidden_size": 3584,
11
+ "image_aspect_ratio": "resize",
12
+ "image_encoder": {
13
+ "_target_": "llava.model.encoders.BasicImageEncoder"
14
+ },
15
+ "interpolate_mode": "linear",
16
+ "llm_cfg": {
17
+ "_attn_implementation_autoset": false,
18
+ "_name_or_path": "./LongVILA-R1-7B/llm",
19
+ "add_cross_attention": false,
20
+ "architectures": [
21
+ "Qwen2ForCausalLM"
22
+ ],
23
+ "attention_dropout": 0.0,
24
+ "bad_words_ids": null,
25
+ "begin_suppress_tokens": null,
26
+ "bos_token_id": 151643,
27
+ "chunk_size_feed_forward": 0,
28
+ "cross_attention_hidden_size": null,
29
+ "decoder_start_token_id": null,
30
+ "diversity_penalty": 0.0,
31
+ "do_sample": false,
32
+ "early_stopping": false,
33
+ "encoder_no_repeat_ngram_size": 0,
34
+ "eos_token_id": 151645,
35
+ "exponential_decay_length_penalty": null,
36
+ "finetuning_task": null,
37
+ "forced_bos_token_id": null,
38
+ "forced_eos_token_id": null,
39
+ "hidden_act": "silu",
40
+ "hidden_size": 3584,
41
+ "id2label": {
42
+ "0": "LABEL_0",
43
+ "1": "LABEL_1"
44
+ },
45
+ "initializer_range": 0.02,
46
+ "intermediate_size": 18944,
47
+ "is_decoder": false,
48
+ "is_encoder_decoder": false,
49
+ "label2id": {
50
+ "LABEL_0": 0,
51
+ "LABEL_1": 1
52
+ },
53
+ "length_penalty": 1.0,
54
+ "max_length": 20,
55
+ "max_position_embeddings": 32768,
56
+ "max_window_layers": 28,
57
+ "min_length": 0,
58
+ "model_max_length": 32768,
59
+ "model_type": "qwen2",
60
+ "no_repeat_ngram_size": 0,
61
+ "num_attention_heads": 28,
62
+ "num_beam_groups": 1,
63
+ "num_beams": 1,
64
+ "num_hidden_layers": 28,
65
+ "num_key_value_heads": 4,
66
+ "num_return_sequences": 1,
67
+ "output_attentions": false,
68
+ "output_hidden_states": false,
69
+ "output_scores": false,
70
+ "pad_token_id": null,
71
+ "prefix": null,
72
+ "problem_type": null,
73
+ "pruned_heads": {},
74
+ "remove_invalid_values": false,
75
+ "repetition_penalty": 1.0,
76
+ "return_dict": true,
77
+ "return_dict_in_generate": false,
78
+ "rms_norm_eps": 1e-06,
79
+ "rope_scaling": null,
80
+ "rope_theta": 1000000.0,
81
+ "sep_token_id": null,
82
+ "sliding_window": null,
83
+ "suppress_tokens": null,
84
+ "task_specific_params": null,
85
+ "temperature": 1.0,
86
+ "tf_legacy_loss": false,
87
+ "tie_encoder_decoder": false,
88
+ "tie_word_embeddings": false,
89
+ "tokenizer_class": null,
90
+ "tokenizer_model_max_length": 4096,
91
+ "tokenizer_padding_side": "right",
92
+ "top_k": 50,
93
+ "top_p": 1.0,
94
+ "torch_dtype": "bfloat16",
95
+ "torchscript": false,
96
+ "typical_p": 1.0,
97
+ "use_bfloat16": false,
98
+ "use_cache": false,
99
+ "use_sliding_window": false,
100
+ "vocab_size": 151651
101
+ },
102
+ "mm_hidden_size": 1152,
103
+ "mm_projector": "mlp_downsample_2x2_fix",
104
+ "mm_projector_cfg": {
105
+ "_attn_implementation_autoset": false,
106
+ "_name_or_path": "./LongVILA-R1-7B/mm_projector",
107
+ "add_cross_attention": false,
108
+ "architectures": [
109
+ "MultimodalProjector"
110
+ ],
111
+ "bad_words_ids": null,
112
+ "begin_suppress_tokens": null,
113
+ "bos_token_id": null,
114
+ "chunk_size_feed_forward": 0,
115
+ "cross_attention_hidden_size": null,
116
+ "decoder_start_token_id": null,
117
+ "diversity_penalty": 0.0,
118
+ "do_sample": false,
119
+ "early_stopping": false,
120
+ "encoder_no_repeat_ngram_size": 0,
121
+ "eos_token_id": null,
122
+ "exponential_decay_length_penalty": null,
123
+ "finetuning_task": null,
124
+ "forced_bos_token_id": null,
125
+ "forced_eos_token_id": null,
126
+ "id2label": {
127
+ "0": "LABEL_0",
128
+ "1": "LABEL_1"
129
+ },
130
+ "is_decoder": false,
131
+ "is_encoder_decoder": false,
132
+ "label2id": {
133
+ "LABEL_0": 0,
134
+ "LABEL_1": 1
135
+ },
136
+ "length_penalty": 1.0,
137
+ "max_length": 20,
138
+ "min_length": 0,
139
+ "mm_projector_type": "mlp_downsample_2x2_fix",
140
+ "model_type": "v2l_projector",
141
+ "no_repeat_ngram_size": 0,
142
+ "num_beam_groups": 1,
143
+ "num_beams": 1,
144
+ "num_return_sequences": 1,
145
+ "output_attentions": false,
146
+ "output_hidden_states": false,
147
+ "output_scores": false,
148
+ "pad_token_id": null,
149
+ "prefix": null,
150
+ "problem_type": null,
151
+ "pruned_heads": {},
152
+ "remove_invalid_values": false,
153
+ "repetition_penalty": 1.0,
154
+ "return_dict": true,
155
+ "return_dict_in_generate": false,
156
+ "sep_token_id": null,
157
+ "suppress_tokens": null,
158
+ "task_specific_params": null,
159
+ "temperature": 1.0,
160
+ "tf_legacy_loss": false,
161
+ "tie_encoder_decoder": false,
162
+ "tie_word_embeddings": true,
163
+ "tokenizer_class": null,
164
+ "top_k": 50,
165
+ "top_p": 1.0,
166
+ "torch_dtype": "bfloat16",
167
+ "torchscript": false,
168
+ "typical_p": 1.0,
169
+ "use_bfloat16": false
170
+ },
171
+ "mm_projector_lr": null,
172
+ "mm_use_im_patch_token": false,
173
+ "mm_use_im_start_end": false,
174
+ "mm_vision_select_feature": "cls_patch",
175
+ "mm_vision_select_layer": -2,
176
+ "model_dtype": "torch.bfloat16",
177
+ "model_name_or_path": "./LongVILA-R1-7B",
178
+ "model_type": "vila",
179
+ "num_time_tokens": 0,
180
+ "num_video_frames": 256,
181
+ "resume_path": "./LongVILA-R1-7B",
182
+ "s2": false,
183
+ "s2_max_split_size": 336,
184
+ "s2_scales": "336,672,1008",
185
+ "soft_ce_std": 1.0,
186
+ "time_token_format": "<t{t}>",
187
+ "time_token_ids": [],
188
+ "transformers_version": "4.46.2",
189
+ "tune_language_model": true,
190
+ "tune_mm_projector": true,
191
+ "tune_vision_tower": true,
192
+ "version": "2.0",
193
+ "video_encoder": {
194
+ "_target_": "llava.model.encoders.TSPVideoEncoder",
195
+ "pool_sizes": [
196
+ [
197
+ 8,
198
+ 1,
199
+ 1
200
+ ]
201
+ ]
202
+ },
203
+ "video_max_tiles": 1,
204
+ "vision_resolution": -1,
205
+ "vision_tower": "Efficient-Large-Model/paligemma-siglip-so400m-patch14-448",
206
+ "vision_tower_cfg": {
207
+ "_attn_implementation_autoset": false,
208
+ "_name_or_path": "./LongVILA-R1-7B/vision_tower",
209
+ "add_cross_attention": false,
210
+ "architectures": [
211
+ "SiglipVisionModel"
212
+ ],
213
+ "attention_dropout": 0.0,
214
+ "bad_words_ids": null,
215
+ "begin_suppress_tokens": null,
216
+ "bos_token_id": null,
217
+ "chunk_size_feed_forward": 0,
218
+ "cross_attention_hidden_size": null,
219
+ "decoder_start_token_id": null,
220
+ "diversity_penalty": 0.0,
221
+ "do_sample": false,
222
+ "early_stopping": false,
223
+ "encoder_no_repeat_ngram_size": 0,
224
+ "eos_token_id": null,
225
+ "exponential_decay_length_penalty": null,
226
+ "finetuning_task": null,
227
+ "forced_bos_token_id": null,
228
+ "forced_eos_token_id": null,
229
+ "hidden_act": "gelu_pytorch_tanh",
230
+ "hidden_size": 1152,
231
+ "id2label": {
232
+ "0": "LABEL_0",
233
+ "1": "LABEL_1"
234
+ },
235
+ "image_size": 448,
236
+ "intermediate_size": 4304,
237
+ "is_decoder": false,
238
+ "is_encoder_decoder": false,
239
+ "label2id": {
240
+ "LABEL_0": 0,
241
+ "LABEL_1": 1
242
+ },
243
+ "layer_norm_eps": 1e-06,
244
+ "length_penalty": 1.0,
245
+ "max_length": 20,
246
+ "min_length": 0,
247
+ "model_type": "siglip_vision_model",
248
+ "no_repeat_ngram_size": 0,
249
+ "num_attention_heads": 16,
250
+ "num_beam_groups": 1,
251
+ "num_beams": 1,
252
+ "num_channels": 3,
253
+ "num_hidden_layers": 27,
254
+ "num_image_tokens": 256,
255
+ "num_return_sequences": 1,
256
+ "output_attentions": false,
257
+ "output_hidden_states": false,
258
+ "output_scores": false,
259
+ "pad_token_id": null,
260
+ "patch_size": 14,
261
+ "prefix": null,
262
+ "problem_type": null,
263
+ "projection_dim": 2048,
264
+ "projector_hidden_act": "gelu_fast",
265
+ "pruned_heads": {},
266
+ "remove_invalid_values": false,
267
+ "repetition_penalty": 1.0,
268
+ "return_dict": true,
269
+ "return_dict_in_generate": false,
270
+ "sep_token_id": null,
271
+ "suppress_tokens": null,
272
+ "task_specific_params": null,
273
+ "temperature": 1.0,
274
+ "tf_legacy_loss": false,
275
+ "tie_encoder_decoder": false,
276
+ "tie_word_embeddings": true,
277
+ "tokenizer_class": null,
278
+ "top_k": 50,
279
+ "top_p": 1.0,
280
+ "torch_dtype": "bfloat16",
281
+ "torchscript": false,
282
+ "typical_p": 1.0,
283
+ "use_bfloat16": false,
284
+ "vision_use_head": false
285
+ },
286
+ "vision_tower_lr": null,
287
+ "weight_memory_efficient": true,
288
+ "xvila_mode": false,
289
+ "auto_map": {
290
+ "AutoProcessor": "auto_processor.VILAProcessor",
291
+ "AutoConfig": "modeling_vila.VILAConfig",
292
+ "AutoModel": "modeling_vila.VILAForCausalLM",
293
+ "AutoModelForCausalLM": "modeling_vila.VILAForCausalLM"
294
+ }
295
+ }
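The auto_map block above wires AutoConfig, AutoModel, AutoModelForCausalLM, and AutoProcessor to the custom classes shipped in this upload, so the checkpoint is intended to be loaded with trust_remote_code. A minimal loading sketch (the path is a placeholder for a local folder or hub repo holding these files):

    import torch
    from transformers import AutoModel, AutoProcessor

    repo = "./LongVILA-R1-7B"  # placeholder
    model = AutoModel.from_pretrained(repo, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto")
    processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)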
configuration_vila.py ADDED
@@ -0,0 +1,92 @@
1
+ import json
2
+ import math
3
+ import os
4
+ import os.path as osp
5
+ from copy import deepcopy
6
+ from threading import Thread
7
+ from typing import List, Optional
8
+
9
+ import torch
10
+ import torchvision
11
+ from PIL import Image
12
+ from transformers import (
13
+ AutoProcessor,
14
+ PretrainedConfig,
15
+ PreTrainedModel,
16
+ Qwen2Config,
17
+ Qwen2ForCausalLM,
18
+ Qwen2PreTrainedModel,
19
+ TextIteratorStreamer,
20
+ )
21
+
22
+
23
+ class VILAConfig(PretrainedConfig):
24
+ model_type = "vila"
25
+ keys_to_ignore_at_inference = ["past_key_values"]
26
+
27
+ def __init__(
28
+ self,
29
+ llm_cfg=None,
30
+ vision_tower_cfg=None,
31
+ mm_projector_cfg=None,
32
+ architectures=None,
33
+ resume_path=None,
34
+ hidden_size=None,
35
+ mm_hidden_size=None,
36
+ image_aspect_ratio=None,
37
+ num_video_frames=None,
38
+ fps=None,
39
+ mm_vision_select_layer=None,
40
+ mm_vision_select_feature=None,
41
+ mm_use_im_start_end=False,
42
+ mm_use_im_patch_token=False,
43
+ mm_projector_lr=None,
44
+ vision_tower_lr=None,
45
+ vision_resolution=None,
46
+ interpolate_mode=None,
47
+ s2=None,
48
+ dynamic_s2=None,
49
+ s2_scales=None,
50
+ s2_max_split_size=None,
51
+ s2_resize_output_to_scale_idx=0,
52
+ min_tiles: Optional[int] = 1,
53
+ max_tiles: Optional[int] = 12,
54
+ num_time_tokens=None,
55
+ time_token_format=None,
56
+ image_encoder: str = '{"_target_": "llava.model.encoders.BasicImageEncoder"}',
57
+ video_encoder: str = '{"_target_": "llava.model.encoders.BasicVideoEncoder"}',
58
+ **kwargs,
59
+ ):
60
+ super().__init__(**kwargs)
61
+
62
+ self.architectures = architectures
63
+ self.llm_cfg = llm_cfg
64
+ self.vision_tower_cfg = vision_tower_cfg
65
+ self.mm_projector_cfg = mm_projector_cfg
66
+ self.resume_path = resume_path
67
+
68
+ self.hidden_size = hidden_size
69
+ self.mm_hidden_size = mm_hidden_size
70
+ self.image_aspect_ratio = image_aspect_ratio
71
+ self.num_video_frames = num_video_frames
72
+ self.fps = fps
73
+ self.mm_vision_select_layer = mm_vision_select_layer
74
+ self.mm_vision_select_feature = mm_vision_select_feature
75
+ self.mm_use_im_start_end = mm_use_im_start_end
76
+ self.mm_use_im_patch_token = mm_use_im_patch_token
77
+ self.mm_projector_lr = mm_projector_lr
78
+ self.vision_tower_lr = vision_tower_lr
79
+ self.vision_resolution = vision_resolution
80
+ self.interpolate_mode = interpolate_mode
81
+ self.s2 = s2
82
+ self.dynamic_s2 = dynamic_s2
83
+ self.s2_scales = s2_scales
84
+ self.s2_max_split_size = s2_max_split_size
85
+ self.s2_resize_output_to_scale_idx = s2_resize_output_to_scale_idx
86
+ self.min_tiles = min_tiles
87
+ self.max_tiles = max_tiles
88
+ self.num_time_tokens = num_time_tokens
89
+ self.time_token_format = time_token_format
90
+
91
+ self.image_encoder = image_encoder
92
+ self.video_encoder = video_encoder
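For reference, a direct construction that mirrors a few of the values serialized in config.json above; in practice the config is restored from that file rather than built by hand:

    cfg = VILAConfig(
        hidden_size=3584,
        mm_hidden_size=1152,
        image_aspect_ratio="resize",
        num_video_frames=256,
        fps=0.0,
        num_time_tokens=0,
        time_token_format="<t{t}>",
    )
    print(cfg.model_type)  # "vila"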
constants.py ADDED
@@ -0,0 +1,83 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
18
+
19
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
20
+ WORKER_HEART_BEAT_INTERVAL = 15
21
+
22
+ LOGDIR = "."
23
+
24
+ # Model Constants
25
+ IGNORE_INDEX = -100
26
+ DEFAULT_IMAGE_TOKEN = "<image>"
27
+ DEFAULT_SOUND_TOKEN = "<sound>"
28
+ DEFAULT_SPEECH_TOKEN = "<speech>"
29
+ SENTINEL_TOKEN = "<vila/sentinel>"
30
+ DEFAULT_IM_START_TOKEN = "<im_start>"
31
+ DEFAULT_IM_END_TOKEN = "<im_end>"
32
+
33
+
34
+ SENTINEL_TOKEN = "<vila/sentinel>"
35
+
36
+ MEDIA_TOKENS = {
37
+ "image": "<image>",
38
+ "video": "<vila/video>",
39
+ "speech": "<speech>",
40
+ "sound": "<sound>",
41
+ }
42
+
43
+ # <image> <vila/video> <vila/sentinel>
44
+ """
45
+ vila:
46
+ 151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
47
+ 151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
48
+ 151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
49
+ 151646: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
50
+ 151647: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
51
+ 151648: AddedToken("<vila/sentinel>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
52
+ 151649: AddedToken("<image>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
53
+ 151650: AddedToken("<vila/video>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
54
+
55
+ xvila:
56
+ 151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
57
+ 151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
58
+ 151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
59
+ 151646: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
60
+ 151647: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
61
+ 151648: AddedToken("<vila/sentinel>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
62
+ 151649: AddedToken("<image>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
63
+ 151650: AddedToken("<vila/video>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
64
+ 151651: AddedToken("<speech>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
65
+ 151652: AddedToken("<sound>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
66
+ 151653: AddedToken("<|image_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
67
+ 151654: AddedToken("<|image_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
68
+ 151655: AddedToken("<|video_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
69
+ 151656: AddedToken("<|video_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
70
+ 151657: AddedToken("<|speech_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
71
+ 151658: AddedToken("<|speech_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
72
+ 151659: AddedToken("<|sound_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
73
+ 151660: AddedToken("<|sound_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
74
+ """
75
+ MM_BOS_EOS_TOKENS = {
76
+ "image": ["<|image_bos|>", "<|image_eos|>"],
77
+ "video": ["<|video_bos|>", "<|video_eos|>"],
78
+ "speech": ["<|speech_bos|>", "<|speech_eos|>"],
79
+ "sound": ["<|sound_bos|>", "<|sound_eos|>"],
80
+ }
81
+
82
+ NUM_EXTRA_TOKENS_VILA = 8
83
+ NUM_EXTRA_TOKENS_XVILA = 10
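A small sanity check tying the token table above to the llm_cfg in config.json, whose vocab_size is 151651 (ids 0..151650): the eight extra VILA tokens start at id 151643, the first entry in the table.

    FIRST_EXTRA_TOKEN_ID = 151643          # "<|endoftext|>" in the table above
    assert FIRST_EXTRA_TOKEN_ID + NUM_EXTRA_TOKENS_VILA == 151651
    print(MEDIA_TOKENS["video"])           # "<vila/video>"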
conversation.py ADDED
@@ -0,0 +1,191 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
17
+
18
+ import dataclasses
19
+ from enum import Enum, auto
20
+ from typing import List
21
+
22
+ # from llava.utils.logging import logger
23
+
24
+
25
+ class SeparatorStyle(Enum):
26
+ """Different separator style."""
27
+
28
+ AUTO = auto()
29
+ TWO = auto()
30
+ MPT = auto()
31
+ PLAIN = auto()
32
+ LLAMA_3 = auto()
33
+
34
+
35
+ @dataclasses.dataclass
36
+ class Conversation:
37
+ """A class that keeps all conversation history."""
38
+
39
+ system: str
40
+ roles: List[str]
41
+ messages: List[List[str]]
42
+ sep_style: SeparatorStyle = SeparatorStyle.AUTO
43
+ sep: str = "###"
44
+ sep2: str = None
45
+ version: str = "Unknown"
46
+
47
+ def get_prompt(self):
48
+ messages = self.messages
49
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
50
+ messages = self.messages.copy()
51
+ init_role, init_msg = messages[0].copy()
52
+ init_msg = init_msg[0].replace("<image>", "").strip()
53
+ messages[0] = (init_role, "<image>\n" + init_msg)
54
+
55
+ if self.sep_style == SeparatorStyle.TWO:
56
+ seps = [self.sep, self.sep2]
57
+ ret = self.system + seps[0]
58
+ for i, (role, message) in enumerate(messages):
59
+ if message:
60
+ if type(message) is tuple:
61
+ message, _, _ = message
62
+ ret += role + ": " + message + seps[i % 2]
63
+ else:
64
+ ret += role + ":"
65
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
66
+ ret = self.system + self.sep
67
+ for rid, (role, message) in enumerate(messages):
68
+ if message:
69
+ if type(message) is tuple:
70
+ message = message[0]
71
+ sep = self.sep if rid < len(messages) - 1 else self.sep2
72
+ ret += role + message + sep
73
+ else:
74
+ ret += role
75
+ elif self.sep_style == SeparatorStyle.MPT:
76
+ ret = self.system + self.sep
77
+ for role, message in messages:
78
+ if message:
79
+ if type(message) is tuple:
80
+ message, _, _ = message
81
+ ret += role + message + self.sep
82
+ else:
83
+ ret += role
84
+ elif self.sep_style == SeparatorStyle.PLAIN:
85
+ seps = [self.sep, self.sep2]
86
+ ret = self.system
87
+ for i, (role, message) in enumerate(messages):
88
+ if message:
89
+ if type(message) is tuple:
90
+ message, _, _ = message
91
+ ret += message + seps[i % 2]
92
+ else:
93
+ ret += ""
94
+ else:
95
+ raise ValueError(f"Invalid style: {self.sep_style}")
96
+
97
+ return ret
98
+
99
+ def append_message(self, role, message):
100
+ self.messages.append([role, message])
101
+
102
+ def copy(self):
103
+ return Conversation(
104
+ system=self.system,
105
+ roles=self.roles,
106
+ messages=[[x, y] for x, y in self.messages],
107
+ sep_style=self.sep_style,
108
+ sep=self.sep,
109
+ sep2=self.sep2,
110
+ version=self.version,
111
+ )
112
+
113
+
114
+ conv_auto = Conversation(
115
+ system="",
116
+ roles=("", ""),
117
+ messages=(),
118
+ sep_style=SeparatorStyle.AUTO,
119
+ sep="\n",
120
+ )
121
+
122
+ conv_vicuna_v1 = Conversation(
123
+ system="A chat between a curious user and an artificial intelligence assistant. "
124
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
125
+ roles=("USER", "ASSISTANT"),
126
+ version="v1",
127
+ messages=(),
128
+ sep_style=SeparatorStyle.TWO,
129
+ sep=" ",
130
+ sep2="</s>",
131
+ )
132
+
133
+ conv_llava_plain = Conversation(
134
+ system="",
135
+ roles=("", ""),
136
+ messages=(),
137
+ sep_style=SeparatorStyle.PLAIN,
138
+ sep="\n",
139
+ )
140
+
141
+ hermes_2 = Conversation(
142
+ system="<|im_start|>system\nAnswer the questions.",
143
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
144
+ sep_style=SeparatorStyle.MPT,
145
+ sep="<|im_end|>",
146
+ messages=(),
147
+ version="hermes-2",
148
+ )
149
+
150
+ # Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template.
151
+ llama_3_chat = Conversation(
152
+ system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
153
+ "You are able to understand the visual content that the user provides, "
154
+ "and assist the user with a variety of tasks using natural language.",
155
+ roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
156
+ version="llama_v3",
157
+ messages=(),
158
+ sep_style=SeparatorStyle.LLAMA_3,
159
+ sep="<|eot_id|>",
160
+ sep2="<|end_of_text|>",
161
+ )
162
+
163
+
164
+ default_conversation = conv_auto
165
+ conv_templates = {
166
+ "auto": conv_auto,
167
+ "hermes-2": hermes_2,
168
+ "llama_3": llama_3_chat,
169
+ "v1": conv_vicuna_v1,
170
+ "vicuna_v1": conv_vicuna_v1,
171
+ "plain": conv_llava_plain,
172
+ }
173
+
174
+
175
+ CONVERSATION_MODE_MAPPING = {
176
+ "vila1.5-3b": "vicuna_v1",
177
+ "vila1.5-8b": "llama_3",
178
+ "vila1.5-13b": "vicuna_v1",
179
+ "vila1.5-40b": "hermes-2",
180
+ "llama-3": "llama_3",
181
+ "llama3": "llama_3",
182
+ }
183
+
184
+
185
+ def auto_set_conversation_mode(model_name_or_path: str) -> None:
186
+ global default_conversation
187
+ for k, v in CONVERSATION_MODE_MAPPING.items():
188
+ if k in model_name_or_path.lower():
189
+ print(f"Setting conversation mode to `{v}` based on model name/path `{model_name_or_path}`.")
190
+ default_conversation = conv_templates[v]
191
+ return
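A short prompt-building sketch using the vicuna_v1 template defined above (the system prompt is elided in the comment):

    conv = conv_templates["vicuna_v1"].copy()
    conv.append_message(conv.roles[0], "<image>\nWhat is shown in this image?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())
    # -> "... USER: <image>\nWhat is shown in this image? ASSISTANT:"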
distributed.py ADDED
@@ -0,0 +1,73 @@
1
+ import os
2
+ import warnings
3
+ from typing import Any, List, Optional
4
+
5
+ from torch import distributed as dist
6
+
7
+ __all__ = [
8
+ "init",
9
+ "is_initialized",
10
+ "size",
11
+ "rank",
12
+ "local_size",
13
+ "local_rank",
14
+ "is_main",
15
+ "barrier",
16
+ "gather",
17
+ "all_gather",
18
+ ]
19
+
20
+
21
+ def init() -> None:
22
+ if "RANK" not in os.environ:
23
+ warnings.warn("Environment variable `RANK` is not set. Skipping distributed initialization.")
24
+ return
25
+ dist.init_process_group(backend="nccl", init_method="env://")
26
+
27
+
28
+ def is_initialized() -> bool:
29
+ return dist.is_initialized()
30
+
31
+
32
+ def size() -> int:
33
+ return int(os.environ.get("WORLD_SIZE", 1))
34
+
35
+
36
+ def rank() -> int:
37
+ return int(os.environ.get("RANK", 0))
38
+
39
+
40
+ def local_size() -> int:
41
+ return int(os.environ.get("LOCAL_WORLD_SIZE", 1))
42
+
43
+
44
+ def local_rank() -> int:
45
+ return int(os.environ.get("LOCAL_RANK", 0))
46
+
47
+
48
+ def is_main() -> bool:
49
+ return rank() == 0
50
+
51
+
52
+ def barrier() -> None:
53
+ dist.barrier()
54
+
55
+
56
+ def gather(obj: Any, dst: int = 0) -> Optional[List[Any]]:
57
+ if not is_initialized():
58
+ return [obj]
59
+ if is_main():
60
+ objs = [None for _ in range(size())]
61
+ dist.gather_object(obj, objs, dst=dst)
62
+ return objs
63
+ else:
64
+ dist.gather_object(obj, dst=dst)
65
+ return None
66
+
67
+
68
+ def all_gather(obj: Any) -> List[Any]:
69
+ if not is_initialized():
70
+ return [obj]
71
+ objs = [None for _ in range(size())]
72
+ dist.all_gather_object(objs, obj)
73
+ return objs
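Typical usage under torchrun, sketched below; with no distributed environment set up, init() warns and the helpers fall back to single-process behavior:

    # e.g. torchrun --nproc_per_node=2 script.py
    init()
    results = all_gather({"rank": rank()})
    if is_main():
        print(f"gathered {len(results)} objects from {size()} processes")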
loss.py ADDED
@@ -0,0 +1,48 @@
1
+ from typing import List, Union
2
+
3
+ import torch
4
+ from torch.nn.functional import cross_entropy
5
+
6
+ from .constants import IGNORE_INDEX
7
+
8
+ __all__ = ["soft_cross_entropy"]
9
+
10
+
11
+ def soft_cross_entropy(
12
+ outputs: torch.Tensor,
13
+ targets: torch.Tensor,
14
+ soft_tokens: Union[torch.Tensor, List[int]],
15
+ std: float = 1,
16
+ ignore_index: int = IGNORE_INDEX,
17
+ ) -> torch.Tensor:
18
+ # Remove last token from outputs and first token from targets
19
+ outputs = outputs[..., :-1, :].contiguous()
20
+ targets = targets[..., 1:].contiguous()
21
+
22
+ # Flatten outputs and targets
23
+ targets = targets.view(-1)
24
+ outputs = outputs.view(targets.size(0), -1)
25
+
26
+ # Remove outputs and targets with ignore_index
27
+ indices = targets != ignore_index
28
+ outputs = outputs[indices]
29
+ targets = targets[indices]
30
+
31
+ # Convert soft token IDs to tensor
32
+ if isinstance(soft_tokens, list):
33
+ soft_tokens = torch.tensor(soft_tokens).to(targets)
34
+
35
+ # Calculate loss for non-soft tokens
36
+ indices = torch.isin(targets, soft_tokens, invert=True)
37
+ loss = cross_entropy(outputs[indices], targets[indices], reduction="sum")
38
+
39
+ # Calculate loss for soft tokens
40
+ indices = torch.isin(targets, soft_tokens)
41
+ targets_indices = torch.zeros_like(outputs[indices])
42
+ for k, target in enumerate(targets[indices]):
43
+ dist = torch.exp(-((target - soft_tokens) ** 2) / (2 * std**2))
44
+ targets_indices[k][soft_tokens] = dist / dist.sum()
45
+ loss += cross_entropy(outputs[indices], targets_indices, reduction="sum")
46
+
47
+ # Return average loss
48
+ return loss / targets.size(0)
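A toy invocation of soft_cross_entropy; the shapes and soft-token ids below are arbitrary (presumably the soft tokens would be the time-token ids referenced in config.json, which are empty for this checkpoint):

    import torch

    logits = torch.randn(2, 8, 32)             # (batch, seq_len, vocab)
    labels = torch.randint(0, 32, (2, 8))
    loss = soft_cross_entropy(logits, labels, soft_tokens=[10, 11, 12, 13], std=1.0)
    print(loss.item())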
media.py ADDED
@@ -0,0 +1,130 @@
1
+ import glob
2
+ import os
3
+ from collections import defaultdict
4
+ from typing import Any, Dict, List, Optional, Union
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import PIL
9
+ import PIL.Image
10
+ import requests
11
+ from transformers import PretrainedConfig
12
+
13
+ # from llava.constants import MEDIA_TOKENS
14
+ # from llava.media import Image, Video
15
+ # from llava.utils import make_list
16
+ # from llava.utils.logging import logger
17
+
18
+ MEDIA_TOKENS = {
19
+ "image": "<image>",
20
+ "video": "<vila/video>",
21
+ }
22
+
23
+
24
+ class Media:
25
+ pass
26
+
27
+
28
+ class File(Media):
29
+ def __init__(self, path: str) -> None:
30
+ self.path = path
31
+
32
+
33
+ class Image(File):
34
+ pass
35
+
36
+
37
+ class Video(File):
38
+ pass
39
+
40
+
41
+ def make_list(obj: Any) -> List:
42
+ return obj if isinstance(obj, list) else [obj]
43
+
44
+
45
+ def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image:
46
+ if isinstance(image, Image):
47
+ if image.path.startswith("http://") or image.path.startswith("https://"):
48
+ image = PIL.Image.open(requests.get(image.path, stream=True).raw)
49
+ else:
50
+ image = PIL.Image.open(image.path)
51
+ return image
52
+
53
+
54
+ def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]:
55
+ # Load video frames from a directory
56
+ if os.path.isdir(video_path):
57
+ frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
58
+ indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int)
59
+ return [PIL.Image.open(frame_paths[index]) for index in indices]
60
+
61
+ # Load video frames from a video file
62
+ vidcap = cv2.VideoCapture(video_path)
63
+
64
+ # Find the last frame as frame count might not be accurate
65
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
66
+ while frame_count > 0:
67
+ vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
68
+ if vidcap.grab():
69
+ break
70
+ frame_count -= 1
71
+ else:
72
+ raise ValueError(f"Video '{video_path}' has no frames.")
73
+
74
+ # Extract frames uniformly
75
+ indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
76
+ frames = {}
77
+ for index in indices:
78
+ if index in frames:
79
+ continue
80
+ vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
81
+ success, frame = vidcap.read()
82
+ if not success:
83
+ print(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
84
+ continue
85
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
86
+ frames[index] = PIL.Image.fromarray(frame)
87
+ return [frames[index] for index in indices if index in frames]
88
+
89
+
90
+ def _extract_video(video, config: PretrainedConfig) -> List[PIL.Image.Image]:
91
+ num_frames = config.num_video_frames
92
+ video_path = video.path if isinstance(video, Video) else video["path"]
93
+ frames = _load_video(video_path, num_frames=num_frames)
94
+ return frames
95
+
96
+
97
+ def extract_media(
98
+ messages: List[Dict[str, Any]],
99
+ config: Optional[PretrainedConfig] = None,
100
+ draft: bool = False,
101
+ ) -> Dict[str, List[Any]]:
102
+ media = defaultdict(list)
103
+ for message in messages:
104
+ text = ""
105
+ for part in make_list(message["value"]):
106
+ if isinstance(part, str):
107
+ for token in MEDIA_TOKENS.values():
108
+ if token in part:
109
+ print(f"Media token '{token}' found in text: '{part}'. Removed.")
110
+ part = part.replace(token, "").strip()
111
+ text += part
112
+ elif isinstance(part, (Image, PIL.Image.Image)):
113
+ if draft:
114
+ media["image"].append(part)
115
+ else:
116
+ media["image"].append(_extract_image(part))
117
+ text += MEDIA_TOKENS["image"]
118
+ elif isinstance(part, dict) or isinstance(part, Video):
119
+ if draft:
120
+ media["video"].append(part)
121
+ else:
122
+ media["video"].append(_extract_video(part, config))
123
+ text += MEDIA_TOKENS["video"]
124
+ else:
125
+ raise ValueError(f"Unsupported prompt part type: {type(part)}")
126
+ message["value"] = text
127
+
128
+ if MEDIA_TOKENS["video"] in messages[0]["value"]:
129
+ messages[0]["value"] = "<vila/video>" + messages[0]["value"].replace("<vila/video>", "")
130
+ return media
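A usage sketch for extract_media; the video path is a placeholder, and the config object only needs a num_video_frames field (256 for this checkpoint per config.json):

    from types import SimpleNamespace

    cfg = SimpleNamespace(num_video_frames=256)
    messages = [{"from": "human", "value": [Video("demo.mp4"), "Summarize this video."]}]
    media = extract_media(messages, config=cfg)
    print(messages[0]["value"])    # "<vila/video>Summarize this video."
    print(len(media["video"][0]))  # 256 uniformly sampled PIL frames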
media_encoder.py ADDED
@@ -0,0 +1,158 @@
1
+ from functools import partial
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+
8
+ class BaseEncoder(nn.Module):
9
+ def __init__(self, parent: nn.Module) -> None:
10
+ super().__init__()
11
+ self._parent = [parent]
12
+
13
+ @property
14
+ def parent(self) -> nn.Module:
15
+ return self._parent[0]
16
+
17
+
18
+ class BasicImageEncoder(BaseEncoder):
19
+ def __init__(
20
+ self,
21
+ parent: torch.nn.Module,
22
+ start_tokens: Optional[str] = None,
23
+ end_tokens: Optional[str] = "\n",
24
+ ) -> None:
25
+ super().__init__(parent)
26
+ self.start_tokens = start_tokens
27
+ self.end_tokens = end_tokens
28
+
29
+ def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
30
+ if tokens is None:
31
+ return None
32
+ token_ids = self.parent.tokenizer(tokens).input_ids
33
+ token_ids = torch.tensor(token_ids, device=self.parent.device)
34
+ return self.parent.llm_model_embed_tokens(token_ids)
35
+
36
+ def _process_features(
37
+ self,
38
+ features: torch.Tensor,
39
+ start_token_embeds: Optional[torch.Tensor],
40
+ end_token_embeds: Optional[torch.Tensor],
41
+ ) -> torch.Tensor:
42
+ if start_token_embeds is not None:
43
+ features = torch.cat([start_token_embeds, features], dim=0)
44
+ if end_token_embeds is not None:
45
+ features = torch.cat([features, end_token_embeds], dim=0)
46
+ return features
47
+
48
+ def forward(self, images: List[torch.Tensor], config: Dict[str, Any], device: torch.device) -> List[torch.Tensor]:
49
+ images = torch.stack(images, dim=0)
50
+ features = self.parent.encode_images(images, block_sizes=config.get("block_sizes"))
51
+ process_features = partial(
52
+ self._process_features,
53
+ start_token_embeds=self.embed_tokens(self.start_tokens),
54
+ end_token_embeds=self.embed_tokens(self.end_tokens),
55
+ )
56
+ return [process_features(f).to(device) for f in features]
57
+
58
+
59
+ class BasicVideoEncoder(BaseEncoder):
60
+ def __init__(
61
+ self,
62
+ parent: torch.nn.Module,
63
+ start_tokens: Optional[str] = None,
64
+ end_tokens: Optional[str] = "\n",
65
+ ) -> None:
66
+ super().__init__(parent)
67
+ self.start_tokens = start_tokens
68
+ self.end_tokens = end_tokens
69
+
70
+ def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
71
+ if tokens is None:
72
+ return None
73
+ token_ids = self.parent.tokenizer(tokens).input_ids
74
+ token_ids = torch.tensor(token_ids, device=self.parent.device)
75
+ return self.parent.llm_model_embed_tokens(token_ids)
76
+
77
+ def _process_features(
78
+ self,
79
+ features: torch.Tensor,
80
+ start_token_embeds: Optional[torch.Tensor],
81
+ end_token_embeds: Optional[torch.Tensor],
82
+ ) -> torch.Tensor:
83
+ if start_token_embeds is not None:
84
+ start_embeds = torch.stack([start_token_embeds] * features.shape[0], dim=0)
85
+ features = torch.cat([start_embeds, features], dim=1)
86
+ if end_token_embeds is not None:
87
+ end_embeds = torch.stack([end_token_embeds] * features.shape[0], dim=0)
88
+ features = torch.cat([features, end_embeds], dim=1)
89
+ return features.flatten(0, 1)
90
+
91
+ def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
92
+ num_frames = [video.shape[0] for video in videos]
93
+ images = torch.cat(videos, dim=0)
94
+ features = self.parent.encode_images(images)
95
+ features = torch.split(features, num_frames)
96
+ process_features = partial(
97
+ self._process_features,
98
+ start_token_embeds=self.embed_tokens(self.start_tokens),
99
+ end_token_embeds=self.embed_tokens(self.end_tokens),
100
+ )
101
+ return [process_features(f) for f in features]
102
+
103
+ def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor:
104
+ if x.shape[dim] % size == 0:
105
+ return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1)
106
+ else:
107
+ return x.narrow(dim, start=0, length=1)
108
+
109
+ class TSPVideoEncoder(BasicVideoEncoder):
110
+ def __init__(
111
+ self,
112
+ parent: torch.nn.Module,
113
+ #pool_sizes: List[Tuple[int, int, int]],
114
+ start_tokens: Optional[str] = None,
115
+ end_tokens: Optional[str] = "\n",
116
+ sep_tokens: Optional[str] = None,
117
+ ) -> None:
118
+ super().__init__(parent, start_tokens=start_tokens, end_tokens=end_tokens)
119
+ self.pool_sizes = [[8, 1, 1]] #pool_sizes
120
+ self.sep_tokens = sep_tokens
121
+
122
+ def _process_features(
123
+ self,
124
+ inputs: torch.Tensor,
125
+ start_token_embeds: Optional[torch.Tensor],
126
+ end_token_embeds: Optional[torch.Tensor],
127
+ sep_token_embeds: Optional[torch.Tensor],
128
+ ) -> torch.Tensor:
129
+ nt, ns = inputs.shape[:2]
130
+ nl = int(ns**0.5)
131
+ outputs = []
132
+ for pool_size in self.pool_sizes:
133
+ features = inputs.view(nt, nl, nl, -1)
134
+ for dim, p in enumerate(pool_size):
135
+ features = pool(features, p, dim=dim)
136
+ features = features.flatten(1, 2)
137
+ features = super()._process_features(
138
+ features,
139
+ start_token_embeds=start_token_embeds,
140
+ end_token_embeds=end_token_embeds,
141
+ )
142
+ if sep_token_embeds is not None:
143
+ features = torch.cat([features, sep_token_embeds], dim=0)
144
+ outputs.append(features)
145
+ return torch.cat(outputs, dim=0)
146
+
147
+ def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
148
+ num_frames = [video.shape[0] for video in videos]
149
+ images = torch.cat(videos, dim=0)
150
+ features = self.parent.encode_images(images)
151
+ features = torch.split(features, num_frames)
152
+ process_features = partial(
153
+ self._process_features,
154
+ start_token_embeds=self.embed_tokens(self.start_tokens),
155
+ end_token_embeds=self.embed_tokens(self.end_tokens),
156
+ sep_token_embeds=self.embed_tokens(self.sep_tokens),
157
+ )
158
+ return [process_features(f) for f in features]
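A worked example of the temporal pooling done by TSPVideoEncoder: with the hard-coded pool size [8, 1, 1], every 8 consecutive frames are averaged while the spatial grid is kept as is. The 16x16 grid and 3584-dim features below are assumptions matching this checkpoint's 2x2-downsampled projector output and LLM hidden size:

    import torch

    nt, nl, d = 256, 16, 3584                  # frames, grid side, hidden size
    features = torch.randn(nt, nl, nl, d)
    for dim, p in enumerate([8, 1, 1]):
        features = pool(features, p, dim=dim)  # pool() as defined above in this file
    print(features.shape)                      # torch.Size([32, 16, 16, 3584])
    print(features.flatten(1, 2).shape)        # torch.Size([32, 256, 3584]), as in _process_features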
mm_utils.py ADDED
@@ -0,0 +1,575 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # dynamic_preprocess and find_closest_aspect_ratio are referenced from https://github.com/OpenGVLab/InternVL
18
+
19
+ import base64
20
+ import os
21
+ import tempfile
22
+ from io import BytesIO
23
+
24
+ import numpy as np
25
+ import torch
26
+ from PIL import Image
27
+ from transformers import StoppingCriteria
28
+
29
+ from .constants import DEFAULT_IMAGE_TOKEN
30
+
31
+
32
+ def get_frame_from_vcap(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
33
+ import cv2
34
+
35
+ if fps is None or frame_count is None:
36
+ # if one of fps or frame_count is None, still recompute
37
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
38
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
39
+ if fps == 0 or frame_count == 0:
40
+ print(f"Video file not found. return empty images. {video_file_name}")
41
+ return [
42
+ Image.new("RGB", (720, 720)),
43
+ ] * num_frames, 0
44
+
45
+ duration = frame_count / fps
46
+ frame_interval = frame_count // num_frames
47
+ if frame_interval == 0 and frame_count <= 1:
48
+ print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
49
+ return [
50
+ Image.new("RGB", (720, 720)),
51
+ ] * num_frames, 0
52
+ # print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval)
53
+
54
+ images = []
55
+ count = 0
56
+ success = True
57
+ frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
58
+ while success:
59
+ # print("frame_count:", frame_count, "count:", count, "num_frames:", num_frames, "frame_interval:", frame_interval)
60
+ if frame_count >= num_frames:
61
+ success, frame = vidcap.read()
62
+ if count in frame_indices:
63
+ try:
64
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
65
+ im_pil = Image.fromarray(img)
66
+ images.append(im_pil)
67
+ except BaseException:
68
+ continue
69
+ if len(images) >= num_frames:
70
+ return images, num_frames
71
+ count += 1
72
+ else:
73
+ # Left padding frames if the video is not long enough
74
+ success, frame = vidcap.read()
75
+ if success:
76
+ try:
77
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
78
+ im_pil = Image.fromarray(img)
79
+ images.append(im_pil)
80
+ except BaseException:
81
+ continue
82
+ count += 1
83
+ else:
84
+ break
85
+ if len(images) == 0:
86
+ raise ValueError("Did not find enough frames in the video. return empty image.")
87
+
88
+ return images, len(images)
89
+
90
+
91
+ def get_frame_from_vcap_with_fps(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
92
+ """
93
+ num_frames is the max number of frames the model can support.
94
+ frame_count is the number of frames in the input video.
95
+ max_fps is the max FPS of the model can support.
96
+ fps is the fps of the input video.
97
+ """
98
+
99
+ import random
100
+
101
+ import cv2
102
+
103
+ if fps is None or frame_count is None:
104
+ # if one of fps or frame_count is None, still recompute
105
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
106
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
107
+
108
+ if fps == 0 or frame_count == 0:
109
+ print(f"Video file not found. return empty images. {video_file_name}")
110
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
111
+ return [
112
+ Image.new("RGB", (720, 720)),
113
+ ] * empty_video_frames, 0
114
+
115
+ duration = frame_count / fps
116
+ # print("duration:", duration, "frames:", frame_count, "fps:", fps, "num_frames:", num_frames, "max_fps:", max_fps)
117
+ # If the video is too long (longer than max_fps and num_frames can support),
118
+ # we will use lower fps to sample frames.
119
+ if duration >= num_frames / max_fps:
120
+ frame_interval = frame_count // num_frames
121
+
122
+ # If the video is too short, we will skip the video if there is only one frame.
123
+ if frame_interval == 0 and frame_count <= 1:
124
+ print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
125
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
126
+ return [
127
+ Image.new("RGB", (720, 720)),
128
+ ] * empty_video_frames, 0
129
+
130
+ images = []
131
+ count = 0
132
+ success = True
133
+ frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
134
+
135
+ while success:
136
+ if frame_count >= num_frames:
137
+ # success, frame = vidcap.read()
138
+ if count in frame_indices:
139
+ success, frame = vidcap.read()
140
+ try:
141
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
142
+ im_pil = Image.fromarray(img)
143
+ images.append(im_pil)
144
+ except:
145
+ # print("Failed to read frame:", count)
146
+ continue
147
+ if len(images) >= num_frames:
148
+ return images, num_frames
149
+ else:
150
+ success = vidcap.grab()
151
+ count += 1
152
+ else:
153
+ # Left padding frames if the video is not long enough
154
+ success, frame = vidcap.read()
155
+ if success:
156
+ try:
157
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
158
+ im_pil = Image.fromarray(img)
159
+ images.append(im_pil)
160
+ except:
161
+ # print("Failed to read frame:", count)
162
+ continue
163
+ count += 1
164
+ else:
165
+ break
166
+ else:
167
+ frames_required = int(duration * max_fps)
168
+ frame_indices = np.linspace(0, frame_count - 1, frames_required, dtype=int)
169
+ if frames_required == 0:
170
+ print(f"frames_required is fewer than 2. Duration {duration}, return empty image.")
171
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
172
+ return [
173
+ Image.new("RGB", (720, 720)),
174
+ ] * empty_video_frames, 0
175
+ elif frames_required == 1:
176
+ frame_indices = np.linspace(0, frame_count - 1, 2, dtype=int)
177
+ images = []
178
+ count = 0
179
+ looked = 0
180
+ success = True
181
+
182
+ while success:
183
+ success, frame = vidcap.read()
184
+ if success and (looked in frame_indices):
185
+ try:
186
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
187
+ im_pil = Image.fromarray(img)
188
+ images.append(im_pil)
189
+ except:
190
+ continue
191
+ count += 1
192
+ looked += 1
193
+
194
+ if len(images) == 0:
195
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
196
+ return [
197
+ Image.new("RGB", (720, 720)),
198
+ ] * empty_video_frames, 0
199
+ else:
200
+ return images, len(images)
201
+
202
+
203
+ def opencv_extract_frames(vpath_or_bytesio, frames=6, max_fps=0.0, fps=None, frame_count=None):
204
+ """
205
+ Extract frames from a video using OpenCV.
206
+
207
+ Args:
208
+ vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video.
209
+ frames (int): Number of frames to extract from the video.
210
+ fps (float): Frames per second of the video. If 0.0, the function will extract frames at equal intervals.
211
+
212
+ Returns:
213
+ list: List of PIL Images extracted from the video.
214
+
215
+ Raises:
216
+ NotImplementedError: If the type of `vpath_or_bytesio` is not supported.
217
+ """
218
+ import cv2
219
+
220
+ if isinstance(vpath_or_bytesio, str):
221
+ vidcap = cv2.VideoCapture(vpath_or_bytesio)
222
+ if max_fps > 0.0:
223
+ return get_frame_from_vcap_with_fps(
224
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
225
+ )
226
+ return get_frame_from_vcap(
227
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
228
+ )
229
+ elif isinstance(vpath_or_bytesio, (BytesIO,)):
230
+ # assuming mp4
231
+ with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
232
+ temp_video.write(vpath_or_bytesio.read())
233
+ temp_video_name = temp_video.name
234
+ vidcap = cv2.VideoCapture(temp_video_name)
235
+ if max_fps > 0.0:
236
+ return get_frame_from_vcap_with_fps(
237
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
238
+ )
239
+ return get_frame_from_vcap(
240
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
241
+ )
242
+ else:
243
+ raise NotImplementedError(type(vpath_or_bytesio))
244
+
245
+
246
+ def load_image_from_base64(image):
247
+ return Image.open(BytesIO(base64.b64decode(image)))
248
+
249
+
250
+ def expand2square(pil_img, background_color):
251
+ """
252
+ Expand the given PIL image to a square shape by adding padding.
253
+
254
+ Parameters:
255
+ - pil_img: The PIL image to be expanded.
256
+ - background_color: The color of the padding to be added.
257
+
258
+ Returns:
259
+ - The expanded PIL image.
260
+
261
+ If the image is already square, it is returned as is.
262
+ If the image is wider than it is tall, padding is added to the top and bottom.
263
+ If the image is taller than it is wide, padding is added to the left and right.
264
+ """
265
+ width, height = pil_img.size
266
+ if pil_img.mode == "L":
267
+ background_color = background_color[0]
268
+ if width == height:
269
+ return pil_img
270
+ elif width > height:
271
+ result = Image.new(pil_img.mode, (width, width), background_color)
272
+ result.paste(pil_img, (0, (width - height) // 2))
273
+ return result
274
+ else:
275
+ result = Image.new(pil_img.mode, (height, height), background_color)
276
+ result.paste(pil_img, ((height - width) // 2, 0))
277
+ return result
278
+
279
+
280
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
281
+ best_ratio_diff = float("inf")
282
+ best_ratio = (1, 1)
283
+ area = width * height
284
+ for ratio in target_ratios:
285
+ target_aspect_ratio = ratio[0] / ratio[1]
286
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
287
+ if ratio_diff < best_ratio_diff:
288
+ best_ratio_diff = ratio_diff
289
+ best_ratio = ratio
290
+ elif ratio_diff == best_ratio_diff:
291
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
292
+ best_ratio = ratio
293
+ return best_ratio
294
+
295
+
296
+ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=384, use_thumbnail=True):
297
+ orig_width, orig_height = image.size
298
+ aspect_ratio = orig_width / orig_height
299
+
300
+ # calculate the existing image aspect ratio
301
+ target_ratios = {
302
+ (i, j)
303
+ for n in range(min_num, max_num + 1)
304
+ for i in range(1, n + 1)
305
+ for j in range(1, n + 1)
306
+ if i * j <= max_num and i * j >= min_num
307
+ }
308
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
309
+
310
+ # find the closest aspect ratio to the target
311
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
312
+
313
+ # calculate the target width and height
314
+ target_width = image_size * target_aspect_ratio[0]
315
+ target_height = image_size * target_aspect_ratio[1]
316
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
317
+
318
+ # resize the image
319
+ resized_img = image.resize((target_width, target_height))
320
+ processed_images = []
321
+ for i in range(blocks):
322
+ box = (
323
+ (i % (target_width // image_size)) * image_size,
324
+ (i // (target_width // image_size)) * image_size,
325
+ ((i % (target_width // image_size)) + 1) * image_size,
326
+ ((i // (target_width // image_size)) + 1) * image_size,
327
+ )
328
+ # split the image
329
+ split_img = resized_img.crop(box)
330
+ processed_images.append(split_img)
331
+ assert len(processed_images) == blocks
332
+ if use_thumbnail and len(processed_images) != 1:
333
+ thumbnail_img = image.resize((image_size, image_size))
334
+ processed_images.append(thumbnail_img)
335
+ return processed_images
336
+
337
+
338
+ def dynamic_s2_preprocess(image, s2_scales=[384, 768, 1152], max_num=12, image_size=384):
339
+ orig_width, orig_height = image.size
340
+ aspect_ratio = orig_width / orig_height
341
+ min_num = (s2_scales[-1] // s2_scales[0]) ** 2 # at least use number of tiles as the largest scale
342
+
343
+ processed_images = []
344
+
345
+ ##########################################################################################
346
+ ############# Add tiles for all but the last scale using fixed square ratio ##############
347
+ ##########################################################################################
348
+
349
+ for scale in s2_scales[:-1]:
350
+ target_width = image_size * (scale // s2_scales[0])
351
+ target_height = image_size * (scale // s2_scales[0])
352
+ blocks = (scale // s2_scales[0]) ** 2
353
+
354
+ # resize the image
355
+ resized_img = image.resize((target_width, target_height))
356
+ for i in range(blocks):
357
+ box = (
358
+ (i % (target_width // image_size)) * image_size,
359
+ (i // (target_width // image_size)) * image_size,
360
+ ((i % (target_width // image_size)) + 1) * image_size,
361
+ ((i // (target_width // image_size)) + 1) * image_size,
362
+ )
363
+ # split the image
364
+ split_img = resized_img.crop(box)
365
+ processed_images.append(split_img)
366
+
367
+ ##########################################################################################
368
+ ################ Add tiles for the last scale using dynamic aspect ratio #################
369
+ ##########################################################################################
370
+
371
+ # calculate the existing image aspect ratio
372
+ target_ratios = {
373
+ (i, j)
374
+ for n in range(min_num, max_num + 1)
375
+ for i in range(1, n + 1)
376
+ for j in range(1, n + 1)
377
+ if i * j <= max_num and i * j >= min_num
378
+ }
379
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
380
+
381
+ # find the closest aspect ratio to the target
382
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
383
+
384
+ # calculate the target width and height
385
+ target_width = image_size * target_aspect_ratio[0]
386
+ target_height = image_size * target_aspect_ratio[1]
387
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
388
+
389
+ # resize the image
390
+ resized_img = image.resize((target_width, target_height))
391
+ for i in range(blocks):
392
+ box = (
393
+ (i % (target_width // image_size)) * image_size,
394
+ (i // (target_width // image_size)) * image_size,
395
+ ((i % (target_width // image_size)) + 1) * image_size,
396
+ ((i // (target_width // image_size)) + 1) * image_size,
397
+ )
398
+ # split the image
399
+ split_img = resized_img.crop(box)
400
+ processed_images.append(split_img)
401
+
402
+ return processed_images, (target_aspect_ratio[1], target_aspect_ratio[0])
403
+
404
+
405
+ def dynamic_process_images_and_prompt(images, prompt, data_args, image_folder=None, max_tiles=None):
406
+ prompt = prompt.split(DEFAULT_IMAGE_TOKEN)
407
+ idx = 0
408
+ all_images = []
409
+ for img in images:
410
+ processed_images = process_image(img, data_args, image_folder, enable_dynamic_res=True, max_tiles=max_tiles)
411
+ all_images.append(processed_images)
412
+ prompt.insert(idx + 1, f"{DEFAULT_IMAGE_TOKEN}\n" * processed_images.shape[0])
413
+ idx += 2
414
+ prompt = "".join(prompt)
415
+ if all_images:
416
+ all_images = torch.cat(all_images)
417
+ else:
418
+ all_images = None
419
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, "")
420
+ return all_images, prompt
421
+
422
+
423
+ def dynamic_s2_process_images_and_prompt(images, prompt, data_args, image_folder=None):
424
+ idx = 0
425
+ all_images = []
426
+ all_block_size = []
427
+ for img in images:
428
+ processed_images, block_size = process_image(img, data_args, image_folder, enable_dynamic_s2=True)
429
+ all_images.append(processed_images)
430
+ all_block_size.append(block_size)
431
+ idx += 2
432
+ if all_images:
433
+ all_images = torch.cat(all_images)
434
+ else:
435
+ all_images = None
436
+ return all_images, all_block_size
437
+
438
+
439
+ def process_image(
440
+ image_file, data_args, image_folder, enable_dynamic_res=False, enable_dynamic_s2=False, max_tiles=None
441
+ ):
442
+ processor = data_args.image_processor
443
+ if isinstance(image_file, str):
444
+ if image_folder is not None:
445
+ image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
446
+ else:
447
+ image = Image.open(image_file).convert("RGB")
448
+ else:
449
+ # image is stored in bytearray
450
+ image = image_file
451
+ image = image.convert("RGB")
452
+ if hasattr(data_args.image_processor, "crop_size"):
453
+ # CLIP vision tower
454
+ crop_size = data_args.image_processor.crop_size
455
+ else:
456
+ # SIGLIP vision tower
457
+ assert hasattr(data_args.image_processor, "size")
458
+ crop_size = data_args.image_processor.size
459
+ if "dynamic_s2" in data_args.image_aspect_ratio and enable_dynamic_s2:
460
+ assert crop_size["height"] == crop_size["width"]
461
+ images, block_size = dynamic_s2_preprocess(
462
+ image, s2_scales=data_args.s2_scales, max_num=data_args.max_tiles, image_size=crop_size["height"]
463
+ )
464
+ images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
465
+ return torch.stack(images), block_size
466
+ if "dynamic" in data_args.image_aspect_ratio and enable_dynamic_res:
467
+ assert crop_size["height"] == crop_size["width"]
468
+ if max_tiles is not None:
469
+ max_num = max_tiles
470
+ else:
471
+ max_num = data_args.max_tiles
472
+ images = dynamic_preprocess(image, min_num=data_args.min_tiles, max_num=max_num, image_size=crop_size["height"])
473
+ images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
474
+ return torch.stack(images)
475
+
476
+ if data_args.image_aspect_ratio == "resize":
477
+ image = image.resize((crop_size["width"], crop_size["height"]))
478
+ if data_args.image_aspect_ratio == "pad":
479
+
480
+ def expand2square(pil_img, background_color):
481
+ width, height = pil_img.size
482
+ if width == height:
483
+ return pil_img
484
+ elif width > height:
485
+ result = Image.new(pil_img.mode, (width, width), background_color)
486
+ result.paste(pil_img, (0, (width - height) // 2))
487
+ return result
488
+ else:
489
+ result = Image.new(pil_img.mode, (height, height), background_color)
490
+ result.paste(pil_img, ((height - width) // 2, 0))
491
+ return result
492
+
493
+ image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
494
+ image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
495
+ else:
496
+ # Using default behavior of the vision encoder
497
+ # For CLIP, default is central crop
498
+ # For Radio, default is central crop
499
+ # For Siglip, default is resize
500
+ # For InternVIT, default is resize
501
+ image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
502
+ return image
503
+
504
+
505
+ def process_images(images, image_processor, model_cfg, enable_dynamic_res=False, max_tiles=None):
506
+ model_cfg.image_processor = image_processor
507
+ new_images = [
508
+ process_image(image, model_cfg, None, enable_dynamic_res=enable_dynamic_res, max_tiles=max_tiles)
509
+ for image in images
510
+ ]
511
+
512
+ if all(x.shape == new_images[0].shape for x in new_images):
513
+ if len(new_images[0].shape) == 4:
514
+ new_images = torch.cat(new_images, dim=0)
515
+ elif len(new_images[0].shape) == 3:
516
+ new_images = torch.stack(new_images, dim=0)
517
+ else:
518
+ raise ValueError(f"new_images rank does not equal to 4, rank: {len(new_images[0].shape)}")
519
+ else:
520
+ raise ValueError("Image tensors in new_images have mismatched shapes!")
521
+ return new_images
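+
+ # Illustrative usage sketch (not part of the original file; `model` and the image path
+ # are placeholders): preprocess a batch of PIL images for the SigLIP vision tower.
+ #   pixel_values = process_images([Image.open("example.jpg")], model.vision_tower.image_processor, model.config)
+ #   pixel_values = pixel_values.to(model.device, dtype=torch.float16)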
522
+
523
+
524
+ def tokenizer_image_token(prompt, tokenizer, return_tensors=None, return_ids=True):
525
+ if return_ids:
526
+ return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
527
+ else:
528
+ return tokenizer(prompt, return_tensors=return_tensors)
529
+
530
+
531
+ def is_gemma_tokenizer(tokenizer):
532
+ return "gemma" in tokenizer.__class__.__name__.lower()
533
+
534
+
535
+ def get_model_name_from_path(model_path):
536
+ model_path = model_path.strip("/")
537
+ model_paths = model_path.split("/")
538
+ if model_paths[-1].startswith("checkpoint-"):
539
+ return model_paths[-2] + "_" + model_paths[-1]
540
+ else:
541
+ return model_paths[-1]
542
+
543
+
544
+ class KeywordsStoppingCriteria(StoppingCriteria):
545
+ def __init__(self, keywords, tokenizer, input_ids):
546
+ self.keywords = keywords
547
+ self.keyword_ids = []
548
+ self.max_keyword_len = 0
549
+ for keyword in keywords:
550
+ cur_keyword_ids = tokenizer(keyword).input_ids
551
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
552
+ cur_keyword_ids = cur_keyword_ids[1:]
553
+ if len(cur_keyword_ids) > self.max_keyword_len:
554
+ self.max_keyword_len = len(cur_keyword_ids)
555
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
556
+ self.tokenizer = tokenizer
557
+ self.start_len = input_ids.shape[1]
558
+
559
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
560
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
561
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
562
+ for keyword_id in self.keyword_ids:
563
+ if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
564
+ return True
565
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
566
+ for keyword in self.keywords:
567
+ if keyword in outputs:
568
+ return True
569
+ return False
570
+
571
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
572
+ outputs = []
573
+ for i in range(output_ids.shape[0]):
574
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
575
+ return all(outputs)
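+
+ # Illustrative usage sketch (placeholders, not part of the original file): stop
+ # generation as soon as a keyword such as "</s>" appears in the decoded output.
+ #   from transformers import StoppingCriteriaList
+ #   criteria = StoppingCriteriaList([KeywordsStoppingCriteria(["</s>"], tokenizer, input_ids)])
+ #   output_ids = model.generate(input_ids, stopping_criteria=criteria, max_new_tokens=64)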
model_utils_packing.py ADDED
@@ -0,0 +1,35 @@
1
+ from importlib import import_module
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import transformers
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ __all__ = ["patch"]
10
+
11
+
12
+ def _get_unpad_data(attention_mask: torch.Tensor, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, int]:
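+ # Drop-in replacement for the flash-attention unpad helper in transformers: if per-sample
+ # sequence lengths were registered via set_seqlens_in_batch (sequence packing), use them
+ # instead of recomputing the lengths from the attention mask.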
13
+ if hasattr(_get_unpad_data, "seqlens_in_batch"):
14
+ seqlens_in_batch = _get_unpad_data.seqlens_in_batch
15
+ else:
16
+ seqlens_in_batch = torch.sum(attention_mask, dim=1)
17
+
18
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
19
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
20
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
21
+ return indices, cu_seqlens, max_seqlen_in_batch
22
+
23
+
24
+ def set_seqlens_in_batch(seqlens_in_batch: torch.Tensor) -> None:
25
+ _get_unpad_data.seqlens_in_batch = seqlens_in_batch
26
+
27
+
28
+ def patch(model: nn.Module) -> None:
29
+ if transformers.__version__ < "4.43.0":
30
+ m = import_module(model.__module__)
31
+ if not hasattr(m, "_get_unpad_data"):
32
+ raise ValueError(f"Module {m} does not have function '_get_unpad_data' for packing")
33
+ m._get_unpad_data = _get_unpad_data
34
+ else:
35
+ transformers.modeling_flash_attention_utils._get_unpad_data = _get_unpad_data
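+
+ # Illustrative usage sketch (placeholders): enable packed-sequence attention by patching
+ # the unpad helper, then register the effective per-sample lengths before each step.
+ #   patch(model.llm)
+ #   set_seqlens_in_batch(torch.sum(attention_mask, dim=1).to(torch.int32))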
modeling_vila.py ADDED
@@ -0,0 +1,1256 @@
1
+ import copy
2
+ import json
3
+ import logging
4
+ import math
5
+ import os
6
+ import os.path
7
+ import os.path as osp
8
+ import shutil
9
+ import warnings
10
+ from abc import ABC
11
+ from collections import OrderedDict, defaultdict, deque
12
+ from copy import deepcopy
13
+ from itertools import chain
14
+ from threading import Thread
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.distributed as dist
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torchvision
22
+ from einops import rearrange
23
+ from PIL import Image
24
+ from transformers import (
25
+ AutoConfig,
26
+ AutoModel,
27
+ AutoProcessor,
28
+ AutoTokenizer,
29
+ GenerationConfig,
30
+ LogitsProcessor,
31
+ PretrainedConfig,
32
+ PreTrainedModel,
33
+ Qwen2Config,
34
+ Qwen2ForCausalLM,
35
+ Qwen2PreTrainedModel,
36
+ TextIteratorStreamer,
37
+ )
38
+ from transformers.modeling_outputs import CausalLMOutputWithPast
39
+ from transformers.modeling_utils import ContextManagers, no_init_weights
40
+
41
+ from .auto_processor import VILAProcessor
42
+ from .base_projector import MultimodalProjector, MultimodalProjectorConfig
43
+ from .builder import build_llm_and_tokenizer
44
+ from .configuration_vila import VILAConfig
45
+ from .constants import *
46
+ from .conversation import SeparatorStyle, default_conversation
47
+ from .distributed import all_gather as vila_all_gather
48
+ from .loss import soft_cross_entropy
49
+ from .media import extract_media
50
+ from .media_encoder import BasicImageEncoder, BasicVideoEncoder, TSPVideoEncoder
51
+ from .mm_utils import process_image, process_images
52
+ from .model_utils_packing import set_seqlens_in_batch
53
+ from .siglip_encoder import SiglipVisionTower, SiglipVisionTowerDynamicS2, SiglipVisionTowerS2
54
+ from .tokenizer_utils import tokenize_conversation
55
+ from .utils import get_model_config, load_tokenizer_then_handle_media_tokens_and_chat_template
56
+
57
+ # from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS
58
+
59
+ # ease debugging
60
+ python_input = input
61
+
62
+
63
+ # quick hack for remote code
64
+ def get_pg_manager():
65
+ return None
66
+
67
+
68
+ def get_model_weights_dtype(model: nn.Module):
69
+ pass
70
+
71
+
72
+ def build_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
73
+ if model_type_or_path is None:
74
+ return None
75
+ ## load from pretrained model
76
+ if config.resume_path:
77
+ assert os.path.exists(model_type_or_path), f"Resume mm projector path {model_type_or_path} does not exist!"
78
+ return MultimodalProjector.from_pretrained(model_type_or_path, config)
79
+ ## build from scratch
80
+ else:
81
+ mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path)
82
+ mm_projector = MultimodalProjector(mm_projector_cfg, config)
83
+ return mm_projector
84
+
85
+
86
+ def check_dot_in_model_path(model_path: str):
87
+ """Check if the model path contains a dot, which will break remote code loading."""
88
+ if osp.isdir(model_path): # local model
89
+ if "." in osp.abspath(model_path):
90
+ return True
91
+ else: # remote model
92
+ if "." in model_path:
93
+ return True
94
+ return False
95
+
96
+
97
+ def get_vila_version(model_path: str) -> str:
98
+ VERSIONS = ["vila1.5", "vila-u", "longvila", "nvila", "vila-m3"]
99
+ for version in VERSIONS:
100
+ if version in model_path.lower():
101
+ return version
102
+ return None
103
+
104
+
105
+ def generate_jinja_template(conv_mode: str) -> str:
106
+ if conv_mode == "vicuna_v1":
107
+ return """{% set system_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. " %}
108
+ {% set roles = ["user", "assistant"] %}
109
+ {% set sep = " " %}
110
+
111
+ {{ system_prompt }}
112
+
113
+ {% for message in messages %}
114
+ {% if message['role'] == roles[0] %}
115
+ {{ "USER: " }}{{ sep }}{{ message['content'] }}{{ sep }}
116
+ {% else %}
117
+ {{ "ASSISTANT: " }}{{ sep }}{{ message['content'] }}{{ sep }}
118
+ {% endif %}
119
+ {% endfor %}
120
+ {% if messages[-1]['role'] == 'user' %}
121
+ {{ "ASSISTANT:" }}
122
+ {% endif %}
123
+ """
124
+ elif conv_mode == "llama_3":
125
+ return """{% set system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|>" %}
126
+ {% set roles = ["<|start_header_id|>user<|end_header_id|>\\n\\n", "<|start_header_id|>assistant<|end_header_id|>\\n\\n"]%}
127
+ {% set sep = "<|eot_id|>" %}
128
+
129
+ {{ system_prompt }}
130
+ {% for message in messages %}
131
+ {% if message['role'] == 'user' %}
132
+ {{ roles[0] }}{{ message['content'] }}{{ sep }}
133
+ {% else %}
134
+ {{ roles[1] }}{{ message['content'] }}{{ sep }}
135
+ {% endif %}
136
+ {% endfor %}
137
+ {% if messages[-1]['role'] == 'user' %}
138
+ {{ roles[1] }}
139
+ {% endif %}
140
+ """
141
+ elif conv_mode == "hermes_2":
142
+ return """{% set system_prompt = "<|im_start|>system\nAnswer the questions." %}
143
+ {% set roles = ["<|im_start|>user\n", "<|im_start|>assistant\n"] %}
144
+ {% set sep = "<|im_end|>" %}
145
+
146
+ {{ system_prompt }}{{ sep }}
147
+
148
+ {% for message in messages %}
149
+ {% if message['role'] == 'user' %}
150
+ {{ roles[0] }}{{ message['content'] }}{{ sep }}
151
+ {% else %}
152
+ {{ roles[1] }}{{ message['content'] }}{{ sep }}
153
+ {% endif %}
154
+ {% endfor %}"""
155
+ else:
156
+ raise NotImplementedError(f"Jinja template generation is not implemented for {conv_mode}.")
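+
+ # Illustrative usage sketch (placeholders): the generated template can be attached to a
+ # tokenizer and rendered through the standard chat-template API.
+ #   tokenizer.chat_template = generate_jinja_template("hermes_2")
+ #   prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)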
157
+
158
+
159
+ def build_vision_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
160
+ ## skip vision tower instantiation
161
+ if model_name_or_path is None:
162
+ return None
163
+
164
+ vision_tower_arch = None
165
+ if config.resume_path and "radio" not in model_name_or_path:
166
+ assert os.path.exists(model_name_or_path), f"Resume vision tower path {model_name_or_path} does not exist!"
167
+ vision_tower_cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
168
+ vision_tower_arch = vision_tower_cfg.architectures[0].lower()
169
+ vision_tower_name = vision_tower_arch if vision_tower_arch is not None else model_name_or_path
170
+
171
+ use_s2 = getattr(config, "s2", False)
172
+ use_dynamic_s2 = getattr(config, "dynamic_s2", False)
173
+
174
+ if "siglip" in vision_tower_name:
175
+ if use_dynamic_s2:
176
+ vision_tower = SiglipVisionTowerDynamicS2(model_name_or_path, config)
177
+ elif use_s2:
178
+ vision_tower = SiglipVisionTowerS2(model_name_or_path, config)
179
+ else:
180
+ vision_tower = SiglipVisionTower(model_name_or_path, config)
181
+ else:
182
+ raise NotImplementedError(f"Unknown vision tower: {model_name_or_path}")
183
+
184
+ config.mm_hidden_size = (
185
+ vision_tower.config.hidden_size if not (use_s2 or use_dynamic_s2) else vision_tower.hidden_size
186
+ )
187
+ return vision_tower
188
+
189
+
190
+ class VILAPretrainedModel(PreTrainedModel):
191
+ config_class = VILAConfig
192
+ main_input_name = "input_embeds"
193
+ supports_gradient_checkpointing = True
194
+ _supports_flash_attn_2 = True
195
+ _no_split_modules = ["Qwen2DecoderLayer", "SiglipEncoderLayer"]
196
+
197
+ def __init__(self, config: VILAConfig, *args, **kwargs):
198
+ super().__init__(config)
199
+ self.config = config
200
+ cfgs = get_model_config(config)
201
+ if len(cfgs) == 3:
202
+ llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs
203
+ else:
204
+ raise ValueError("`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config.")
205
+
206
+ # loading on auto by default
207
+ device_map = kwargs.get("device_map", "auto")
208
+ self.mm_projector = build_mm_projector(mm_projector_cfg, config)
209
+ self.vision_tower = build_vision_tower(vision_tower_cfg, config)
210
+ if device_map in ["auto", "cuda"]:
211
+ self.mm_projector = self.mm_projector.cuda()
212
+ self.vision_tower = self.vision_tower.cuda()
213
+ # setting device_map to "auto" can automatically shard the llm across different devices
214
+ self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map)
215
+ self.llm_model_embed_tokens = self.llm.model.embed_tokens
216
+
217
+ try:
218
+ use_tsp_encoder = "TSPVideoEncoder" in getattr(config, "video_encoder", None)["_target_"]
219
+ except Exception:  # video_encoder may be absent or not specify a "_target_"
220
+ use_tsp_encoder = False
221
+ print("use_tsp_encoder", use_tsp_encoder)
222
+ self.tokenizer.padding_side = "left"
223
+ self.encoders = {"image": BasicImageEncoder(self), "video": TSPVideoEncoder(self) if use_tsp_encoder else BasicVideoEncoder(self)}
224
+
225
+ self.post_config()
226
+ self.is_loaded = True
227
+ self.llm_only_need_embed = kwargs.get("llm_only_need_embed", False)
228
+ if self.llm_only_need_embed:
229
+ print("We only need the embed_tokens in llm.")
230
+ del self.llm
231
+ self.llm = None
232
+ torch.cuda.empty_cache()
233
+
234
+ assert (
235
+ self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
236
+ ), "At least one of the components must be instantiated."
237
+
238
+ @classmethod
239
+ def convert_vila_dev_ckpt_to_remote(
240
+ self,
241
+ model_path: str,
242
+ output_dir: str = None,
243
+ vila_version: str | None = None,
244
+ conv_mode: str | None = None,
245
+ copy: bool = False,
246
+ copy_weights: bool = True,
247
+ copy_code: bool = True,
248
+ *model_args,
249
+ **kwargs,
250
+ ):
251
+ # assert type(self) == VILAForCausalLM, "This method is only available for VILAForCausalLM."
252
+ assert model_path != output_dir, "model_path and output_dir cannot be the same"
253
+ if os.path.isdir(model_path):
254
+ model_path = model_path
255
+ else:
256
+ from huggingface_hub import HfApi, snapshot_download
257
+
258
+ model_path = snapshot_download(model_path)
259
+ print("downloading HF model to", model_path)
260
+
261
+ if check_dot_in_model_path(model_path) and output_dir is None:
262
+ raise ValueError(
263
+ f"Model path {model_path} contains a dot, which will break remote code loading. Please specify an output directory without a dot in its path to fix this issue."
264
+ )
265
+ if output_dir is not None and "." in output_dir:
266
+ raise ValueError(
267
+ f"Output directory {output_dir} contains a dot, which will affect the remote code loading. Please specify a valid output directory without dots."
268
+ )
269
+
270
+ if copy:
271
+ print("copy is set to True, copying weights and code to output_dir")
272
+ copy_weights = copy_code = True
273
+ # copy weights and code to output_dir
274
+ self.copy_or_symlink_directory(model_path, output_dir, copy=copy_weights)
275
+ self.copy_remote_py_files(output_dir, copy=copy_code)
276
+
277
+ if vila_version is None:
278
+ vila_version = get_vila_version(output_dir)
279
+
280
+ cfg_path = os.path.join(output_dir, "config.json")
281
+ config = json.load(open(cfg_path))
282
+ config["version"] = "2.0" # nvila tag
283
+ config["architectures"] = ["VILAForCausalLM"]
284
+ config["auto_map"] = {
285
+ "AutoProcessor": "auto_processor.VILAProcessor",
286
+ "AutoConfig": "modeling_vila.VILAConfig",
287
+ "AutoModel": "modeling_vila.VILAForCausalLM",
288
+ "AutoModelForCausalLM": "modeling_vila.VILAForCausalLM",
289
+ }
290
+ # vila1.5 legacy support
291
+ config["model_type"] = "vila"
292
+ if vila_version in ["vila1.5", "vila-m3"]:
293
+ if conv_mode is None:
294
+ raise ValueError(f"Please specify the conversation mode for {output_dir}.")
295
+ config["chat_template"] = conv_mode
296
+ jinja_template = generate_jinja_template(conv_mode)
297
+ jinja_path = os.path.join(output_dir, f"{conv_mode}.jinja")
298
+ with open(jinja_path, "w") as f:
299
+ f.write(jinja_template)
300
+ json.dump(config, open(cfg_path, "w"), indent=2)
301
+
302
+ ##########################################################################################
303
+ config = AutoConfig.from_pretrained(output_dir, trust_remote_code=True)
304
+ tokenizer = load_tokenizer_then_handle_media_tokens_and_chat_template(output_dir, config)
305
+ tokenizer.save_pretrained(osp.join(output_dir, "llm"))
306
+ ##########################################################################################
307
+
308
+ @classmethod
309
+ def copy_or_symlink_directory(cls, model_path, output_dir, copy=True):
310
+ # Create output directory if it doesn't exist
311
+ os.makedirs(output_dir, exist_ok=True)
312
+ # Copy or symlink every file from model_path into output_dir
313
+ for item in os.listdir(model_path):
314
+ src_path = os.path.join(model_path, item)
315
+ dst_path = os.path.join(output_dir, item)
316
+
317
+ # Remove existing file/directory at destination if it exists
318
+ if os.path.exists(dst_path):
319
+ if os.path.islink(dst_path):
320
+ os.unlink(dst_path)
321
+ elif os.path.isdir(dst_path):
322
+ shutil.rmtree(dst_path)
323
+ else:
324
+ os.remove(dst_path)
325
+
326
+ # Copy (or symlink) the entry depending on the `copy` flag
327
+ if copy:
328
+ if os.path.isdir(src_path):
329
+ shutil.copytree(src_path, dst_path)
330
+ else:
331
+ shutil.copy2(src_path, dst_path)
332
+ print(f"Copied {src_path} to {dst_path}")
333
+ else:
334
+ os.symlink(src_path, dst_path)
335
+ print(f"Created symlink from {src_path} to {dst_path}")
336
+
337
+ @classmethod
338
+ def copy_remote_py_files(cls, output_dir, copy=True):
339
+ ## copy .py files and README so the remote code can be loaded next time
340
+ current_file_path = os.path.abspath(__file__)
341
+ current_folder = os.path.dirname(current_file_path)
342
+ for file_name in os.listdir(current_folder):
343
+ if file_name == "INSTRUCTIONS.md":
344
+ src_fname = os.path.join(current_folder, file_name)
345
+ dst_fname = os.path.join(output_dir, "README.md")
346
+ if os.path.exists(dst_fname):
347
+ old_readme = open(dst_fname).read()
348
+ else:
349
+ old_readme = ""
350
+ with open(src_fname) as src, open(dst_fname, "w") as dst:
351
+ dst.write(src.read())
352
+ dst.write(old_readme)
353
+ print("[HF remote code] README ", src_fname, "to", dst_fname)
354
+ if file_name.endswith(".py") or file_name.endswith(".jinja"):
355
+ full_file_name = os.path.join(current_folder, file_name)
356
+ if os.path.isfile(full_file_name):
357
+ if copy:
358
+ shutil.copy(full_file_name, output_dir)
359
+ print("[HF remote code] copying", full_file_name, "to", output_dir)
360
+ else:
361
+ # symlink to ease development
362
+ if os.path.exists(os.path.join(output_dir, file_name)):
363
+ os.remove(os.path.join(output_dir, file_name))
364
+ os.symlink(full_file_name, os.path.join(output_dir, file_name))
365
+ print("[HF remote code] linking", full_file_name, "to", output_dir)
366
+
367
+ def save_pretrained(self, output_dir, state_dict=None, **kwargs):
368
+ if state_dict is None:
369
+ # otherwise fetch from deepspeed
370
+ # state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
371
+ state_dict = self.state_dict()
372
+
373
+ if getattr(self, "tokenizer", None):
374
+ self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))
375
+
376
+ if self.get_llm():
377
+ print(f"saving llm to {osp.join(output_dir, 'llm')}")
378
+ self.llm.config._name_or_path = osp.join(output_dir, "llm")
379
+ llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k})
380
+ self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict)
381
+ self.config.llm_cfg = self.llm.config
382
+
383
+ if self.get_vision_tower():
384
+ print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
385
+ self.vision_tower.config._name_or_path = osp.join(output_dir, "vision_tower")
386
+ vision_tower_state_dict = OrderedDict(
387
+ {k.split("vision_tower.vision_tower.")[-1]: v for k, v in state_dict.items() if "vision_tower" in k}
388
+ )
389
+ self.vision_tower.vision_tower.save_pretrained(
390
+ os.path.join(output_dir, "vision_tower"),
391
+ state_dict=vision_tower_state_dict,
392
+ )
393
+ self.vision_tower.image_processor.save_pretrained(os.path.join(output_dir, "vision_tower"))
394
+ self.config.vision_tower_cfg = self.vision_tower.config
395
+ if hasattr(self.config.vision_tower_cfg, "auto_map"):
396
+ if "radio" not in self.get_vision_tower().__class__.__name__.lower():
397
+ delattr(self.config.vision_tower_cfg, "auto_map")
398
+
399
+ if self.get_mm_projector():
400
+ print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
401
+ self.mm_projector.config._name_or_path = osp.join(output_dir, "mm_projector")
402
+ mm_projector_state_dict = OrderedDict(
403
+ {k.split("mm_projector.")[-1]: v for k, v in state_dict.items() if "mm_projector" in k}
404
+ )
405
+ self.mm_projector.save_pretrained(
406
+ os.path.join(output_dir, "mm_projector"),
407
+ state_dict=mm_projector_state_dict,
408
+ )
409
+ self.config.mm_projector_cfg = self.mm_projector.config
410
+
411
+ ## update and save top-level config
412
+ self.config._name_or_path = output_dir
413
+ self.config.architectures = [self.__class__.__name__]
414
+ self.config.save_pretrained(output_dir)
415
+
416
+ ## copy .py files and README so the remote code can be loaded next time
417
+ self.copy_remote_py_files(output_dir)
418
+
419
+ @classmethod
420
+ def from_pretrained(
421
+ cls,
422
+ pretrained_model_name_or_path: Optional[str] = None,
423
+ *model_args,
424
+ config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
425
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
426
+ ignore_mismatched_sizes: bool = False,
427
+ force_download: bool = False,
428
+ local_files_only: bool = False,
429
+ token: Optional[Union[str, bool]] = None,
430
+ revision: str = "main",
431
+ use_safetensors: Optional[bool] = None,
432
+ weights_only: bool = True,
433
+ **kwargs,
434
+ ):
435
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
436
+ return cls._from_config(config, **kwargs)
437
+
438
+ def init_llm(self, llm_config, config, *args, **kwargs):
439
+ self.llm, self.tokenizer = build_llm_and_tokenizer(llm_config, config, *args, **kwargs)
440
+ # hard coded for NVILA
441
+ # variables for XGrammar
442
+ NUM_EXTRA_TOKENS = len(self.tokenizer.added_tokens_encoder.keys())
443
+
444
+ self.pad_token_list = (
445
+ self.tokenizer.pad_token_id,
446
+ self.tokenizer.eos_token_id,
447
+ self.tokenizer.tokenize("<|endoftext|>")[0], # for qwen
448
+ )
449
+
450
+ self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS
451
+ # XGrammar tokenizer and grammar compiler
452
+ # lazy init only when specified json output during inference
453
+ self.grammar_compiler = None
454
+ self.llm.resize_token_embeddings(len(self.tokenizer))
455
+ return self.llm, self.tokenizer
456
+
457
+ def post_config(self):
458
+ ######################################################################
459
+ self.llm = self.llm.to(torch.float16)
460
+ self.mm_projector = self.mm_projector.to(torch.float16)
461
+ self.vision_tower = self.vision_tower.to(torch.float16)
462
+ ######################################################################
463
+ self.training = self.llm.training
464
+ if self.training:
465
+ self.train()
466
+ else:
467
+ self.eval()
468
+ ## configuration
469
+ if getattr(self.config, "llm_cfg", None) is None:
470
+ self.config.llm_cfg = self.llm.config
471
+ if getattr(self.config, "vision_tower_cfg", None) is None:
472
+ self.config.vision_tower_cfg = self.vision_tower.config
473
+ if getattr(self.config, "mm_projector_cfg", None) is None:
474
+ self.config.mm_projector_cfg = self.mm_projector.config
475
+
476
+ def get_llm(self):
477
+ llm = getattr(self, "llm", None)
478
+ if type(llm) is list:
479
+ llm = llm[0]
480
+ return llm
481
+
482
+ def get_lm_head(self):
483
+ lm_head = getattr(self.get_llm(), "lm_head", None)
484
+ return lm_head
485
+
486
+ def get_vision_tower(self):
487
+ vision_tower = getattr(self, "vision_tower", None)
488
+ if type(vision_tower) is list:
489
+ vision_tower = vision_tower[0]
490
+ return vision_tower
491
+
492
+ def get_mm_projector(self):
493
+ mm_projector = getattr(self, "mm_projector", None)
494
+ if type(mm_projector) is list:
495
+ mm_projector = mm_projector[0]
496
+ return mm_projector
497
+
498
+ def freezed_module_patch(self):
499
+ """
500
+ Hugging Face calls model.train() at each training step. To ensure the expected behavior of modules like dropout and batchnorm, we need to call model.eval() on the frozen modules.
501
+ """
502
+ if self.training:
503
+ if self.get_llm() and not getattr(self.config, "tune_language_model", False):
504
+ pass
505
+ # logging.warning("Caution: Your LLM is currently in training mode, ensuring accurate gradient computation. Please be vigilant, particularly regarding BatchNorm and Dropout operations.")
506
+ if self.get_vision_tower() and not getattr(self.config, "tune_vision_tower", False):
507
+ self.get_vision_tower().eval()
508
+ if self.get_mm_projector() and not getattr(self.config, "tune_mm_projector", False):
509
+ self.get_mm_projector().eval()
510
+
511
+
512
+ class VILAForCausalLM(VILAPretrainedModel):
513
+ def __init__(self, config: VILAConfig, *args, **kwargs):
514
+ super().__init__(config, *args, **kwargs)
515
+
516
+ def merge_features_for_dynamic_s2(self, image_features, block_sizes):
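+ """Reassemble per-tile vision features into one multi-scale feature map per image.
+ Tiles of each S2 scale are merged back into their chessboard layout, resized to a common
+ spatial size, and concatenated along the channel dimension. Returns the merged features
+ and the block sizes actually used for each image.
+ """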
517
+ scales = self.get_vision_tower().scales
518
+ resize_output_to_scale_idx = self.get_vision_tower().resize_output_to_scale_idx
519
+
520
+ image_features_each_image = []
521
+ new_block_sizes = []
522
+ block_cnt = 0
523
+ for block_size_each_image in block_sizes:
524
+ if block_size_each_image is None:
525
+ cur_features = image_features[block_cnt : block_cnt + 1]
526
+ cur_features = rearrange(cur_features, "1 (h w) c -> 1 c h w", h=int(cur_features.shape[1] ** 0.5))
527
+ cur_features = cur_features.repeat(1, len(scales), 1, 1)
528
+ image_features_each_image.append(cur_features)
529
+ new_block_sizes.append((1, 1))
530
+ block_cnt += 1
531
+ else:
532
+ cur_features_each_scale = []
533
+ for scale in scales[:-1]:
534
+ num_blocks_this_scale = (scale // scales[0]) ** 2
535
+ cur_features_each_scale.append(
536
+ self.merge_chessboard(
537
+ image_features[block_cnt : block_cnt + num_blocks_this_scale],
538
+ num_split_h=scale // scales[0],
539
+ num_split_w=scale // scales[0],
540
+ )
541
+ ) # 1 * C * H * W
542
+ block_cnt += num_blocks_this_scale
543
+ num_blocks_last_scale = block_size_each_image[0] * block_size_each_image[1]
544
+ cur_features_each_scale.append(
545
+ self.merge_chessboard(
546
+ image_features[block_cnt : block_cnt + num_blocks_last_scale],
547
+ num_split_h=block_size_each_image[0],
548
+ num_split_w=block_size_each_image[1],
549
+ )
550
+ ) # 1 * C * H * W
551
+ block_cnt += num_blocks_last_scale
552
+
553
+ # resize and concat features from different scales
554
+ output_size = cur_features_each_scale[resize_output_to_scale_idx].shape[-2:]
555
+ cur_features = torch.cat(
556
+ [
557
+ F.interpolate(cur_features_each_scale[i].to(torch.float32), size=output_size, mode="area").to(
558
+ cur_features_each_scale[i].dtype
559
+ )
560
+ for i in range(len(cur_features_each_scale))
561
+ ],
562
+ dim=1,
563
+ )
564
+ # cur_features = rearrange(cur_features, "1 c h w -> (h w) c")
565
+
566
+ image_features_each_image.append(cur_features)
567
+
568
+ if resize_output_to_scale_idx == len(scales) - 1 or resize_output_to_scale_idx == -1:
569
+ new_block_sizes.append(block_size_each_image)
570
+ else:
571
+ new_block_sizes.append(
572
+ (
573
+ scales[resize_output_to_scale_idx] // scales[0],
574
+ scales[resize_output_to_scale_idx] // scales[0],
575
+ )
576
+ )
577
+
578
+ assert block_cnt == len(image_features)
579
+
580
+ return image_features_each_image, new_block_sizes
581
+
582
+ def encode_images(self, images, block_sizes: Optional[Optional[Tuple[int, ...]]] = None):
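+ """Encode image tiles with the vision tower and project them into the LLM embedding space.
+ When dynamic_s2 is enabled, per-tile features are first merged across scales, passed
+ through the multimodal projector, and then reassembled into per-image token sequences.
+ """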
583
+ if block_sizes is None:
584
+ block_sizes = [None] * len(images)
585
+ if getattr(self.config, "dynamic_s2", False):
586
+ image_features = self.get_vision_tower()(images)
587
+ image_features, new_block_sizes = self.merge_features_for_dynamic_s2(image_features, block_sizes)
588
+
589
+ image_features = [
590
+ self.split_chessboard(x, block_size[0], block_size[1])
591
+ for x, block_size in zip(image_features, new_block_sizes)
592
+ ] # list of B * C * H * W tensors
593
+ image_features = torch.cat(
594
+ [rearrange(x, "b c h w -> b (h w) c") for x in image_features], dim=0
595
+ ) # B * N * C
596
+ image_features = self.get_mm_projector()(image_features)
597
+ image_features = list(
598
+ image_features.split([block_size[0] * block_size[1] for block_size in new_block_sizes], dim=0)
599
+ )
600
+ image_features = [
601
+ self.merge_chessboard(x, block_size[0], block_size[1])
602
+ for x, block_size in zip(image_features, new_block_sizes)
603
+ ] # list of 1 * C * H * W tensors
604
+ image_features = [rearrange(x, "1 c h w -> (h w) c") for x in image_features] # list of N * C tensors
605
+ if all([feature.shape[0] == image_features[0].shape[0] for feature in image_features]):
606
+ image_features = torch.stack(image_features, dim=0)
607
+ else:
608
+ image_features = self.get_vision_tower()(images)
609
+ image_features = self.get_mm_projector()(image_features)
610
+ return image_features
611
+
612
+ def train(self, mode: bool = True):
613
+ super().train(mode)
614
+ return self
615
+
616
+ @torch.inference_mode()
617
+ def _embed(
618
+ self,
619
+ input_ids: torch.Tensor,
620
+ media: Dict[str, List[torch.Tensor]],
621
+ media_config: Dict[str, Dict[str, Any]],
622
+ labels: Optional[torch.Tensor],
623
+ attention_mask: Optional[torch.Tensor],
624
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
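+ """Fuse text token embeddings with encoded media embeddings.
+ Each media placeholder token in input_ids is replaced by the corresponding image/video
+ embeddings; returns (inputs_embeds, labels, attention_mask) after truncation and padding.
+ """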
625
+ media = copy.deepcopy(media)
626
+ media_config = copy.deepcopy(media_config)
627
+
628
+ labels = labels if labels is not None else torch.full_like(input_ids, IGNORE_INDEX)
629
+ attention_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids, dtype=torch.bool)
630
+
631
+ PROCESS_GROUP_MANAGER = get_pg_manager()
632
+ if PROCESS_GROUP_MANAGER is not None:
633
+ for name in media:
634
+ self.encoders[name].end_tokens = None
635
+
636
+ # Extract text and media embeddings
637
+ text_embeds = self.llm_model_embed_tokens(input_ids)
638
+
639
+ use_cache = False
640
+ if "use_cache" in media_config:
641
+ use_cache = media_config.pop("use_cache")
642
+
643
+ if use_cache:
644
+ print("Use cached embedding")
645
+ if media is not None:
646
+ media_embeds = media if use_cache else self.__embed_media_tokens(media, media_config)
647
+ else:
648
+ # no media was provided, so we just return an empty dict
649
+ media_embeds = {}
650
+
651
+ # This is a workaround to make sure the dummy embeddings are consumed
652
+ while media_embeds.get("dummy"):
653
+ dummy_embed = media_embeds["dummy"].popleft()
654
+ text_embeds += torch.sum(dummy_embed) * 0
655
+
656
+ # Remove padding
657
+ batch_size = labels.shape[0]
658
+ text_embeds = [text_embeds[k][attention_mask[k]] for k in range(batch_size)]
659
+ labels = [labels[k][attention_mask[k]] for k in range(batch_size)]
660
+
661
+ # Build inverse mapping from token ID to media name
662
+ media_tokens = {}
663
+ for name, token_id in self.tokenizer.media_token_ids.items():
664
+ media_tokens[token_id] = name
665
+
666
+ # Fuse text and media embeddings
667
+ inputs_m, labels_m = [], []
668
+ for k in range(batch_size):
669
+ inputs_mk, labels_mk = [], []
670
+ pos = 0
671
+ while pos < len(labels[k]):
672
+ if input_ids[k][pos].item() in media_tokens:
673
+ end = pos + 1
674
+ name = media_tokens[input_ids[k][pos].item()]
675
+ input = media_embeds[name].popleft()
676
+ label = torch.full([input.shape[0]], IGNORE_INDEX, device=labels[k].device, dtype=labels[k].dtype)
677
+ elif input_ids[k][pos].item() in self.pad_token_list:
678
+ # skip pad tokens
679
+ end = pos + 1
680
+ pos = end
681
+ continue
682
+ else:
683
+ end = pos
684
+ while end < len(labels[k]) and input_ids[k][end].item() not in media_tokens:
685
+ end += 1
686
+ input = text_embeds[k][pos:end]
687
+ label = labels[k][pos:end]
688
+
689
+ inputs_mk.append(input)
690
+ labels_mk.append(label)
691
+ pos = end
692
+ inputs_m.append(torch.cat(inputs_mk, dim=0))
693
+ labels_m.append(torch.cat(labels_mk, dim=0))
694
+ inputs, labels = inputs_m, labels_m
695
+
696
+ # Check if all media embeddings are consumed
697
+ for name in media_embeds:
698
+ if media_embeds[name]:
699
+ raise ValueError(f"Not all {name} embeddings are consumed! Still {len(media_embeds[name])} left.")
700
+
701
+ # Truncate sequences to `model_max_length` as media embeddings are inserted
702
+ inputs, labels = self.__truncate_sequence(inputs, labels)
703
+
704
+ # Pad sequences to the longest one in the batch
705
+ return self.__batchify_sequence(inputs, labels)
706
+
707
+ def __embed_media_tokens(
708
+ self,
709
+ media: Dict[str, List[torch.Tensor]],
710
+ media_config: Dict[str, Dict[str, Any]],
711
+ ) -> Dict[str, List[torch.Tensor]]:
712
+ embeds = defaultdict(deque)
713
+ for name in media:
714
+ if self.training:
715
+ # Gather metainfo of media objects from all ranks
716
+ info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
717
+ infos = list(chain(vila_all_gather(info)))
718
+
719
+ # The entire batch does not contain any media objects of this type.
720
+ if not infos:
721
+ continue
722
+
723
+ # Create a dummy tensor to ensure the encoder is called, otherwise the training will hang.
724
+ if media.get(name) is None or len(media[name]) == 0:
725
+ dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device)
726
+ embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name]))
727
+ continue
728
+ embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
729
+ return embeds
730
+
731
+ def __truncate_sequence(
732
+ self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
733
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
734
+ if self.training and any(len(input) > self.tokenizer.model_max_length for input in inputs):
735
+ warnings.warn(f"Truncating sequences to `model_max_length` ({self.tokenizer.model_max_length}).")
736
+ inputs = [input[: self.tokenizer.model_max_length] for input in inputs]
737
+ labels = [label[: self.tokenizer.model_max_length] for label in labels]
738
+ return inputs, labels
739
+
740
+ def __batchify_sequence(
741
+ self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
742
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
743
+ batch_size = len(inputs)
744
+ device = inputs[0].device
745
+ hidden_size = inputs[0].shape[1]
746
+ max_length = max(inputs[k].shape[0] for k in range(batch_size))
747
+ attention_mask = torch.ones((batch_size, max_length), dtype=torch.bool, device=device)
748
+
749
+ inputs_p, labels_p = [], []
750
+ for k in range(batch_size):
751
+ size_pk = max_length - inputs[k].shape[0]
752
+ inputs_pk = torch.zeros((size_pk, hidden_size), dtype=inputs[k].dtype, device=device)
753
+ labels_pk = torch.full((size_pk,), IGNORE_INDEX, dtype=labels[k].dtype, device=device)
754
+ if self.tokenizer.padding_side == "right":
755
+ attention_mask[k, inputs[k].shape[0] :] = False
756
+ inputs_pk = torch.cat([inputs[k], inputs_pk], dim=0)
757
+ labels_pk = torch.cat([labels[k], labels_pk], dim=0)
758
+ else:
759
+ attention_mask[k, : -inputs[k].shape[0]] = False
760
+ inputs_pk = torch.cat([inputs_pk, inputs[k]], dim=0)
761
+ labels_pk = torch.cat([labels_pk, labels[k]], dim=0)
762
+ inputs_p.append(inputs_pk)
763
+ labels_p.append(labels_pk)
764
+
765
+ inputs = torch.stack(inputs_p, dim=0)
766
+ labels = torch.stack(labels_p, dim=0)
767
+ return inputs, labels, attention_mask
768
+
769
+ def repack_multimodal_data(self, inputs_embeds, attention_mask, position_ids, labels):
770
+ # Handle sequence parallelism
771
+ PROCESS_GROUP_MANAGER = get_pg_manager()
772
+
773
+ # We do re-sharding instead of packing here to ensure the sequence length is the same across all ranks.
774
+ if PROCESS_GROUP_MANAGER is not None:
775
+ sp_degree = PROCESS_GROUP_MANAGER.sp_degree
776
+ sp_rank = PROCESS_GROUP_MANAGER.sp_rank
777
+ sp_group = PROCESS_GROUP_MANAGER.sp_pg
778
+ ring_degree = PROCESS_GROUP_MANAGER.ring_degree
779
+ ring_rank = PROCESS_GROUP_MANAGER.ring_rank
780
+ ring_type = PROCESS_GROUP_MANAGER.ring_type
781
+ ulysses_degree = PROCESS_GROUP_MANAGER.ulysses_degree
782
+ ulysses_rank = PROCESS_GROUP_MANAGER.ulysses_rank
783
+
784
+ bs, shard_seqlen = position_ids.shape
785
+ sp_seq_len = [torch.zeros(1, dtype=torch.int64, device=position_ids.device) for _ in range(sp_degree)]
786
+ dist.all_gather(sp_seq_len, torch.tensor(shard_seqlen, device=position_ids.device), group=sp_group)
787
+ sp_seq_len_cat = torch.cat(sp_seq_len, dim=0)
788
+
789
+ if sp_rank == 0:
790
+ original_start_id = 0
791
+ else:
792
+ original_start_id = torch.sum(sp_seq_len_cat[:sp_rank]).item()
793
+ original_end_id = torch.sum(sp_seq_len_cat[: sp_rank + 1]).item()
794
+
795
+ # Gather attention_mask, position_ids, labels and input_embeds
796
+ all_inputs_embeds = torch.zeros(
797
+ bs,
798
+ torch.sum(sp_seq_len_cat),
799
+ inputs_embeds.shape[-1],
800
+ dtype=inputs_embeds.dtype,
801
+ device=inputs_embeds.device,
802
+ ).contiguous()
803
+ all_inputs_embeds[:, original_start_id:original_end_id, :] += inputs_embeds
804
+ dist.barrier(group=sp_group)
805
+ dist.all_reduce(all_inputs_embeds, group=sp_group)
806
+ dist.barrier(group=sp_group)
807
+
808
+ attention_mask_list = [
809
+ torch.zeros((bs, sp_seq_len[i]), dtype=attention_mask.dtype, device=attention_mask.device)
810
+ for i in range(sp_degree)
811
+ ]
812
+ position_ids_list = [
813
+ torch.zeros((bs, sp_seq_len[i]), dtype=position_ids.dtype, device=position_ids.device)
814
+ for i in range(sp_degree)
815
+ ]
816
+ labels_list = [
817
+ torch.zeros((bs, sp_seq_len[i]), dtype=labels.dtype, device=labels.device) for i in range(sp_degree)
818
+ ]
819
+
820
+ dist.all_gather(attention_mask_list, attention_mask, group=sp_group)
821
+ dist.all_gather(position_ids_list, position_ids, group=sp_group)
822
+ dist.all_gather(labels_list, labels, group=sp_group)
823
+
824
+ effective_seqlen_list = [attention_mask_list[i].sum(dim=-1) for i in range(sp_degree)]
825
+ effective_seqlen = torch.stack(effective_seqlen_list, dim=-1)
826
+ effective_seqlen_batch_list = torch.unbind(effective_seqlen, dim=0)
827
+
828
+ global_attention_mask_list = []
829
+ global_position_ids_list = []
830
+ global_labels_list = []
831
+ global_inputs_embeds_list = []
832
+ for i in range(bs):
833
+ global_attention_mask_batch_list = []
834
+ global_position_ids_batch_list = []
835
+ global_labels_batch_list = []
836
+ global_inputs_embeds_batch_list = []
837
+ for j in range(sp_degree):
838
+ eff_len = effective_seqlen_batch_list[i][j]
839
+ prev_len = torch.sum(sp_seq_len_cat[:j]).item() if j > 0 else 0
840
+
841
+ global_attention_mask_batch_list.append(attention_mask_list[j][i, :eff_len])
842
+ global_position_ids_batch_list.append(position_ids_list[j][i, :eff_len])
843
+ global_labels_batch_list.append(labels_list[j][i, :eff_len])
844
+ global_inputs_embeds_batch_list.append(all_inputs_embeds[i, prev_len : prev_len + eff_len, :])
845
+ global_attention_mask_list.append(torch.cat(global_attention_mask_batch_list, dim=0))
846
+ global_position_ids_list.append(torch.cat(global_position_ids_batch_list, dim=0))
847
+ global_labels_list.append(torch.cat(global_labels_batch_list, dim=0))
848
+ global_inputs_embeds_list.append(torch.cat(global_inputs_embeds_batch_list, dim=0))
849
+
850
+ global_attention_mask = torch.nn.utils.rnn.pad_sequence(
851
+ global_attention_mask_list, batch_first=True, padding_value=False
852
+ )
853
+ global_position_ids = torch.nn.utils.rnn.pad_sequence(
854
+ global_position_ids_list, batch_first=True, padding_value=-1
855
+ )
856
+ global_labels = torch.nn.utils.rnn.pad_sequence(
857
+ global_labels_list, batch_first=True, padding_value=IGNORE_INDEX
858
+ )
859
+ global_inputs_embeds = torch.nn.utils.rnn.pad_sequence(
860
+ global_inputs_embeds_list, batch_first=True, padding_value=0
861
+ )
862
+
863
+ # Re-shard the inputs
864
+ if ring_degree > 1:
865
+ total_effective_seqlen = torch.sum(effective_seqlen, dim=1)
866
+ new_seqlen_per_rank = total_effective_seqlen // sp_degree
867
+ assert torch.all(
868
+ total_effective_seqlen % sp_degree == 0
869
+ ), "total_effective_seqlen must be divisible by sp_degree"
870
+
871
+ max_new_seqlen = torch.max(new_seqlen_per_rank).item()
872
+
873
+ new_attention_mask = torch.zeros(
874
+ (bs, max_new_seqlen), dtype=global_attention_mask.dtype, device=global_attention_mask.device
875
+ )
876
+ new_position_ids = torch.zeros(
877
+ (bs, max_new_seqlen), dtype=global_position_ids.dtype, device=global_position_ids.device
878
+ )
879
+ new_labels = torch.full(
880
+ (bs, max_new_seqlen), IGNORE_INDEX, dtype=global_labels.dtype, device=global_labels.device
881
+ )
882
+ new_inputs_embeds = torch.zeros(
883
+ (bs, max_new_seqlen, global_inputs_embeds.shape[-1]),
884
+ dtype=global_inputs_embeds.dtype,
885
+ device=global_inputs_embeds.device,
886
+ )
887
+
888
+ if ring_type == "ring_varlen":
889
+ for i in range(bs):
890
+ start_idx = new_seqlen_per_rank[i] * sp_rank
891
+ end_idx = start_idx + new_seqlen_per_rank[i]
892
+ new_attention_mask[i, : new_seqlen_per_rank[i]] = global_attention_mask[i, start_idx:end_idx]
893
+ new_position_ids[i, : new_seqlen_per_rank[i]] = global_position_ids[i, start_idx:end_idx]
894
+ new_labels[i, : new_seqlen_per_rank[i]] = global_labels[i, start_idx:end_idx]
895
+ new_inputs_embeds[i, : new_seqlen_per_rank[i], :] = global_inputs_embeds[
896
+ i, start_idx:end_idx, :
897
+ ]
898
+ elif ring_type == "zigzag_ring_varlen":
899
+ chunk_size = total_effective_seqlen // (2 * sp_degree)
900
+ for i in range(bs):
901
+ # Zigzag pattern indices
902
+ if sp_degree == ring_degree:
903
+ forward_rank_idx = sp_rank
904
+ backward_rank_idx = 2 * sp_degree - sp_rank - 1
905
+ else:
906
+ ulysses_offset = ulysses_rank * ring_degree * 2
907
+ forward_rank_idx = ring_rank + ulysses_offset
908
+ backward_rank_idx = sp_degree - ring_rank - 1 + ulysses_offset
909
+
910
+ # Calculate start and end indices for the forward and backward zigzag
911
+ start_idx_fwd = forward_rank_idx * chunk_size[i]
912
+ end_idx_fwd = start_idx_fwd + chunk_size[i]
913
+
914
+ start_idx_bwd = backward_rank_idx * chunk_size[i]
915
+ end_idx_bwd = start_idx_bwd + chunk_size[i]
916
+
917
+ # Fill new tensors with zigzag data
918
+ new_attention_mask[i, : chunk_size[i]] = global_attention_mask[i, start_idx_fwd:end_idx_fwd]
919
+ new_attention_mask[i, chunk_size[i] : 2 * chunk_size[i]] = global_attention_mask[
920
+ i, start_idx_bwd:end_idx_bwd
921
+ ]
922
+
923
+ new_position_ids[i, : chunk_size[i]] = global_position_ids[i, start_idx_fwd:end_idx_fwd]
924
+ new_position_ids[i, chunk_size[i] : 2 * chunk_size[i]] = global_position_ids[
925
+ i, start_idx_bwd:end_idx_bwd
926
+ ]
927
+
928
+ new_labels[i, : chunk_size[i]] = global_labels[i, start_idx_fwd:end_idx_fwd]
929
+ new_labels[i, chunk_size[i] : 2 * chunk_size[i]] = global_labels[i, start_idx_bwd:end_idx_bwd]
930
+
931
+ new_inputs_embeds[i, : chunk_size[i], :] = global_inputs_embeds[i, start_idx_fwd:end_idx_fwd, :]
932
+ new_inputs_embeds[i, chunk_size[i] : 2 * chunk_size[i], :] = global_inputs_embeds[
933
+ i, start_idx_bwd:end_idx_bwd, :
934
+ ]
935
+ else:
936
+ raise ValueError(f"Invalid ring_type: {ring_type}")
937
+ else:
938
+ global_seq_len = global_attention_mask.shape[-1]
939
+ seq_len_sharded = global_seq_len // sp_degree
940
+ start_idx_reshard = seq_len_sharded * sp_rank
941
+ end_idx_reshard = start_idx_reshard + seq_len_sharded if sp_rank < sp_degree - 1 else global_seq_len
942
+
943
+ new_attention_mask = torch.narrow(
944
+ global_attention_mask, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
945
+ )
946
+ new_position_ids = torch.narrow(
947
+ global_position_ids, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
948
+ )
949
+ new_labels = torch.narrow(global_labels, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard)
950
+ new_inputs_embeds = torch.narrow(
951
+ global_inputs_embeds, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
952
+ )
953
+
954
+ return new_inputs_embeds, new_attention_mask, new_position_ids, new_labels
955
+
956
+ device = inputs_embeds.device
957
+ batch_size = inputs_embeds.shape[0]
958
+ seqlens = [attention_mask[k].sum().item() for k in range(batch_size)]
959
+
960
+ # Pack all sequences together
961
+ inputs_embeds_p = [inputs_embeds[k][attention_mask[k]] for k in range(batch_size)]
962
+ attention_mask_p = [torch.ones(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
963
+ position_ids_p = [torch.arange(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
964
+ labels_p = [labels[k][attention_mask[k]] for k in range(batch_size)]
965
+
966
+ # Add one dummy token at the end of the packed sequence to ensure that `_get_unpad_data` will be called
967
+ inputs_embeds_p.append(torch.zeros(1, inputs_embeds.shape[-1], dtype=inputs_embeds.dtype, device=device))
968
+ attention_mask_p.append(torch.tensor([0], dtype=torch.int, device=device))
969
+ position_ids_p.append(torch.tensor([0], dtype=torch.int, device=device))
970
+ labels_p.append(torch.tensor([IGNORE_INDEX], dtype=torch.int, device=device))
971
+
972
+ # Mask the first token of each sequence to avoid contamination
973
+ for label in labels_p:
974
+ label[0] = IGNORE_INDEX
975
+
976
+ # Batch the data
977
+ inputs_embeds_p = torch.cat(inputs_embeds_p, dim=0).unsqueeze(0)
978
+ attention_mask_p = torch.cat(attention_mask_p, dim=0).unsqueeze(0)
979
+ position_ids_p = torch.cat(position_ids_p, dim=0).unsqueeze(0)
980
+ labels_p = torch.cat(labels_p, dim=0).unsqueeze(0)
981
+
982
+ if hasattr(
983
+ self, "pad_to_multiple_of"
984
+ ): # related to quantization, please refer to ModelArguments for more information.
985
+ assert len(labels_p.shape) == 2
986
+ batch_size, max_length, cur_length = labels_p.shape[0], labels_p.shape[1], labels_p.shape[1]
987
+ hidden_size = inputs_embeds_p.shape[-1]
988
+
989
+ if max_length % self.pad_to_multiple_of != 0:
990
+ max_length = ((max_length // self.pad_to_multiple_of) + 1) * self.pad_to_multiple_of
991
+ difference = max_length - cur_length
992
+
993
+ inputs_embeds_p = torch.cat(
994
+ (
995
+ inputs_embeds_p,
996
+ torch.full((batch_size, difference, hidden_size), self.llm.pad_token_id).to(inputs_embeds_p),
997
+ ),
998
+ dim=1,
999
+ )
1000
+ labels_p = torch.cat((labels_p, torch.full((batch_size, difference), IGNORE_INDEX).to(labels_p)), dim=1)
1001
+ attention_mask_p = torch.cat(
1002
+ (
1003
+ attention_mask_p,
1004
+ torch.zeros((batch_size, difference), dtype=torch.bool).to(attention_mask_p),
1005
+ ),
1006
+ dim=1,
1007
+ )
1008
+ position_ids_p = torch.cat(
1009
+ (position_ids_p, torch.full((batch_size, difference), -1).to(position_ids_p)), dim=1
1010
+ )
1011
+
1012
+ return inputs_embeds_p, attention_mask_p, position_ids_p, labels_p
1013
+
1014
+ def get_xgr_logits_processor(self, response_format) -> List[LogitsProcessor]:
1015
+ raise NotImplementedError("This method is not implemented for VILA model.")
1016
+ # Convert response format to logits processor
1017
+ import xgrammar as xgr
1018
+
1019
+ logging.info("[XGrammar] Compiling grammar for constrained output")
1020
+
1021
+ if self.grammar_compiler is None:
1022
+ # logging.info(f"[XGrammar] {self.tokenizer}, {self.tokenizer.vocab_size}, {self.vocab_size}")
1023
+ self.grammar_compiler = xgr.GrammarCompiler(
1024
+ xgr.TokenizerInfo.from_huggingface(self.tokenizer, vocab_size=self.vocab_size)
1025
+ )
1026
+
1027
+ if response_format.type == "json_schema":
1028
+ compiled_grammar = self.grammar_compiler.compile_json_schema(
1029
+ response_format.json_schema.schema_,
1030
+ indent=2,
1031
+ )
1032
+ else:
1033
+ compiled_grammar = self.grammar_compiler.compile_builtin_json_grammar()
1034
+
1035
+ return [xgr.contrib.hf.LogitsProcessor(compiled_grammar)]
1036
+
1037
+ def forward(
1038
+ self,
1039
+ input_ids: torch.LongTensor = None,
1040
+ media: Optional[Dict[str, List[torch.Tensor]]] = None,
1041
+ images: Optional[torch.FloatTensor] = None,
1042
+ media_config: Optional[List] = None,
1043
+ pixel_values: Optional[torch.FloatTensor] = None,
1044
+ attention_mask: Optional[torch.Tensor] = None,
1045
+ position_ids: Optional[torch.LongTensor] = None,
1046
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1047
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1048
+ labels: Optional[torch.LongTensor] = None,
1049
+ packing: bool = True,
1050
+ force_packing: bool = False,
1051
+ seqlens_in_batch: Optional[torch.LongTensor] = None,
1052
+ dpo_forward: bool = False,
1053
+ **kwargs,
1054
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1055
+ self.freezed_module_patch()
1056
+
1057
+ if images is not None:
1058
+ if media is not None:
1059
+ raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
1060
+ print("The 'images' argument is deprecated. Please use 'media' instead.")
1061
+ media = {"image": images}
1062
+
1063
+ if media_config is None:
1064
+ media_config = defaultdict(dict)
1065
+
1066
+ if inputs_embeds is None:
1067
+ inputs_embeds, labels, attention_mask = self._embed(input_ids, media, media_config, labels, attention_mask)
1068
+
1069
+ if force_packing or (packing and self.training and not dpo_forward):
1070
+ if seqlens_in_batch is None:
1071
+ seqlens_in_batch = torch.sum(attention_mask, dim=1)
1072
+ set_seqlens_in_batch(seqlens_in_batch)
1073
+
1074
+ (inputs_embeds, attention_mask, position_ids, labels) = self.repack_multimodal_data(
1075
+ inputs_embeds, attention_mask, position_ids, labels
1076
+ )
1077
+
1078
+ outputs = self.llm(
1079
+ inputs_embeds=inputs_embeds,
1080
+ attention_mask=attention_mask,
1081
+ position_ids=position_ids,
1082
+ past_key_values=past_key_values,
1083
+ labels=labels,
1084
+ **kwargs,
1085
+ )
1086
+
1087
+ if self.training and getattr(self.config, "time_token_ids", []):
1088
+ outputs.loss = soft_cross_entropy(
1089
+ outputs.logits,
1090
+ labels,
1091
+ soft_tokens=self.config.time_token_ids,
1092
+ std=self.config.soft_ce_std,
1093
+ )
1094
+
1095
+ if dpo_forward:
1096
+ return outputs.logits, labels
1097
+
1098
+ return outputs
1099
+
1100
+ # @torch.inference_mode()
1101
+ def generate(
1102
+ self,
1103
+ input_ids: Optional[torch.FloatTensor] = None,
1104
+ media: Optional[Dict[str, List[torch.Tensor]]] = None,
1105
+ media_config: Dict[str, Dict[str, Any]] = None,
1106
+ attention_mask: Optional[torch.LongTensor] = None,
1107
+ return_output_ids_only: bool = True,
1108
+ **generation_kwargs,
1109
+ ) -> torch.LongTensor:
1110
+ """
1111
+ input_tokens: <image> describe the image
1112
+ media: [Tensor(1, 3, 384, 384), ]
1113
+ ----------->
1114
+ input_tokens: 36000 001 002 003 004
1115
+ input_embeds: <media embed> 001 002 003 004
1116
+ """
1117
+ inputs_embeds, _, attention_mask = self._embed(input_ids, media, media_config, None, attention_mask)
1118
+ output_ids = self.llm.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)
1119
+
1120
+ if return_output_ids_only:
1121
+ return_value = output_ids
1122
+ else:
1123
+ # by default, return the input_ids and output_ids concatenated to keep consistency with the community VLMs like qwen
1124
+ generation_config = generation_kwargs.get("generation_config", None)
1125
+ if generation_config is not None:
1126
+ num_generations = generation_config.num_return_sequences
1127
+ repeat_input_ids = input_ids.repeat_interleave(num_generations, dim=0)
1128
+ return_value = torch.cat([repeat_input_ids, output_ids], dim=-1)
1129
+ else:
1130
+ return_value = torch.cat([input_ids, output_ids], dim=-1)
1131
+
1132
+ return return_value
1133
+
1134
+ @torch.inference_mode()
1135
+ def generate_content(
1136
+ self,
1137
+ prompt: Union[str, List],
1138
+ generation_config: Optional[GenerationConfig] = None,
1139
+ response_format=None,
1140
+ ) -> str:
1141
+ conversation = [{"from": "human", "value": prompt}]
1142
+
1143
+ # Convert response format to logits processor
1144
+ xgr_logits_processor = None
1145
+
1146
+ # Extract media from the conversation
1147
+
1148
+ media = extract_media(conversation, self.config)
1149
+
1150
+ # Process media
1151
+ media_config = defaultdict(dict)
1152
+ for name in media:
1153
+ if name == "image":
1154
+ if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
1155
+ self.config.image_processor = self.vision_tower.image_processor
1156
+ if self.config.image_aspect_ratio == "dynamic":
1157
+ images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
1158
+ conversation[0]["value"] = conversation[0]["value"].replace(
1159
+ DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
1160
+ )
1161
+ else:
1162
+ if type(self.config.s2_scales) is str:
1163
+ self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
1164
+ images, block_sizes = process_image(
1165
+ media["image"][0], self.config, None, enable_dynamic_s2=True
1166
+ )
1167
+ images = images.half()
1168
+ media_config[name]["block_sizes"] = [block_sizes]
1169
+ else:
1170
+ images = process_images(media["image"], self.vision_tower.image_processor, self.config).half()
1171
+ media[name] = [image for image in images]
1172
+ elif name == "video":
1173
+ if self.config.image_aspect_ratio == "dynamic" and self.config.video_max_tiles > 1:
1174
+ media[name] = [
1175
+ process_images(
1176
+ images,
1177
+ self.vision_tower.image_processor,
1178
+ self.config,
1179
+ enable_dynamic_res=True,
1180
+ max_tiles=self.config.video_max_tiles,
1181
+ ).half()
1182
+ for images in media[name]
1183
+ ]
1184
+ elif self.config.image_aspect_ratio == "dynamic_s2" and self.config.video_max_tiles > 1:
1185
+ self.config.image_processor = self.vision_tower.image_processor
1186
+ if type(self.config.s2_scales) is str:
1187
+ self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
1188
+ media[name] = [
1189
+ torch.cat(
1190
+ [
1191
+ process_image(
1192
+ image,
1193
+ self.config,
1194
+ None,
1195
+ enable_dynamic_s2=True,
1196
+ max_tiles=self.config.video_max_tiles,
1197
+ )[0].half()
1198
+ for image in images
1199
+ ]
1200
+ )
1201
+ for images in media[name]
1202
+ ]
1203
+ else:
1204
+ media[name] = [
1205
+ process_images(images, self.vision_tower.image_processor, self.config).half()
1206
+ for images in media[name]
1207
+ ]
1208
+ else:
1209
+ raise ValueError(f"Unsupported media type: {name}")
1210
+
1211
+ # Tokenize the conversation
1212
+ input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).unsqueeze(0).cuda()
1213
+
1214
+ # Set up the generation config
1215
+ generation_config = generation_config or self.default_generation_config
1216
+
1217
+ # Generate the response
1218
+ try:
1219
+ output_ids = self.generate(
1220
+ input_ids=input_ids,
1221
+ media=media,
1222
+ media_config=media_config,
1223
+ generation_config=generation_config,
1224
+ logits_processor=xgr_logits_processor, # structured generation
1225
+ )
1226
+ except ValueError:
1227
+ if not generation_config.do_sample:
1228
+ raise
1229
+ logging.warning("Generation failed with sampling, retrying with greedy decoding.")
1230
+ generation_config.do_sample = False
1231
+ output_ids = self.generate(
1232
+ input_ids=input_ids,
1233
+ media=media,
1234
+ media_config=media_config,
1235
+ generation_config=generation_config,
1236
+ logits_processor=xgr_logits_processor,
1237
+ )
1238
+
1239
+ # Decode the response
1240
+ response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
1241
+ return response
1242
+
1243
+ @property
1244
+ def default_generation_config(self) -> GenerationConfig:
1245
+ generation_config = copy.deepcopy(self.generation_config or GenerationConfig())
1246
+ if self.tokenizer.eos_token_id is None:
1247
+ raise ValueError("Tokenizer must have an EOS token")
1248
+ if generation_config.max_length == GenerationConfig().max_length:
1249
+ generation_config.max_length = self.tokenizer.model_max_length
1250
+ if generation_config.pad_token_id is None:
1251
+ generation_config.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
1252
+ if generation_config.bos_token_id is None:
1253
+ generation_config.bos_token_id = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
1254
+ if generation_config.eos_token_id is None:
1255
+ generation_config.eos_token_id = self.tokenizer.eos_token_id
1256
+ return generation_config
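
Note: the two methods above are the main inference entry points. generate consumes already-tokenized input_ids plus pre-processed media, while generate_content wraps the whole pipeline (media extraction, processing, tokenization, generation, decoding) and falls back to greedy decoding once if sampling raises a ValueError. A minimal usage sketch follows; the checkpoint path and image file are placeholders, and it assumes the checkpoint is loaded with trust_remote_code=True so this modeling code is used, and that extract_media accepts a prompt list mixing PIL images and text, as the Union[str, List] signature suggests.

    # Hypothetical usage of generate_content (paths below are placeholders, not real files).
    import PIL.Image
    import torch
    from transformers import AutoModel

    model = AutoModel.from_pretrained(
        "path/to/this/checkpoint",      # placeholder: a local clone of this repo
        trust_remote_code=True,         # required so modeling_vila.py is picked up
        torch_dtype=torch.float16,
        device_map="auto",
    )

    image = PIL.Image.open("demo.jpg")  # placeholder image
    prompt = [image, "Describe this image in one sentence."]
    print(model.generate_content(prompt))
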
siglip_encoder.py ADDED
@@ -0,0 +1,286 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from accelerate.hooks import add_hook_to_module
21
+ from einops import rearrange
22
+ from s2wrapper import forward as multiscale_forward
23
+ from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, SiglipImageProcessor
24
+ from transformers.image_processing_utils import BaseImageProcessor
25
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
26
+ from transformers.models.siglip import SiglipVisionModel
27
+
28
+
29
+ class VisionTower(nn.Module):
30
+ def __init__(self, vision_tower, args, delay_load=False):
31
+ super().__init__()
32
+
33
+ self.is_loaded = False
34
+
35
+ self.vision_tower_name = vision_tower
36
+ self.select_layer = getattr(args, "mm_vision_select_layer", -2)
37
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
38
+
39
+ self.cfg_only = None
40
+
41
+ def feature_select(self, image_forward_outs):
42
+ image_features = image_forward_outs.hidden_states[self.select_layer]
43
+ if self.select_feature == "patch":
44
+ image_features = image_features[:, 1:]
45
+ elif self.select_feature == "cls_patch":
46
+ image_features = image_features
47
+ else:
48
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
49
+ return image_features
50
+
51
+ def _maybe_resize_pos_embeds(
52
+ self,
53
+ model: PreTrainedModel,
54
+ image_processor: BaseImageProcessor,
55
+ resolution: int = -1,
56
+ interpolate_mode: str = "linear",
57
+ ):
58
+ if resolution in [model.config.image_size, -1]:
59
+ return
60
+ print(
61
+ f"Resizing vision model's position embeddings to support higher vision resolution: from {model.config.image_size} to {resolution} ..."
62
+ )
63
+ embeddings = model.vision_model.embeddings
64
+ patch_size = embeddings.patch_size
65
+ num_new_tokens = int((resolution // patch_size) ** 2)
66
+
67
+ old_embeddings = embeddings.position_embedding
68
+ match interpolate_mode:
69
+ case "linear":
70
+ ## Step 1: Calculate the corresponding patch ID (pid) in the current resolution (M patches) based on the target resolution (N patches). Formula: pid = pid / N * M
71
+ ## Step 2: Obtain new embeddings by interpolating between the embeddings of the two nearest calculated patch IDs. Formula: new_embeds = (pid - floor(pid)) * embeds[ceil(pid)] + (ceil(pid) - pid) * embeds[floor(pid)]
72
+ import torch
73
+ import torch.nn as nn
74
+
75
+ if is_deepspeed_zero3_enabled():
76
+ try:
77
+ import deepspeed
78
+ except ImportError:
79
+ raise ImportError("DeepSpeed is not installed. Please install it with `pip install deepspeed`.")
80
+ with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
81
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
82
+ else:
83
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
84
+ new_embeddings = nn.Embedding(
85
+ num_new_tokens,
86
+ old_embedding_dim,
87
+ dtype=old_embeddings.weight.dtype,
88
+ device=old_embeddings.weight.device,
89
+ )
90
+ mapped_indices = (
91
+ torch.arange(num_new_tokens).to(old_embeddings.weight.device)
92
+ / (num_new_tokens - 1)
93
+ * (old_num_tokens - 1)
94
+ )
95
+ floor_indices = torch.clamp(mapped_indices.floor().long(), min=0, max=old_num_tokens - 1)
96
+ ceil_indices = torch.clamp(mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1)
97
+ if is_deepspeed_zero3_enabled():
98
+ params = [old_embeddings.weight, new_embeddings.weight]
99
+ with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
100
+ interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
101
+ ceil_indices, :
102
+ ] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :]
103
+ else:
104
+ interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
105
+ ceil_indices, :
106
+ ] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :]
107
+ new_embeddings.weight.data = interpolated_embeds
108
+ case _:
109
+ raise NotImplementedError
110
+
111
+ if hasattr(old_embeddings, "_hf_hook"):
112
+ hook = old_embeddings._hf_hook
113
+ add_hook_to_module(new_embeddings, hook)
114
+ new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
115
+ ## update vision encoder's configurations
116
+ model.config.image_size = resolution
117
+ if hasattr(image_processor, "crop_size"):
118
+ # CLIP vision tower
119
+ image_processor.crop_size = resolution
120
+ else:
121
+ # SIGLIP vision tower
122
+ assert hasattr(image_processor, "size")
123
+ image_processor.size = {"height": resolution, "width": resolution}
124
+ embeddings.position_embedding = new_embeddings
125
+ embeddings.image_size = resolution
126
+ embeddings.num_patches = embeddings.num_positions = num_new_tokens
127
+ embeddings.position_ids = (
128
+ torch.arange(embeddings.num_positions).expand((1, -1)).to(old_embeddings.weight.device)
129
+ )
130
+
131
+ def forward(self, images):
132
+ if type(images) is list:
133
+ image_features = []
134
+ for image in images:
135
+ image_forward_out = self.vision_tower(
136
+ image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
137
+ output_hidden_states=True,
138
+ )
139
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
140
+ image_features.append(image_feature)
141
+ else:
142
+ image_forward_outs = self.vision_tower(
143
+ images.to(device=self.device, dtype=self.dtype),
144
+ output_hidden_states=True,
145
+ )
146
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
147
+
148
+ return image_features
149
+
150
+ @property
151
+ def dummy_feature(self):
152
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
153
+
154
+ @property
155
+ def dtype(self):
156
+ return self.vision_tower.dtype
157
+
158
+ @property
159
+ def device(self):
160
+ return self.vision_tower.device
161
+
162
+ @property
163
+ def config(self):
164
+ if self.is_loaded:
165
+ return self.vision_tower.config
166
+ else:
167
+ return self.cfg_only
168
+
169
+ @property
170
+ def hidden_size(self):
171
+ return self.config.hidden_size
172
+
173
+ @property
174
+ def num_patches(self):
175
+ return (self.config.image_size // self.config.patch_size) ** 2
176
+
177
+
178
+ class VisionTowerS2(VisionTower):
179
+ def __init__(self, vision_tower, args, delay_load=False):
180
+ super().__init__(vision_tower, args, delay_load)
181
+
182
+ self.scales = list(map(int, args.s2_scales.split(",")))
183
+ self.scales.sort()
184
+ self.max_split_size = args.s2_max_split_size
185
+ self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)
186
+
187
+ def forward_feature(self, images):
188
+ image_forward_outs = self.vision_tower(
189
+ images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
190
+ )
191
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
192
+ return image_features
193
+
194
+ def forward(self, images):
195
+ if type(images) is list:
196
+ image_features = []
197
+ for image in images:
198
+ image_feature = multiscale_forward(
199
+ self.forward_feature,
200
+ image.unsqueeze(0),
201
+ img_sizes=self.scales,
202
+ max_split_size=self.max_split_size,
203
+ resize_output_to_idx=self.resize_output_to_scale_idx,
204
+ )
205
+ image_features.append(image_feature)
206
+ else:
207
+ image_features = multiscale_forward(
208
+ self.forward_feature,
209
+ images,
210
+ img_sizes=self.scales,
211
+ max_split_size=self.max_split_size,
212
+ resize_output_to_idx=self.resize_output_to_scale_idx,
213
+ )
214
+
215
+ return image_features
216
+
217
+ @property
218
+ def hidden_size(self):
219
+ return self.config.hidden_size * len(self.scales)
220
+
221
+
222
+ class VisionTowerDynamicS2(VisionTower):
223
+ def __init__(self, vision_tower, args, delay_load=False):
224
+ super().__init__(vision_tower, args, delay_load)
225
+
226
+ self.scales = list(map(int, args.s2_scales.split(",")))
227
+ self.scales.sort()
228
+ self.max_split_size = args.s2_max_split_size
229
+ self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)
230
+
231
+ def forward_feature(self, images):
232
+ image_forward_outs = self.vision_tower(
233
+ images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
234
+ )
235
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
236
+ return image_features
237
+
238
+ def forward(self, images):
239
+ assert type(images) is not list
240
+ image_features = self.forward_feature(images)
241
+
242
+ return image_features
243
+
244
+ @property
245
+ def hidden_size(self):
246
+ return self.config.hidden_size * len(self.scales)
247
+
248
+
249
+ class SiglipVisionTower(VisionTower):
250
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
251
+ super().__init__(model_name_or_path, config)
252
+ self.vision_tower = SiglipVisionModel.from_pretrained(
253
+ model_name_or_path,
254
+ attn_implementation=config._attn_implementation,
255
+ torch_dtype=eval(config.model_dtype),
256
+ )
257
+ self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
258
+ self.is_loaded = True
259
+
260
+
261
+ class SiglipVisionTowerS2(VisionTowerS2):
262
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
263
+ super().__init__(model_name_or_path, config)
264
+ self.vision_tower = SiglipVisionModel.from_pretrained(
265
+ model_name_or_path,
266
+ attn_implementation=config._attn_implementation,
267
+ torch_dtype=eval(config.model_dtype),
268
+ )
269
+ self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
270
+ # Make sure it crops/resizes the image to the largest scale in self.scales to maintain high-res information
271
+ self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[-1]
272
+ self.is_loaded = True
273
+
274
+
275
+ class SiglipVisionTowerDynamicS2(VisionTowerDynamicS2):
276
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
277
+ super().__init__(model_name_or_path, config)
278
+ self.vision_tower = SiglipVisionModel.from_pretrained(
279
+ model_name_or_path,
280
+ attn_implementation=config._attn_implementation,
281
+ torch_dtype=eval(config.model_dtype),
282
+ )
283
+ self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
284
+ # For dynamic S2, crop/resize to the smallest scale in self.scales; higher resolutions are covered by dynamic tiling
285
+ self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[0]
286
+ self.is_loaded = True
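
Note: the towers above differ only in how scales are handled. VisionTower encodes a single resolution, VisionTowerS2 runs every scale in s2_scales through s2wrapper and concatenates per-patch features channel-wise (hence hidden_size * len(scales)), and the dynamic variants expect the caller to do the tiling. A standalone sketch of SiglipVisionTower follows, assuming s2wrapper and accelerate are installed so this module imports, that SiglipVisionTower is importable from it, and that the SimpleNamespace attributes mirror what the constructor reads; the SigLIP checkpoint id is only an example.

    # Hypothetical standalone use of SiglipVisionTower (normally it is built by builder.py).
    from types import SimpleNamespace

    from PIL import Image

    cfg = SimpleNamespace(
        _attn_implementation="sdpa",       # attention backend for SiglipVisionModel
        model_dtype="torch.float32",       # eval()'d by the tower to pick the dtype
        mm_vision_select_layer=-2,         # hidden state used by feature_select
        mm_vision_select_feature="patch",  # keep patch tokens only
    )
    tower = SiglipVisionTower("google/siglip-so400m-patch14-384", cfg)

    pixels = tower.image_processor(images=Image.new("RGB", (384, 384)), return_tensors="pt")["pixel_values"]
    features = tower(pixels.to(dtype=tower.dtype))
    print(features.shape)  # (batch, num_tokens, hidden_size)
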
tokenizer_utils.py ADDED
@@ -0,0 +1,181 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ from typing import Any, Dict, List, Optional, Sequence
18
+
19
+ import torch
20
+ import transformers
21
+
22
+ from .constants import IGNORE_INDEX, SENTINEL_TOKEN
23
+ from .conversation import SeparatorStyle, default_conversation
24
+ from .mm_utils import tokenizer_image_token
25
+
26
+ DUMMY_CONVERSATION = [
27
+ {"from": "human", "value": "question"},
28
+ {"from": "gpt", "value": "answer"},
29
+ ] * 10
30
+
31
+
32
+ def tokenize_conversation_legacy(
33
+ messages: Sequence[Dict[str, str]],
34
+ tokenizer: transformers.PreTrainedTokenizer,
35
+ add_generation_prompt: bool = False,
36
+ overrides: Optional[Dict[str, str]] = None,
37
+ no_system_prompt: bool = False,
38
+ ) -> torch.Tensor:
39
+ conv = default_conversation.copy()
40
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
41
+
42
+ if no_system_prompt:
43
+ conv.system = ""
44
+
45
+ # Skip the first message if it is not from human
46
+ if messages[0]["from"] != "human":
47
+ messages = messages[1:]
48
+
49
+ # Add a generation prompt if needed
50
+ if add_generation_prompt:
51
+ messages.append({"from": "gpt", "value": None})
52
+
53
+ conv.messages = []
54
+ for turn, message in enumerate(messages):
55
+ role = roles[message["from"]]
56
+ assert role == conv.roles[turn % 2]
57
+ if overrides is not None and message["from"] in overrides:
58
+ conv.append_message(role, overrides[message["from"]])
59
+ else:
60
+ conv.append_message(role, message["value"])
61
+
62
+ return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")
63
+
64
+
65
+ def tokenize_conversation(
66
+ messages: Sequence[Dict[str, str]],
67
+ tokenizer: transformers.PreTrainedTokenizer,
68
+ add_generation_prompt: bool = False,
69
+ overrides: Optional[Dict[str, str]] = None,
70
+ no_system_prompt: bool = False,
71
+ return_ids_only=True,
72
+ ) -> torch.Tensor:
73
+ # Normalize the conversation before tokenization
74
+ for message in messages:
75
+ message["value"] = message["value"].strip()
76
+
77
+ if default_conversation.sep_style != SeparatorStyle.AUTO:
78
+ return tokenize_conversation_legacy(
79
+ messages,
80
+ tokenizer,
81
+ add_generation_prompt=add_generation_prompt,
82
+ overrides=overrides,
83
+ no_system_prompt=no_system_prompt,
84
+ )
85
+
86
+ conversation = []
87
+ for m in messages:
88
+ message = {}
89
+ if m["from"] == "human":
90
+ message["role"] = "user"
91
+ elif m["from"] == "gpt":
92
+ message["role"] = "assistant"
93
+ elif m["from"] == "system":
94
+ message["role"] = "system"
95
+ if no_system_prompt:
96
+ raise ValueError("message[role]=system is not allowed when no_system_prompt is set to True.")
97
+ else:
98
+ raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")
99
+
100
+ message["content"] = m["value"]
101
+ if overrides is not None and m["from"] in overrides:
102
+ message["content"] = overrides[m["from"]]
103
+ conversation.append(message)
104
+
105
+ if no_system_prompt:
106
+ conversation = [{"role": "system", "content": ""}] + conversation
107
+
108
+ text = tokenizer.apply_chat_template(
109
+ conversation,
110
+ add_generation_prompt=add_generation_prompt,
111
+ tokenize=False,
112
+ )
113
+ return tokenizer_image_token(text, tokenizer, return_tensors="pt", return_ids=return_ids_only)
114
+
115
+
116
+ def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
117
+ if not hasattr(tokenizer, "sentinel_token"):
118
+ tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
119
+ tokenizer.sentinel_token = SENTINEL_TOKEN
120
+ tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)
121
+
122
+
123
+ def preprocess_conversation(
124
+ conversation: Sequence[Dict[str, str]],
125
+ tokenizer: transformers.PreTrainedTokenizer,
126
+ no_system_prompt: bool = False,
127
+ retried: bool = False,
128
+ **kwargs: Any,
129
+ ) -> Dict[str, Any]:
130
+ inputs = tokenize_conversation(conversation, tokenizer, no_system_prompt=no_system_prompt)
131
+ labels = torch.ones_like(inputs) * IGNORE_INDEX
132
+
133
+ # Generate the template by replacing the assistant's response with a sentinel.
134
+ _maybe_add_sentinel_token(tokenizer)
135
+ template = tokenize_conversation(
136
+ conversation, tokenizer, overrides={"gpt": SENTINEL_TOKEN}, no_system_prompt=no_system_prompt
137
+ )
138
+
139
+ # Remove sentinel tokens from the template.
140
+ mask = torch.ones_like(template, dtype=torch.bool)
141
+ for k in range(template.size(0) - 1):
142
+ if template[k] == tokenizer.sentinel_token_id:
143
+ mask[k : k + 2] = False
144
+ if k > 0 and retried:
145
+ mask[k - 1] = False
146
+ template = template[mask]
147
+
148
+ # Match the tokenized conversation with the template (with no assistant's response).
149
+ # Every token that is not matched will be included in the label for training.
150
+ p = 0
151
+ for k in range(inputs.size(0)):
152
+ if p < template.size(0) and inputs[k] == template[p]:
153
+ p += 1
154
+ else:
155
+ labels[k] = inputs[k]
156
+
157
+ # Mask all tokens in the label if the template is not fully matched.
158
+ if p < template.size(0):
159
+ if not retried:
160
+ return preprocess_conversation(
161
+ conversation,
162
+ tokenizer,
163
+ no_system_prompt=no_system_prompt,
164
+ retried=True,
165
+ )
166
+ print(f"Failed to process the conversation: '{conversation}'. All tokens will be masked in the label.")
167
+ labels[:] = IGNORE_INDEX
168
+
169
+ return {"input_ids": inputs, "labels": labels}
170
+
171
+
172
+ def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
173
+ _maybe_add_sentinel_token(tokenizer)
174
+ template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})
175
+
176
+ stop_tokens = {tokenizer.eos_token}
177
+ for k in range(template.size(0) - 1):
178
+ if template[k] == tokenizer.sentinel_token_id:
179
+ stop_token = tokenizer.decode(template[k + 1])
180
+ stop_tokens.add(stop_token)
181
+ return list(stop_tokens)
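
Note: preprocess_conversation derives training labels by tokenizing the conversation twice, once as-is and once with every assistant turn replaced by SENTINEL_TOKEN; any token that also appears in the sentinel template is masked to IGNORE_INDEX, so only the assistant's answers (and their trailing stop token) are supervised. infer_stop_tokens reuses the same trick to discover which token the chat template emits right after an assistant turn. A small illustrative sketch, assuming preprocess_conversation and IGNORE_INDEX are imported from this package, with a placeholder tokenizer path:

    # Hypothetical example of the label masking done by preprocess_conversation.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/this/checkpoint/llm")  # placeholder
    conversation = [
        {"from": "human", "value": "What is in the image?"},
        {"from": "gpt", "value": "A cat sitting on a sofa."},
    ]
    out = preprocess_conversation(conversation, tokenizer)

    supervised = out["labels"] != IGNORE_INDEX
    print(out["input_ids"].shape, int(supervised.sum()))  # only assistant tokens are supervised
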
utils.py ADDED
@@ -0,0 +1,211 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
17
+ import os
18
+ import os.path as osp
19
+
20
+ from huggingface_hub import repo_exists, snapshot_download
21
+ from huggingface_hub.utils import HFValidationError, validate_repo_id
22
+ from transformers import AutoConfig, AutoTokenizer, PretrainedConfig
23
+
24
+ from .configuration_vila import VILAConfig
25
+ from .constants import MEDIA_TOKENS
26
+ from .tokenizer_utils import infer_stop_tokens
27
+
28
+
29
+ def load_tokenizer_then_handle_media_tokens_and_chat_template(
30
+ model_name_or_path, config: VILAConfig, model_max_length=None
31
+ ):
32
+ tokenizer = AutoTokenizer.from_pretrained(
33
+ osp.join(model_name_or_path, "llm"), padding_side="right", use_fast=True, legacy=False
34
+ )
35
+ if model_max_length is not None:
36
+ tokenizer.model_max_length = model_max_length
37
+
38
+ # Load chat template if specified.
39
+ if getattr(config, "chat_template", None) is not None:
40
+ print(f"Using chat template: {config.chat_template}")
41
+ fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
42
+ if not os.path.exists(fpath):
43
+ fpath = os.path.join(model_name_or_path, f"{config.chat_template}.jinja")
44
+ with open(fpath) as fd:
45
+ chat_template = fd.read()
46
+ tokenizer.chat_template = chat_template.replace(" ", "").replace("\n", "")
47
+
48
+ # Set stop tokens for the tokenizer
49
+ tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
50
+ tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)
51
+
52
+ # Add media tokens to the tokenizer
53
+ tokenizer.media_tokens = MEDIA_TOKENS
54
+ tokenizer.media_token_ids = {}
55
+ for name, token in MEDIA_TOKENS.items():
56
+ tokenizer.add_tokens([token], special_tokens=True)
57
+ tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)
58
+
59
+ return tokenizer
60
+
61
+
62
+ def get_model_config(config):
63
+ default_keys = ["llm_cfg", "vision_tower_cfg", "mm_projector_cfg"]
64
+
65
+ if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
66
+ root_path = config._name_or_path
67
+ else:
68
+ root_path = config.resume_path
69
+
70
+ # download from huggingface
71
+ if root_path is not None and not osp.exists(root_path):
72
+ try:
73
+ valid_hf_repo = repo_exists(root_path)
74
+ except HFValidationError as e:
75
+ valid_hf_repo = False
76
+ if valid_hf_repo:
77
+ root_path = snapshot_download(root_path)
78
+
79
+ return_list = []
80
+ for key in default_keys:
81
+ cfg = getattr(config, key, None)
82
+ if isinstance(cfg, dict):
83
+ try:
84
+ return_list.append(os.path.join(root_path, key[:-4]))
85
+ except:
86
+ raise ValueError(f"Cannot find resume path in config for {key}!")
87
+ elif isinstance(cfg, PretrainedConfig):
88
+ return_list.append(os.path.join(root_path, key[:-4]))
89
+ elif isinstance(cfg, str):
90
+ return_list.append(cfg)
91
+
92
+ return return_list
93
+
94
+
95
+ def get_model_config_fp8(config):
96
+ default_keys = ["llm_cfg", "vision_tower_cfg", "mm_projector_cfg"]
97
+
98
+ if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
99
+ root_path = config._name_or_path
100
+ else:
101
+ root_path = config.resume_path
102
+
103
+ # download from huggingface
104
+ if root_path is not None and not osp.exists(root_path):
105
+ try:
106
+ valid_hf_repo = repo_exists(root_path)
107
+ except HFValidationError as e:
108
+ valid_hf_repo = False
109
+ if valid_hf_repo:
110
+ root_path = snapshot_download(root_path)
111
+
112
+ return_list = []
113
+ for key in default_keys:
114
+ cfg = getattr(config, key, None)
115
+ if isinstance(cfg, dict):
116
+ try:
117
+ return_list.append(os.path.join(root_path, key[:-4]))
118
+ except:
119
+ raise ValueError(f"Cannot find resume path in config for {key}!")
120
+ elif isinstance(cfg, PretrainedConfig):
121
+ return_list.append(os.path.join(root_path, key[:-4]))
122
+ elif isinstance(cfg, str):
123
+ return_list.append(cfg)
124
+
125
+ # fp8_llm
126
+ key = "fp8_llm_cfg"
127
+ directory_path = os.path.join(root_path, key[:-4])
128
+ assert os.path.isdir(directory_path) and os.listdir(
129
+ directory_path
130
+ ), "You need to first convert the model weights to FP8 explicitly."
131
+ return_list.append(directory_path)
132
+
133
+ return return_list
175
+
176
+
177
+ def is_mm_model(model_path):
178
+ """
179
+ Check if the model at the given path is a visual language model.
180
+
181
+ Args:
182
+ model_path (str): The path to the model.
183
+
184
+ Returns:
185
+ bool: True if the model is an MM model, False otherwise.
186
+ """
187
+ config = AutoConfig.from_pretrained(model_path)
188
+ architectures = config.architectures
189
+ for architecture in architectures:
190
+ if "llava" in architecture.lower():
191
+ return True
192
+ return False
193
+
194
+
195
+ def auto_upgrade(config):
196
+ cfg = AutoConfig.from_pretrained(config)
197
+ if "llava" in config and "llava" not in cfg.model_type:
198
+ assert cfg.model_type == "llama"
199
+ print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
200
+ print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
201
+ confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
202
+ if confirm.lower() in ["y", "yes"]:
203
+ print("Upgrading checkpoint...")
204
+ assert len(cfg.architectures) == 1
205
+ setattr(cfg.__class__, "model_type", "llava")
206
+ cfg.architectures[0] = "LlavaLlamaForCausalLM"
207
+ cfg.save_pretrained(config)
208
+ print("Checkpoint upgraded.")
209
+ else:
210
+ print("Checkpoint upgrade aborted.")
211
+ exit(1)
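
Note: get_model_config maps the nested sub-configs back to their on-disk folders by stripping the _cfg suffix, so a checkpoint laid out as <root>/llm, <root>/vision_tower and <root>/mm_projector resolves to exactly those three paths (downloading the repo from the Hub first when root_path is not a local directory). get_model_config_fp8 additionally appends <root>/fp8_llm and asserts that the FP8 weights were exported beforehand. A brief sketch, assuming get_model_config is imported from this module and using a placeholder checkpoint path:

    # Hypothetical illustration of get_model_config (the checkpoint path is a placeholder).
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("path/to/this/checkpoint", trust_remote_code=True)
    llm_path, vision_tower_path, mm_projector_path = get_model_config(config)
    print(llm_path)           # e.g. path/to/this/checkpoint/llm
    print(vision_tower_path)  # e.g. path/to/this/checkpoint/vision_tower
    print(mm_projector_path)  # e.g. path/to/this/checkpoint/mm_projector
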