Upload 19 files
 - NV_LICENSE +35 -0
 - README.md +184 -0
 - auto_processor.py +495 -0
 - base_projector.py +228 -0
 - builder.py +247 -0
 - config.json +295 -0
 - configuration_vila.py +92 -0
 - constants.py +83 -0
 - conversation.py +191 -0
 - distributed.py +73 -0
 - loss.py +48 -0
 - media.py +130 -0
 - media_encoder.py +158 -0
 - mm_utils.py +575 -0
 - model_utils_packing.py +35 -0
 - modeling_vila.py +1256 -0
 - siglip_encoder.py +286 -0
 - tokenizer_utils.py +181 -0
 - utils.py +211 -0
 
    	
NV_LICENSE
ADDED

NVIDIA License

1. Definitions

“Licensor” means any person or entity that distributes its Work.

“Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license.

The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.

Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license.

2. License Grant

2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.

3. Limitations

3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.

3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.

3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or educational purposes only.

3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately.

3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license.

3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately.

4. Disclaimer of Warranty.

THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.

5. Limitation of Liability.

EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
    	
README.md
ADDED

# LongVILA-R1-7B

[Paper](https://arxiv.org/abs/2507.07966)
[Code](https://github.com/NVlabs/Long-RL)
[Model](https://huggingface.co/Efficient-Large-Model/LongVILA-R1-7B)
[Video](https://www.youtube.com/watch?v=ykbblK2jiEg)
[Demo](https://6d8b5579459b555d59.gradio.live)


## Introduction:
<p>
  <strong>LongVILA-R1-7B</strong> supports both <u>multiple-choice</u> and <u>open-ended</u> questions, and can switch between thinking and non-thinking modes.<br>
  <strong>LongVILA-R1-7B</strong> demonstrates strong performance in long video reasoning, achieving <strong>70.7%</strong> on VideoMME (w/ sub.) and surpassing Gemini-1.5-Pro across diverse reasoning tasks.<br>
  <strong>Long-RL</strong> is a codebase that accelerates long video RL training by up to <strong>2.1×</strong> through its MR-SP system. It supports RL training on image, video, and omni inputs across VILA, Qwen/Qwen-VL, and diffusion models.
</p>

## Evaluation:
### Video QA Benchmarks
| Models             | VideoMME (w/o sub) | VideoMME (w sub) | ActivityNet-QA (test) | LongVideoBench (val) | PerceptionTest (val) | NExT-QA  | VNBench (val) |
|:-------------------|:------------------:|:----------------:|:---------------------:|:--------------------:|:--------------------:|:--------:|:-------------:|
| **LongVILA-7B**    |      **60.1**      |     **65.1**     |       **59.5**        |       **57.1**       |       **58.1**       | **80.7** |   **63.0**    |
| **LongVILA-R1-7B** |      **65.0**      |     **70.7**     |       **64.8**        |       **58.0**       |       **68.9**       | **81.5** |   **75.5**    |

### LongVideo-Reason-eval
| Models | Temporal | Goal | Plot | Spatial | Overall |
| :--- | :---: | :---: | :---: | :---: | :---: |
| **LongVILA-R1-7B** | **68.1** | **85.7** | **70.6** | **53.3** | **72.0** |


## Usage

### Generation
```python
from transformers import AutoModel

model_path = "Efficient-Large-Model/LongVILA-R1-7B"
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto")

use_thinking = True  # Switch between thinking and non-thinking modes
system_prompt_thinking = "You are a helpful assistant. The user asks a question, and then you solves it.\n\nPlease first think deeply about the question based on the given video, and then provide the final answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.\n\n Question: {question}"

prompt = "What is the main purpose of the video?"
video_path = "video.mp4"

if use_thinking:
    prompt = system_prompt_thinking.format(question=prompt)

response = model.generate_content([prompt, {"path": video_path}])
print("Response: ", response)
```
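In thinking mode, the response follows the `<think> ... </think> <answer> ... </answer>` format described in the system prompt above. A minimal way to split the reasoning from the final answer is sketched below; the helper is illustrative only and not part of the released code:
```python
import re

def parse_thinking_response(response: str) -> tuple[str, str]:
    """Split a thinking-mode response into (reasoning, answer).

    Falls back to the raw response as the answer when the tags are absent,
    e.g. in non-thinking mode.
    """
    think = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
    answer = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
    reasoning = think.group(1).strip() if think else ""
    final_answer = answer.group(1).strip() if answer else response.strip()
    return reasoning, final_answer

reasoning, answer = parse_thinking_response(response)
print("Answer: ", answer)
```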

### with vLLM engine
Tested on `vllm==0.9.1`. We need to copy the remote code locally first.
```bash
mkdir remote_code
cp path_to/Efficient-Large-Model/LongVILA-R1-7B/*.py remote_code
```
Then, you can use the following code for model generation.
```python
import os
from transformers import AutoModel
from vllm import LLM, SamplingParams
from remote_code.media import extract_media
from remote_code.mm_utils import process_images
from remote_code.tokenizer_utils import tokenize_conversation

model_path = "path_to/Efficient-Large-Model/LongVILA-R1-7B"

model_encoder = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto", llm_only_need_embed=True)
# You can adjust gpu_memory_utilization according to the available GPU memory
llm = LLM(model=os.path.join(model_path, "llm"), enable_prompt_embeds=True, gpu_memory_utilization=0.5)

use_thinking = True  # Switch between thinking and non-thinking modes
system_prompt_thinking = "You are a helpful assistant. The user asks a question, and then you solves it.\n\nPlease first think deeply about the question based on the given video, and then provide the final answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.\n\n Question: {question}"

prompt = "What is the main purpose of the video?"
video_path = "video.mp4"

if use_thinking:
    prompt = system_prompt_thinking.format(question=prompt)

conversation = [{"from": "human", "value": [prompt, {"path": video_path}]}]
media = extract_media(conversation, model_encoder.config)
input_ids = tokenize_conversation(conversation, model_encoder.tokenizer, add_generation_prompt=True).unsqueeze(0).cuda()
media["video"] = [
    process_images(images, model_encoder.vision_tower.image_processor, model_encoder.config).half()
    for images in media["video"]
]

inputs_embeds, _, _ = model_encoder._embed(input_ids, media, {"video": {}}, None, None)

completions = llm.generate(prompts=[{"prompt_embeds": inputs_embeds.squeeze(0)}], sampling_params=SamplingParams(max_tokens=1024))
response = completions[0].outputs[0].text
print("Response: ", response)
```

# LongVILA-R1 Model Card

## Model details

**Model type:**
LongVILA-R1 addresses the unique challenges of long video reasoning by integrating three critical components: (1) a large-scale dataset, LongVideo-Reason, comprising 104K long video QA pairs with high-quality reasoning annotations across diverse domains such as sports, games, and vlogs; (2) a two-stage training pipeline that extends VLMs with chain-of-thought supervised fine-tuning (CoT-SFT) and reinforcement learning (RL); and (3) a training infrastructure for long video RL, named Multi-modal Reinforcement Sequence Parallelism (MR-SP), which incorporates sequence parallelism and a vLLM-based engine tailored for long video, using cached video embeddings for efficient rollout and prefilling. In our experiments, LongVILA-R1-7B achieves strong performance on video benchmarks, reaching 65.0% and 70.7% accuracy on VideoMME without and with subtitles, respectively, and consistently outperforming LongVILA-7B across multiple benchmarks. Moreover, LongVILA-R1 shows steady performance improvements as the number of input video frames increases.
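The cached-embedding idea behind MR-SP can be sketched in a few lines: encode each long video once with the vision tower and reuse the stored embeddings across RL rollouts and prefilling instead of re-encoding the frames every time. The class and names below are a hypothetical illustration, not the Long-RL implementation:
```python
import torch

class VideoEmbeddingCache:
    """Illustrative cache: reuse vision-tower outputs across rollouts of the same video."""

    def __init__(self):
        self._cache: dict[str, torch.Tensor] = {}

    def get(self, video_id: str, frames: torch.Tensor, encode_fn) -> torch.Tensor:
        # Only the first rollout pays the vision-tower cost; later rollouts
        # for the same video reuse the cached embeddings for prefilling.
        if video_id not in self._cache:
            with torch.no_grad():
                self._cache[video_id] = encode_fn(frames)
        return self._cache[video_id]
```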
**Model date:**
LongVILA-R1-7B was trained in July 2025.

**Paper or resources for more information:**
- Paper https://arxiv.org/abs/2507.07966
- Code https://github.com/NVlabs/Long-RL
- Model https://huggingface.co/Efficient-Large-Model/LongVILA-R1-7B
- Video https://www.youtube.com/watch?v=ykbblK2jiEg
- Demo https://6d8b5579459b555d59.gradio.live

```bibtex
@misc{long-rl,
  title = {Long-RL: Scaling RL to Long Sequences},
  author = {Yukang Chen, Wei Huang, Shuai Yang, Qinghao Hu, Baifeng Shi, Hanrong Ye, Ligeng Zhu, Zhijian Liu, Pavlo Molchanov, Jan Kautz, Xiaojuan Qi, Sifei Liu, Hongxu Yin, Yao Lu, Song Han},
  year = {2025},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/NVlabs/Long-RL}},
}
```
```bibtex
@article{chen2025longvila-r1,
      title={Scaling RL to Long Videos},
      author={Yukang Chen and Wei Huang and Baifeng Shi and Qinghao Hu and Hanrong Ye and Ligeng Zhu and Zhijian Liu and Pavlo Molchanov and Jan Kautz and Xiaojuan Qi and Sifei Liu and Hongxu Yin and Yao Lu and Song Han},
      year={2025},
      eprint={2507.07966},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```
```bibtex
@inproceedings{chen2024longvila,
      title={LongVILA: Scaling Long-Context Visual Language Models for Long Videos},
      author={Yukang Chen and Fuzhao Xue and Dacheng Li and Qinghao Hu and Ligeng Zhu and Xiuyu Li and Yunhao Fang and Haotian Tang and Shang Yang and Zhijian Liu and Ethan He and Hongxu Yin and Pavlo Molchanov and Jan Kautz and Linxi Fan and Yuke Zhu and Yao Lu and Song Han},
      booktitle={The International Conference on Learning Representations (ICLR)},
      year={2025},
}
```

## License
- The weights are released under the [CC-BY-NC-SA-4.0 license](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en).
- The service is a research preview intended for non-commercial use only, and is subject to the following licenses and terms:
    - [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI
    - [Dataset Licenses](https://github.com/Efficient-Large-Model/VILA/blob/main/data_prepare/LICENSE) for each dataset used during training.
    - [NVIDIA License](https://huggingface.co/Efficient-Large-Model/LongVILA-R1-7B/blob/main/NV_LICENSE)

**Where to send questions or comments about the model:**
https://github.com/NVlabs/Long-RL/issues

## Intended use
**Primary intended uses:**
The primary use of LongVILA-R1 is research on large multimodal models and chatbots.

**Primary intended users:**
The primary intended users of the model are researchers and hobbyists in computer vision, natural language processing, machine learning, and artificial intelligence.

## Input:
**Input Type:** Video and Text
**Input Format:** MP4 and other video formats

## Output:
**Output Type:** Text
**Output Format:** String


**Supported Operating System(s):** <br>
Linux


## Inference:
**Engine:**
* PyTorch


**Test Hardware:**
* A100
* H100
* A6000

## Ethical Considerations
NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse.
    	
auto_processor.py
ADDED

import base64
import copy
import os
import os.path as osp
import warnings
from collections import defaultdict
from io import BytesIO
from typing import List, Optional, Union

import PIL.Image
import requests
import torch
from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoProcessor, AutoTokenizer
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging

from .constants import DEFAULT_IMAGE_TOKEN, MEDIA_TOKENS
from .media import Image, Video, extract_media
from .mm_utils import process_image, process_images
from .tokenizer_utils import tokenize_conversation


def to_rgb(pil_image: PIL.Image.Image) -> PIL.Image.Image:
    if pil_image.mode == "RGBA":
        white_background = PIL.Image.new("RGB", pil_image.size, (255, 255, 255))
        white_background.paste(pil_image, mask=pil_image.split()[3])  # Use alpha channel as mask
        return white_background
    else:
        return pil_image.convert("RGB")


def fetch_image(ele: dict[str, str | PIL.Image.Image], size_factor=None) -> PIL.Image.Image:
    if "image" in ele:
        image = ele["image"]
    else:
        image = ele["image_url"]
    image_obj = None
    if isinstance(image, PIL.Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        response = requests.get(image, stream=True)
        image_obj = PIL.Image.open(BytesIO(response.content))
    elif image.startswith("file://"):
        image_obj = PIL.Image.open(image[7:])
    elif image.startswith("data:image"):
        if "base64," in image:
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            image_obj = PIL.Image.open(BytesIO(data))
    else:
        image_obj = PIL.Image.open(image)
    if image_obj is None:
        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
    image = to_rgb(image_obj)

    return image


def fetch_image_url_or_fpath(url_or_fpath):
    if url_or_fpath.startswith("http") or url_or_fpath.startswith("https"):
        import tempfile

        import requests

        # Download the image to a temporary file
        temp_dir = tempfile.mkdtemp()
        temp_file = os.path.join(temp_dir, os.path.basename(url_or_fpath))

        response = requests.get(url_or_fpath, stream=True)
        response.raise_for_status()

        with open(temp_file, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)

        return temp_file
    elif url_or_fpath.startswith("file://"):
        fpath = url_or_fpath.replace("file://", "")
        assert osp.exists(fpath), f"File {fpath} does not exist"
        return fpath
    elif osp.exists(url_or_fpath):
        assert osp.isfile(url_or_fpath), f"File {url_or_fpath} is not a regular file"
        return url_or_fpath
    else:
        raise ValueError(f"Unsupported image path: {url_or_fpath}")


def pad_fn(input_ids_list: List[torch.Tensor], padding_value=0, target_len=None, padding_side="left") -> torch.Tensor:
    # tensor shape is (batch_size, seq_len)
    max_len = max([ids.shape[1] for ids in input_ids_list])
    if target_len is not None:
        assert target_len >= max_len, "target_len must be greater than or equal to max_len"
        max_len = target_len

    new_input_ids_list = []
    for i, input_ids in enumerate(input_ids_list):
        pad_tensor = torch.ones_like(input_ids) * padding_value
        curr_len = input_ids.shape[1]
        pad_tensor = pad_tensor[:, : max_len - curr_len]
        if padding_side == "right":
            input_ids = torch.cat((input_ids, pad_tensor), dim=1)
        else:
            input_ids = torch.cat((pad_tensor, input_ids), dim=1)
        new_input_ids_list.append(input_ids)
    return torch.cat(new_input_ids_list, dim=0)


def extract_value_from_conv(chat):
    value = []
    if isinstance(chat["content"], str):
        # vila_chat["value"].append(chat["content"])
        value.append(chat["content"])
        return value

    # otherwise, it's a list of content
    for content in chat["content"]:
        if content["type"] == "image":
            if "path" in content:
                # VILA style, can be either filepath or http url
                value.append(Image(fetch_image_url_or_fpath(content["path"])))
            elif "image" in content:
                # Qwen style
                value.append(Image(fetch_image_url_or_fpath(content["image"])))
            elif "image_pil" in content:
                # Qwen style
                assert isinstance(content["image_pil"], PIL.Image.Image), "Type of image_pil must be PIL.Image.Image"
                value.append(content["image_pil"])
            else:
                raise ValueError(f"Type = `image`, but no `path` or `image` in | {content=}, {chat=}")
        elif content["type"] == "video":
            if "video" in content:
                # Qwen style
                value.append(Video(fetch_image_url_or_fpath(content["video"])))
            else:
                raise ValueError(f"Type = `video`, but no `video` in | {content=}, {chat=}")
        elif content["type"] == "text":
            value.append(content["text"])
        else:
            raise ValueError(f"Unsupported content type: {content['type']}")
    return value


class VILAProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
    }


class VILAProcessor(ProcessorMixin):
    # attributes = ["image_processor", "tokenizer"]
    attributes = []
    # valid_kwargs = ["chat_template"]
    valid_kwargs = []
    # image_processor_class = "VILAImageProcessor"
    # tokenizer_class = ("VILATokenizer", "VILATokenizerFast")

    def __init__(
        self, image_processor=None, tokenizer=None, chat_template=None, config=None, padding_side="left", **kwargs
    ):
        self.image_token = MEDIA_TOKENS["image"]
        self.video_token = MEDIA_TOKENS["video"]
        self.config = config
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.padding_side = padding_side

        # This is a special setting for Qwen.
        # self.pad_token_id = tokenizer.pad_token_id
        self.pad_token_id = self.tokenizer("<|endoftext|>").input_ids[0]  # 151643
        self.eos_token_id = self.tokenizer.eos_token_id
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    @staticmethod
    def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
        """
        reference from qwen_vl_utils
        """
        vision_infos = []
        if isinstance(conversations[0], dict):
            conversations = [conversations]
        for conversation in conversations:
            for message in conversation:
                if isinstance(message["content"], list):
                    for ele in message["content"]:
                        if (
                            "image" in ele
                            or "image_url" in ele
                            or "video" in ele
                            or ele["type"] in ("image", "image_url", "video")
                        ):
                            vision_infos.append(ele)
        return vision_infos

    @staticmethod
    def process_vision_info(
        conversations: list[dict] | list[list[dict]],
        return_video_kwargs: bool = False,
    ) -> tuple[list[PIL.Image.Image] | None, list[torch.Tensor | list[PIL.Image.Image]] | None, Optional[dict]]:
        """
        reference from qwen_vl_utils
        NVILA does not depend on the function, but the interface is the same.
        """
        # Staticmethods are not in local scope here, so call through the class.
        vision_infos = VILAProcessor.extract_vision_info(conversations)
        ## Read images or videos
         
     | 
| 209 | 
         
            +
                    image_inputs = []
         
     | 
| 210 | 
         
            +
                    video_inputs = []
         
     | 
| 211 | 
         
            +
                    video_sample_fps_list = []
         
     | 
| 212 | 
         
            +
                    for vision_info in vision_infos:
         
     | 
| 213 | 
         
            +
                        if "image" in vision_info or "image_url" in vision_info:
         
     | 
| 214 | 
         
            +
                            image_inputs.append(fetch_image(vision_info))
         
     | 
| 215 | 
         
            +
                        elif "video" in vision_info:
         
     | 
| 216 | 
         
            +
                            video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
         
     | 
| 217 | 
         
            +
                            video_sample_fps_list.append(video_sample_fps)
         
     | 
| 218 | 
         
            +
                            video_inputs.append(video_input)
         
     | 
| 219 | 
         
            +
                        else:
         
     | 
| 220 | 
         
            +
                            raise ValueError("image, image_url or video should in content.")
         
     | 
| 221 | 
         
            +
                    if len(image_inputs) == 0:
         
     | 
| 222 | 
         
            +
                        image_inputs = None
         
     | 
| 223 | 
         
            +
                    if len(video_inputs) == 0:
         
     | 
| 224 | 
         
            +
                        video_inputs = None
         
     | 
| 225 | 
         
            +
                    if return_video_kwargs:
         
     | 
| 226 | 
         
            +
                        return image_inputs, video_inputs, {"fps": video_sample_fps_list}
         
     | 
| 227 | 
         
            +
                    return image_inputs, video_inputs
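
    # Example (sketch) of the Qwen-style message format accepted by
    # `extract_vision_info` / `process_vision_info` above; the image path is the
    # demo asset used further below:
    #
    #   messages = [{
    #       "role": "user",
    #       "content": [
    #           {"type": "image", "image": "demo_images/demo_img_1.png"},
    #           {"type": "text", "text": "Describe this image."},
    #       ],
    #   }]
    #   images, videos = VILAProcessor.process_vision_info(messages)
    #   # -> `images` is a list with one PIL.Image.Image, `videos` is None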
         
    @staticmethod
    def move_data_to_device(cls, prompt_inputs):
        # NOTE: despite the name, `cls` here is the (GRPO) trainer instance, which
        # provides `args.device`, `is_deepspeed_enabled`, and `accelerator`.
        def _move_data_to_device(item):
            # wraps the device/dtype logic of the GRPO trainer's `_prepare_input`
            kwargs = {"device": cls.args.device}
            if cls.is_deepspeed_enabled and (torch.is_floating_point(item) or torch.is_complex(item)):
                kwargs.update({"dtype": cls.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
            return item.to(**kwargs)

        prompt_inputs.input_ids = _move_data_to_device(prompt_inputs.input_ids)
        prompt_inputs.attention_mask = _move_data_to_device(prompt_inputs.attention_mask)
        if "image" in prompt_inputs.media:
            prompt_inputs.media["image"] = [_move_data_to_device(img) for img in prompt_inputs.media["image"]]
        return prompt_inputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        padding_side = kwargs.get("padding_side", "left")
        if not os.path.isdir(pretrained_model_name_or_path):
            print(f"pretrained_model_name_or_path {pretrained_model_name_or_path} is not a directory, downloading")
            from huggingface_hub import snapshot_download

            pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path)

        image_processor = AutoImageProcessor.from_pretrained(
            osp.join(pretrained_model_name_or_path, "vision_tower"), trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(
            osp.join(pretrained_model_name_or_path, "llm"), trust_remote_code=True
        )
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
        return cls(image_processor=image_processor, tokenizer=tokenizer, config=config, padding_side=padding_side)
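
    # Example (sketch): load the processor from a checkpoint directory laid out as
    # expected above (an `llm/` folder for the tokenizer and a `vision_tower/`
    # folder for the image processor); the path is the demo checkpoint used below:
    #
    #   processor = VILAProcessor.from_pretrained("NVILA-Lite-2B-hf-preview")
    #   print(processor)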
         
    def __repr__(self):
        return f"VILAProcessor(image_processor=SigLip, tokenizer={self.tokenizer}, config={self.config})"

    def __call__(
        self,
        conversation=None,
        **kwargs: Unpack[VILAProcessorKwargs],
    ) -> BatchFeature:
        """
        Each `conv` looks like
        [
            {
                'from': 'human',
                'value': [
                    <transformers_modules.NVILA-Lite-2B-hf-preview.media.Image object at 0x154e68e4c460>,
                    'What are the common elements in these pictures?'
                ]
            }
        ]
        and `conversation` is a list of such `conv`s.
        """
        if kwargs.get("text", None) is not None:
            conversation = kwargs.get("text")
        assert conversation is not None, "`conversation` or `text` is required"
        padding_side = kwargs.get("padding_side", self.padding_side)

        input_ids_list = []
        attention_mask = []
        media = defaultdict(list)
        media_config = defaultdict(dict)
        for conv in conversation:
            feat = self.__single_call__(conv, **kwargs)
            input_ids_list.append(feat.input_ids)
            attention_mask.append(feat.attention_mask)
            for name in feat.media:
                media[name] += feat.media[name]
            for name in feat.media_config:
                media_config[name].update(feat.media_config[name])

        # pad the input_ids to batchify them
        input_ids = pad_fn(
            input_ids_list,
            padding_value=self.pad_token_id,
            padding_side=padding_side,
        )
        # ignore the pad tokens in the attention mask
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        attention_mask[input_ids == self.pad_token_id] = False
        input_texts = self.tokenizer.batch_decode(input_ids)
        bdata = BatchFeature(
            data={
                # "input_texts": input_texts,
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "media": media,
                "media_config": media_config,
            }
        )
        # NOTE: hard coded to cuda
        # bdata.input_ids = bdata.input_ids.cuda()
        # bdata.attention_mask = bdata.attention_mask.cuda()
        # bdata.media["image"] = [img.cuda() for img in bdata.media["image"]]
        return bdata
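
    # Padding sketch: with the default `padding_side="left"` and two conversations of
    # 7 and 5 tokens, the batch becomes
    #
    #   input_ids      = [[t1, t2, t3, t4, t5, t6, t7],
    #                     [pad, pad, s1, s2, s3, s4, s5]]
    #   attention_mask = [[True] * 7,
    #                     [False, False, True, True, True, True, True]]
    #
    # so prompts are right-aligned for generation and pad positions are masked out.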
         
    def __single_call__(
        self,
        conversation,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        videos=None,
        **kwargs: Unpack[VILAProcessorKwargs],
    ) -> BatchFeature:
        conversation = copy.deepcopy(conversation)
        media = extract_media(conversation, self.config)
        # Process media
        media_config = defaultdict(dict)
        for name in media:
            if name == "image":
                if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
                    self.config.image_processor = self.image_processor
                    if self.config.image_aspect_ratio == "dynamic":
                        images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
                        # NOTE: this only works for images appearing in the first conversation turn
                        conversation[0]["value"] = conversation[0]["value"].replace(
                            DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
                        )
                    else:
                        if type(self.config.s2_scales) is str:
                            self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
                        images, block_sizes = process_image(
                            media["image"][0], self.config, None, enable_dynamic_s2=True
                        )
                        images = images.half()
                        media_config[name]["block_sizes"] = [block_sizes]
                else:
                    images = process_images(media["image"], self.image_processor, self.config).half()
                media[name] = [image for image in images]
            elif name == "video":
                media[name] = [
                    process_images(images, self.image_processor, self.config).half() for images in media[name]
                ]
            else:
                raise ValueError(f"Unsupported media type: {name}")

        inputs = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True, return_ids_only=False)
        input_ids = inputs.input_ids[0].unsqueeze(0).cuda()

        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        return BatchFeature(
            data={
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "media": media,
                "media_config": media_config,
            }
        )

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_image_text_to_text(self, generated_outputs):
        """
        Post-process the output of the model to decode the text.

        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
                or `(sequence_length,)`.

        Returns:
            `List[str]`: The decoded text.
        """
        return self.tokenizer.batch_decode(
            generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    def convert_gpt_conv_to_vila_conv(self, conversation):
        vila_conv = []
        for chat in conversation:
            vila_chat = {"from": "", "value": []}
            if chat["role"] in ("user", "system"):
                # user/system turns may contain both images and text
                vila_chat["from"] = "human" if chat["role"] == "user" else "system"
                vila_chat["value"] = extract_value_from_conv(chat)
            elif chat["role"] == "assistant":
                vila_chat["from"] = "gpt"
                vila_chat["value"] = extract_value_from_conv(chat)
            else:
                raise ValueError(f"Unsupported role: {chat['role']} in chat {chat}")
            vila_conv.append(vila_chat)

        return vila_conv

    def apply_chat_template(self, conversation, add_generation_prompt=True, **kwargs):
        return self.convert_gpt_conv_to_vila_conv(conversation)


if __name__ == "__main__":
    # gpt style: user, assistant
    # vila style: human, gpt
    gpt_conv = [
        {
            "role": "user",
            "content": [
                {"type": "image", "path": "demo_images/demo_img_1.png"},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]

    llavaconv = [
        {
            "from": "human",
            "value": [
                PIL.Image.open("demo_images/demo_img_1.png"),
                "Describe this image.",
            ],
        }
    ]

    model_path = "NVILA-Lite-2B-hf-preview"
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    inputs = processor.apply_chat_template(conversation=gpt_conv, padding=True, return_tensors="pt")
    # model = llava.load("Efficient-Large-Model/qwen25_2B_3x3-sft").cuda()
    # print(model)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto")
    # res = model.generate_content(["how are you today?"])
    # print(model.config)
    # print(model.tokenizer)
    # print(res)

    processor = VILAProcessor(
        config=model.config,
        image_processor=model.vision_tower.image_processor,
        tokenizer=model.tokenizer,
    )

    inputs = processor(conversation=llavaconv, padding=True, return_tensors="pt")
    print(inputs.keys(), inputs.input_ids.shape, [_.shape for _ in inputs.image])
    print("vila conv pass")

    inputs = processor.apply_chat_template(conversation=gpt_conv, padding=True, return_tensors="pt")
    print(inputs.keys(), inputs.input_ids.shape, [_.shape for _ in inputs.image])
    print("gpt conv pass")

    output_ids = model.generate(
        input_ids=inputs.input_ids,
        media={
            "image": inputs.image,
        },
        media_config={"image": {}},
        generation_config=model.generation_config,
        max_new_tokens=100,
    )
    print(output_ids)
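
    # Decode the generated ids back to text (sketch): `skip_special_tokens=True`
    # drops media and end-of-text tokens; whether `output_ids` also contains the
    # prompt tokens depends on the model's generate() implementation.
    response = processor.post_process_image_text_to_text(output_ids)
    print(response)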

base_projector.py
ADDED
@@ -0,0 +1,228 @@
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import re

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel


class IdentityMap(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": "identity"}


class SimpleResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)


class DownSampleBlock(nn.Module):
    def forward(self, x):
        vit_embeds = x
        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.flat_square(vit_embeds)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        return vit_embeds

    def flat_square(self, x):
        n, w, h, c = x.size()
        if w % 2 == 1:
            x = torch.concat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
            n, w, h, c = x.size()
        if h % 2 == 1:
            x = torch.concat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
            n, w, h, c = x.size()
        x = x.contiguous()
        x = x.view(n, w, int(h / 2), int(c * 2))
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
        x = x.permute(0, 2, 1, 3).contiguous()
        return x


class DownSample2x2BlockFix(nn.Module):
    def forward(self, x):
        vit_embeds = x
        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = flat_square_2x2(vit_embeds)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        return vit_embeds


def flat_square_2x2(x):
    n, w, h, c = x.size()
    if w % 2 == 1:
        x = torch.concat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
        n, w, h, c = x.size()
    x = x.contiguous()
    if h % 2 == 1:
        x = torch.concat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
        n, w, h, c = x.size()
    x = x.view(n, w, int(h / 2), int(c * 2))
    x = x.permute(0, 2, 1, 3).contiguous()
    x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
    x = x.permute(0, 2, 1, 3).contiguous()
    return x
     | 

class DownSample3x3BlockFix(nn.Module):
    def forward(self, x):
        vit_embeds = x
        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = flat_square_3x3(vit_embeds)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        return vit_embeds


def flat_square_3x3(x):
    n, w, h, c = x.size()
    if w % 3 != 0:
        x = torch.concat([x, torch.zeros((n, 3 - (w % 3), h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
        n, w, h, c = x.size()
    x = x.contiguous()
    if h % 3 != 0:
        x = torch.concat([x, torch.zeros((n, w, 3 - (h % 3), c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
        n, w, h, c = x.size()
    x = x.view(n, w, int(h / 3), int(c * 3))
    x = x.permute(0, 2, 1, 3).contiguous()
    x = x.view(n, int(h / 3), int(w / 3), int(c * 9))
    x = x.permute(0, 2, 1, 3).contiguous()
    return x


class MultimodalProjectorConfig(PretrainedConfig):
    model_type = "v2l_projector"

    def __init__(self, mm_projector_type: str = None, **kwargs):
        super().__init__()
        self.mm_projector_type = mm_projector_type


class MultimodalProjector(PreTrainedModel):
    config_class = MultimodalProjectorConfig

    def __init__(self, mm_projector_cfg: MultimodalProjectorConfig, config: PretrainedConfig):
        super().__init__(mm_projector_cfg)
        mm_projector_type = mm_projector_cfg.mm_projector_type
        self.downsample_rate = 1
        if mm_projector_type == "identity":
            self.layers = IdentityMap()
        elif mm_projector_type == "linear":
            self.layers = nn.Linear(config.mm_hidden_size, config.hidden_size)
        elif mm_projector_type == "mlp_downsample":
            self.layers = nn.Sequential(
                DownSampleBlock(),
                nn.LayerNorm(config.mm_hidden_size * 4),
                nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size),
            )
            self.downsample_rate = 2
        elif mm_projector_type == "mlp_downsample_2x2_fix":
            self.layers = nn.Sequential(
                DownSample2x2BlockFix(),
                nn.LayerNorm(config.mm_hidden_size * 4),
                nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size),
            )
            self.downsample_rate = 2
        elif mm_projector_type == "mlp_downsample_3x3_fix":
            self.layers = nn.Sequential(
                DownSample3x3BlockFix(),
                nn.LayerNorm(config.mm_hidden_size * 9),
                nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 3),
                nn.GELU(),
                nn.LayerNorm(config.mm_hidden_size * 3),
                nn.Linear(config.mm_hidden_size * 3, config.hidden_size),
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size),
            )
            self.downsample_rate = 3
        elif mm_projector_type == "mlp_downsample_3x3_s2":
            self.layers = nn.Sequential(
                DownSample3x3BlockFix(),
                nn.LayerNorm(config.mm_hidden_size * 9),
                nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 3),
                nn.GELU(),
                nn.LayerNorm(config.mm_hidden_size * 3),
                nn.Linear(config.mm_hidden_size * 3, config.mm_hidden_size),
                nn.GELU(),
                nn.LayerNorm(config.mm_hidden_size),
                nn.Linear(config.mm_hidden_size, config.mm_hidden_size // 3),
                nn.GELU(),
                nn.LayerNorm(config.mm_hidden_size // 3),
                nn.Linear(config.mm_hidden_size // 3, config.hidden_size),
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size),
            )
        elif mm_projector_type == "mlp_downsample_3x3_s2_new":
            self.layers = nn.Sequential(
                DownSample3x3BlockFix(),
                nn.LayerNorm(config.mm_hidden_size * 9),
                nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 4),
                nn.GELU(),
                nn.LayerNorm(config.mm_hidden_size * 4),
                nn.Linear(config.mm_hidden_size * 4, config.mm_hidden_size * 2),
                nn.GELU(),
                nn.LayerNorm(config.mm_hidden_size * 2),
                nn.Linear(config.mm_hidden_size * 2, config.mm_hidden_size),
                nn.GELU(),
                nn.LayerNorm(config.mm_hidden_size),
                nn.Linear(config.mm_hidden_size, config.mm_hidden_size // 3),
                nn.GELU(),
                nn.LayerNorm(config.mm_hidden_size // 3),
                nn.Linear(config.mm_hidden_size // 3, config.hidden_size),
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size),
            )
        else:
            mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", mm_projector_type)
            if mlp_gelu_match:
                mlp_depth = int(mlp_gelu_match.group(1))
                modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
                for _ in range(1, mlp_depth):
                    modules.append(nn.GELU())
                    modules.append(nn.Linear(config.hidden_size, config.hidden_size))
                self.layers = nn.Sequential(*modules)
            else:
                raise ValueError(f"Unknown projector type: {mm_projector_type}")

    def forward(self, x, *args, **kwargs):
        return self.layers(x)


# AutoConfig.register("v2l_projector", MultimodalProjectorConfig)
         
     | 
| 228 | 
         
            +
            # AutoModel.register(MultimodalProjectorConfig, MultimodalProjector)
         
     | 
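The regex fallback above is the only projector variant whose layer stack is not spelled out explicitly, so here is a minimal sketch (not part of the upload) of what a type string such as "mlp2x_gelu" expands to, using the mm_hidden_size (1152) and hidden_size (3584) values from config.json further down; the batch and token dimensions are made up for illustration.

import torch
import torch.nn as nn

mm_hidden_size, hidden_size = 1152, 3584  # values from config.json below
# "mlp2x_gelu" -> depth 2: one projection to hidden_size, then one (GELU, Linear) pair
layers = nn.Sequential(
    nn.Linear(mm_hidden_size, hidden_size),
    nn.GELU(),
    nn.Linear(hidden_size, hidden_size),
)
vision_features = torch.randn(2, 256, mm_hidden_size)  # e.g. 256 patch tokens per image
print(layers(vision_features).shape)  # torch.Size([2, 256, 3584])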
    	
builder.py
ADDED
@@ -0,0 +1,247 @@
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import math
import os
import os.path as osp
import warnings
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Sequence, Tuple

import torch
import transformers
from huggingface_hub import file_exists, repo_exists
from huggingface_hub.utils import HFValidationError
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

# from .conversation import *
from .conversation import SeparatorStyle, default_conversation

SENTINEL_TOKEN = "<vila/sentinel>"
MEDIA_TOKENS = {
    "image": "<image>",
    "video": "<vila/video>",
}

# from llava.model.utils import packing
# from llava.utils.logging import logger
# from llava.utils.tokenizer import infer_stop_tokens

DUMMY_CONVERSATION = [
    {"from": "human", "value": "question"},
    {"from": "gpt", "value": "answer"},
] * 10


def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
    return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
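# Annotation (not in the upload): this helper can tokenize the prompt directly
# because build_llm_and_tokenizer below registers the media tokens ("<image>",
# "<vila/video>") as special tokens, so each media placeholder already maps to
# a single token id without any manual prompt splitting.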


def has_tokenizer(repo_id_or_path: str) -> bool:
    # Check if the tokenizer is in a local directory
    if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
        return True

    # Check if the tokenizer is in a Hugging Face Hub repo
    try:
        return repo_exists(repo_id_or_path) and file_exists(repo_id_or_path, "tokenizer_config.json")
    except HFValidationError:
        return False


def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
    if not hasattr(tokenizer, "sentinel_token"):
        tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
        tokenizer.sentinel_token = SENTINEL_TOKEN
        tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)


def tokenize_conversation_legacy(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
    conv = default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    if no_system_prompt:
        conv.system = ""

    # Skip the first message if it is not from human
    if messages[0]["from"] != "human":
        messages = messages[1:]

    # Add a generation prompt if needed
    if add_generation_prompt:
        messages.append({"from": "gpt", "value": None})

    conv.messages = []
    for turn, message in enumerate(messages):
        role = roles[message["from"]]
        assert role == conv.roles[turn % 2]
        if overrides is not None and message["from"] in overrides:
            conv.append_message(role, overrides[message["from"]])
        else:
            conv.append_message(role, message["value"])

    return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")


def tokenize_conversation(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
    # Normalize the conversation before tokenization
    for message in messages:
        message["value"] = message["value"].strip()

    if default_conversation.sep_style != SeparatorStyle.AUTO:
        return tokenize_conversation_legacy(
            messages,
            tokenizer,
            add_generation_prompt=add_generation_prompt,
            overrides=overrides,
            no_system_prompt=no_system_prompt,
        )

    conversation = []
    for m in messages:
        message = {}
        if m["from"] == "human":
            message["role"] = "user"
        elif m["from"] == "gpt":
            message["role"] = "assistant"
        else:
            raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")

        message["content"] = m["value"]
        if overrides is not None and m["from"] in overrides:
            message["content"] = overrides[m["from"]]
        conversation.append(message)

    if no_system_prompt:
        conversation = [{"role": "system", "content": ""}] + conversation

    text = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=add_generation_prompt,
        tokenize=False,
    )
    return tokenizer_image_token(text, tokenizer, return_tensors="pt")


def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
    _maybe_add_sentinel_token(tokenizer)
    template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})

    stop_tokens = {tokenizer.eos_token}
    for k in range(template.size(0) - 1):
        if template[k] == tokenizer.sentinel_token_id:
            stop_token = tokenizer.decode(template[k + 1])
            stop_tokens.add(stop_token)
    return list(stop_tokens)
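# Annotation (not in the upload): infer_stop_tokens renders a dummy conversation
# in which every assistant turn is replaced by the sentinel token, then collects
# whichever token the chat template emits immediately after each sentinel
# occurrence. Together with eos_token, those become the stop tokens, so no
# template-specific turn separator has to be hard-coded.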


def context_length_extension(config):
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    model_max_length = getattr(config, "model_max_length", None)
    if orig_ctx_len and model_max_length > orig_ctx_len:
        print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
        scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
    return config
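# Worked example (annotation, not in the upload): the Qwen2 backbone in
# config.json below has max_position_embeddings = 32768. Requesting
# model_max_length = 131072 would set rope_scaling to
# {"type": "linear", "factor": 4.0}, since ceil(131072 / 32768) = 4; with
# model_max_length <= 32768 the config is left untouched.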


def build_llm_and_tokenizer(
    model_name_or_path: str,
    config: PretrainedConfig,
    attn_implementation=None,
    model_max_length=None,
    *args,
    **kwargs,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    # print(model_name_or_path)
    llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
    llm_cfg._attn_implementation = attn_implementation
    llm_cfg.model_max_length = model_max_length
    if model_max_length is not None:
        context_length_extension(llm_cfg)

    # Quantization related
    quantization_restore_from_checkpoint = False

    if quantization_restore_from_checkpoint:
        fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)

        llm = AutoModelForCausalLM.from_pretrained(
            fp8_model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
        )
    else:
        if is_deepspeed_zero3_enabled():
            # NOTE: found by wei, need to pop out device_map when using zero3
            kwargs.pop("device_map")
        llm = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
        )
    # packing.patch(llm)

    # Locate the tokenizer.
    llm_path = model_name_or_path
    if not has_tokenizer(llm_path):
        llm_path = osp.join(llm_path, "llm")
    if not has_tokenizer(llm_path):
        raise ValueError(f"Cannot find tokenizer in {llm_path}.")

    tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side="right", use_fast=True, legacy=False)
    if model_max_length is not None:
        tokenizer.model_max_length = model_max_length

    # Load chat template if specified.
    if getattr(config, "chat_template", None) is not None:
        print(f"Using chat template: {config.chat_template}")
        fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
        if not os.path.exists(fpath):
            fpath = os.path.join(os.path.dirname(model_name_or_path), f"{config.chat_template}.jinja")
        with open(fpath) as fd:
            chat_template = fd.read()
        tokenizer.chat_template = chat_template.replace("    ", "").replace("\n", "")

    # Set stop tokens for the tokenizer
    tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
    tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)

    # Add media tokens to the tokenizer
    tokenizer.media_tokens = MEDIA_TOKENS
    tokenizer.media_token_ids = {}
    for name, token in MEDIA_TOKENS.items():
        tokenizer.add_tokens([token], special_tokens=True)
        tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)

    config.hidden_size = llm.config.hidden_size
    return llm, tokenizer
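A minimal usage sketch for build_llm_and_tokenizer (not part of the upload). It assumes a local copy of the checkpoint laid out as described in config.json below, with the Qwen2 weights and tokenizer under an llm/ subfolder, and that the custom VILA config class is resolvable via trust_remote_code; the import path, checkpoint path, and attention implementation are placeholders.

import os
from transformers import AutoConfig

from builder import build_llm_and_tokenizer  # import path depends on how the repo is packaged

model_path = "./LongVILA-R1-7B"  # hypothetical local checkpoint directory
vila_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

llm, tokenizer = build_llm_and_tokenizer(
    os.path.join(model_path, "llm"),           # LLM subfolder: source of llm_cfg and tokenizer
    vila_config,                               # supplies model_dtype and chat_template
    attn_implementation="flash_attention_2",   # placeholder; None also works
    model_max_length=32768,
)
print(tokenizer.stop_tokens)      # eos token plus template turn terminators
print(tokenizer.media_token_ids)  # {"image": ..., "video": ...}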
    	
config.json
ADDED
@@ -0,0 +1,295 @@
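One way to read the config that follows: the top-level hidden_size (3584) and mm_hidden_size (1152) are the two widths the MultimodalProjector above bridges, and mm_projector is set to "mlp_downsample_2x2_fix", a variant not included in the excerpt shown earlier; by analogy with the 3x3 blocks it presumably concatenates a 2x2 patch neighbourhood (1152 * 4 = 4608 features) before projecting down to 3584.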
{
  "_attn_implementation_autoset": true,
  "_name_or_path": "./LongVILA-R1-7B",
  "architectures": [
    "VILAForCausalLM"
  ],
  "chat_template": null,
  "drop_path_rate": 0.0,
  "fps": 0.0,
  "hidden_size": 3584,
  "image_aspect_ratio": "resize",
  "image_encoder": {
    "_target_": "llava.model.encoders.BasicImageEncoder"
  },
  "interpolate_mode": "linear",
  "llm_cfg": {
    "_attn_implementation_autoset": false,
    "_name_or_path": "./LongVILA-R1-7B/llm",
    "add_cross_attention": false,
    "architectures": [
      "Qwen2ForCausalLM"
    ],
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 151643,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 151645,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "silu",
    "hidden_size": 3584,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_range": 0.02,
    "intermediate_size": 18944,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "length_penalty": 1.0,
    "max_length": 20,
    "max_position_embeddings": 32768,
    "max_window_layers": 28,
    "min_length": 0,
    "model_max_length": 32768,
    "model_type": "qwen2",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 28,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 28,
    "num_key_value_heads": 4,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "rms_norm_eps": 1e-06,
    "rope_scaling": null,
    "rope_theta": 1000000.0,
    "sep_token_id": null,
    "sliding_window": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": false,
    "tokenizer_class": null,
    "tokenizer_model_max_length": 4096,
    "tokenizer_padding_side": "right",
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": "bfloat16",
    "torchscript": false,
    "typical_p": 1.0,
    "use_bfloat16": false,
    "use_cache": false,
    "use_sliding_window": false,
    "vocab_size": 151651
  },
  "mm_hidden_size": 1152,
  "mm_projector": "mlp_downsample_2x2_fix",
  "mm_projector_cfg": {
    "_attn_implementation_autoset": false,
    "_name_or_path": "./LongVILA-R1-7B/mm_projector",
    "add_cross_attention": false,
    "architectures": [
      "MultimodalProjector"
    ],
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "mm_projector_type": "mlp_downsample_2x2_fix",
    "model_type": "v2l_projector",
    "no_repeat_ngram_size": 0,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": "bfloat16",
    "torchscript": false,
    "typical_p": 1.0,
    "use_bfloat16": false
  },
  "mm_projector_lr": null,
  "mm_use_im_patch_token": false,
  "mm_use_im_start_end": false,
  "mm_vision_select_feature": "cls_patch",
  "mm_vision_select_layer": -2,
  "model_dtype": "torch.bfloat16",
  "model_name_or_path": "./LongVILA-R1-7B",
  "model_type": "vila",
  "num_time_tokens": 0,
  "num_video_frames": 256,
  "resume_path": "./LongVILA-R1-7B",
  "s2": false,
  "s2_max_split_size": 336,
  "s2_scales": "336,672,1008",
  "soft_ce_std": 1.0,
  "time_token_format": "<t{t}>",
  "time_token_ids": [],
  "transformers_version": "4.46.2",
  "tune_language_model": true,
  "tune_mm_projector": true,
  "tune_vision_tower": true,
  "version": "2.0",
  "video_encoder": {
    "_target_": "llava.model.encoders.TSPVideoEncoder",
    "pool_sizes": [
      [
        8,
        1,
        1
      ]
    ]
  },
  "video_max_tiles": 1,
  "vision_resolution": -1,
  "vision_tower": "Efficient-Large-Model/paligemma-siglip-so400m-patch14-448",
  "vision_tower_cfg": {
    "_attn_implementation_autoset": false,
    "_name_or_path": "./LongVILA-R1-7B/vision_tower",
    "add_cross_attention": false,
    "architectures": [
      "SiglipVisionModel"
    ],
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "gelu_pytorch_tanh",
    "hidden_size": 1152,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 448,
    "intermediate_size": 4304,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-06,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "siglip_vision_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 16,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_channels": 3,
    "num_hidden_layers": 27,
    "num_image_tokens": 256,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
         
     | 
| 260 | 
         
            +
                "patch_size": 14,
         
     | 
| 261 | 
         
            +
                "prefix": null,
         
     | 
| 262 | 
         
            +
                "problem_type": null,
         
     | 
| 263 | 
         
            +
                "projection_dim": 2048,
         
     | 
| 264 | 
         
            +
                "projector_hidden_act": "gelu_fast",
         
     | 
| 265 | 
         
            +
                "pruned_heads": {},
         
     | 
| 266 | 
         
            +
                "remove_invalid_values": false,
         
     | 
| 267 | 
         
            +
                "repetition_penalty": 1.0,
         
     | 
| 268 | 
         
            +
                "return_dict": true,
         
     | 
| 269 | 
         
            +
                "return_dict_in_generate": false,
         
     | 
| 270 | 
         
            +
                "sep_token_id": null,
         
     | 
| 271 | 
         
            +
                "suppress_tokens": null,
         
     | 
| 272 | 
         
            +
                "task_specific_params": null,
         
     | 
| 273 | 
         
            +
                "temperature": 1.0,
         
     | 
| 274 | 
         
            +
                "tf_legacy_loss": false,
         
     | 
| 275 | 
         
            +
                "tie_encoder_decoder": false,
         
     | 
| 276 | 
         
            +
                "tie_word_embeddings": true,
         
     | 
| 277 | 
         
            +
                "tokenizer_class": null,
         
     | 
| 278 | 
         
            +
                "top_k": 50,
         
     | 
| 279 | 
         
            +
                "top_p": 1.0,
         
     | 
| 280 | 
         
            +
                "torch_dtype": "bfloat16",
         
     | 
| 281 | 
         
            +
                "torchscript": false,
         
     | 
| 282 | 
         
            +
                "typical_p": 1.0,
         
     | 
| 283 | 
         
            +
                "use_bfloat16": false,
         
     | 
| 284 | 
         
            +
                "vision_use_head": false
         
     | 
| 285 | 
         
            +
              },
         
     | 
| 286 | 
         
            +
              "vision_tower_lr": null,
         
     | 
| 287 | 
         
            +
              "weight_memory_efficient": true,
         
     | 
| 288 | 
         
            +
              "xvila_mode": false,
         
     | 
| 289 | 
         
            +
              "auto_map": {
         
     | 
| 290 | 
         
            +
                "AutoProcessor": "auto_processor.VILAProcessor",
         
     | 
| 291 | 
         
            +
                "AutoConfig": "modeling_vila.VILAConfig",
         
     | 
| 292 | 
         
            +
                "AutoModel": "modeling_vila.VILAForCausalLM",
         
     | 
| 293 | 
         
            +
                "AutoModelForCausalLM": "modeling_vila.VILAForCausalLM"
         
     | 
| 294 | 
         
            +
              }
         
     | 
| 295 | 
         
            +
            }
         
     | 
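The `auto_map` block above binds the Hugging Face auto classes to the custom code shipped in this upload (auto_processor.py and modeling_vila.py). A minimal loading sketch, assuming a local copy of the checkpoint at ./LongVILA-R1-7B and that you opt in to the repo's custom code with trust_remote_code:

from transformers import AutoConfig, AutoModel, AutoProcessor

path = "./LongVILA-R1-7B"  # assumed local checkpoint directory
config = AutoConfig.from_pretrained(path, trust_remote_code=True)        # resolves to modeling_vila.VILAConfig
model = AutoModel.from_pretrained(path, trust_remote_code=True)          # resolves to modeling_vila.VILAForCausalLM
processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)  # resolves to auto_processor.VILAProcessor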
    	
configuration_vila.py
ADDED
@@ -0,0 +1,92 @@
import json
import math
import os
import os.path as osp
from copy import deepcopy
from threading import Thread
from typing import List, Optional

import torch
import torchvision
from PIL import Image
from transformers import (
    AutoProcessor,
    PretrainedConfig,
    PreTrainedModel,
    Qwen2Config,
    Qwen2ForCausalLM,
    Qwen2PreTrainedModel,
    TextIteratorStreamer,
)


class VILAConfig(PretrainedConfig):
    model_type = "vila"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        llm_cfg=None,
        vision_tower_cfg=None,
        mm_projector_cfg=None,
        architectures=None,
        resume_path=None,
        hidden_size=None,
        mm_hidden_size=None,
        image_aspect_ratio=None,
        num_video_frames=None,
        fps=None,
        mm_vision_select_layer=None,
        mm_vision_select_feature=None,
        mm_use_im_start_end=False,
        mm_use_im_patch_token=False,
        mm_projector_lr=None,
        vision_tower_lr=None,
        vision_resolution=None,
        interpolate_mode=None,
        s2=None,
        dynamic_s2=None,
        s2_scales=None,
        s2_max_split_size=None,
        s2_resize_output_to_scale_idx=0,
        min_tiles: Optional[int] = 1,
        max_tiles: Optional[int] = 12,
        num_time_tokens=None,
        time_token_format=None,
        image_encoder: str = '{"_target_": "llava.model.encoders.BasicImageEncoder"}',
        video_encoder: str = '{"_target_": "llava.model.encoders.BasicVideoEncoder"}',
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.architectures = architectures
        self.llm_cfg = llm_cfg
        self.vision_tower_cfg = vision_tower_cfg
        self.mm_projector_cfg = mm_projector_cfg
        self.resume_path = resume_path

        self.hidden_size = hidden_size
        self.mm_hidden_size = mm_hidden_size
        self.image_aspect_ratio = image_aspect_ratio
        self.num_video_frames = num_video_frames
        self.fps = fps
        self.mm_vision_select_layer = mm_vision_select_layer
        self.mm_vision_select_feature = mm_vision_select_feature
        self.mm_use_im_start_end = mm_use_im_start_end
        self.mm_use_im_patch_token = mm_use_im_patch_token
        self.mm_projector_lr = mm_projector_lr
        self.vision_tower_lr = vision_tower_lr
        self.vision_resolution = vision_resolution
        self.interpolate_mode = interpolate_mode
        self.s2 = s2
        self.dynamic_s2 = dynamic_s2
        self.s2_scales = s2_scales
        self.s2_max_split_size = s2_max_split_size
        self.s2_resize_output_to_scale_idx = s2_resize_output_to_scale_idx
        self.min_tiles = min_tiles
        self.max_tiles = max_tiles
        self.num_time_tokens = num_time_tokens
        self.time_token_format = time_token_format

        self.image_encoder = image_encoder
        self.video_encoder = video_encoder
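VILAConfig is a standard PretrainedConfig subclass, so it can be constructed directly and round-tripped through JSON. A minimal sketch, assuming the repository directory is on PYTHONPATH so configuration_vila is importable; the field values are illustrative only:

from configuration_vila import VILAConfig

cfg = VILAConfig(num_video_frames=256, time_token_format="<t{t}>")  # illustrative values
assert cfg.model_type == "vila"
print(cfg.to_json_string())  # serializes to JSON in the same shape as config.json above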
    	
constants.py
ADDED
@@ -0,0 +1,83 @@
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

# This file is modified from https://github.com/haotian-liu/LLaVA/

CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_SOUND_TOKEN = "<sound>"
DEFAULT_SPEECH_TOKEN = "<speech>"
SENTINEL_TOKEN = "<vila/sentinel>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


SENTINEL_TOKEN = "<vila/sentinel>"

MEDIA_TOKENS = {
    "image": "<image>",
    "video": "<vila/video>",
    "speech": "<speech>",
    "sound": "<sound>",
}

# <image> <vila/video> <vila/sentinel>
"""
vila:
    151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151646: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151647: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151648: AddedToken("<vila/sentinel>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151649: AddedToken("<image>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151650: AddedToken("<vila/video>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),

xvila:
    151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151646: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151647: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151648: AddedToken("<vila/sentinel>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151649: AddedToken("<image>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151650: AddedToken("<vila/video>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151651: AddedToken("<speech>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151652: AddedToken("<sound>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151653: AddedToken("<|image_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151654: AddedToken("<|image_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151655: AddedToken("<|video_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151656: AddedToken("<|video_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151657: AddedToken("<|speech_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151658: AddedToken("<|speech_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151659: AddedToken("<|sound_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    151660: AddedToken("<|sound_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
"""
MM_BOS_EOS_TOKENS = {
    "image": ["<|image_bos|>", "<|image_eos|>"],
    "video": ["<|video_bos|>", "<|video_eos|>"],
    "speech": ["<|speech_bos|>", "<|speech_eos|>"],
    "sound": ["<|sound_bos|>", "<|sound_eos|>"],
}

NUM_EXTRA_TOKENS_VILA = 8
NUM_EXTRA_TOKENS_XVILA = 10
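MEDIA_TOKENS holds the placeholder string inserted into prompts for each modality, and MM_BOS_EOS_TOKENS the matching begin/end markers. A minimal sketch of combining them, assuming constants is importable from the repo directory; the prompt text is illustrative:

from constants import MEDIA_TOKENS, MM_BOS_EOS_TOKENS

bos, eos = MM_BOS_EOS_TOKENS["video"]
prompt = f"{bos}{MEDIA_TOKENS['video']}{eos}\nDescribe the video."
print(prompt)  # <|video_bos|><vila/video><|video_eos|> followed by the question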
    	
conversation.py
ADDED
@@ -0,0 +1,191 @@
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# This file is modified from https://github.com/haotian-liu/LLaVA/

import dataclasses
from enum import Enum, auto
from typing import List

# from llava.utils.logging import logger


class SeparatorStyle(Enum):
    """Different separator style."""

    AUTO = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_3 = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""

    system: str
    roles: List[str]
    messages: List[List[str]]
    sep_style: SeparatorStyle = SeparatorStyle.AUTO
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.LLAMA_3:
            ret = self.system + self.sep
            for rid, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message = message[0]
                    sep = self.sep if rid < len(messages) - 1 else self.sep2
                    ret += role + message + sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
        )


conv_auto = Conversation(
    system="",
    roles=("", ""),
    messages=(),
    sep_style=SeparatorStyle.AUTO,
    sep="\n",
)

conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(),
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

hermes_2 = Conversation(
    system="<|im_start|>system\nAnswer the questions.",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
    messages=(),
    version="hermes-2",
)

# Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template.
llama_3_chat = Conversation(
    system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
    "You are able to understand the visual content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
    version="llama_v3",
    messages=(),
    sep_style=SeparatorStyle.LLAMA_3,
    sep="<|eot_id|>",
    sep2="<|end_of_text|>",
)


default_conversation = conv_auto
conv_templates = {
    "auto": conv_auto,
    "hermes-2": hermes_2,
    "llama_3": llama_3_chat,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "plain": conv_llava_plain,
}


CONVERSATION_MODE_MAPPING = {
    "vila1.5-3b": "vicuna_v1",
    "vila1.5-8b": "llama_3",
    "vila1.5-13b": "vicuna_v1",
    "vila1.5-40b": "hermes-2",
    "llama-3": "llama_3",
    "llama3": "llama_3",
}


def auto_set_conversation_mode(model_name_or_path: str) -> None:
    global default_conversation
    for k, v in CONVERSATION_MODE_MAPPING.items():
        if k in model_name_or_path.lower():
            print(f"Setting conversation mode to `{v}` based on model name/path `{model_name_or_path}`.")
            default_conversation = conv_templates[v]
            return
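The templates above are meant to be copied, filled with alternating role/message pairs, and rendered with get_prompt(). A minimal sketch using only the classes and templates defined in this file, assuming conversation is importable from the repo directory:

from conversation import conv_templates

conv = conv_templates["hermes-2"].copy()   # copy() turns the empty message tuple into a mutable list
conv.append_message(conv.roles[0], "<image>\nWhat is shown in this image?")
conv.append_message(conv.roles[1], None)   # leave the assistant turn open for generation
print(conv.get_prompt())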
    	
distributed.py
ADDED
@@ -0,0 +1,73 @@
import os
import warnings
from typing import Any, List, Optional

from torch import distributed as dist

__all__ = [
    "init",
    "is_initialized",
    "size",
    "rank",
    "local_size",
    "local_rank",
    "is_main",
    "barrier",
    "gather",
    "all_gather",
]


def init() -> None:
    if "RANK" not in os.environ:
        warnings.warn("Environment variable `RANK` is not set. Skipping distributed initialization.")
        return
    dist.init_process_group(backend="nccl", init_method="env://")


def is_initialized() -> bool:
    return dist.is_initialized()


def size() -> int:
    return int(os.environ.get("WORLD_SIZE", 1))


def rank() -> int:
    return int(os.environ.get("RANK", 0))


def local_size() -> int:
    return int(os.environ.get("LOCAL_WORLD_SIZE", 1))


def local_rank() -> int:
    return int(os.environ.get("LOCAL_RANK", 0))


def is_main() -> bool:
    return rank() == 0


def barrier() -> None:
    dist.barrier()


def gather(obj: Any, dst: int = 0) -> Optional[List[Any]]:
    if not is_initialized():
        return [obj]
    if is_main():
        objs = [None for _ in range(size())]
        dist.gather_object(obj, objs, dst=dst)
        return objs
    else:
        dist.gather_object(obj, dst=dst)
        return None


def all_gather(obj: Any) -> List[Any]:
    if not is_initialized():
        return [obj]
    objs = [None for _ in range(size())]
    dist.all_gather_object(objs, obj)
    return objs
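These helpers are thin wrappers around torch.distributed that read torchrun's environment variables (RANK, WORLD_SIZE, LOCAL_RANK, LOCAL_WORLD_SIZE) and fall back to single-process behavior when they are absent. A minimal sketch, assuming the script is launched with e.g. `torchrun --nproc_per_node=2 demo.py` and that this file is importable as `distributed`:

from distributed import all_gather, init, is_main, rank

init()                       # warns and returns if RANK is not set
ranks = all_gather(rank())   # falls back to [rank()] outside a distributed run
if is_main():
    print(f"gathered ranks: {ranks}")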
    	
        loss.py
    ADDED
    
    | 
         @@ -0,0 +1,48 @@ 
+from typing import List, Union
+
+import torch
+from torch.nn.functional import cross_entropy
+
+from .constants import IGNORE_INDEX
+
+__all__ = ["soft_cross_entropy"]
+
+
+def soft_cross_entropy(
+    outputs: torch.Tensor,
+    targets: torch.Tensor,
+    soft_tokens: Union[torch.Tensor, List[int]],
+    std: float = 1,
+    ignore_index: int = IGNORE_INDEX,
+) -> torch.Tensor:
+    # Remove last token from outputs and first token from targets
+    outputs = outputs[..., :-1, :].contiguous()
+    targets = targets[..., 1:].contiguous()
+
+    # Flatten outputs and targets
+    targets = targets.view(-1)
+    outputs = outputs.view(targets.size(0), -1)
+
+    # Remove outputs and targets with ignore_index
+    indices = targets != ignore_index
+    outputs = outputs[indices]
+    targets = targets[indices]
+
+    # Convert soft token IDs to tensor
+    if isinstance(soft_tokens, list):
+        soft_tokens = torch.tensor(soft_tokens).to(targets)
+
+    # Calculate loss for non-soft tokens
+    indices = torch.isin(targets, soft_tokens, invert=True)
+    loss = cross_entropy(outputs[indices], targets[indices], reduction="sum")
+
+    # Calculate loss for soft tokens
+    indices = torch.isin(targets, soft_tokens)
+    targets_indices = torch.zeros_like(outputs[indices])
+    for k, target in enumerate(targets[indices]):
+        dist = torch.exp(-((target - soft_tokens) ** 2) / (2 * std**2))
+        targets_indices[k][soft_tokens] = dist / dist.sum()
+    loss += cross_entropy(outputs[indices], targets_indices, reduction="sum")
+
+    # Return average loss
+    return loss / targets.size(0)
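For orientation, a toy call to soft_cross_entropy (not part of the upload) showing the expected shapes; the vocabulary size, the soft-token IDs, and the -100 ignore value (the usual value of IGNORE_INDEX) are illustrative assumptions:

    # Toy shapes: batch 2, sequence length 5, vocabulary 32. Tokens 10-13 are
    # treated as "soft" (e.g. adjacent numeric bins), so their targets become a
    # Gaussian over the soft-token IDs instead of a one-hot label.
    # Assumes the function is importable, e.g. `from loss import soft_cross_entropy`.
    import torch

    logits = torch.randn(2, 5, 32)
    labels = torch.tensor([[3, 11, -100, 12, 7],
                           [5, 13, 2, 11, 10]])  # -100 positions are dropped
    loss = soft_cross_entropy(logits, labels, soft_tokens=[10, 11, 12, 13], std=1.0)
    print(loss.item())

Labels outside soft_tokens still contribute an ordinary one-hot cross-entropy term; only labels inside soft_tokens are smoothed over the neighbouring soft-token IDs, and the two sums are averaged over all non-ignored positions.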
    	
media.py
ADDED
@@ -0,0 +1,130 @@
+import glob
+import os
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Union
+
+import cv2
+import numpy as np
+import PIL
+import PIL.Image
+import requests
+from transformers import PretrainedConfig
+
+# from llava.constants import MEDIA_TOKENS
+# from llava.media import Image, Video
+# from llava.utils import make_list
+# from llava.utils.logging import logger
+
+MEDIA_TOKENS = {
+    "image": "<image>",
+    "video": "<vila/video>",
+}
+
+
+class Media:
+    pass
+
+
+class File(Media):
+    def __init__(self, path: str) -> None:
+        self.path = path
+
+
+class Image(File):
+    pass
+
+
+class Video(File):
+    pass
+
+
+def make_list(obj: Any) -> List:
+    return obj if isinstance(obj, list) else [obj]
+
+
+def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image:
+    if isinstance(image, Image):
+        if image.path.startswith("http://") or image.path.startswith("https://"):
+            image = PIL.Image.open(requests.get(image.path, stream=True).raw)
+        else:
+            image = PIL.Image.open(image.path)
+    return image
+
+
+def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]:
+    # Load video frames from a directory
+    if os.path.isdir(video_path):
+        frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
+        indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int)
+        return [PIL.Image.open(frame_paths[index]) for index in indices]
+
+    # Load video frames from a video file
+    vidcap = cv2.VideoCapture(video_path)
+
+    # Find the last frame as frame count might not be accurate
+    frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+    while frame_count > 0:
+        vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
+        if vidcap.grab():
+            break
+        frame_count -= 1
+    else:
+        raise ValueError(f"Video '{video_path}' has no frames.")
+
+    # Extract frames uniformly
+    indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
+    frames = {}
+    for index in indices:
+        if index in frames:
+            continue
+        vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
+        success, frame = vidcap.read()
+        if not success:
+            print(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
+            continue
+        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        frames[index] = PIL.Image.fromarray(frame)
+    return [frames[index] for index in indices if index in frames]
+
+
+def _extract_video(video, config: PretrainedConfig) -> List[PIL.Image.Image]:
+    num_frames = config.num_video_frames
+    video_path = video.path if isinstance(video, Video) else video["path"]
+    frames = _load_video(video_path, num_frames=num_frames)
+    return frames
+
+
+def extract_media(
+    messages: List[Dict[str, Any]],
+    config: Optional[PretrainedConfig] = None,
+    draft: bool = False,
+) -> Dict[str, List[Any]]:
+    media = defaultdict(list)
+    for message in messages:
+        text = ""
+        for part in make_list(message["value"]):
+            if isinstance(part, str):
+                for token in MEDIA_TOKENS.values():
+                    if token in part:
+                        print(f"Media token '{token}' found in text: '{part}'. Removed.")
+                        part = part.replace(token, "").strip()
+                text += part
+            elif isinstance(part, (Image, PIL.Image.Image)):
+                if draft:
+                    media["image"].append(part)
+                else:
+                    media["image"].append(_extract_image(part))
+                text += MEDIA_TOKENS["image"]
+            elif isinstance(part, dict) or isinstance(part, Video):
+                if draft:
+                    media["video"].append(part)
+                else:
+                    media["video"].append(_extract_video(part, config))
+                text += MEDIA_TOKENS["video"]
+            else:
+                raise ValueError(f"Unsupported prompt part type: {type(part)}")
+        message["value"] = text
+
+    if MEDIA_TOKENS["video"] in messages[0]["value"]:
+        messages[0]["value"] = "<vila/video>" + messages[0]["value"].replace("<vila/video>", "")
+    return media
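A short sketch of how extract_media rewrites a conversation in place and collects its media; the file name is a placeholder, and config is only required when a video is present (it supplies num_video_frames):

    # Hypothetical call; "demo.jpg" stands in for a real image on disk.
    messages = [{"from": "human", "value": [Image("demo.jpg"), "Describe this picture."]}]
    media = extract_media(messages)

    print(messages[0]["value"])  # "<image>Describe this picture."
    print(len(media["image"]))   # 1 loaded PIL.Image.Image

With draft=True the raw Image/Video objects are collected without being decoded, which lets callers defer the actual file or network I/O.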
    	
media_encoder.py
ADDED
@@ -0,0 +1,158 @@
+from functools import partial
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch import nn
+
+
+class BaseEncoder(nn.Module):
+    def __init__(self, parent: nn.Module) -> None:
+        super().__init__()
+        self._parent = [parent]
+
+    @property
+    def parent(self) -> nn.Module:
+        return self._parent[0]
+
+
+class BasicImageEncoder(BaseEncoder):
+    def __init__(
+        self,
+        parent: torch.nn.Module,
+        start_tokens: Optional[str] = None,
+        end_tokens: Optional[str] = "\n",
+    ) -> None:
+        super().__init__(parent)
+        self.start_tokens = start_tokens
+        self.end_tokens = end_tokens
+
+    def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
+        if tokens is None:
+            return None
+        token_ids = self.parent.tokenizer(tokens).input_ids
+        token_ids = torch.tensor(token_ids, device=self.parent.device)
+        return self.parent.llm_model_embed_tokens(token_ids)
+
+    def _process_features(
+        self,
+        features: torch.Tensor,
+        start_token_embeds: Optional[torch.Tensor],
+        end_token_embeds: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        if start_token_embeds is not None:
+            features = torch.cat([start_token_embeds, features], dim=0)
+        if end_token_embeds is not None:
+            features = torch.cat([features, end_token_embeds], dim=0)
+        return features
+
+    def forward(self, images: List[torch.Tensor], config: Dict[str, Any], device: torch.device) -> List[torch.Tensor]:
+        images = torch.stack(images, dim=0)
+        features = self.parent.encode_images(images, block_sizes=config.get("block_sizes"))
+        process_features = partial(
+            self._process_features,
+            start_token_embeds=self.embed_tokens(self.start_tokens),
+            end_token_embeds=self.embed_tokens(self.end_tokens),
+        )
+        return [process_features(f).to(device) for f in features]
+
+
+class BasicVideoEncoder(BaseEncoder):
+    def __init__(
+        self,
+        parent: torch.nn.Module,
+        start_tokens: Optional[str] = None,
+        end_tokens: Optional[str] = "\n",
+    ) -> None:
+        super().__init__(parent)
+        self.start_tokens = start_tokens
+        self.end_tokens = end_tokens
+
+    def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
+        if tokens is None:
+            return None
+        token_ids = self.parent.tokenizer(tokens).input_ids
+        token_ids = torch.tensor(token_ids, device=self.parent.device)
+        return self.parent.llm_model_embed_tokens(token_ids)
+
+    def _process_features(
+        self,
+        features: torch.Tensor,
+        start_token_embeds: Optional[torch.Tensor],
+        end_token_embeds: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        if start_token_embeds is not None:
+            start_embeds = torch.stack([start_token_embeds] * features.shape[0], dim=0)
+            features = torch.cat([start_embeds, features], dim=1)
+        if end_token_embeds is not None:
+            end_embeds = torch.stack([end_token_embeds] * features.shape[0], dim=0)
+            features = torch.cat([features, end_embeds], dim=1)
+        return features.flatten(0, 1)
+
+    def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
+        num_frames = [video.shape[0] for video in videos]
+        images = torch.cat(videos, dim=0)
+        features = self.parent.encode_images(images)
+        features = torch.split(features, num_frames)
+        process_features = partial(
+            self._process_features,
+            start_token_embeds=self.embed_tokens(self.start_tokens),
+            end_token_embeds=self.embed_tokens(self.end_tokens),
+        )
+        return [process_features(f) for f in features]
+
+def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor:
+    if x.shape[dim] % size == 0:
+        return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1)
+    else:
+        return x.narrow(dim, start=0, length=1)
+
+class TSPVideoEncoder(BasicVideoEncoder):
+    def __init__(
+        self,
+        parent: torch.nn.Module,
+        #pool_sizes: List[Tuple[int, int, int]],
+        start_tokens: Optional[str] = None,
+        end_tokens: Optional[str] = "\n",
+        sep_tokens: Optional[str] = None,
+    ) -> None:
+        super().__init__(parent, start_tokens=start_tokens, end_tokens=end_tokens)
+        self.pool_sizes = [[8, 1, 1]] #pool_sizes
+        self.sep_tokens = sep_tokens
+
+    def _process_features(
+        self,
+        inputs: torch.Tensor,
+        start_token_embeds: Optional[torch.Tensor],
+        end_token_embeds: Optional[torch.Tensor],
+        sep_token_embeds: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        nt, ns = inputs.shape[:2]
+        nl = int(ns**0.5)
+        outputs = []
+        for pool_size in self.pool_sizes:
+            features = inputs.view(nt, nl, nl, -1)
+            for dim, p in enumerate(pool_size):
+                features = pool(features, p, dim=dim)
+            features = features.flatten(1, 2)
+            features = super()._process_features(
+                features,
+                start_token_embeds=start_token_embeds,
+                end_token_embeds=end_token_embeds,
+            )
+            if sep_token_embeds is not None:
+                features = torch.cat([features, sep_token_embeds], dim=0)
+            outputs.append(features)
+        return torch.cat(outputs, dim=0)
+
+    def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
+        num_frames = [video.shape[0] for video in videos]
+        images = torch.cat(videos, dim=0)
+        features = self.parent.encode_images(images)
+        features = torch.split(features, num_frames)
+        process_features = partial(
+            self._process_features,
+            start_token_embeds=self.embed_tokens(self.start_tokens),
+            end_token_embeds=self.embed_tokens(self.end_tokens),
+            sep_token_embeds=self.embed_tokens(self.sep_tokens),
+        )
+        return [process_features(f) for f in features]
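The pool helper mean-pools along one dimension when the pool size divides it evenly and otherwise keeps only the first slice; TSPVideoEncoder applies it over (time, height, width) before re-flattening the spatial grid. A standalone shape check with made-up dimensions (8 frames, a 3x3 token grid, hidden size 16), assuming pool from this file is in scope:

    # Shape-only check of the temporal/spatial pooling, detached from the model.
    import torch

    feats = torch.randn(8, 9, 16)      # (frames, spatial tokens, hidden)
    nt, ns, c = feats.shape
    nl = int(ns ** 0.5)

    grid = feats.view(nt, nl, nl, c)   # (8, 3, 3, 16)
    pooled = pool(grid, 8, dim=0)      # temporal pool size 8 -> (1, 3, 3, 16)
    pooled = pool(pooled, 1, dim=1)    # spatial pool size 1 keeps height
    pooled = pool(pooled, 1, dim=2)    # spatial pool size 1 keeps width
    print(pooled.flatten(1, 2).shape)  # torch.Size([1, 9, 16])

With the hard-coded pool_sizes=[[8, 1, 1]], every group of 8 frames is averaged into a single token grid, and clips whose frame count is not divisible by 8 fall back to keeping only the first slice along time.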
    	
mm_utils.py
ADDED
@@ -0,0 +1,575 @@
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# dynamic_preprocess and find_closest_aspect_ratio are referenced from https://github.com/OpenGVLab/InternVL
+
+import base64
+import os
+import tempfile
+from io import BytesIO
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import StoppingCriteria
+
+from .constants import DEFAULT_IMAGE_TOKEN
+
+
+def get_frame_from_vcap(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
+    import cv2
+
+    if fps == None or frame_count == None:
+        # if one of fps or frame_count is None, still recompute
+        fps = vidcap.get(cv2.CAP_PROP_FPS)
+        frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+    if fps == 0 or frame_count == 0:
+        print(f"Video file not found. return empty images. {video_file_name}")
+        return [
+            Image.new("RGB", (720, 720)),
+        ] * num_frames, 0
+
+    duration = frame_count / fps
+    frame_interval = frame_count // num_frames
+    if frame_interval == 0 and frame_count <= 1:
+        print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
+        return [
+            Image.new("RGB", (720, 720)),
+        ] * num_frames, 0
+    # print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval)
+
+    images = []
+    count = 0
+    success = True
+    frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
+    while success:
+        # print("frame_count:", frame_count, "count:", count, "num_frames:", num_frames, "frame_interval:", frame_interval)
+        if frame_count >= num_frames:
+            success, frame = vidcap.read()
+            if count in frame_indices:
+                try:
+                    img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                    im_pil = Image.fromarray(img)
+                    images.append(im_pil)
+                except BaseException:
+                    continue
+                if len(images) >= num_frames:
+                    return images, num_frames
+            count += 1
+        else:
+            # Left padding frames if the video is not long enough
+            success, frame = vidcap.read()
+            if success:
+                try:
+                    img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                    im_pil = Image.fromarray(img)
+                    images.append(im_pil)
+                except BaseException:
+                    continue
+                count += 1
+            else:
+                break
+    if len(images) == 0:
+        raise ValueError("Did not find enough frames in the video. return empty image.")
+
+    return images, len(images)
+
+
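A minimal sketch of calling get_frame_from_vcap on an opened capture; "clip.mp4" is a placeholder path:

    # Hypothetical caller; uniform sampling of up to 8 frames from a video file.
    import cv2

    vidcap = cv2.VideoCapture("clip.mp4")
    frames, n = get_frame_from_vcap(vidcap, num_frames=8, video_file_name="clip.mp4")
    vidcap.release()
    print(n, [f.size for f in frames])

If the capture reports zero fps or zero frames, the function returns num_frames blank 720x720 images with a count of 0, so downstream code does not need to special-case unreadable files.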
         
            +
            def get_frame_from_vcap_with_fps(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
         
     | 
| 92 | 
         
            +
                """
         
     | 
| 93 | 
         
            +
                num_frames is the max number of frames the model can support.
         
     | 
| 94 | 
         
            +
                frame_count is the number of frames in the input video.
         
     | 
| 95 | 
         
            +
                max_fps is the max FPS of the model can support.
         
     | 
| 96 | 
         
            +
                fps is the fps of the input video.
         
     | 
| 97 | 
         
            +
                """
         
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
                import random
         
     | 
| 100 | 
         
            +
             
     | 
| 101 | 
         
            +
                import cv2
         
     | 
| 102 | 
         
            +
             
     | 
| 103 | 
         
            +
                if fps == None or frame_count == None:
         
     | 
| 104 | 
         
            +
                    # if one of fps or frame_count is None, still recompute
         
     | 
| 105 | 
         
            +
                    fps = vidcap.get(cv2.CAP_PROP_FPS)
         
     | 
| 106 | 
         
            +
                    frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
         
     | 
| 107 | 
         
            +
             
     | 
| 108 | 
         
            +
                if fps == 0 or frame_count == 0:
         
     | 
| 109 | 
         
            +
                    print(f"Video file not found. return empty images. {video_file_name}")
         
     | 
| 110 | 
         
            +
                    empty_video_frames = int(random.uniform(2, 8 * max_fps))
         
     | 
| 111 | 
         
            +
                    return [
         
     | 
| 112 | 
         
            +
                        Image.new("RGB", (720, 720)),
         
     | 
| 113 | 
         
            +
                    ] * empty_video_frames, 0
         
     | 
| 114 | 
         
            +
             
     | 
| 115 | 
         
            +
                duration = frame_count / fps
         
     | 
| 116 | 
         
            +
                # print("duration:", duration, "frames:", frame_count, "fps:", fps, "num_frames:", num_frames, "max_fps:", max_fps)
         
     | 
| 117 | 
         
            +
                # If the video is too long (longer than max_fps and num_frames can support),
         
     | 
| 118 | 
         
            +
                # we will use lower fps to sample frames.
         
     | 
| 119 | 
         
            +
                if duration >= num_frames / max_fps:
         
     | 
| 120 | 
         
            +
                    frame_interval = frame_count // num_frames
         
     | 
| 121 | 
         
            +
             
     | 
| 122 | 
         
            +
                    # If the video is too short, we will skip the video if there is only one frame.
         
     | 
| 123 | 
         
            +
                    if frame_interval == 0 and frame_count <= 1:
         
     | 
| 124 | 
         
            +
                        print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
         
     | 
| 125 | 
         
            +
                        empty_video_frames = int(random.uniform(2, 8 * max_fps))
         
     | 
| 126 | 
         
            +
                        return [
         
     | 
| 127 | 
         
            +
                            Image.new("RGB", (720, 720)),
         
     | 
| 128 | 
         
            +
                        ] * empty_video_frames, 0
         
     | 
| 129 | 
         
            +
             
     | 
| 130 | 
         
            +
                    images = []
         
     | 
| 131 | 
         
            +
                    count = 0
         
     | 
| 132 | 
         
            +
                    success = True
         
     | 
| 133 | 
         
            +
                    frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
         
     | 
| 134 | 
         
            +
             
     | 
| 135 | 
         
            +
                    while success:
         
     | 
| 136 | 
         
            +
                        if frame_count >= num_frames:
         
     | 
| 137 | 
         
            +
                            # success, frame = vidcap.read()
         
     | 
| 138 | 
         
            +
                            if count in frame_indices:
         
     | 
| 139 | 
         
            +
                                success, frame = vidcap.read()
         
     | 
| 140 | 
         
            +
                                try:
         
     | 
| 141 | 
         
            +
                                    img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
         
     | 
| 142 | 
         
            +
                                    im_pil = Image.fromarray(img)
         
     | 
| 143 | 
         
            +
                                    images.append(im_pil)
         
     | 
| 144 | 
         
            +
                                except:
         
     | 
| 145 | 
         
            +
                                    # print("Failed to read frame:", count)
         
     | 
| 146 | 
         
            +
                                    continue
         
     | 
| 147 | 
         
            +
                                if len(images) >= num_frames:
         
     | 
| 148 | 
         
            +
                                    return images, num_frames
         
     | 
| 149 | 
         
            +
                            else:
         
     | 
| 150 | 
         
            +
                                success = vidcap.grab()
         
     | 
| 151 | 
         
            +
                            count += 1
         
     | 
| 152 | 
         
            +
                        else:
         
     | 
| 153 | 
         
            +
                            # Left padding frames if the video is not long enough
         
     | 
| 154 | 
         
            +
                            success, frame = vidcap.read()
         
     | 
| 155 | 
         
            +
                            if success:
         
     | 
| 156 | 
         
            +
                                try:
         
     | 
| 157 | 
         
            +
                                    img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
         
     | 
| 158 | 
         
            +
                                    im_pil = Image.fromarray(img)
         
     | 
| 159 | 
         
            +
                                    images.append(im_pil)
         
     | 
| 160 | 
         
            +
                                except:
         
     | 
| 161 | 
         
            +
                                    # print("Failed to read frame:", count)
         
     | 
| 162 | 
         
            +
                                    continue
         
     | 
| 163 | 
         
            +
                                count += 1
         
     | 
| 164 | 
         
            +
                            else:
         
     | 
| 165 | 
         
            +
                                break
         
     | 
| 166 | 
         
            +
                else:
         
     | 
| 167 | 
         
            +
                    frames_required = int(duration * max_fps)
         
     | 
| 168 | 
         
            +
                    frame_indices = np.linspace(0, frame_count - 1, frames_required, dtype=int)
         
     | 
| 169 | 
         
            +
                    if frames_required == 0:
         
     | 
| 170 | 
         
            +
                        print(f"frames_required is fewer than 2. Duration {duration}, return empty image.")
         
     | 
| 171 | 
         
            +
                        empty_video_frames = int(random.uniform(2, 8 * max_fps))
         
     | 
| 172 | 
         
            +
                        return [
         
     | 
| 173 | 
         
            +
                            Image.new("RGB", (720, 720)),
         
     | 
| 174 | 
         
            +
                        ] * empty_video_frames, 0
         
     | 
| 175 | 
         
            +
                    elif frames_required == 1:
         
     | 
| 176 | 
         
            +
                        frame_indices = np.linspace(0, frame_count - 1, 2, dtype=int)
         
     | 
| 177 | 
         
            +
                    images = []
         
     | 
| 178 | 
         
            +
                    count = 0
         
     | 
| 179 | 
         
            +
                    looked = 0
         
     | 
| 180 | 
         
            +
                    success = True
         
     | 
| 181 | 
         
            +
             
     | 
| 182 | 
         
            +
                    while success:
         
     | 
| 183 | 
         
            +
                        success, frame = vidcap.read()
         
     | 
| 184 | 
         
            +
                        if success and (looked in frame_indices):
         
     | 
| 185 | 
         
            +
                            try:
         
     | 
| 186 | 
         
            +
                                img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
         
     | 
| 187 | 
         
            +
                                im_pil = Image.fromarray(img)
         
     | 
| 188 | 
         
            +
                                images.append(im_pil)
         
     | 
| 189 | 
         
            +
                            except:
         
     | 
| 190 | 
         
            +
                                continue
         
     | 
| 191 | 
         
            +
                            count += 1
         
     | 
| 192 | 
         
            +
                        looked += 1
         
     | 
| 193 | 
         
            +
             
     | 
| 194 | 
         
            +
                if len(images) == 0:
         
     | 
| 195 | 
         
            +
                    empty_video_frames = int(random.uniform(2, 8 * max_fps))
         
     | 
| 196 | 
         
            +
                    return [
         
     | 
| 197 | 
         
            +
                        Image.new("RGB", (720, 720)),
         
     | 
| 198 | 
         
            +
                    ] * empty_video_frames, 0
         
     | 
| 199 | 
         
            +
                else:
         
     | 
| 200 | 
         
            +
                    return images, len(images)
         
     | 
| 201 | 
         
            +
             
     | 
| 202 | 
         
            +
             
     | 
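The function above switches between two sampling regimes: for long videos it spreads num_frames indices uniformly over the whole clip, while for short videos it samples at roughly max_fps. A standalone sketch of that index selection (not part of the uploaded file; the numbers are illustrative):

import numpy as np

num_frames, max_fps = 10, 2.0      # assumed model limits
frame_count, fps = 900, 30.0       # a hypothetical 30 s clip at 30 fps
duration = frame_count / fps

if duration >= num_frames / max_fps:
    # long clip: num_frames indices spread uniformly over the whole video
    indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
else:
    # short clip: sample at roughly max_fps
    indices = np.linspace(0, frame_count - 1, int(duration * max_fps), dtype=int)

print(indices)
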
def opencv_extract_frames(vpath_or_bytesio, frames=6, max_fps=0.0, fps=None, frame_count=None):
    """
    Extract frames from a video using OpenCV.

    Args:
        vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video.
        frames (int): Number of frames to extract from the video.
        fps (float): Frames per second of the video. If 0.0, the function will extract frames at equal intervals.

    Returns:
        tuple: (list of PIL Images extracted from the video, number of frames actually extracted).

    Raises:
        NotImplementedError: If the type of `vpath_or_bytesio` is not supported.
    """
    import cv2

    if isinstance(vpath_or_bytesio, str):
        vidcap = cv2.VideoCapture(vpath_or_bytesio)
        if max_fps > 0.0:
            return get_frame_from_vcap_with_fps(
                vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
            )
        return get_frame_from_vcap(
            vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
        )
    elif isinstance(vpath_or_bytesio, (BytesIO,)):
        # assuming mp4
        with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
            temp_video.write(vpath_or_bytesio.read())
            temp_video_name = temp_video.name
            vidcap = cv2.VideoCapture(temp_video_name)
            if max_fps > 0.0:
                return get_frame_from_vcap_with_fps(
                    vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
                )
            return get_frame_from_vcap(
                vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
            )
    else:
        raise NotImplementedError(type(vpath_or_bytesio))

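A usage sketch for the extractor above (not part of the uploaded file); "sample.mp4" is a placeholder path and mm_utils is assumed to be importable from the repository root:

from mm_utils import opencv_extract_frames

frames, n = opencv_extract_frames("sample.mp4", frames=6)
print(n, frames[0].size)  # number of extracted frames and the size of the first PIL image
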
def load_image_from_base64(image):
    return Image.open(BytesIO(base64.b64decode(image)))


def expand2square(pil_img, background_color):
    """
    Expand the given PIL image to a square shape by adding padding.

    Parameters:
    - pil_img: The PIL image to be expanded.
    - background_color: The color of the padding to be added.

    Returns:
    - The expanded PIL image.

    If the image is already square, it is returned as is.
    If the image is wider than it is tall, padding is added to the top and bottom.
    If the image is taller than it is wide, padding is added to the left and right.
    """
    width, height = pil_img.size
    if pil_img.mode == "L":
        background_color = background_color[0]
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result

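A quick illustration of expand2square (not part of the uploaded file), padding a wide dummy image onto a square black canvas:

from PIL import Image

img = Image.new("RGB", (640, 360), (255, 0, 0))   # dummy 16:9 image
square = expand2square(img, background_color=(0, 0, 0))
print(square.size)  # (640, 640); the original is pasted with padding above and below
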
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

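An illustrative call (not part of the uploaded file) showing how the helper picks the candidate grid whose aspect ratio is closest to the input image:

ratios = [(1, 1), (1, 2), (2, 1), (2, 2), (3, 1)]
best = find_closest_aspect_ratio(16 / 9, ratios, width=1920, height=1080, image_size=384)
print(best)  # (2, 1): among the candidates, 2:1 is closest to 16:9
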
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=384, use_thumbnail=True):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num
    }
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size,
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

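A usage sketch for dynamic_preprocess (not part of the uploaded file): a 2:1 dummy image resolves to a 4x2 grid of 384x384 crops plus a global thumbnail:

from PIL import Image

img = Image.new("RGB", (1536, 768))  # dummy 2:1 image
tiles = dynamic_preprocess(img, min_num=1, max_num=12, image_size=384, use_thumbnail=True)
print(len(tiles), tiles[0].size)     # 9 (384, 384): eight grid crops plus the thumbnail
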
def dynamic_s2_preprocess(image, s2_scales=[384, 768, 1152], max_num=12, image_size=384):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height
    min_num = (s2_scales[-1] // s2_scales[0]) ** 2  # at least use number of tiles as the largest scale

    processed_images = []

    ##########################################################################################
    ############# Add tiles for all but the last scale using fixed square ratio ##############
    ##########################################################################################

    for scale in s2_scales[:-1]:
        target_width = image_size * (scale // s2_scales[0])
        target_height = image_size * (scale // s2_scales[0])
        blocks = (scale // s2_scales[0]) ** 2

        # resize the image
        resized_img = image.resize((target_width, target_height))
        for i in range(blocks):
            box = (
                (i % (target_width // image_size)) * image_size,
                (i // (target_width // image_size)) * image_size,
                ((i % (target_width // image_size)) + 1) * image_size,
                ((i // (target_width // image_size)) + 1) * image_size,
            )
            # split the image
            split_img = resized_img.crop(box)
            processed_images.append(split_img)

    ##########################################################################################
    ################ Add tiles for the last scale using dynamic aspect ratio #################
    ##########################################################################################

    # calculate the existing image aspect ratio
    target_ratios = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num
    }
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size,
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)

    return processed_images, (target_aspect_ratio[1], target_aspect_ratio[0])

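A usage sketch for the S2 variant (not part of the uploaded file): for a square 1152x1152 dummy image, the two smaller scales contribute 1 + 4 fixed-grid crops and the last scale contributes a 3x3 dynamic grid:

from PIL import Image

img = Image.new("RGB", (1152, 1152))  # dummy square image
tiles, block_size = dynamic_s2_preprocess(img, s2_scales=[384, 768, 1152], max_num=12, image_size=384)
print(len(tiles), block_size)         # 14 (3, 3)
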
def dynamic_process_images_and_prompt(images, prompt, data_args, image_folder=None, max_tiles=None):
    prompt = prompt.split(DEFAULT_IMAGE_TOKEN)
    idx = 0
    all_images = []
    for img in images:
        processed_images = process_image(img, data_args, image_folder, enable_dynamic_res=True, max_tiles=max_tiles)
        all_images.append(processed_images)
        prompt.insert(idx + 1, f"{DEFAULT_IMAGE_TOKEN}\n" * processed_images.shape[0])
        idx += 2
    prompt = "".join(prompt)
    if all_images:
        all_images = torch.cat(all_images)
    else:
        all_images = None
        prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, "")
    return all_images, prompt


def dynamic_s2_process_images_and_prompt(images, prompt, data_args, image_folder=None):
    idx = 0
    all_images = []
    all_block_size = []
    for img in images:
        processed_images, block_size = process_image(img, data_args, image_folder, enable_dynamic_s2=True)
        all_images.append(processed_images)
        all_block_size.append(block_size)
        idx += 2
    if all_images:
        all_images = torch.cat(all_images)
    else:
        all_images = None
    return all_images, all_block_size

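The prompt handling above splits on the image placeholder and re-inserts one placeholder per generated tile. A minimal sketch of that expansion (not part of the uploaded file; the token string is an assumption meant to match constants.py):

DEFAULT_IMAGE_TOKEN = "<image>"  # assumed value
prompt = f"{DEFAULT_IMAGE_TOKEN}\nDescribe the photo."
parts = prompt.split(DEFAULT_IMAGE_TOKEN)
num_tiles = 5                    # e.g. four crops plus a thumbnail from dynamic_preprocess
parts.insert(1, f"{DEFAULT_IMAGE_TOKEN}\n" * num_tiles)
print("".join(parts))
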
def process_image(
    image_file, data_args, image_folder, enable_dynamic_res=False, enable_dynamic_s2=False, max_tiles=None
):
    processor = data_args.image_processor
    if isinstance(image_file, str):
        if image_folder is not None:
            image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
        else:
            image = Image.open(image_file).convert("RGB")
    else:
        # image is stored in bytearray
        image = image_file
    image = image.convert("RGB")
    if hasattr(data_args.image_processor, "crop_size"):
        # CLIP vision tower
        crop_size = data_args.image_processor.crop_size
    else:
        # SIGLIP vision tower
        assert hasattr(data_args.image_processor, "size")
        crop_size = data_args.image_processor.size
    if "dynamic_s2" in data_args.image_aspect_ratio and enable_dynamic_s2:
        assert crop_size["height"] == crop_size["width"]
        images, block_size = dynamic_s2_preprocess(
            image, s2_scales=data_args.s2_scales, max_num=data_args.max_tiles, image_size=crop_size["height"]
        )
        images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
        return torch.stack(images), block_size
    if "dynamic" in data_args.image_aspect_ratio and enable_dynamic_res:
        assert crop_size["height"] == crop_size["width"]
        if max_tiles is not None:
            max_num = max_tiles
        else:
            max_num = data_args.max_tiles
        images = dynamic_preprocess(image, min_num=data_args.min_tiles, max_num=max_num, image_size=crop_size["height"])
        images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
        return torch.stack(images)

    if data_args.image_aspect_ratio == "resize":
        image = image.resize((crop_size["width"], crop_size["height"]))
    if data_args.image_aspect_ratio == "pad":

        def expand2square(pil_img, background_color):
            width, height = pil_img.size
            if width == height:
                return pil_img
            elif width > height:
                result = Image.new(pil_img.mode, (width, width), background_color)
                result.paste(pil_img, (0, (width - height) // 2))
                return result
            else:
                result = Image.new(pil_img.mode, (height, height), background_color)
                result.paste(pil_img, ((height - width) // 2, 0))
                return result

        image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
        image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
    else:
        # Using default behavior of the vision encoder
        # For CLIP, default is central crop
        # For Radio, default is central crop
        # For Siglip, default is resize
        # For InternVIT, default is resize
        image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
    return image

def process_images(images, image_processor, model_cfg, enable_dynamic_res=False, max_tiles=None):
    model_cfg.image_processor = image_processor
    new_images = [
        process_image(image, model_cfg, None, enable_dynamic_res=enable_dynamic_res, max_tiles=max_tiles)
        for image in images
    ]

    if all(x.shape == new_images[0].shape for x in new_images):
        if len(new_images[0].shape) == 4:
            new_images = torch.cat(new_images, dim=0)
        elif len(new_images[0].shape) == 3:
            new_images = torch.stack(new_images, dim=0)
        else:
            raise ValueError(f"new_images rank does not equal to 4, rank: {len(new_images[0].shape)}")
    else:
        raise ValueError("The shape of images in new_images is different!")
    return new_images

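A usage sketch for process_images (not part of the uploaded file). SimpleNamespace stands in for the config object the caller normally provides, and the SigLIP checkpoint name is an assumption:

from types import SimpleNamespace
from PIL import Image
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")  # assumed checkpoint
cfg = SimpleNamespace(image_aspect_ratio="resize")
batch = process_images([Image.new("RGB", (640, 480))] * 2, processor, cfg)
print(batch.shape)  # expected torch.Size([2, 3, 384, 384])
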
def tokenizer_image_token(prompt, tokenizer, return_tensors=None, return_ids=True):
    if return_ids:
        return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
    else:
        return tokenizer(prompt, return_tensors=return_tensors)


def is_gemma_tokenizer(tokenizer):
    return "gemma" in tokenizer.__class__.__name__.lower()


def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith("checkpoint-"):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]

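Illustrative calls (not part of the uploaded file); the paths and model name are placeholders:

print(get_model_name_from_path("/path/to/VILA1.5-3b"))                  # "VILA1.5-3b"
print(get_model_name_from_path("/path/to/VILA1.5-3b/checkpoint-1000"))  # "VILA1.5-3b_checkpoint-1000"
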
class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        outputs = []
        for i in range(output_ids.shape[0]):
            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
        return all(outputs)
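A usage sketch (not part of the uploaded file) wiring the stopping criterion into generation; model, tokenizer, and input_ids are assumed to already exist:

from transformers import StoppingCriteriaList

stopping = KeywordsStoppingCriteria(["</s>"], tokenizer, input_ids)
output_ids = model.generate(input_ids, max_new_tokens=256, stopping_criteria=StoppingCriteriaList([stopping]))
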
    	
model_utils_packing.py
ADDED
@@ -0,0 +1,35 @@
from importlib import import_module
from typing import Tuple

import torch
import transformers
from torch import nn
from torch.nn import functional as F

__all__ = ["patch"]


def _get_unpad_data(attention_mask: torch.Tensor, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, int]:
    if hasattr(_get_unpad_data, "seqlens_in_batch"):
        seqlens_in_batch = _get_unpad_data.seqlens_in_batch
    else:
        seqlens_in_batch = torch.sum(attention_mask, dim=1)

    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch


def set_seqlens_in_batch(seqlens_in_batch: torch.Tensor) -> None:
    _get_unpad_data.seqlens_in_batch = seqlens_in_batch


def patch(model: nn.Module) -> None:
    if transformers.__version__ < "4.43.0":
        m = import_module(model.__module__)
        if not hasattr(m, "_get_unpad_data"):
            raise ValueError(f"Module {m} does not have function '_get_unpad_data' for packing")
        m._get_unpad_data = _get_unpad_data
    else:
        transformers.modeling_flash_attention_utils._get_unpad_data = _get_unpad_data
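A usage sketch (not part of the uploaded file) for the packing patch; llm is an already-loaded causal LM, and seqlens/packed_input_ids/packed_attention_mask describe a packed batch built elsewhere:

from model_utils_packing import patch, set_seqlens_in_batch

patch(llm)                     # reroute _get_unpad_data to the packing-aware version above
set_seqlens_in_batch(seqlens)  # record the true per-sequence lengths before the forward pass
outputs = llm(input_ids=packed_input_ids, attention_mask=packed_attention_mask)
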
    	
modeling_vila.py
ADDED
@@ -0,0 +1,1256 @@
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
import copy
import json
import logging
import math
import os
import os.path
import os.path as osp
import shutil
import warnings
from abc import ABC
from collections import OrderedDict, defaultdict, deque
from copy import deepcopy
from itertools import chain
from threading import Thread
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from einops import rearrange
from PIL import Image
from transformers import (
    AutoConfig,
    AutoModel,
    AutoProcessor,
    AutoTokenizer,
    GenerationConfig,
    LogitsProcessor,
    PretrainedConfig,
    PreTrainedModel,
    Qwen2Config,
    Qwen2ForCausalLM,
    Qwen2PreTrainedModel,
    TextIteratorStreamer,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import ContextManagers, no_init_weights

from .auto_processor import VILAProcessor
from .base_projector import MultimodalProjector, MultimodalProjectorConfig
from .builder import build_llm_and_tokenizer
from .configuration_vila import VILAConfig
from .constants import *
from .conversation import SeparatorStyle, default_conversation
from .distributed import all_gather as vila_all_gather
from .loss import soft_cross_entropy
from .media import extract_media
from .media_encoder import BasicImageEncoder, BasicVideoEncoder, TSPVideoEncoder
from .mm_utils import process_image, process_images
from .model_utils_packing import set_seqlens_in_batch
from .siglip_encoder import SiglipVisionTower, SiglipVisionTowerDynamicS2, SiglipVisionTowerS2
from .tokenizer_utils import tokenize_conversation
from .utils import get_model_config, load_tokenizer_then_handle_media_tokens_and_chat_template

# from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS

# ease debugging
python_input = input


# quick hack for remote code
def get_pg_manager():
    return None


def get_model_weights_dtype(model: nn.Module):
    pass

def build_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
    if model_type_or_path is None:
        return None
    ## load from pretrained model
    if config.resume_path:
        assert os.path.exists(model_type_or_path), f"Resume mm projector path {model_type_or_path} does not exist!"
        return MultimodalProjector.from_pretrained(model_type_or_path, config)
    ## build from scratch
    else:
        mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path)
        mm_projector = MultimodalProjector(mm_projector_cfg, config)
        return mm_projector

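
# Illustrative sketch (not part of the original file): `build_mm_projector` either reloads a
# trained projector or builds a fresh one, depending on `config.resume_path`. The projector
# type string and checkpoint path below are assumptions for the example only.
#
#   # build a new projector of an assumed type (config.resume_path falsy):
#   projector = build_mm_projector("mlp_downsample", config)
#   # or reload one from an existing checkpoint directory (config.resume_path truthy):
#   projector = build_mm_projector("/path/to/ckpt/mm_projector", config)
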
def check_dot_in_model_path(model_path: str):
    """Check whether the model path contains a dot, which would break remote code loading."""
    if osp.isdir(model_path):  # local model
        if "." in osp.abspath(model_path):
            return True
    else:  # remote model
        if "." in model_path:
            return True
    return False


def get_vila_version(model_path: str) -> str:
    VERSIONS = ["vila1.5", "vila-u", "longvila", "nvila", "vila-m3"]
    for version in VERSIONS:
        if version in model_path.lower():
            return version
    return None

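
# Doctest-style illustration (not part of the original file); the model paths are examples only.
#
#   >>> check_dot_in_model_path("Efficient-Large-Model/NVILA-8B")    # no dot -> remote code loads fine
#   False
#   >>> check_dot_in_model_path("Efficient-Large-Model/VILA1.5-3b")  # "." breaks dynamic module imports
#   True
#   >>> get_vila_version("Efficient-Large-Model/NVILA-8B")
#   'nvila'
#   >>> get_vila_version("unrelated/model") is None
#   True
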
def generate_jinja_template(conv_mode: str) -> str:
    if conv_mode == "vicuna_v1":
        return """{% set system_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. " %}
{% set roles = ["user", "assistant"] %}
{% set sep = " " %}

{{ system_prompt }}

{% for message in messages %}
    {% if message['role'] == roles[0] %}
        {{ "USER: " }}{{ sep }}{{ message['content'] }}{{ sep }}
    {% else %}
        {{ "ASSISTANT: " }}{{ sep }}{{ message['content'] }}{{ sep }}
    {% endif %}
{% endfor %}
{% if messages[-1]['role'] == 'user' %}
    {{ "ASSISTANT:" }}
{% endif %}
"""
    elif conv_mode == "llama_3":
        return """{% set system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|>" %}
{% set roles = ["<|start_header_id|>user<|end_header_id|>\\n\\n", "<|start_header_id|>assistant<|end_header_id|>\\n\\n"]%}
{% set sep = "<|eot_id|>" %}

{{ system_prompt }}
{% for message in messages %}
    {% if message['role'] == 'user' %}
        {{ roles[0] }}{{ message['content'] }}{{ sep }}
    {% else %}
        {{ roles[1] }}{{ message['content'] }}{{ sep }}
    {% endif %}
{% endfor %}
{% if messages[-1]['role'] == 'user' %}
    {{ roles[1] }}
{% endif %}
"""
    elif conv_mode == "hermes_2":
        return """{% set system_prompt = "<|im_start|>system\nAnswer the questions." %}
{% set roles = ["<|im_start|>user\n", "<|im_start|>assistant\n"] %}
{% set sep = "<|im_end|>" %}

{{ system_prompt }}{{ sep }}

{% for message in messages %}
    {% if message['role'] == 'user' %}
        {{ roles[0] }}{{ message['content'] }}{{ sep }}
    {% else %}
        {{ roles[1] }}{{ message['content'] }}{{ sep }}
    {% endif %}
{% endfor %}"""
    else:
        raise NotImplementedError(f"Jinja template generation is not implemented for {conv_mode}.")

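
# Minimal sketch (not part of the original file) of rendering one of the generated templates,
# assuming the `jinja2` package is available. In the conversion flow below, the template is
# instead written to a "<conv_mode>.jinja" file in the export directory and referenced via
# config["chat_template"].
#
#   import jinja2
#   template = jinja2.Template(generate_jinja_template("hermes_2"))
#   prompt = template.render(messages=[{"role": "user", "content": "Describe the image <image>."}])
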
def build_vision_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
    ## skip vision tower instantiation
    if model_name_or_path is None:
        return None

    vision_tower_arch = None
    if config.resume_path and "radio" not in model_name_or_path:
        assert os.path.exists(model_name_or_path), f"Resume vision tower path {model_name_or_path} does not exist!"
        vision_tower_cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        vision_tower_arch = vision_tower_cfg.architectures[0].lower()
    vision_tower_name = vision_tower_arch if vision_tower_arch is not None else model_name_or_path

    use_s2 = getattr(config, "s2", False)
    use_dynamic_s2 = getattr(config, "dynamic_s2", False)

    if "siglip" in vision_tower_name:
        if use_dynamic_s2:
            vision_tower = SiglipVisionTowerDynamicS2(model_name_or_path, config)
        elif use_s2:
            vision_tower = SiglipVisionTowerS2(model_name_or_path, config)
        else:
            vision_tower = SiglipVisionTower(model_name_or_path, config)
    else:
        raise NotImplementedError(f"Unknown vision tower: {model_name_or_path}")

    config.mm_hidden_size = (
        vision_tower.config.hidden_size if not (use_s2 or use_dynamic_s2) else vision_tower.hidden_size
    )
    return vision_tower

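
# Note (not part of the original file): the tower variant is selected purely from config flags,
# with `dynamic_s2` taking precedence over `s2` and the plain SigLIP tower as the fallback; as a
# side effect, `config.mm_hidden_size` is updated to the tower's output width.
#
#   config.dynamic_s2 = True   ->  SiglipVisionTowerDynamicS2
#   config.s2 = True           ->  SiglipVisionTowerS2
#   otherwise                  ->  SiglipVisionTower
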
class VILAPretrainedModel(PreTrainedModel):
    config_class = VILAConfig
    main_input_name = "input_embeds"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _no_split_modules = ["Qwen2DecoderLayer", "SiglipEncoderLayer"]

    def __init__(self, config: VILAConfig, *args, **kwargs):
        super().__init__(config)
        self.config = config
        cfgs = get_model_config(config)
        if len(cfgs) == 3:
            llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs
        else:
            raise ValueError("`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config.")

        # loading on auto by default
        device_map = kwargs.get("device_map", "auto")
        self.mm_projector = build_mm_projector(mm_projector_cfg, config)
        self.vision_tower = build_vision_tower(vision_tower_cfg, config)
        if device_map in ["auto", "cuda"]:
            self.mm_projector = self.mm_projector.cuda()
            self.vision_tower = self.vision_tower.cuda()
        # setting device_map to "auto" automatically shards the llm across devices
        self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map)
        self.llm_model_embed_tokens = self.llm.model.embed_tokens

        try:
            use_tsp_encoder = "TSPVideoEncoder" in getattr(config, "video_encoder", None)["_target_"]
        except (TypeError, KeyError):
            use_tsp_encoder = False
        print("use_tsp_encoder", use_tsp_encoder)
        self.tokenizer.padding_side = "left"
        self.encoders = {
            "image": BasicImageEncoder(self),
            "video": TSPVideoEncoder(self) if use_tsp_encoder else BasicVideoEncoder(self),
        }

        self.post_config()
        self.is_loaded = True
        self.llm_only_need_embed = kwargs.get("llm_only_need_embed", False)
        if self.llm_only_need_embed:
            print("We only need the embed_tokens in llm.")
            del self.llm
            self.llm = None
            torch.cuda.empty_cache()

        assert (
            self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
        ), "At least one of the components must be instantiated."

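
    # Illustrative end-to-end load (not part of the original file), assuming the checkpoint has
    # already been exported with the remote-code files handled below; the directory name is a
    # placeholder.
    #
    #   from transformers import AutoModel
    #   model = AutoModel.from_pretrained("<exported_model_dir>", trust_remote_code=True, device_map="auto")
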
    @classmethod
    def convert_vila_dev_ckpt_to_remote(
        self,
        model_path: str,
        output_dir: str = None,
        vila_version: str | None = None,
        conv_mode: str | None = None,
        copy: bool = False,
        copy_weights: bool = True,
        copy_code: bool = True,
        *model_args,
        **kwargs,
    ):
        # assert type(self) == VILAForCausalLM, "This method is only available for VILAForCausalLM."
        assert model_path != output_dir, "model_path and output_dir cannot be the same"
        if not os.path.isdir(model_path):
            # not a local directory: fetch the checkpoint from the Hugging Face Hub
            from huggingface_hub import snapshot_download

            model_path = snapshot_download(model_path)
            print("downloaded HF model to", model_path)

        if check_dot_in_model_path(model_path) and output_dir is None:
            raise ValueError(
                f"Model path {model_path} contains a dot, which will affect the remote code loading. Please specify the output directory without dot in the path to fix this issue."
            )
        if output_dir is not None and "." in output_dir:
            raise ValueError(
                f"Output directory {output_dir} contains a dot, which will affect the remote code loading. Please specify a valid output directory without dots."
            )

        if copy:
            print("copy is set to True, copying weights and code to output_dir")
            copy_weights = copy_code = True
        # copy weights and code to output_dir
        self.copy_or_symlink_directory(model_path, output_dir, copy=copy_weights)
        self.copy_remote_py_files(output_dir, copy=copy_code)

        if vila_version is None:
            vila_version = get_vila_version(output_dir)

        cfg_path = os.path.join(output_dir, "config.json")
        with open(cfg_path) as f:
            config = json.load(f)
        config["version"] = "2.0"  # nvila tag
        config["architectures"] = ["VILAForCausalLM"]
        config["auto_map"] = {
            "AutoProcessor": "auto_processor.VILAProcessor",
            "AutoConfig": "modeling_vila.VILAConfig",
            "AutoModel": "modeling_vila.VILAForCausalLM",
            "AutoModelForCausalLM": "modeling_vila.VILAForCausalLM",
        }
        # vila1.5 legacy support
        config["model_type"] = "vila"
        if vila_version in ["vila1.5", "vila-m3"]:
            if conv_mode is None:
                raise ValueError(f"Please specify the conversation mode for {output_dir}.")
            config["chat_template"] = conv_mode
            jinja_template = generate_jinja_template(conv_mode)
            jinja_path = os.path.join(output_dir, f"{conv_mode}.jinja")
            with open(jinja_path, "w") as f:
                f.write(jinja_template)
        with open(cfg_path, "w") as f:
            json.dump(config, f, indent=2)

        ##########################################################################################
        config = AutoConfig.from_pretrained(output_dir, trust_remote_code=True)
        tokenizer = load_tokenizer_then_handle_media_tokens_and_chat_template(output_dir, config)
        tokenizer.save_pretrained(osp.join(output_dir, "llm"))
        ##########################################################################################

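
    # Illustrative call (not part of the original file), assuming a development checkpoint under
    # /ckpts/nvila-dev and a dot-free export directory:
    #
    #   VILAPretrainedModel.convert_vila_dev_ckpt_to_remote(
    #       "/ckpts/nvila-dev",
    #       output_dir="/ckpts/nvila-hf",
    #       copy=True,
    #   )
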
    @classmethod
    def copy_or_symlink_directory(cls, model_path, output_dir, copy=True):
        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)
        # Copy or symlink every entry from model_path into output_dir
        for item in os.listdir(model_path):
            src_path = os.path.join(model_path, item)
            dst_path = os.path.join(output_dir, item)

            # Remove existing file/directory at destination if it exists
            if os.path.exists(dst_path):
                if os.path.islink(dst_path):
                    os.unlink(dst_path)
                elif os.path.isdir(dst_path):
                    shutil.rmtree(dst_path)
                else:
                    os.remove(dst_path)

            # Copy the entry, or symlink it to ease development
            if copy:
                if os.path.isdir(src_path):
                    shutil.copytree(src_path, dst_path)
                else:
                    shutil.copy2(src_path, dst_path)
                print(f"Copied {src_path} to {dst_path}")
            else:
                os.symlink(src_path, dst_path)
                print(f"Created symlink from {src_path} to {dst_path}")

    @classmethod
    def copy_remote_py_files(cls, output_dir, copy=True):
        ## copy the .py files and README needed for the next remote-code load
        current_file_path = os.path.abspath(__file__)
        current_folder = os.path.dirname(current_file_path)
        for file_name in os.listdir(current_folder):
            if file_name == "INSTRUCTIONS.md":
                src_fname = os.path.join(current_folder, file_name)
                dst_fname = os.path.join(output_dir, "README.md")
                if os.path.exists(dst_fname):
                    with open(dst_fname) as f:
                        old_readme = f.read()
                else:
                    old_readme = ""
                with open(src_fname) as src, open(dst_fname, "w") as dst:
                    dst.write(src.read())
                    dst.write(old_readme)
                print("[HF remote code] README", src_fname, "to", dst_fname)
            if file_name.endswith(".py") or file_name.endswith(".jinja"):
                full_file_name = os.path.join(current_folder, file_name)
                if os.path.isfile(full_file_name):
                    if copy:
                        shutil.copy(full_file_name, output_dir)
                        print("[HF remote code] copying", full_file_name, "to", output_dir)
                    else:
                        # symlink to ease development
                        if os.path.exists(os.path.join(output_dir, file_name)):
                            os.remove(os.path.join(output_dir, file_name))
                        os.symlink(full_file_name, os.path.join(output_dir, file_name))
                        print("[HF remote code] linking", full_file_name, "to", output_dir)

    def save_pretrained(self, output_dir, state_dict=None, **kwargs):
        if state_dict is None:
            # otherwise fetch it from deepspeed
            # state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
            state_dict = self.state_dict()

        if getattr(self, "tokenizer", None):
            self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))

        if self.get_llm():
            print(f"saving llm to {osp.join(output_dir, 'llm')}")
            self.llm.config._name_or_path = osp.join(output_dir, "llm")
            llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k})
            self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict)
            self.config.llm_cfg = self.llm.config

        if self.get_vision_tower():
            print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
            self.vision_tower.config._name_or_path = osp.join(output_dir, "vision_tower")
            vision_tower_state_dict = OrderedDict(
                {k.split("vision_tower.vision_tower.")[-1]: v for k, v in state_dict.items() if "vision_tower" in k}
            )
            self.vision_tower.vision_tower.save_pretrained(
                os.path.join(output_dir, "vision_tower"),
                state_dict=vision_tower_state_dict,
            )
            self.vision_tower.image_processor.save_pretrained(os.path.join(output_dir, "vision_tower"))
            self.config.vision_tower_cfg = self.vision_tower.config
            if hasattr(self.config.vision_tower_cfg, "auto_map"):
                if "radio" not in self.get_vision_tower().__class__.__name__.lower():
                    delattr(self.config.vision_tower_cfg, "auto_map")

        if self.get_mm_projector():
            print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
            self.mm_projector.config._name_or_path = osp.join(output_dir, "mm_projector")
            mm_projector_state_dict = OrderedDict(
                {k.split("mm_projector.")[-1]: v for k, v in state_dict.items() if "mm_projector" in k}
            )
            self.mm_projector.save_pretrained(
                os.path.join(output_dir, "mm_projector"),
                state_dict=mm_projector_state_dict,
            )
            self.config.mm_projector_cfg = self.mm_projector.config
         
     | 
| 410 | 
         
            +
             
     | 
| 411 | 
         
            +
                    ## update and save top-level config
         
     | 
| 412 | 
         
            +
                    self.config._name_or_path = output_dir
         
     | 
| 413 | 
         
            +
                    self.config.architectures = [self.__class__.__name__]
         
     | 
| 414 | 
         
            +
                    self.config.save_pretrained(output_dir)
         
     | 
| 415 | 
         
            +
             
     | 
| 416 | 
         
            +
                    ## copy .py and REAMDE for next loading remote code
         
     | 
| 417 | 
         
            +
                    self.copy_remote_py_files(output_dir)
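
    # A minimal reload sketch (not part of this class), assuming a hypothetical local checkpoint
    # directory "./vila-checkpoint" produced by `save_pretrained` above:
    #
    #     from transformers import AutoConfig, AutoModel
    #     config = AutoConfig.from_pretrained("./vila-checkpoint", trust_remote_code=True)
    #     model = AutoModel.from_config(config, trust_remote_code=True)
    #
    # This mirrors what `from_pretrained` below does; `trust_remote_code=True` is needed because
    # the architecture is defined in the copied .py files, not inside the transformers library.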
         
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[str] = None,
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: Optional[bool] = None,
        weights_only: bool = True,
        **kwargs,
    ):
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
        return cls._from_config(config, **kwargs)
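
    # `init_llm` below builds the language model and tokenizer via `build_llm_and_tokenizer` and
    # records `pad_token_list` (pad id, eos id, and the "<|endoftext|>" token); `_embed` later uses
    # this tuple to drop padding tokens while fusing text and media embeddings.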
         
    def init_llm(self, llm_config, config, *args, **kwargs):
        self.llm, self.tokenizer = build_llm_and_tokenizer(llm_config, config, *args, **kwargs)
        # hard coded for NVILA
        # variables for XGrammar
        NUM_EXTRA_TOKENS = len(self.tokenizer.added_tokens_encoder.keys())

        self.pad_token_list = (
            self.tokenizer.pad_token_id,
            self.tokenizer.eos_token_id,
            self.tokenizer.tokenize("<|endoftext|>")[0],  # for qwen
        )

        self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS
        # XGrammar tokenizer and grammar compiler
        # lazily initialized only when JSON output is requested during inference
        self.grammar_compiler = None
        self.llm.resize_token_embeddings(len(self.tokenizer))
        return self.llm, self.tokenizer
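
    # `post_config` below unconditionally casts the LLM, the mm projector and the vision tower to
    # float16, and mirrors their configs into the top-level config when those fields are unset.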
         
    def post_config(self):
        ######################################################################
        self.llm = self.llm.to(torch.float16)
        self.mm_projector = self.mm_projector.to(torch.float16)
        self.vision_tower = self.vision_tower.to(torch.float16)
        ######################################################################
        self.training = self.llm.training
        if self.training:
            self.train()
        else:
            self.eval()
        ## configuration
        if getattr(self.config, "llm_cfg", None) is None:
            self.config.llm_cfg = self.llm.config
        if getattr(self.config, "vision_tower_cfg", None) is None:
            self.config.vision_tower_cfg = self.vision_tower.config
        if getattr(self.config, "mm_projector_cfg", None) is None:
            self.config.mm_projector_cfg = self.mm_projector.config

    def get_llm(self):
        llm = getattr(self, "llm", None)
        if type(llm) is list:
            llm = llm[0]
        return llm

    def get_lm_head(self):
        lm_head = getattr(self.get_llm(), "lm_head", None)
        return lm_head

    def get_vision_tower(self):
        vision_tower = getattr(self, "vision_tower", None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def get_mm_projector(self):
        mm_projector = getattr(self, "mm_projector", None)
        if type(mm_projector) is list:
            mm_projector = mm_projector[0]
        return mm_projector

    def freezed_module_patch(self):
        """
        Huggingface calls model.train() at every training_step. To keep the expected behavior of
        modules such as dropout and batchnorm, we call model.eval() on the frozen modules.
        """
        if self.training:
            if self.get_llm() and not getattr(self.config, "tune_language_model", False):
                pass
                # logging.warning("Caution: Your LLM is currently in training mode, ensuring accurate gradient computation. Please be vigilant, particularly regarding BatchNorm and Dropout operations.")
            if self.get_vision_tower() and not getattr(self.config, "tune_vision_tower", False):
                self.get_vision_tower().eval()
            if self.get_mm_projector() and not getattr(self.config, "tune_mm_projector", False):
                self.get_mm_projector().eval()
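
# `VILAForCausalLM` adds the multimodal data path on top of `VILAPretrainedModel`: image encoding
# (optionally with dynamic S2 multi-scale tiling), splicing media embeddings into the token stream,
# and re-sharding packed sequences for sequence parallelism.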
         
class VILAForCausalLM(VILAPretrainedModel):
    def __init__(self, config: VILAConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
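
    # Dynamic S2 merging, as implemented below: the vision tower returns one feature block per
    # image tile at several scales. `merge_chessboard` stitches the tiles of each scale back into a
    # single C x H x W map, every scale is interpolated (area mode) to the spatial size of the
    # scale selected by `resize_output_to_scale_idx`, and the per-scale maps are concatenated along
    # the channel axis. Images without a block size (None) are replicated across scales as a 1x1 grid.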
         
    def merge_features_for_dynamic_s2(self, image_features, block_sizes):
        scales = self.get_vision_tower().scales
        resize_output_to_scale_idx = self.get_vision_tower().resize_output_to_scale_idx

        image_features_each_image = []
        new_block_sizes = []
        block_cnt = 0
        for block_size_each_image in block_sizes:
            if block_size_each_image is None:
                cur_features = image_features[block_cnt : block_cnt + 1]
                cur_features = rearrange(cur_features, "1 (h w) c -> 1 c h w", h=int(cur_features.shape[1] ** 0.5))
                cur_features = cur_features.repeat(1, len(scales), 1, 1)
                image_features_each_image.append(cur_features)
                new_block_sizes.append((1, 1))
                block_cnt += 1
            else:
                cur_features_each_scale = []
                for scale in scales[:-1]:
                    num_blocks_this_scale = (scale // scales[0]) ** 2
                    cur_features_each_scale.append(
                        self.merge_chessboard(
                            image_features[block_cnt : block_cnt + num_blocks_this_scale],
                            num_split_h=scale // scales[0],
                            num_split_w=scale // scales[0],
                        )
                    )  # 1 * C * H * W
                    block_cnt += num_blocks_this_scale
                num_blocks_last_scale = block_size_each_image[0] * block_size_each_image[1]
                cur_features_each_scale.append(
                    self.merge_chessboard(
                        image_features[block_cnt : block_cnt + num_blocks_last_scale],
                        num_split_h=block_size_each_image[0],
                        num_split_w=block_size_each_image[1],
                    )
                )  # 1 * C * H * W
                block_cnt += num_blocks_last_scale

                # resize and concat features from different scales
                output_size = cur_features_each_scale[resize_output_to_scale_idx].shape[-2:]
                cur_features = torch.cat(
                    [
                        F.interpolate(cur_features_each_scale[i].to(torch.float32), size=output_size, mode="area").to(
                            cur_features_each_scale[i].dtype
                        )
                        for i in range(len(cur_features_each_scale))
                    ],
                    dim=1,
                )
                # cur_features = rearrange(cur_features, "1 c h w -> (h w) c")

                image_features_each_image.append(cur_features)

                if resize_output_to_scale_idx == len(scales) - 1 or resize_output_to_scale_idx == -1:
                    new_block_sizes.append(block_size_each_image)
                else:
                    new_block_sizes.append(
                        (
                            scales[resize_output_to_scale_idx] // scales[0],
                            scales[resize_output_to_scale_idx] // scales[0],
                        )
                    )

        assert block_cnt == len(image_features)

        return image_features_each_image, new_block_sizes
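
    # With dynamic S2, `encode_images` below therefore runs: vision tower -> merge scales per image
    # -> split back into blocks -> flatten each block to (h*w) tokens -> mm projector -> regroup the
    # blocks per image -> one (h*w) x C token sequence per image (stacked into a batch when every
    # image yields the same number of tokens). Without dynamic S2 it is just vision tower + projector.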
         
    def encode_images(self, images, block_sizes: Optional[Tuple[int, ...]] = None):
        if block_sizes is None:
            block_sizes = [None] * len(images)
        if getattr(self.config, "dynamic_s2", False):
            image_features = self.get_vision_tower()(images)
            image_features, new_block_sizes = self.merge_features_for_dynamic_s2(image_features, block_sizes)

            image_features = [
                self.split_chessboard(x, block_size[0], block_size[1])
                for x, block_size in zip(image_features, new_block_sizes)
            ]  # list of B * C * H * W tensors
            image_features = torch.cat(
                [rearrange(x, "b c h w -> b (h w) c") for x in image_features], dim=0
            )  # B * N * C
            image_features = self.get_mm_projector()(image_features)
            image_features = list(
                image_features.split([block_size[0] * block_size[1] for block_size in new_block_sizes], dim=0)
            )
            image_features = [
                self.merge_chessboard(x, block_size[0], block_size[1])
                for x, block_size in zip(image_features, new_block_sizes)
            ]  # list of 1 * C * H * W tensors
            image_features = [rearrange(x, "1 c h w -> (h w) c") for x in image_features]  # list of N * C tensors
            if all([feature.shape[0] == image_features[0].shape[0] for feature in image_features]):
                image_features = torch.stack(image_features, dim=0)
        else:
            image_features = self.get_vision_tower()(images)
            image_features = self.get_mm_projector()(image_features)
        return image_features
         
    def train(self, mode: bool = True):
        super().train(mode)
        return self
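
    # `_embed` walks each sample's input_ids once: spans of ordinary tokens keep their text
    # embeddings, each media token is replaced by the next queued media embedding (with labels set
    # to IGNORE_INDEX), and ids in `pad_token_list` are skipped. The fused sequences are then
    # truncated to `model_max_length` and re-padded to the longest sequence in the batch.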
         
    @torch.inference_mode()
    def _embed(
        self,
        input_ids: torch.Tensor,
        media: Dict[str, List[torch.Tensor]],
        media_config: Dict[str, Dict[str, Any]],
        labels: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        media = copy.deepcopy(media)
        media_config = copy.deepcopy(media_config)

        labels = labels if labels is not None else torch.full_like(input_ids, IGNORE_INDEX)
        attention_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids, dtype=torch.bool)

        PROCESS_GROUP_MANAGER = get_pg_manager()
        if PROCESS_GROUP_MANAGER is not None:
            for name in media:
                self.encoders[name].end_tokens = None

        # Extract text and media embeddings
        text_embeds = self.llm_model_embed_tokens(input_ids)

        use_cache = False
        if "use_cache" in media_config:
            use_cache = media_config.pop("use_cache")

        if use_cache:
            print("Use cached embedding")
        if media is not None:
            media_embeds = media if use_cache else self.__embed_media_tokens(media, media_config)
        else:
            # no media was provided, so fall back to an empty dict
            media_embeds = {}

        # This is a workaround to make sure the dummy embeddings are consumed
        while media_embeds.get("dummy"):
            dummy_embed = media_embeds["dummy"].popleft()
            text_embeds += torch.sum(dummy_embed) * 0

        # Remove padding
        batch_size = labels.shape[0]
        text_embeds = [text_embeds[k][attention_mask[k]] for k in range(batch_size)]
        labels = [labels[k][attention_mask[k]] for k in range(batch_size)]

        # Build inverse mapping from token ID to media name
        media_tokens = {}
        for name, token_id in self.tokenizer.media_token_ids.items():
            media_tokens[token_id] = name

        # Fuse text and media embeddings
        inputs_m, labels_m = [], []
        for k in range(batch_size):
            inputs_mk, labels_mk = [], []
            pos = 0
            while pos < len(labels[k]):
                if input_ids[k][pos].item() in media_tokens:
                    end = pos + 1
                    name = media_tokens[input_ids[k][pos].item()]
                    input = media_embeds[name].popleft()
                    label = torch.full([input.shape[0]], IGNORE_INDEX, device=labels[k].device, dtype=labels[k].dtype)
                elif input_ids[k][pos].item() in self.pad_token_list:
                    # skip pad tokens
                    end = pos + 1
                    pos = end
                    continue
                else:
                    end = pos
                    while end < len(labels[k]) and input_ids[k][end].item() not in media_tokens:
                        end += 1
                    input = text_embeds[k][pos:end]
                    label = labels[k][pos:end]

                inputs_mk.append(input)
                labels_mk.append(label)
                pos = end
            inputs_m.append(torch.cat(inputs_mk, dim=0))
            labels_m.append(torch.cat(labels_mk, dim=0))
        inputs, labels = inputs_m, labels_m

        # Check if all media embeddings are consumed
        for name in media_embeds:
            if media_embeds[name]:
                raise ValueError(f"Not all {name} embeddings are consumed! Still {len(media_embeds[name])} left.")

        # Truncate sequences to `model_max_length` as media embeddings are inserted
        inputs, labels = self.__truncate_sequence(inputs, labels)

        # Pad sequences to the longest one in the batch
        return self.__batchify_sequence(inputs, labels)
         
    def __embed_media_tokens(
        self,
        media: Dict[str, List[torch.Tensor]],
        media_config: Dict[str, Dict[str, Any]],
    ) -> Dict[str, List[torch.Tensor]]:
        embeds = defaultdict(deque)
        for name in media:
            if self.training:
                # Gather metainfo of media objects from all ranks
                info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
                infos = list(chain(vila_all_gather(info)))

                # The entire batch does not contain any media objects of this type.
                if not infos:
                    continue

                # Create a dummy tensor to ensure the encoder is called, otherwise the training will hang.
                if media.get(name) is None or len(media[name]) == 0:
                    dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device)
                    embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name]))
                    continue
            embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
        return embeds

    def __truncate_sequence(
        self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.training and any(len(input) > self.tokenizer.model_max_length for input in inputs):
            warnings.warn(f"Truncating sequences to `model_max_length` ({self.tokenizer.model_max_length}).")
            inputs = [input[: self.tokenizer.model_max_length] for input in inputs]
            labels = [label[: self.tokenizer.model_max_length] for label in labels]
        return inputs, labels
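
    # `__batchify_sequence` pads embeddings with zeros and labels with IGNORE_INDEX, honoring
    # `tokenizer.padding_side`: right padding appends the filler and masks the tail of the
    # attention mask, left padding prepends it and masks the head.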
         
    def __batchify_sequence(
        self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        batch_size = len(inputs)
        device = inputs[0].device
        hidden_size = inputs[0].shape[1]
        max_length = max(inputs[k].shape[0] for k in range(batch_size))
        attention_mask = torch.ones((batch_size, max_length), dtype=torch.bool, device=device)

        inputs_p, labels_p = [], []
        for k in range(batch_size):
            size_pk = max_length - inputs[k].shape[0]
            inputs_pk = torch.zeros((size_pk, hidden_size), dtype=inputs[k].dtype, device=device)
            labels_pk = torch.full((size_pk,), IGNORE_INDEX, dtype=labels[k].dtype, device=device)
            if self.tokenizer.padding_side == "right":
                attention_mask[k, inputs[k].shape[0] :] = False
                inputs_pk = torch.cat([inputs[k], inputs_pk], dim=0)
                labels_pk = torch.cat([labels[k], labels_pk], dim=0)
            else:
                attention_mask[k, : -inputs[k].shape[0]] = False
                inputs_pk = torch.cat([inputs_pk, inputs[k]], dim=0)
                labels_pk = torch.cat([labels_pk, labels[k]], dim=0)
            inputs_p.append(inputs_pk)
            labels_p.append(labels_pk)

        inputs = torch.stack(inputs_p, dim=0)
        labels = torch.stack(labels_p, dim=0)
        return inputs, labels, attention_mask
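
    # `repack_multimodal_data` handles sequence parallelism: every rank all-gathers its shard of
    # embeddings, attention masks, position ids and labels, strips the padding to rebuild the full
    # sequence per sample, and then re-slices it evenly across ranks. "ring_varlen" assigns each
    # rank a contiguous chunk; "zigzag_ring_varlen" assigns each rank one forward and one backward
    # chunk of the zigzag pattern.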
         
    def repack_multimodal_data(self, inputs_embeds, attention_mask, position_ids, labels):
        # Handle sequence parallelism
        PROCESS_GROUP_MANAGER = get_pg_manager()

        # We do re-sharding instead of packing here to ensure the sequence length is the same across all ranks.
        if PROCESS_GROUP_MANAGER is not None:
            sp_degree = PROCESS_GROUP_MANAGER.sp_degree
            sp_rank = PROCESS_GROUP_MANAGER.sp_rank
            sp_group = PROCESS_GROUP_MANAGER.sp_pg
            ring_degree = PROCESS_GROUP_MANAGER.ring_degree
            ring_rank = PROCESS_GROUP_MANAGER.ring_rank
            ring_type = PROCESS_GROUP_MANAGER.ring_type
            ulysses_degree = PROCESS_GROUP_MANAGER.ulysses_degree
            ulysses_rank = PROCESS_GROUP_MANAGER.ulysses_rank

            bs, shard_seqlen = position_ids.shape
            sp_seq_len = [torch.zeros(1, dtype=torch.int64, device=position_ids.device) for _ in range(sp_degree)]
            dist.all_gather(sp_seq_len, torch.tensor(shard_seqlen, device=position_ids.device), group=sp_group)
            sp_seq_len_cat = torch.cat(sp_seq_len, dim=0)

            if sp_rank == 0:
                original_start_id = 0
            else:
                original_start_id = torch.sum(sp_seq_len_cat[:sp_rank]).item()
            original_end_id = torch.sum(sp_seq_len_cat[: sp_rank + 1]).item()

            # Gather attention_mask, position_ids, labels and input_embeds
            all_inputs_embeds = torch.zeros(
                bs,
                torch.sum(sp_seq_len_cat),
                inputs_embeds.shape[-1],
                dtype=inputs_embeds.dtype,
                device=inputs_embeds.device,
            ).contiguous()
            all_inputs_embeds[:, original_start_id:original_end_id, :] += inputs_embeds
            dist.barrier(group=sp_group)
            dist.all_reduce(all_inputs_embeds, group=sp_group)
            dist.barrier(group=sp_group)

            attention_mask_list = [
                torch.zeros((bs, sp_seq_len[i]), dtype=attention_mask.dtype, device=attention_mask.device)
                for i in range(sp_degree)
            ]
            position_ids_list = [
                torch.zeros((bs, sp_seq_len[i]), dtype=position_ids.dtype, device=position_ids.device)
                for i in range(sp_degree)
            ]
            labels_list = [
                torch.zeros((bs, sp_seq_len[i]), dtype=labels.dtype, device=labels.device) for i in range(sp_degree)
            ]

            dist.all_gather(attention_mask_list, attention_mask, group=sp_group)
            dist.all_gather(position_ids_list, position_ids, group=sp_group)
            dist.all_gather(labels_list, labels, group=sp_group)

            effective_seqlen_list = [attention_mask_list[i].sum(dim=-1) for i in range(sp_degree)]
            effective_seqlen = torch.stack(effective_seqlen_list, dim=-1)
            effective_seqlen_batch_list = torch.unbind(effective_seqlen, dim=0)

            global_attention_mask_list = []
            global_position_ids_list = []
            global_labels_list = []
            global_inputs_embeds_list = []
            for i in range(bs):
                global_attention_mask_batch_list = []
                global_position_ids_batch_list = []
                global_labels_batch_list = []
                global_inputs_embeds_batch_list = []
                for j in range(sp_degree):
                    eff_len = effective_seqlen_batch_list[i][j]
                    prev_len = torch.sum(sp_seq_len_cat[:j]).item() if j > 0 else 0

                    global_attention_mask_batch_list.append(attention_mask_list[j][i, :eff_len])
                    global_position_ids_batch_list.append(position_ids_list[j][i, :eff_len])
                    global_labels_batch_list.append(labels_list[j][i, :eff_len])
                    global_inputs_embeds_batch_list.append(all_inputs_embeds[i, prev_len : prev_len + eff_len, :])
                global_attention_mask_list.append(torch.cat(global_attention_mask_batch_list, dim=0))
                global_position_ids_list.append(torch.cat(global_position_ids_batch_list, dim=0))
                global_labels_list.append(torch.cat(global_labels_batch_list, dim=0))
                global_inputs_embeds_list.append(torch.cat(global_inputs_embeds_batch_list, dim=0))

                global_attention_mask = torch.nn.utils.rnn.pad_sequence(
                    global_attention_mask_list, batch_first=True, padding_value=False
                )
                global_position_ids = torch.nn.utils.rnn.pad_sequence(
                    global_position_ids_list, batch_first=True, padding_value=-1
                )
                global_labels = torch.nn.utils.rnn.pad_sequence(
                    global_labels_list, batch_first=True, padding_value=IGNORE_INDEX
                )
                global_inputs_embeds = torch.nn.utils.rnn.pad_sequence(
                    global_inputs_embeds_list, batch_first=True, padding_value=0
                )

            # Re-shard the inputs
            if ring_degree > 1:
                total_effective_seqlen = torch.sum(effective_seqlen, dim=1)
                new_seqlen_per_rank = total_effective_seqlen // sp_degree
                assert torch.all(
                    total_effective_seqlen % sp_degree == 0
                ), "total_effective_seqlen must be divisible by sp_degree"

                max_new_seqlen = torch.max(new_seqlen_per_rank).item()

                new_attention_mask = torch.zeros(
                    (bs, max_new_seqlen), dtype=global_attention_mask.dtype, device=global_attention_mask.device
                )
                new_position_ids = torch.zeros(
                    (bs, max_new_seqlen), dtype=global_position_ids.dtype, device=global_position_ids.device
                )
                new_labels = torch.full(
                    (bs, max_new_seqlen), IGNORE_INDEX, dtype=global_labels.dtype, device=global_labels.device
                )
                new_inputs_embeds = torch.zeros(
                    (bs, max_new_seqlen, global_inputs_embeds.shape[-1]),
                    dtype=global_inputs_embeds.dtype,
                    device=global_inputs_embeds.device,
                )

                if ring_type == "ring_varlen":
                    for i in range(bs):
                        start_idx = new_seqlen_per_rank[i] * sp_rank
                        end_idx = start_idx + new_seqlen_per_rank[i]
                        new_attention_mask[i, : new_seqlen_per_rank[i]] = global_attention_mask[i, start_idx:end_idx]
                        new_position_ids[i, : new_seqlen_per_rank[i]] = global_position_ids[i, start_idx:end_idx]
                        new_labels[i, : new_seqlen_per_rank[i]] = global_labels[i, start_idx:end_idx]
                        new_inputs_embeds[i, : new_seqlen_per_rank[i], :] = global_inputs_embeds[
                            i, start_idx:end_idx, :
                        ]
                elif ring_type == "zigzag_ring_varlen":
                    chunk_size = total_effective_seqlen // (2 * sp_degree)
                    for i in range(bs):
                        # Zigzag pattern indices
                        if sp_degree == ring_degree:
                            forward_rank_idx = sp_rank
                            backward_rank_idx = 2 * sp_degree - sp_rank - 1
                        else:
                            ulysses_offset = ulysses_rank * ring_degree * 2
                            forward_rank_idx = ring_rank + ulysses_offset
                            backward_rank_idx = sp_degree - ring_rank - 1 + ulysses_offset

                        # Calculate start and end indices for the forward and backward zigzag
                        start_idx_fwd = forward_rank_idx * chunk_size[i]
                        end_idx_fwd = start_idx_fwd + chunk_size[i]

                        start_idx_bwd = backward_rank_idx * chunk_size[i]
                        end_idx_bwd = start_idx_bwd + chunk_size[i]

                        # Fill new tensors with zigzag data
                        new_attention_mask[i, : chunk_size[i]] = global_attention_mask[i, start_idx_fwd:end_idx_fwd]
                        new_attention_mask[i, chunk_size[i] : 2 * chunk_size[i]] = global_attention_mask[
                            i, start_idx_bwd:end_idx_bwd
                        ]

                        new_position_ids[i, : chunk_size[i]] = global_position_ids[i, start_idx_fwd:end_idx_fwd]
                        new_position_ids[i, chunk_size[i] : 2 * chunk_size[i]] = global_position_ids[
         
     | 
| 925 | 
         
            +
                                        i, start_idx_bwd:end_idx_bwd
         
     | 
| 926 | 
         
            +
                                    ]
         
     | 
| 927 | 
         
            +
             
     | 
| 928 | 
         
            +
                                    new_labels[i, : chunk_size[i]] = global_labels[i, start_idx_fwd:end_idx_fwd]
         
     | 
| 929 | 
         
            +
                                    new_labels[i, chunk_size[i] : 2 * chunk_size[i]] = global_labels[i, start_idx_bwd:end_idx_bwd]
         
     | 
| 930 | 
         
            +
             
     | 
| 931 | 
         
            +
                                    new_inputs_embeds[i, : chunk_size[i], :] = global_inputs_embeds[i, start_idx_fwd:end_idx_fwd, :]
         
     | 
| 932 | 
         
            +
                                    new_inputs_embeds[i, chunk_size[i] : 2 * chunk_size[i], :] = global_inputs_embeds[
         
     | 
| 933 | 
         
            +
                                        i, start_idx_bwd:end_idx_bwd, :
         
     | 
| 934 | 
         
            +
                                    ]
         
     | 
| 935 | 
         
            +
                            else:
         
     | 
| 936 | 
         
            +
                                raise ValueError(f"Invalid ring_type: {ring_type}")
         
     | 
| 937 | 
         
            +
                        else:
         
     | 
| 938 | 
         
            +
                            global_seq_len = global_attention_mask.shape[-1]
         
     | 
| 939 | 
         
            +
                            seq_len_sharded = global_seq_len // sp_degree
         
     | 
| 940 | 
         
            +
                            start_idx_reshard = seq_len_sharded * sp_rank
         
     | 
| 941 | 
         
            +
                            end_idx_reshard = start_idx_reshard + seq_len_sharded if sp_rank < sp_degree - 1 else global_seq_len
         
     | 
| 942 | 
         
            +
             
     | 
| 943 | 
         
            +
                            new_attention_mask = torch.narrow(
         
     | 
| 944 | 
         
            +
                                global_attention_mask, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
         
     | 
| 945 | 
         
            +
                            )
         
     | 
| 946 | 
         
            +
                            new_position_ids = torch.narrow(
         
     | 
| 947 | 
         
            +
                                global_position_ids, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
         
     | 
| 948 | 
         
            +
                            )
         
     | 
| 949 | 
         
            +
                            new_labels = torch.narrow(global_labels, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard)
         
     | 
| 950 | 
         
            +
                            new_inputs_embeds = torch.narrow(
         
     | 
| 951 | 
         
            +
                                global_inputs_embeds, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
         
     | 
| 952 | 
         
            +
                            )
         
     | 
| 953 | 
         
            +
             
     | 
| 954 | 
         
            +
                        return new_inputs_embeds, new_attention_mask, new_position_ids, new_labels
         
     | 
| 955 | 
         
            +
             
     | 
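        # Illustrative note on the zigzag branch above (numbers are an example, not from the code):
        # with sp_degree = 4 the packed sequence is viewed as 2 * sp_degree = 8 equal chunks, and
        # the forward/backward index formulas assign
        #   rank 0 -> chunks (0, 7), rank 1 -> (1, 6), rank 2 -> (2, 5), rank 3 -> (3, 4),
        # so each rank holds one early and one late chunk, the usual way zigzag ring attention
        # balances causal-attention work across ranks.
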
        device = inputs_embeds.device
        batch_size = inputs_embeds.shape[0]
        seqlens = [attention_mask[k].sum().item() for k in range(batch_size)]

        # Pack all sequences together
        inputs_embeds_p = [inputs_embeds[k][attention_mask[k]] for k in range(batch_size)]
        attention_mask_p = [torch.ones(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
        position_ids_p = [torch.arange(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
        labels_p = [labels[k][attention_mask[k]] for k in range(batch_size)]

        # Add one dummy token at the end of the packed sequence to ensure that `_get_unpacked_data` will be called
        inputs_embeds_p.append(torch.zeros(1, inputs_embeds.shape[-1], dtype=inputs_embeds.dtype, device=device))
        attention_mask_p.append(torch.tensor([0], dtype=torch.int, device=device))
        position_ids_p.append(torch.tensor([0], dtype=torch.int, device=device))
        labels_p.append(torch.tensor([IGNORE_INDEX], dtype=torch.int, device=device))

        # Mask the first token of each sequence to avoid contamination
        for label in labels_p:
            label[0] = IGNORE_INDEX

        # Batch the data
        inputs_embeds_p = torch.cat(inputs_embeds_p, dim=0).unsqueeze(0)
        attention_mask_p = torch.cat(attention_mask_p, dim=0).unsqueeze(0)
        position_ids_p = torch.cat(position_ids_p, dim=0).unsqueeze(0)
        labels_p = torch.cat(labels_p, dim=0).unsqueeze(0)

        if hasattr(
            self, "pad_to_multiple_of"
        ):  # related to quantization, please refer to ModelArguments for more information.
            assert len(labels_p.shape) == 2
            batch_size, max_length, cur_length = labels_p.shape[0], labels_p.shape[1], labels_p.shape[1]
            hidden_size = inputs_embeds_p.shape[-1]

            if max_length % self.pad_to_multiple_of != 0:
                max_length = ((max_length // self.pad_to_multiple_of) + 1) * self.pad_to_multiple_of
                difference = max_length - cur_length

                inputs_embeds_p = torch.cat(
                    (
                        inputs_embeds_p,
                        torch.full((batch_size, difference, hidden_size), self.llm.pad_token_id).to(inputs_embeds_p),
                    ),
                    dim=1,
                )
                labels_p = torch.cat((labels_p, torch.full((batch_size, difference), IGNORE_INDEX).to(labels_p)), dim=1)
                attention_mask_p = torch.cat(
                    (
                        attention_mask_p,
                        torch.zeros((batch_size, difference), dtype=torch.bool).to(attention_mask_p),
                    ),
                    dim=1,
                )
                position_ids_p = torch.cat(
                    (position_ids_p, torch.full((batch_size, difference), -1).to(position_ids_p)), dim=1
                )

        return inputs_embeds_p, attention_mask_p, position_ids_p, labels_p

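    # Worked example for the packing above (illustrative numbers): for valid lengths [3, 2], the
    # padded (2, max_len) batch is flattened into a single row of length 3 + 2 + 1 = 6 (the extra
    # token is the dummy appended above), position_ids restart per sequence as [0, 1, 2, 0, 1, 0],
    # and the first label of every sequence is set to IGNORE_INDEX so no sequence is supervised by
    # tokens carried over from the previous one.
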
    def get_xgr_logits_processor(self, response_format) -> List[LogitsProcessor]:
        raise NotImplementedError("This method is not implemented for VILA model.")
        # Convert response format to logits processor
        import xgrammar as xgr

        logging.info("[XGrammar] Compiling grammar for constrained output")

        if self.grammar_compiler is None:
            # logging.info(f"[XGrammar] {self.tokenizer}, {self.tokenizer.vocab_size}, {self.vocab_size}")
            self.grammar_compiler = xgr.GrammarCompiler(
                xgr.TokenizerInfo.from_huggingface(self.tokenizer, vocab_size=self.vocab_size)
            )

        if response_format.type == "json_schema":
            compiled_grammar = self.grammar_compiler.compile_json_schema(
                response_format.json_schema.schema_,
                indent=2,
            )
        else:
            compiled_grammar = self.grammar_compiler.compile_builtin_json_grammar()

        return [xgr.contrib.hf.LogitsProcessor(compiled_grammar)]

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        media: Optional[Dict[str, List[torch.Tensor]]] = None,
        images: Optional[torch.FloatTensor] = None,
        media_config: Optional[List] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        packing: bool = True,
        force_packing: bool = False,
        seqlens_in_batch: Optional[torch.LongTensor] = None,
        dpo_forward: bool = False,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        self.freezed_module_patch()

        if images is not None:
            if media is not None:
                raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
            print("The 'images' argument is deprecated. Please use 'media' instead.")
            media = {"image": images}

        if media_config is None:
            media_config = defaultdict(dict)

        if inputs_embeds is None:
            inputs_embeds, labels, attention_mask = self._embed(input_ids, media, media_config, labels, attention_mask)

        if force_packing or (packing and self.training and not dpo_forward):
            if seqlens_in_batch is None:
                seqlens_in_batch = torch.sum(attention_mask, dim=1)
            set_seqlens_in_batch(seqlens_in_batch)

            (inputs_embeds, attention_mask, position_ids, labels) = self.repack_multimodal_data(
                inputs_embeds, attention_mask, position_ids, labels
            )

        outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            labels=labels,
            **kwargs,
        )

        if self.training and getattr(self.config, "time_token_ids", []):
            outputs.loss = soft_cross_entropy(
                outputs.logits,
                labels,
                soft_tokens=self.config.time_token_ids,
                std=self.config.soft_ce_std,
            )

        if dpo_forward:
            return outputs.logits, labels

        return outputs

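    # Note on the forward pass above: packing is only applied for training-style calls
    # (force_packing, or packing while self.training and not dpo_forward). set_seqlens_in_batch
    # records the per-sequence lengths first, which is presumably what lets the packing-aware
    # attention path keep the flattened sequences from attending to one another.
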
    # @torch.inference_mode()
    def generate(
        self,
        input_ids: Optional[torch.FloatTensor] = None,
        media: Optional[Dict[str, List[torch.Tensor]]] = None,
        media_config: Dict[str, Dict[str, Any]] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        return_output_ids_only: bool = True,
        **generation_kwargs,
    ) -> torch.LongTensor:
        """
        input_tokens: <image> describe the image
        media:        [Tensor(1, 3, 384, 384), ]
        ----------->
        input_tokens:       36000        001 002 003 004
        input_embeds:   <media embed>    001 002 003 004
        """
        inputs_embeds, _, attention_mask = self._embed(input_ids, media, media_config, None, attention_mask)
        output_ids = self.llm.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)

        if return_output_ids_only:
            return_value = output_ids
        else:
            # by default, return the input_ids and output_ids concatenated to keep consistency with community VLMs like Qwen
            generation_config = generation_kwargs.get("generation_config", None)
            if generation_config is not None:
                num_generations = generation_config.num_return_sequences
                repeat_input_ids = input_ids.repeat_interleave(num_generations, dim=0)
                return_value = torch.cat([repeat_input_ids, output_ids], dim=-1)
            else:
                return_value = torch.cat([input_ids, output_ids], dim=-1)

        return return_value

    @torch.inference_mode()
    def generate_content(
        self,
        prompt: Union[str, List],
        generation_config: Optional[GenerationConfig] = None,
        response_format=None,
    ) -> str:
        conversation = [{"from": "human", "value": prompt}]

        # Convert response format to logits processor
        xgr_logits_processor = None

        # Extract media from the conversation
        media = extract_media(conversation, self.config)

        # Process media
        media_config = defaultdict(dict)
        for name in media:
            if name == "image":
                if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
                    self.config.image_processor = self.vision_tower.image_processor
                    if self.config.image_aspect_ratio == "dynamic":
                        images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
                        conversation[0]["value"] = conversation[0]["value"].replace(
                            DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
                        )
                    else:
                        if type(self.config.s2_scales) is str:
                            self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
                        images, block_sizes = process_image(
                            media["image"][0], self.config, None, enable_dynamic_s2=True
                        )
                        images = images.half()
                        media_config[name]["block_sizes"] = [block_sizes]
                else:
                    images = process_images(media["image"], self.vision_tower.image_processor, self.config).half()
                media[name] = [image for image in images]
            elif name == "video":
                if self.config.image_aspect_ratio == "dynamic" and self.config.video_max_tiles > 1:
                    media[name] = [
                        process_images(
                            images,
                            self.vision_tower.image_processor,
                            self.config,
                            enable_dynamic_res=True,
                            max_tiles=self.config.video_max_tiles,
                        ).half()
                        for images in media[name]
                    ]
                elif self.config.image_aspect_ratio == "dynamic_s2" and self.config.video_max_tiles > 1:
                    self.config.image_processor = self.vision_tower.image_processor
                    if type(self.config.s2_scales) is str:
                        self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
                    media[name] = [
                        torch.cat(
                            [
                                process_image(
                                    image,
                                    self.config,
                                    None,
                                    enable_dynamic_s2=True,
                                    max_tiles=self.config.video_max_tiles,
                                )[0].half()
                                for image in images
                            ]
                        )
                        for images in media[name]
                    ]
                else:
                    media[name] = [
                        process_images(images, self.vision_tower.image_processor, self.config).half()
                        for images in media[name]
                    ]
            else:
                raise ValueError(f"Unsupported media type: {name}")

        # Tokenize the conversation
        input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).unsqueeze(0).cuda()

        # Set up the generation config
        generation_config = generation_config or self.default_generation_config

        # Generate the response
        try:
            output_ids = self.generate(
                input_ids=input_ids,
                media=media,
                media_config=media_config,
                generation_config=generation_config,
                logits_processor=xgr_logits_processor,  # structured generation
            )
        except ValueError:
            if not generation_config.do_sample:
                raise
            logging.warning("Generation failed with sampling, retrying with greedy decoding.")
            generation_config.do_sample = False
            output_ids = self.generate(
                input_ids=input_ids,
                media=media,
                media_config=media_config,
                generation_config=generation_config,
                logits_processor=xgr_logits_processor,
            )

        # Decode the response
        response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
        return response

    @property
    def default_generation_config(self) -> GenerationConfig:
        generation_config = copy.deepcopy(self.generation_config or GenerationConfig())
        if self.tokenizer.eos_token_id is None:
            raise ValueError("Tokenizer must have an EOS token")
        if generation_config.max_length == GenerationConfig().max_length:
            generation_config.max_length = self.tokenizer.model_max_length
        if generation_config.pad_token_id is None:
            generation_config.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
        if generation_config.bos_token_id is None:
            generation_config.bos_token_id = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
        if generation_config.eos_token_id is None:
            generation_config.eos_token_id = self.tokenizer.eos_token_id
        return generation_config
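# Example usage of generate_content (an illustrative sketch; the repo id, image path, and device
# placement are placeholders, and it assumes the checkpoint is loaded with trust_remote_code=True):
#
#   from PIL import Image
#   from transformers import AutoModel
#
#   model = AutoModel.from_pretrained("<path-or-repo-id>", trust_remote_code=True, device_map="auto")
#   response = model.generate_content([Image.open("demo.jpg"), "Describe this image."])
#   print(response)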
    	
siglip_encoder.py ADDED
@@ -0,0 +1,286 @@
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate.hooks import add_hook_to_module
from einops import rearrange
from s2wrapper import forward as multiscale_forward
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, SiglipImageProcessor
from transformers.image_processing_utils import BaseImageProcessor
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.models.siglip import SiglipVisionModel


class VisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = getattr(args, "mm_vision_select_layer", -2)
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")

        self.cfg_only = None

    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == "patch":
            image_features = image_features[:, 1:]
        elif self.select_feature == "cls_patch":
            image_features = image_features
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    def _maybe_resize_pos_embeds(
        self,
        model: PreTrainedModel,
        image_processor: BaseImageProcessor,
        resolution: int = -1,
        interpolate_mode: str = "linear",
    ):
        if resolution in [model.config.image_size, -1]:
            return
        print(
            f"Resizing vision model's position embeddings to support higher vision resolution: from {model.config.image_size} to {resolution} ..."
        )
        embeddings = model.vision_model.embeddings
        patch_size = embeddings.patch_size
        num_new_tokens = int((resolution // patch_size) ** 2)

        old_embeddings = embeddings.position_embedding
        match interpolate_mode:
            case "linear":
                ## Step 1: Calculate the corresponding patch ID (pid) in the current resolution (M patches) based on the target resolution (N patches). Formula: pid = pid / N * M
                ## Step 2: Obtain new embeddings by interpolating between the embeddings of the two nearest calculated patch IDs. Formula: new_embeds = (pid - floor(pid)) * embeds[ceil(pid)] + (ceil(pid) - pid) * embeds[floor(pid)]
                import torch
                import torch.nn as nn

                if is_deepspeed_zero3_enabled():
                    try:
                        import deepspeed
                    except ImportError:
                        raise ImportError("DeepSpeed is not installed. Please install it with `pip install deepspeed`.")
                    with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
                        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
                else:
                    old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
                new_embeddings = nn.Embedding(
                    num_new_tokens,
                    old_embedding_dim,
                    dtype=old_embeddings.weight.dtype,
                    device=old_embeddings.weight.device,
                )
                mapped_indices = (
                    torch.arange(num_new_tokens).to(old_embeddings.weight.device)
                    / (num_new_tokens - 1)
                    * (old_num_tokens - 1)
                )
                floor_indices = torch.clamp(mapped_indices.floor().long(), min=0, max=old_num_tokens - 1)
                ceil_indices = torch.clamp(mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1)
                if is_deepspeed_zero3_enabled():
                    params = [old_embeddings.weight, new_embeddings.weight]
                    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                        interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
                            ceil_indices, :
                        ] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :]
                else:
                    interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
                        ceil_indices, :
                    ] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :]
                new_embeddings.weight.data = interpolated_embeds
            case _:
                raise NotImplementedError

        if hasattr(old_embeddings, "_hf_hook"):
            hook = old_embeddings._hf_hook
            add_hook_to_module(new_embeddings, hook)
        new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
        ## update vision encoder's configurations
        model.config.image_size = resolution
        if hasattr(image_processor, "crop_size"):
            # CLIP vision tower
            image_processor.crop_size = resolution
        else:
            # SIGLIP vision tower
            assert hasattr(image_processor, "size")
            image_processor.size = {"height": resolution, "width": resolution}
        embeddings.position_embedding = new_embeddings
        embeddings.image_size = resolution
        embeddings.num_patches = embeddings.num_positions = num_new_tokens
        embeddings.position_ids = (
            torch.arange(embeddings.num_positions).expand((1, -1)).to(old_embeddings.weight.device)
        )

| 131 | 
         
            +
                def forward(self, images):
         
     | 
| 132 | 
         
            +
                    if type(images) is list:
         
     | 
| 133 | 
         
            +
                        image_features = []
         
     | 
| 134 | 
         
            +
                        for image in images:
         
     | 
| 135 | 
         
            +
                            image_forward_out = self.vision_tower(
         
     | 
| 136 | 
         
            +
                                image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
         
     | 
| 137 | 
         
            +
                                output_hidden_states=True,
         
     | 
| 138 | 
         
            +
                            )
         
     | 
| 139 | 
         
            +
                            image_feature = self.feature_select(image_forward_out).to(image.dtype)
         
     | 
| 140 | 
         
            +
                            image_features.append(image_feature)
         
     | 
| 141 | 
         
            +
                    else:
         
     | 
| 142 | 
         
            +
                        image_forward_outs = self.vision_tower(
         
     | 
| 143 | 
         
            +
                            images.to(device=self.device, dtype=self.dtype),
         
     | 
| 144 | 
         
            +
                            output_hidden_states=True,
         
     | 
| 145 | 
         
            +
                        )
         
     | 
| 146 | 
         
            +
                        image_features = self.feature_select(image_forward_outs).to(images.dtype)
         
     | 
| 147 | 
         
            +
             
     | 
| 148 | 
         
            +
                    return image_features
         
     | 
| 149 | 
         
            +
             
     | 
| 150 | 
         
            +
                @property
         
     | 
| 151 | 
         
            +
                def dummy_feature(self):
         
     | 
| 152 | 
         
            +
                    return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
         
     | 
| 153 | 
         
            +
             
     | 
| 154 | 
         
            +
                @property
         
     | 
| 155 | 
         
            +
                def dtype(self):
         
     | 
| 156 | 
         
            +
                    return self.vision_tower.dtype
         
     | 
| 157 | 
         
            +
             
     | 
| 158 | 
         
            +
                @property
         
     | 
| 159 | 
         
            +
                def device(self):
         
     | 
| 160 | 
         
            +
                    return self.vision_tower.device
         
     | 
| 161 | 
         
            +
             
     | 
| 162 | 
         
            +
                @property
         
     | 
| 163 | 
         
            +
                def config(self):
         
     | 
| 164 | 
         
            +
                    if self.is_loaded:
         
     | 
| 165 | 
         
            +
                        return self.vision_tower.config
         
     | 
| 166 | 
         
            +
                    else:
         
     | 
| 167 | 
         
            +
                        return self.cfg_only
         
     | 
| 168 | 
         
            +
             
     | 
| 169 | 
         
            +
                @property
         
     | 
| 170 | 
         
            +
                def hidden_size(self):
         
     | 
| 171 | 
         
            +
                    return self.config.hidden_size
         
     | 
| 172 | 
         
            +
             
     | 
| 173 | 
         
            +
                @property
         
     | 
| 174 | 
         
            +
                def num_patches(self):
         
     | 
| 175 | 
         
            +
                    return (self.config.image_size // self.config.patch_size) ** 2
         
     | 
| 176 | 
         
            +
             
     | 
| 177 | 
         
            +
             
     | 
| 178 | 
         
            +
            class VisionTowerS2(VisionTower):
         
     | 
| 179 | 
         
            +
                def __init__(self, vision_tower, args, delay_load=False):
         
     | 
| 180 | 
         
            +
                    super().__init__(vision_tower, args, delay_load)
         
     | 
| 181 | 
         
            +
             
     | 
| 182 | 
         
            +
                    self.scales = list(map(int, args.s2_scales.split(",")))
         
     | 
| 183 | 
         
            +
                    self.scales.sort()
         
     | 
| 184 | 
         
            +
                    self.max_split_size = args.s2_max_split_size
         
     | 
| 185 | 
         
            +
                    self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)
         
     | 
| 186 | 
         
            +
             
     | 
| 187 | 
         
            +
                def forward_feature(self, images):
         
     | 
| 188 | 
         
            +
                    image_forward_outs = self.vision_tower(
         
     | 
| 189 | 
         
            +
                        images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
         
     | 
| 190 | 
         
            +
                    )
         
     | 
| 191 | 
         
            +
                    image_features = self.feature_select(image_forward_outs).to(images.dtype)
         
     | 
| 192 | 
         
            +
                    return image_features
         
     | 
| 193 | 
         
            +
             
     | 
| 194 | 
         
            +
                def forward(self, images):
         
     | 
| 195 | 
         
            +
                    if type(images) is list:
         
     | 
| 196 | 
         
            +
                        image_features = []
         
     | 
| 197 | 
         
            +
                        for image in images:
         
     | 
| 198 | 
         
            +
                            image_feature = multiscale_forward(
         
     | 
| 199 | 
         
            +
                                self.forward_feature,
         
     | 
| 200 | 
         
            +
                                image.unsqueeze(0),
         
     | 
| 201 | 
         
            +
                                img_sizes=self.scales,
         
     | 
| 202 | 
         
            +
                                max_split_size=self.max_split_size,
         
     | 
| 203 | 
         
            +
                                resize_output_to_idx=self.resize_output_to_scale_idx,
         
     | 
| 204 | 
         
            +
                            )
         
     | 
| 205 | 
         
            +
                            image_features.append(image_feature)
         
     | 
| 206 | 
         
            +
                    else:
         
     | 
| 207 | 
         
            +
                        image_features = multiscale_forward(
         
     | 
| 208 | 
         
            +
                            self.forward_feature,
         
     | 
| 209 | 
         
            +
                            images,
         
     | 
| 210 | 
         
            +
                            img_sizes=self.scales,
         
     | 
| 211 | 
         
            +
                            max_split_size=self.max_split_size,
         
     | 
| 212 | 
         
            +
                            resize_output_to_idx=self.resize_output_to_scale_idx,
         
     | 
| 213 | 
         
            +
                        )
         
     | 
| 214 | 
         
            +
             
     | 
| 215 | 
         
            +
                    return image_features
         
     | 
| 216 | 
         
            +
             
     | 
| 217 | 
         
            +
                @property
         
     | 
| 218 | 
         
            +
                def hidden_size(self):
         
     | 
| 219 | 
         
            +
                    return self.config.hidden_size * len(self.scales)
         
     | 
| 220 | 
         
            +
             
     | 
| 221 | 
         
            +
             
     | 
| 222 | 
         
            +
            class VisionTowerDynamicS2(VisionTower):
         
     | 
| 223 | 
         
            +
                def __init__(self, vision_tower, args, delay_load=False):
         
     | 
| 224 | 
         
            +
                    super().__init__(vision_tower, args, delay_load)
         
     | 
| 225 | 
         
            +
             
     | 
| 226 | 
         
            +
                    self.scales = list(map(int, args.s2_scales.split(",")))
         
     | 
| 227 | 
         
            +
                    self.scales.sort()
         
     | 
| 228 | 
         
            +
                    self.max_split_size = args.s2_max_split_size
         
     | 
| 229 | 
         
            +
                    self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)
         
     | 
| 230 | 
         
            +
             
     | 
| 231 | 
         
            +
                def forward_feature(self, images):
         
     | 
| 232 | 
         
            +
                    image_forward_outs = self.vision_tower(
         
     | 
| 233 | 
         
            +
                        images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
         
     | 
| 234 | 
         
            +
                    )
         
     | 
| 235 | 
         
            +
                    image_features = self.feature_select(image_forward_outs).to(images.dtype)
         
     | 
| 236 | 
         
            +
                    return image_features
         
     | 
| 237 | 
         
            +
             
     | 
| 238 | 
         
            +
                def forward(self, images):
         
     | 
| 239 | 
         
            +
                    assert type(images) is not list
         
     | 
| 240 | 
         
            +
                    image_features = self.forward_feature(images)
         
     | 
| 241 | 
         
            +
             
     | 
| 242 | 
         
            +
                    return image_features
         
     | 
| 243 | 
         
            +
             
     | 
| 244 | 
         
            +
                @property
         
     | 
| 245 | 
         
            +
                def hidden_size(self):
         
     | 
| 246 | 
         
            +
                    return self.config.hidden_size * len(self.scales)
         
     | 
| 247 | 
         
            +
             
     | 
| 248 | 
         
            +
             
     | 
| 249 | 
         
            +
            class SiglipVisionTower(VisionTower):
         
     | 
| 250 | 
         
            +
                def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
         
     | 
| 251 | 
         
            +
                    super().__init__(model_name_or_path, config)
         
     | 
| 252 | 
         
            +
                    self.vision_tower = SiglipVisionModel.from_pretrained(
         
     | 
| 253 | 
         
            +
                        model_name_or_path,
         
     | 
| 254 | 
         
            +
                        attn_implementation=config._attn_implementation,
         
     | 
| 255 | 
         
            +
                        torch_dtype=eval(config.model_dtype),
         
     | 
| 256 | 
         
            +
                    )
         
     | 
| 257 | 
         
            +
                    self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
         
     | 
| 258 | 
         
            +
                    self.is_loaded = True
         
     | 
| 259 | 
         
            +
             
     | 
| 260 | 
         
            +
             
     | 
| 261 | 
         
            +
            class SiglipVisionTowerS2(VisionTowerS2):
         
     | 
| 262 | 
         
            +
                def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
         
     | 
| 263 | 
         
            +
                    super().__init__(model_name_or_path, config)
         
     | 
| 264 | 
         
            +
                    self.vision_tower = SiglipVisionModel.from_pretrained(
         
     | 
| 265 | 
         
            +
                        model_name_or_path,
         
     | 
| 266 | 
         
            +
                        attn_implementation=config._attn_implementation,
         
     | 
| 267 | 
         
            +
                        torch_dtype=eval(config.model_dtype),
         
     | 
| 268 | 
         
            +
                    )
         
     | 
| 269 | 
         
            +
                    self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
         
     | 
| 270 | 
         
            +
                    # Make sure it crops/resizes the image to the largest scale in self.scales to maintain high-res information
         
     | 
| 271 | 
         
            +
                    self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[-1]
         
     | 
| 272 | 
         
            +
                    self.is_loaded = True
         
     | 
| 273 | 
         
            +
             
     | 
| 274 | 
         
            +
             
     | 
class SiglipVisionTowerDynamicS2(VisionTowerDynamicS2):
    def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
        super().__init__(model_name_or_path, config)
        self.vision_tower = SiglipVisionModel.from_pretrained(
            model_name_or_path,
            attn_implementation=config._attn_implementation,
            torch_dtype=eval(config.model_dtype),
        )
        self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
        # Resize to the smallest (base) scale in self.scales; the dynamic S2 pipeline produces the higher-resolution crops itself
        self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[0]
        self.is_loaded = True
     | 
    	
tokenizer_utils.py ADDED
@@ -0,0 +1,181 @@
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional, Sequence

import torch
import transformers

from .constants import IGNORE_INDEX, SENTINEL_TOKEN
from .conversation import SeparatorStyle, default_conversation
from .mm_utils import tokenizer_image_token

DUMMY_CONVERSATION = [
    {"from": "human", "value": "question"},
    {"from": "gpt", "value": "answer"},
] * 10


def tokenize_conversation_legacy(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
    conv = default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    if no_system_prompt:
        conv.system = ""

    # Skip the first message if it is not from human
    if messages[0]["from"] != "human":
        messages = messages[1:]

    # Add a generation prompt if needed
    if add_generation_prompt:
        messages.append({"from": "gpt", "value": None})

    conv.messages = []
    for turn, message in enumerate(messages):
        role = roles[message["from"]]
        assert role == conv.roles[turn % 2]
        if overrides is not None and message["from"] in overrides:
            conv.append_message(role, overrides[message["from"]])
        else:
            conv.append_message(role, message["value"])

    return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")


def tokenize_conversation(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
    return_ids_only=True,
) -> torch.Tensor:
    # Normalize the conversation before tokenization
    for message in messages:
        message["value"] = message["value"].strip()

    if default_conversation.sep_style != SeparatorStyle.AUTO:
        return tokenize_conversation_legacy(
            messages,
            tokenizer,
            add_generation_prompt=add_generation_prompt,
            overrides=overrides,
            no_system_prompt=no_system_prompt,
        )

    conversation = []
    for m in messages:
        message = {}
        if m["from"] == "human":
            message["role"] = "user"
        elif m["from"] == "gpt":
            message["role"] = "assistant"
        elif m["from"] == "system":
            message["role"] = "system"
            if no_system_prompt:
                raise ValueError("message[role]=system is not allowed when no_system_prompt is set to True.")
        else:
            raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")

        message["content"] = m["value"]
        if overrides is not None and m["from"] in overrides:
            message["content"] = overrides[m["from"]]
        conversation.append(message)

    if no_system_prompt:
        conversation = [{"role": "system", "content": ""}] + conversation

    text = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=add_generation_prompt,
        tokenize=False,
    )
    return tokenizer_image_token(text, tokenizer, return_tensors="pt", return_ids=return_ids_only)


def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
    if not hasattr(tokenizer, "sentinel_token"):
        tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
        tokenizer.sentinel_token = SENTINEL_TOKEN
        tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)


def preprocess_conversation(
    conversation: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    no_system_prompt: bool = False,
    retried: bool = False,
    **kwargs: Any,
) -> Dict[str, Any]:
    inputs = tokenize_conversation(conversation, tokenizer, no_system_prompt=no_system_prompt)
    labels = torch.ones_like(inputs) * IGNORE_INDEX

    # Generate the template by replacing the assistant's response with a sentinel.
    _maybe_add_sentinel_token(tokenizer)
    template = tokenize_conversation(
        conversation, tokenizer, overrides={"gpt": SENTINEL_TOKEN}, no_system_prompt=no_system_prompt
    )

    # Remove sentinel tokens from the template.
    mask = torch.ones_like(template, dtype=torch.bool)
    for k in range(template.size(0) - 1):
        if template[k] == tokenizer.sentinel_token_id:
            mask[k : k + 2] = False
            if k > 0 and retried:
                mask[k - 1] = False
    template = template[mask]

    # Match the tokenized conversation with the template (with no assistant's response).
    # Every token that is not matched will be included in the label for training.
    p = 0
    for k in range(inputs.size(0)):
        if p < template.size(0) and inputs[k] == template[p]:
            p += 1
        else:
            labels[k] = inputs[k]

    # Mask all tokens in the label if the template is not fully matched.
    if p < template.size(0):
        if not retried:
            return preprocess_conversation(
                conversation,
                tokenizer,
                no_system_prompt=no_system_prompt,
                retried=True,
            )
        print(f"Failed to process the conversation: '{conversation}'. All tokens will be masked in the label.")
        labels[:] = IGNORE_INDEX

    return {"input_ids": inputs, "labels": labels}


def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
    _maybe_add_sentinel_token(tokenizer)
    template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})

    stop_tokens = {tokenizer.eos_token}
    for k in range(template.size(0) - 1):
        if template[k] == tokenizer.sentinel_token_id:
            stop_token = tokenizer.decode(template[k + 1])
            stop_tokens.add(stop_token)
    return list(stop_tokens)
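
Aside (not part of the uploaded file): the helpers above are typically combined as below when building supervised targets. This is a hedged sketch; the checkpoint path is a placeholder, and preprocess_conversation / infer_stop_tokens are assumed to be imported from this module.

from transformers import AutoTokenizer

# Placeholder path; any VILA-style checkpoint with an "llm" subfolder and a chat template would do.
tokenizer = AutoTokenizer.from_pretrained("path/to/vila_checkpoint/llm")

conversation = [
    {"from": "human", "value": "Describe <image> briefly."},
    {"from": "gpt", "value": "A cat sitting on a mat."},
]
batch = preprocess_conversation(conversation, tokenizer)
# batch["input_ids"]: the full tokenized dialogue
# batch["labels"]:    IGNORE_INDEX everywhere except the assistant's answer tokens
print(infer_stop_tokens(tokenizer))  # usually the EOS token plus the template's turn separator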
    	
utils.py ADDED
@@ -0,0 +1,211 @@
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# This file is modified from https://github.com/haotian-liu/LLaVA/
import os
import os.path as osp

from huggingface_hub import repo_exists, snapshot_download
from huggingface_hub.utils import HFValidationError, validate_repo_id
from transformers import AutoConfig, AutoTokenizer, PretrainedConfig

from .configuration_vila import VILAConfig
from .constants import MEDIA_TOKENS
from .tokenizer_utils import infer_stop_tokens


def load_tokenizer_then_handle_media_tokens_and_chat_template(
    model_name_or_path, config: VILAConfig, model_max_length=None
):
    tokenizer = AutoTokenizer.from_pretrained(
        osp.join(model_name_or_path, "llm"), padding_side="right", use_fast=True, legacy=False
    )
    if model_max_length is not None:
        tokenizer.model_max_length = model_max_length

    # Load chat template if specified.
    if getattr(config, "chat_template", None) is not None:
        print(f"Using chat template: {config.chat_template}")
        fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
        if not os.path.exists(fpath):
            fpath = os.path.join(model_name_or_path, f"{config.chat_template}.jinja")
        with open(fpath) as fd:
            chat_template = fd.read()
        tokenizer.chat_template = chat_template.replace("    ", "").replace("\n", "")

    # Set stop tokens for the tokenizer
    tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
    tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)

    # Add media tokens to the tokenizer
    tokenizer.media_tokens = MEDIA_TOKENS
    tokenizer.media_token_ids = {}
    for name, token in MEDIA_TOKENS.items():
        tokenizer.add_tokens([token], special_tokens=True)
        tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)

    return tokenizer


def get_model_config(config):
    default_keys = ["llm_cfg", "vision_tower_cfg", "mm_projector_cfg"]

    if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
        root_path = config._name_or_path
    else:
        root_path = config.resume_path

    # download from huggingface
    if root_path is not None and not osp.exists(root_path):
        try:
            valid_hf_repo = repo_exists(root_path)
        except HFValidationError as e:
            valid_hf_repo = False
        if valid_hf_repo:
            root_path = snapshot_download(root_path)

    return_list = []
    for key in default_keys:
        cfg = getattr(config, key, None)
        if isinstance(cfg, dict):
            try:
                return_list.append(os.path.join(root_path, key[:-4]))
            except:
                raise ValueError(f"Cannot find resume path in config for {key}!")
        elif isinstance(cfg, PretrainedConfig):
            return_list.append(os.path.join(root_path, key[:-4]))
        elif isinstance(cfg, str):
            return_list.append(cfg)

    return return_list

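Aside (not part of the uploaded file): get_model_config above resolves each "*_cfg" entry to a sibling folder whose name drops the "_cfg" suffix, in the order llm, vision_tower, mm_projector. A hedged usage sketch with a placeholder checkpoint path:

from transformers import AutoConfig

# Placeholder path; assumes the usual VILA layout with llm/, vision_tower/ and mm_projector/ subfolders.
config = AutoConfig.from_pretrained("path/to/vila_checkpoint", trust_remote_code=True)
llm_path, vision_tower_path, mm_projector_path = get_model_config(config)
# Each entry is "<root>/<component>" when the sub-config is embedded in config.json,
# or the raw string stored in the config otherwise.
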
| 95 | 
         
            +
            def get_model_config_fp8(config):
         
     | 
| 96 | 
         
            +
                default_keys = ["llm_cfg", "vision_tower_cfg", "mm_projector_cfg"]
         
     | 
| 97 | 
         
            +
             
     | 
| 98 | 
         
            +
                if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
         
     | 
| 99 | 
         
            +
                    root_path = config._name_or_path
         
     | 
| 100 | 
         
            +
                else:
         
     | 
| 101 | 
         
            +
                    root_path = config.resume_path
         
     | 
| 102 | 
         
            +
             
     | 
| 103 | 
         
            +
                # download from huggingface
         
     | 
| 104 | 
         
            +
                if root_path is not None and not osp.exists(root_path):
         
     | 
| 105 | 
         
            +
                    try:
         
     | 
| 106 | 
         
            +
                        valid_hf_repo = repo_exists(root_path)
         
     | 
| 107 | 
         
            +
                    except HFValidationError as e:
         
     | 
| 108 | 
         
            +
                        valid_hf_repo = False
         
     | 
| 109 | 
         
            +
                    if valid_hf_repo:
         
     | 
| 110 | 
         
            +
                        root_path = snapshot_download(root_path)
         
     | 
| 111 | 
         
            +
             
     | 
| 112 | 
         
            +
                return_list = []
         
     | 
| 113 | 
         
            +
                for key in default_keys:
         
     | 
| 114 | 
         
            +
                    cfg = getattr(config, key, None)
         
     | 
| 115 | 
         
            +
                    if isinstance(cfg, dict):
         
     | 
| 116 | 
         
            +
                        try:
         
     | 
| 117 | 
         
            +
                            return_list.append(os.path.join(root_path, key[:-4]))
         
     | 
| 118 | 
         
            +
                        except:
         
     | 
| 119 | 
         
            +
                            raise ValueError(f"Cannot find resume path in config for {key}!")
         
     | 
| 120 | 
         
            +
                    elif isinstance(cfg, PretrainedConfig):
         
     | 
| 121 | 
         
            +
                        return_list.append(os.path.join(root_path, key[:-4]))
         
     | 
| 122 | 
         
            +
                    elif isinstance(cfg, str):
         
     | 
| 123 | 
         
            +
                        return_list.append(cfg)
         
     | 
| 124 | 
         
            +
             
     | 
| 125 | 
         
            +
                # fp8_llm
         
     | 
| 126 | 
         
            +
                key = "fp8_llm_cfg"
         
     | 
| 127 | 
         
            +
                directory_path = os.path.join(root_path, key[:-4])
         
     | 
| 128 | 
         
            +
                assert os.path.isdir(directory_path) and os.listdir(
         
     | 
| 129 | 
         
            +
                    directory_path
         
     | 
| 130 | 
         
            +
                ), "You need to first convert the model weights to FP8 explicitly."
         
     | 
| 131 | 
         
            +
                return_list.append(directory_path)
         
     | 
| 132 | 
         
            +
             
     | 
| 133 | 
         
            +
                return return_list
         
     | 
| 134 | 
         
            +
             
     | 
| 135 | 
         
            +
             
     | 
| 136 | 
         
            +
def get_model_config_fp8(config):
    default_keys = ["llm_cfg", "vision_tower_cfg", "mm_projector_cfg"]

    if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
        root_path = config._name_or_path
    else:
        root_path = config.resume_path

    # download from huggingface
    if root_path is not None and not osp.exists(root_path):
        try:
            valid_hf_repo = repo_exists(root_path)
        except HFValidationError as e:
            valid_hf_repo = False
        if valid_hf_repo:
            root_path = snapshot_download(root_path)

    return_list = []
    for key in default_keys:
        cfg = getattr(config, key, None)
        if isinstance(cfg, dict):
            try:
                return_list.append(os.path.join(root_path, key[:-4]))
            except:
                raise ValueError(f"Cannot find resume path in config for {key}!")
        elif isinstance(cfg, PretrainedConfig):
            return_list.append(os.path.join(root_path, key[:-4]))
        elif isinstance(cfg, str):
            return_list.append(cfg)

    # fp8_llm
    key = "fp8_llm_cfg"
    directory_path = os.path.join(root_path, key[:-4])
    assert os.path.isdir(directory_path) and os.listdir(
        directory_path
    ), "You need to first convert the model weights to FP8 explicitly."
    return_list.append(directory_path)

    return return_list

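# Usage sketch (illustrative, not in builder.py): how the return value of
# get_model_config_fp8 might be consumed. Assumes a local VILA-style checkpoint
# whose config defines llm_cfg / vision_tower_cfg / mm_projector_cfg and whose
# directory already contains a pre-converted, non-empty fp8_llm/ folder; the
# path below is a placeholder.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("path/to/vila_checkpoint", trust_remote_code=True)
llm_path, vision_tower_path, mm_projector_path, fp8_llm_path = get_model_config_fp8(config)
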
def is_mm_model(model_path):
    """
    Check if the model at the given path is a visual language model.

    Args:
        model_path (str): The path to the model.

    Returns:
        bool: True if the model is an MM model, False otherwise.
    """
    config = AutoConfig.from_pretrained(model_path)
    architectures = config.architectures
    for architecture in architectures:
        if "llava" in architecture.lower():
            return True
    return False

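# Usage sketch (illustrative, not in builder.py): is_mm_model only inspects
# config.architectures, so any checkpoint whose architecture name contains
# "llava" is treated as a visual language model; the path is a placeholder.
if is_mm_model("path/to/checkpoint"):
    print("multimodal (LLaVA-style) checkpoint")
else:
    print("plain language-model checkpoint")
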
def auto_upgrade(config):
    cfg = AutoConfig.from_pretrained(config)
    if "llava" in config and "llava" not in cfg.model_type:
        assert cfg.model_type == "llama"
        print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
        print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
        confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
        if confirm.lower() in ["y", "yes"]:
            print("Upgrading checkpoint...")
            assert len(cfg.architectures) == 1
            setattr(cfg.__class__, "model_type", "llava")
            cfg.architectures[0] = "LlavaLlamaForCausalLM"
            cfg.save_pretrained(config)
            print("Checkpoint upgraded.")
        else:
            print("Checkpoint upgrade aborted.")
            exit(1)
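
# Usage sketch (illustrative, not in builder.py): auto_upgrade only acts when
# the checkpoint path contains "llava" while its saved config still reports the
# legacy "llama" model_type; after interactive confirmation it rewrites
# model_type and architectures[0] in place via cfg.save_pretrained. The path is
# a placeholder.
auto_upgrade("checkpoints/llava-v0-7b")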