Add a remote code file for transformers integration 🤗
#2
by reach-vb (HF Staff) - opened
- README.md +58 -1
- config.json +4 -0
- llava_qwen.py +2195 -0
    	
README.md
CHANGED

@@ -3,6 +3,8 @@ license: apple-amlr
 license_name: apple-ascl
 license_link: https://github.com/apple/ml-fastvlm/blob/main/LICENSE_MODEL
 library_name: ml-fastvlm
+tags:
+- transformers
 ---
 # FastVLM: Efficient Vision Encoding for Vision Language Models
 
@@ -51,6 +53,61 @@ python predict.py --model-path /path/to/checkpoint-dir \
                   --prompt "Describe the image."
 ```
 
+### Run inference with Transformers (Remote Code)
+To run inference with transformers we can leverage `trust_remote_code` along with the following snippet:
+
+```python
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+MID = "apple/FastVLM-0.5B"
+IMAGE_TOKEN_INDEX = -200  # what the model code looks for
+
+# Load
+tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(
+    MID,
+    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+    device_map="auto",
+    trust_remote_code=True,
+)
+
+# Build chat -> render to string (not tokens) so we can place <image> exactly
+messages = [
+    {"role": "user", "content": "<image>\nDescribe this image in detail."}
+]
+rendered = tok.apply_chat_template(
+    messages, add_generation_prompt=True, tokenize=False
+)
+
+pre, post = rendered.split("<image>", 1)
+
+# Tokenize the text *around* the image token (no extra specials!)
+pre_ids  = tok(pre,  return_tensors="pt", add_special_tokens=False).input_ids
+post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
+
+# Splice in the IMAGE token id (-200) at the placeholder position
+img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
+input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
+attention_mask = torch.ones_like(input_ids, device=model.device)
+
+# Preprocess image via the model's own processor
+img = Image.open("test-2.jpg").convert("RGB")
+px = model.get_vision_tower().image_processor(images=img, return_tensors="pt")["pixel_values"]
+px = px.to(model.device, dtype=model.dtype)
+
+# Generate
+with torch.no_grad():
+    out = model.generate(
+        inputs=input_ids,
+        attention_mask=attention_mask,
+        images=px,
+        max_new_tokens=128,
+    )
+
+print(tok.decode(out[0], skip_special_tokens=True))
+```
 
 ## Citation
 If you found this model useful, please cite the following paper:
@@ -62,4 +119,4 @@ If you found this model useful, please cite the following paper:
   month = {June},
   year = {2025},
 }
-```
+```
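As a follow-up to the README snippet above (not part of this diff), the same `generate` call can stream the answer token-by-token with transformers' `TextStreamer`; a minimal sketch reusing the `tok`, `input_ids`, `attention_mask`, and `px` objects from the snippet:

```python
# Optional sketch, not in this PR: stream decoded text as it is generated.
from transformers import TextStreamer

streamer = TextStreamer(tok, skip_prompt=True, skip_special_tokens=True)
with torch.no_grad():
    model.generate(
        inputs=input_ids,
        attention_mask=attention_mask,
        images=px,
        max_new_tokens=128,
        streamer=streamer,  # prints the answer incrementally instead of waiting for the full output
    )
```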
    	
config.json
CHANGED

@@ -3,6 +3,10 @@
   "architectures": [
     "LlavaQwen2ForCausalLM"
   ],
+  "auto_map": {
+    "AutoConfig": "llava_qwen.LlavaConfig",
+    "AutoModelForCausalLM": "llava_qwen.LlavaQwen2ForCausalLM"
+  },
   "attention_dropout": 0.0,
   "bos_token_id": 151643,
   "eos_token_id": 151645,
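The `auto_map` block is what lets the `Auto*` classes resolve the remote code shipped in `llava_qwen.py`. A minimal sanity-check sketch (not part of this diff, assuming the repo id from the README snippet):

```python
# Sketch, not in this PR: confirm that trust_remote_code resolves the classes via auto_map.
from transformers import AutoConfig, AutoModelForCausalLM

repo = "apple/FastVLM-0.5B"  # assumed repo id, taken from the README snippet
cfg = AutoConfig.from_pretrained(repo, trust_remote_code=True)
print(type(cfg).__name__)    # expected: LlavaConfig, loaded from llava_qwen.py
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
print(type(model).__name__)  # expected: LlavaQwen2ForCausalLM, loaded from llava_qwen.py
```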
    	
llava_qwen.py
ADDED

@@ -0,0 +1,2195 @@
#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.


from typing import List, Optional, Tuple, Union

import re
import copy
from timm.models import create_model
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from torch.nn.init import normal_

from transformers import CLIPImageProcessor
from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config, Qwen2Model, Qwen2ForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput

from functools import partial
from typing import List, Tuple, Optional, Union, Dict, Any

from timm.models import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, SqueezeExcite

CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15
LOGDIR = "."
# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"

class LlavaConfig(Qwen2Config):
    model_type = "llava_qwen2"

def _cfg(url="", **kwargs):
    return {
        "url": url,
        "num_classes": 1000,
        "input_size": (3, 256, 256),
        "pool_size": None,
        "crop_pct": 0.95,
        "interpolation": "bicubic",
        "mean": IMAGENET_DEFAULT_MEAN,
        "std": IMAGENET_DEFAULT_STD,
        "classifier": "head",
        **kwargs,
    }


default_cfgs = {
    "fastvit_t": _cfg(crop_pct=0.9),
    "fastvit_s": _cfg(crop_pct=0.9),
    "fastvit_m": _cfg(crop_pct=0.95),
}


class SEBlock(nn.Module):
    """Squeeze and Excite module.
    Pytorch implementation of `Squeeze-and-Excitation Networks` -
    https://arxiv.org/pdf/1709.01507.pdf
    """

    def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
        """Construct a Squeeze and Excite Module.
        Args:
            in_channels: Number of input channels.
            rd_ratio: Input channel reduction ratio.
        """
        super(SEBlock, self).__init__()
        self.reduce = nn.Conv2d(
            in_channels=in_channels,
            out_channels=int(in_channels * rd_ratio),
            kernel_size=1,
            stride=1,
            bias=True,
        )
        self.expand = nn.Conv2d(
            in_channels=int(in_channels * rd_ratio),
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            bias=True,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        b, c, h, w = inputs.size()
        # x = F.avg_pool2d(inputs, kernel_size=[h, w])
        x = F.avg_pool2d(inputs, kernel_size=[16, 16])
        x = self.reduce(x)
        x = F.relu(x)
        x = self.expand(x)
        x = torch.sigmoid(x)
        x = x.view(-1, c, 1, 1)
        return inputs * x

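Since `SEBlock.forward` pools with a fixed 16×16 kernel before reshaping the gate to `(-1, c, 1, 1)`, a 16×16 feature map is the natural input size for a quick check. A minimal sketch (not part of the file) exercising the block, assuming it is importable from this module:

```python
# Sketch, not in the PR: channel re-weighting with SEBlock on a dummy feature map.
import torch
from llava_qwen import SEBlock  # assumes this file is importable as llava_qwen.py

se = SEBlock(in_channels=64)        # reduce to 64 * 0.0625 = 4 channels, then expand back
feat = torch.randn(2, 64, 16, 16)   # 16x16 spatial size matches the hard-coded pooling kernel
out = se(feat)                      # same shape as the input, re-weighted per channel
assert out.shape == feat.shape
```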
class MobileOneBlock(nn.Module):
    """MobileOne building block.
    This block has a multi-branched architecture at train-time
    and plain-CNN style architecture at inference time
    For more details, please refer to our paper:
    `An Improved One millisecond Mobile Backbone` -
    https://arxiv.org/pdf/2206.04040.pdf
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        inference_mode: bool = False,
        use_se: bool = False,
        use_act: bool = True,
        use_scale_branch: bool = True,
        num_conv_branches: int = 1,
        activation: nn.Module = nn.GELU(),
    ) -> None:
        """Construct a MobileOneBlock module.
        Args:
            in_channels: Number of channels in the input.
            out_channels: Number of channels produced by the block.
            kernel_size: Size of the convolution kernel.
            stride: Stride size.
            padding: Zero-padding size.
            dilation: Kernel dilation factor.
            groups: Group number.
            inference_mode: If True, instantiates model in inference mode.
            use_se: Whether to use SE-ReLU activations.
            use_act: Whether to use activation. Default: ``True``
            use_scale_branch: Whether to use scale branch. Default: ``True``
            num_conv_branches: Number of linear conv branches.
        """
        super(MobileOneBlock, self).__init__()
        self.inference_mode = inference_mode
        self.groups = groups
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_conv_branches = num_conv_branches

        # Check if SE-ReLU is requested
        if use_se:
            self.se = SEBlock(out_channels)
        else:
            self.se = nn.Identity()

        if use_act:
            self.activation = activation
        else:
            self.activation = nn.Identity()

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=True,
            )
        else:
            # Re-parameterizable skip connection
            # Fallback, sometimes batchnorm tensors
            # do not get instantiated correctly on some processes
            # when using deepspeed + accelerate
            norm_layer = nn.BatchNorm2d(num_features=in_channels)
            if norm_layer.weight.shape[0] == 0:
                norm_layer.weight = nn.Parameter(torch.zeros(in_channels))
            if norm_layer.bias.shape[0] == 0:
                norm_layer.bias = nn.Parameter(torch.zeros(in_channels))

            self.rbr_skip = (
                norm_layer
                if out_channels == in_channels and stride == 1
                else None
            )

            # Re-parameterizable conv branches
            if num_conv_branches > 0:
                rbr_conv = list()
                for _ in range(self.num_conv_branches):
                    rbr_conv.append(
                        self._conv_bn(kernel_size=kernel_size, padding=padding)
                    )
                self.rbr_conv = nn.ModuleList(rbr_conv)
            else:
                self.rbr_conv = None

            # Re-parameterizable scale branch
            self.rbr_scale = None
            if not isinstance(kernel_size, int):
                kernel_size = kernel_size[0]
            if (kernel_size > 1) and use_scale_branch:
                self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        # Inference mode forward pass.
        if self.inference_mode:
            return self.activation(self.se(self.reparam_conv(x)))

        # Multi-branched train-time forward pass.
        # Skip branch output
        identity_out = 0
        if self.rbr_skip is not None:
            identity_out = self.rbr_skip(x)

        # Scale branch output
        scale_out = 0
        if self.rbr_scale is not None:
            scale_out = self.rbr_scale(x)

        # Other branches
        out = scale_out + identity_out
        if self.rbr_conv is not None:
            for ix in range(self.num_conv_branches):
                out += self.rbr_conv[ix](x)

        return self.activation(self.se(out))

    def reparameterize(self):
        """Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference.
        """
        if self.inference_mode:
            return
        kernel, bias = self._get_kernel_bias()
        self.reparam_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
            bias=True,
        )
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias

        # Delete un-used branches
        self.__delattr__("rbr_conv")
        self.__delattr__("rbr_scale")
        if hasattr(self, "rbr_skip"):
            self.__delattr__("rbr_skip")

        self.inference_mode = True

    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to obtain re-parameterized kernel and bias.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
        Returns:
            Tuple of (kernel, bias) after fusing branches.
        """
        # get weights and bias of scale branch
        kernel_scale = 0
        bias_scale = 0
        if self.rbr_scale is not None:
            kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
            # Pad scale branch kernel to match conv branch kernel size.
            pad = self.kernel_size // 2
            kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])

        # get weights and bias of skip branch
        kernel_identity = 0
        bias_identity = 0
        if self.rbr_skip is not None:
            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)

        # get weights and bias of conv branches
        kernel_conv = 0
        bias_conv = 0
        if self.rbr_conv is not None:
            for ix in range(self.num_conv_branches):
                _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
                kernel_conv += _kernel
                bias_conv += _bias

        kernel_final = kernel_conv + kernel_scale + kernel_identity
        bias_final = bias_conv + bias_scale + bias_identity
        return kernel_final, bias_final

    def _fuse_bn_tensor(
        self, branch: Union[nn.Sequential, nn.BatchNorm2d]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to fuse batchnorm layer with preceeding conv layer.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
        Args:
            branch: Sequence of ops to be fused.
        Returns:
            Tuple of (kernel, bias) after fusing batchnorm.
        """
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, "id_tensor"):
                input_dim = self.in_channels // self.groups

                kernel_size = self.kernel_size
                if isinstance(self.kernel_size, int):
                    kernel_size = (self.kernel_size, self.kernel_size)

                kernel_value = torch.zeros(
                    (self.in_channels, input_dim, kernel_size[0], kernel_size[1]),
                    dtype=branch.weight.dtype,
                    device=branch.weight.device,
                )
                for i in range(self.in_channels):
                    kernel_value[
                        i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2
                    ] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
        """Helper method to construct conv-batchnorm layers.
        Args:
            kernel_size: Size of the convolution kernel.
            padding: Zero-padding size.
        Returns:
            Conv-BN module.
        """
        # Fallback, sometimes batchnorm tensors
        # do not get instantiated correctly on some processes
        # when using deepspeed + accelerate
        norm_layer = nn.BatchNorm2d(num_features=self.out_channels)
        if norm_layer.weight.shape[0] == 0:
            norm_layer.weight = nn.Parameter(torch.zeros(self.out_channels))
        if norm_layer.bias.shape[0] == 0:
            norm_layer.bias = nn.Parameter(torch.zeros(self.out_channels))

        mod_list = nn.Sequential()
        mod_list.add_module(
            "conv",
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                padding=padding,
                groups=self.groups,
                bias=False,
            ),
        )
        mod_list.add_module("bn", norm_layer)
        return mod_list

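The train-time multi-branch forward and the re-parameterized single-conv forward are meant to be numerically equivalent once BatchNorm runs in eval mode. A minimal sketch (not part of the file) of that check, assuming the class is importable from this module:

```python
# Sketch, not in the PR: RepVGG-style fusion check for MobileOneBlock.
import torch
from llava_qwen import MobileOneBlock  # assumes this file is importable as llava_qwen.py

block = MobileOneBlock(in_channels=32, out_channels=32, kernel_size=3,
                       stride=1, padding=1, num_conv_branches=2)
block.eval()                               # fusion relies on BN running statistics
x = torch.randn(1, 32, 16, 16)
with torch.no_grad():
    y_multi_branch = block(x)              # skip + scale + conv branches
    block.reparameterize()                 # fold everything into block.reparam_conv
    y_single_conv = block(x)
print(torch.allclose(y_multi_branch, y_single_conv, atol=1e-5))  # expected: True
```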
| 398 | 
            +
                """Building Block of RepLKNet
         | 
| 399 | 
            +
                This class defines overparameterized large kernel conv block
         | 
| 400 | 
            +
                introduced in `RepLKNet <https://arxiv.org/abs/2203.06717>`_
         | 
| 401 | 
            +
                Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
         | 
| 402 | 
            +
                """
         | 
| 403 | 
            +
             | 
| 404 | 
            +
                def __init__(
         | 
| 405 | 
            +
                    self,
         | 
| 406 | 
            +
                    in_channels: int,
         | 
| 407 | 
            +
                    out_channels: int,
         | 
| 408 | 
            +
                    kernel_size: int,
         | 
| 409 | 
            +
                    stride: int,
         | 
| 410 | 
            +
                    groups: int,
         | 
| 411 | 
            +
                    small_kernel: int,
         | 
| 412 | 
            +
                    inference_mode: bool = False,
         | 
| 413 | 
            +
                    use_se: bool = False,
         | 
| 414 | 
            +
                    activation: nn.Module = nn.GELU(),
         | 
| 415 | 
            +
                ) -> None:
         | 
| 416 | 
            +
                    """Construct a ReparamLargeKernelConv module.
         | 
| 417 | 
            +
                    Args:
         | 
| 418 | 
            +
                        in_channels: Number of input channels.
         | 
| 419 | 
            +
                        out_channels: Number of output channels.
         | 
| 420 | 
            +
                        kernel_size: Kernel size of the large kernel conv branch.
         | 
| 421 | 
            +
                    stride: Stride size.
         | 
| 422 | 
            +
                    groups: Group number.
         | 
| 423 | 
            +
                        small_kernel: Kernel size of small kernel conv branch.
         | 
| 424 | 
            +
                        inference_mode: If True, instantiates model in inference mode. Default: ``False``
         | 
| 425 | 
            +
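                    use_se: If ``True``, a Squeeze-Excite block is applied to the output. Default: ``False``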
                        activation: Activation module. Default: ``nn.GELU``
         | 
| 426 | 
            +
                    """
         | 
| 427 | 
            +
                    super(ReparamLargeKernelConv, self).__init__()
         | 
| 428 | 
            +
             | 
| 429 | 
            +
                    self.stride = stride
         | 
| 430 | 
            +
                    self.groups = groups
         | 
| 431 | 
            +
                    self.in_channels = in_channels
         | 
| 432 | 
            +
                    self.out_channels = out_channels
         | 
| 433 | 
            +
                    self.activation = activation
         | 
| 434 | 
            +
             | 
| 435 | 
            +
                    self.kernel_size = kernel_size
         | 
| 436 | 
            +
                    self.small_kernel = small_kernel
         | 
| 437 | 
            +
                    self.padding = kernel_size // 2
         | 
| 438 | 
            +
             | 
| 439 | 
            +
                    # Check if SE is requested
         | 
| 440 | 
            +
                    if use_se:
         | 
| 441 | 
            +
                        self.se = SqueezeExcite(out_channels, rd_ratio=0.25)
         | 
| 442 | 
            +
                    else:
         | 
| 443 | 
            +
                        self.se = nn.Identity()
         | 
| 444 | 
            +
             | 
| 445 | 
            +
                    if inference_mode:
         | 
| 446 | 
            +
                        self.lkb_reparam = nn.Conv2d(
         | 
| 447 | 
            +
                            in_channels=in_channels,
         | 
| 448 | 
            +
                            out_channels=out_channels,
         | 
| 449 | 
            +
                            kernel_size=kernel_size,
         | 
| 450 | 
            +
                            stride=stride,
         | 
| 451 | 
            +
                            padding=self.padding,
         | 
| 452 | 
            +
                            dilation=1,
         | 
| 453 | 
            +
                            groups=groups,
         | 
| 454 | 
            +
                            bias=True,
         | 
| 455 | 
            +
                        )
         | 
| 456 | 
            +
                    else:
         | 
| 457 | 
            +
                        self.lkb_origin = self._conv_bn(
         | 
| 458 | 
            +
                            kernel_size=kernel_size, padding=self.padding
         | 
| 459 | 
            +
                        )
         | 
| 460 | 
            +
                        if small_kernel is not None:
         | 
| 461 | 
            +
                            assert (
         | 
| 462 | 
            +
                                small_kernel <= kernel_size
         | 
| 463 | 
            +
                            ), "The kernel size for re-param cannot be larger than the large kernel!"
         | 
| 464 | 
            +
                            self.small_conv = self._conv_bn(
         | 
| 465 | 
            +
                                kernel_size=small_kernel, padding=small_kernel // 2
         | 
| 466 | 
            +
                            )
         | 
| 467 | 
            +
             | 
| 468 | 
            +
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 469 | 
            +
                    """Apply forward pass."""
         | 
| 470 | 
            +
                    if hasattr(self, "lkb_reparam"):
         | 
| 471 | 
            +
                        out = self.lkb_reparam(x)
         | 
| 472 | 
            +
                    else:
         | 
| 473 | 
            +
                        out = self.lkb_origin(x)
         | 
| 474 | 
            +
                        if hasattr(self, "small_conv"):
         | 
| 475 | 
            +
                            out += self.small_conv(x)
         | 
| 476 | 
            +
             | 
| 477 | 
            +
                    return self.activation(self.se(out))
         | 
| 478 | 
            +
             | 
| 479 | 
            +
                def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 480 | 
            +
                    """Method to obtain re-parameterized kernel and bias.
         | 
| 481 | 
            +
                    Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
         | 
| 482 | 
            +
                    Returns:
         | 
| 483 | 
            +
                        Tuple of (kernel, bias) after fusing branches.
         | 
| 484 | 
            +
                    """
         | 
| 485 | 
            +
                    eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
         | 
| 486 | 
            +
                    if hasattr(self, "small_conv"):
         | 
| 487 | 
            +
                        small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
         | 
| 488 | 
            +
                        eq_b += small_b
         | 
| 489 | 
            +
                        eq_k += nn.functional.pad(
         | 
| 490 | 
            +
                            small_k, [(self.kernel_size - self.small_kernel) // 2] * 4
         | 
| 491 | 
            +
                        )
         | 
| 492 | 
            +
                    return eq_k, eq_b
         | 
| 493 | 
            +
             | 
| 494 | 
            +
                def reparameterize(self) -> None:
         | 
| 495 | 
            +
                    """
         | 
| 496 | 
            +
                    Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
         | 
| 497 | 
            +
                    https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
         | 
| 498 | 
            +
                    architecture used at training time to obtain a plain CNN-like structure
         | 
| 499 | 
            +
                    for inference.
         | 
| 500 | 
            +
                    """
         | 
| 501 | 
            +
                    eq_k, eq_b = self.get_kernel_bias()
         | 
| 502 | 
            +
                    self.lkb_reparam = nn.Conv2d(
         | 
| 503 | 
            +
                        in_channels=self.in_channels,
         | 
| 504 | 
            +
                        out_channels=self.out_channels,
         | 
| 505 | 
            +
                        kernel_size=self.kernel_size,
         | 
| 506 | 
            +
                        stride=self.stride,
         | 
| 507 | 
            +
                        padding=self.padding,
         | 
| 508 | 
            +
                        dilation=self.lkb_origin.conv.dilation,
         | 
| 509 | 
            +
                        groups=self.groups,
         | 
| 510 | 
            +
                        bias=True,
         | 
| 511 | 
            +
                    )
         | 
| 512 | 
            +
             | 
| 513 | 
            +
                    self.lkb_reparam.weight.data = eq_k
         | 
| 514 | 
            +
                    self.lkb_reparam.bias.data = eq_b
         | 
| 515 | 
            +
                    self.__delattr__("lkb_origin")
         | 
| 516 | 
            +
                    if hasattr(self, "small_conv"):
         | 
| 517 | 
            +
                        self.__delattr__("small_conv")
         | 
| 518 | 
            +
             | 
| 519 | 
            +
                @staticmethod
         | 
| 520 | 
            +
                def _fuse_bn(
         | 
| 521 | 
            +
                    conv: torch.Tensor, bn: nn.BatchNorm2d
         | 
| 522 | 
            +
                ) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 523 | 
            +
                    """Method to fuse batchnorm layer with conv layer.
         | 
| 524 | 
            +
                    Args:
         | 
| 525 | 
            +
                        conv: Convolutional kernel weights.
         | 
| 526 | 
            +
                        bn: Batchnorm 2d layer.
         | 
| 527 | 
            +
                    Returns:
         | 
| 528 | 
            +
                        Tuple of (kernel, bias) after fusing batchnorm.
         | 
| 529 | 
            +
                    """
         | 
| 530 | 
            +
                    kernel = conv.weight
         | 
| 531 | 
            +
                    running_mean = bn.running_mean
         | 
| 532 | 
            +
                    running_var = bn.running_var
         | 
| 533 | 
            +
                    gamma = bn.weight
         | 
| 534 | 
            +
                    beta = bn.bias
         | 
| 535 | 
            +
                    eps = bn.eps
         | 
| 536 | 
            +
                    std = (running_var + eps).sqrt()
         | 
| 537 | 
            +
                    t = (gamma / std).reshape(-1, 1, 1, 1)
         | 
| 538 | 
            +
                    return kernel * t, beta - running_mean * gamma / std
         | 
| 539 | 
            +
             | 
| 540 | 
            +
                def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
         | 
| 541 | 
            +
                    """Helper method to construct conv-batchnorm layers.
         | 
| 542 | 
            +
                    Args:
         | 
| 543 | 
            +
                        kernel_size: Size of the convolution kernel.
         | 
| 544 | 
            +
                        padding: Zero-padding size.
         | 
| 545 | 
            +
                    Returns:
         | 
| 546 | 
            +
                        A nn.Sequential Conv-BN module.
         | 
| 547 | 
            +
                    """
         | 
| 548 | 
            +
                    # Fallback, sometimes batchnorm tensors
         | 
| 549 | 
            +
                    # do not get instantiated correctly on some processes
         | 
| 550 | 
            +
                    # when using deepspeed + accelerate
         | 
| 551 | 
            +
                    norm_layer = nn.BatchNorm2d(num_features=self.out_channels)
         | 
| 552 | 
            +
                    if norm_layer.weight.shape[0] == 0:
         | 
| 553 | 
            +
                        norm_layer.weight = nn.Parameter(torch.zeros(self.out_channels))
         | 
| 554 | 
            +
                    if norm_layer.bias.shape[0] == 0:
         | 
| 555 | 
            +
                        norm_layer.bias = nn.Parameter(torch.zeros(self.out_channels))
         | 
| 556 | 
            +
             | 
| 557 | 
            +
                    mod_list = nn.Sequential()
         | 
| 558 | 
            +
                    mod_list.add_module(
         | 
| 559 | 
            +
                        "conv",
         | 
| 560 | 
            +
                        nn.Conv2d(
         | 
| 561 | 
            +
                            in_channels=self.in_channels,
         | 
| 562 | 
            +
                            out_channels=self.out_channels,
         | 
| 563 | 
            +
                            kernel_size=kernel_size,
         | 
| 564 | 
            +
                            stride=self.stride,
         | 
| 565 | 
            +
                            padding=padding,
         | 
| 566 | 
            +
                            groups=self.groups,
         | 
| 567 | 
            +
                            bias=False,
         | 
| 568 | 
            +
                        ),
         | 
| 569 | 
            +
                    )
         | 
| 570 | 
            +
                    mod_list.add_module("bn", norm_layer)
         | 
| 571 | 
            +
                    return mod_list
         | 
| 572 | 
            +
             | 
| 573 | 
            +
             | 
| 574 | 
            +
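The fusion above pads the small-kernel branch into the centre of the large kernel before summing (`get_kernel_bias`). A standalone sanity check of that identity in plain `torch`, with depthwise convs standing in for the already-fused branches:

```python
# Standalone sketch: conv_large(x) + conv_small(x) equals one large-kernel
# conv whose weight is w_large + pad(w_small) and whose bias is the sum.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
c, k_big, k_small = 16, 7, 3
big = nn.Conv2d(c, c, k_big, padding=k_big // 2, groups=c, bias=True)
small = nn.Conv2d(c, c, k_small, padding=k_small // 2, groups=c, bias=True)

pad = (k_big - k_small) // 2
fused = nn.Conv2d(c, c, k_big, padding=k_big // 2, groups=c, bias=True)
with torch.no_grad():
    fused.weight.copy_(big.weight + F.pad(small.weight, [pad] * 4))
    fused.bias.copy_(big.bias + small.bias)

x = torch.randn(1, c, 32, 32)
with torch.no_grad():
    assert torch.allclose(big(x) + small(x), fused(x), atol=1e-5)
```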
            def convolutional_stem(
         | 
| 575 | 
            +
                in_channels: int, out_channels: int, inference_mode: bool = False, use_scale_branch: bool = True,
         | 
| 576 | 
            +
            ) -> nn.Sequential:
         | 
| 577 | 
            +
                """Build convolutional stem with MobileOne blocks.
         | 
| 578 | 
            +
                Args:
         | 
| 579 | 
            +
                    in_channels: Number of input channels.
         | 
| 580 | 
            +
                    out_channels: Number of output channels.
         | 
| 581 | 
            +
                    inference_mode: Flag to instantiate model in inference mode. Default: ``False``
         | 
| 582 | 
            +
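                    use_scale_branch: Flag to include the scale branch in each MobileOneBlock. Default: ``True``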
                Returns:
         | 
| 583 | 
            +
                    nn.Sequential object with stem elements.
         | 
| 584 | 
            +
                """
         | 
| 585 | 
            +
                return nn.Sequential(
         | 
| 586 | 
            +
                    MobileOneBlock(
         | 
| 587 | 
            +
                        in_channels=in_channels,
         | 
| 588 | 
            +
                        out_channels=out_channels,
         | 
| 589 | 
            +
                        kernel_size=3,
         | 
| 590 | 
            +
                        stride=2,
         | 
| 591 | 
            +
                        padding=1,
         | 
| 592 | 
            +
                        groups=1,
         | 
| 593 | 
            +
                        inference_mode=inference_mode,
         | 
| 594 | 
            +
                        use_se=False,
         | 
| 595 | 
            +
                        num_conv_branches=1,
         | 
| 596 | 
            +
                        use_scale_branch=use_scale_branch
         | 
| 597 | 
            +
                    ),
         | 
| 598 | 
            +
                    MobileOneBlock(
         | 
| 599 | 
            +
                        in_channels=out_channels,
         | 
| 600 | 
            +
                        out_channels=out_channels,
         | 
| 601 | 
            +
                        kernel_size=3,
         | 
| 602 | 
            +
                        stride=2,
         | 
| 603 | 
            +
                        padding=1,
         | 
| 604 | 
            +
                        groups=out_channels,
         | 
| 605 | 
            +
                        inference_mode=inference_mode,
         | 
| 606 | 
            +
                        use_se=False,
         | 
| 607 | 
            +
                        num_conv_branches=1,
         | 
| 608 | 
            +
                        use_scale_branch=use_scale_branch
         | 
| 609 | 
            +
                    ),
         | 
| 610 | 
            +
                    MobileOneBlock(
         | 
| 611 | 
            +
                        in_channels=out_channels,
         | 
| 612 | 
            +
                        out_channels=out_channels,
         | 
| 613 | 
            +
                        kernel_size=1,
         | 
| 614 | 
            +
                        stride=1,
         | 
| 615 | 
            +
                        padding=0,
         | 
| 616 | 
            +
                        groups=1,
         | 
| 617 | 
            +
                        inference_mode=inference_mode,
         | 
| 618 | 
            +
                        use_se=False,
         | 
| 619 | 
            +
                        num_conv_branches=1,
         | 
| 620 | 
            +
                        use_scale_branch=use_scale_branch
         | 
| 621 | 
            +
                    ),
         | 
| 622 | 
            +
                )
         | 
| 623 | 
            +
             | 
| 624 | 
            +
             | 
| 625 | 
            +
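As a quick way to see what the stem does, a hypothetical snippet follows (it assumes `llava_qwen.py` and its dependencies, e.g. `torch` and `timm`, are importable from the working directory): the two stride-2 blocks reduce the spatial resolution by 4x and map `in_channels` to `out_channels`.

```python
# Hypothetical usage sketch; the import path assumes the remote-code file
# has been saved locally as llava_qwen.py.
import torch
from llava_qwen import convolutional_stem

stem = convolutional_stem(in_channels=3, out_channels=64)
stem.eval()
with torch.no_grad():
    out = stem(torch.randn(1, 3, 256, 256))
print(out.shape)  # two stride-2 blocks -> expected torch.Size([1, 64, 64, 64])
```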
            class LayerNormChannel(nn.Module):
         | 
| 626 | 
            +
                """
         | 
| 627 | 
            +
                LayerNorm only for Channel Dimension.
         | 
| 628 | 
            +
                Input: tensor in shape [B, C, H, W]
         | 
| 629 | 
            +
                """
         | 
| 630 | 
            +
                def __init__(self, num_features, eps=1e-05) -> None:
         | 
| 631 | 
            +
                    super().__init__()
         | 
| 632 | 
            +
                    self.weight = nn.Parameter(torch.ones(num_features))
         | 
| 633 | 
            +
                    self.bias = nn.Parameter(torch.zeros(num_features))
         | 
| 634 | 
            +
                    self.eps = eps
         | 
| 635 | 
            +
             | 
| 636 | 
            +
                def forward(self, x) -> torch.Tensor:
         | 
| 637 | 
            +
                    u = x.mean(1, keepdim=True)
         | 
| 638 | 
            +
                    s = (x - u).pow(2).mean(1, keepdim=True)
         | 
| 639 | 
            +
                    x = (x - u) / torch.sqrt(s + self.eps)
         | 
| 640 | 
            +
                    x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \
         | 
| 641 | 
            +
                        + self.bias.unsqueeze(-1).unsqueeze(-1)
         | 
| 642 | 
            +
                    return x
         | 
| 643 | 
            +
             | 
| 644 | 
            +
             | 
| 645 | 
            +
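`LayerNormChannel` above normalizes over the channel axis only. A standalone check in plain `torch` that the hand-written mean/variance math matches `F.layer_norm` applied to a channel-last view:

```python
# Standalone sketch: channel-only LayerNorm vs. F.layer_norm over a
# permuted (channel-last) view of the same tensor.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W = 2, 8, 5, 5
x = torch.randn(B, C, H, W)
weight, bias, eps = torch.rand(C), torch.rand(C), 1e-5

u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
y_manual = weight[None, :, None, None] * (x - u) / torch.sqrt(s + eps) \
    + bias[None, :, None, None]

y_ref = F.layer_norm(x.permute(0, 2, 3, 1), (C,), weight, bias, eps)
y_ref = y_ref.permute(0, 3, 1, 2)
assert torch.allclose(y_manual, y_ref, atol=1e-5)
```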
            class MHSA(nn.Module):
         | 
| 646 | 
            +
                """Multi-headed Self Attention module.
         | 
| 647 | 
            +
                Source modified from:
         | 
| 648 | 
            +
                https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
         | 
| 649 | 
            +
                """
         | 
| 650 | 
            +
             | 
| 651 | 
            +
                def __init__(
         | 
| 652 | 
            +
                    self,
         | 
| 653 | 
            +
                    dim: int,
         | 
| 654 | 
            +
                    head_dim: int = 32,
         | 
| 655 | 
            +
                    qkv_bias: bool = False,
         | 
| 656 | 
            +
                    attn_drop: float = 0.0,
         | 
| 657 | 
            +
                    proj_drop: float = 0.0,
         | 
| 658 | 
            +
                ) -> None:
         | 
| 659 | 
            +
                    """Build MHSA module that can handle 3D or 4D input tensors.
         | 
| 660 | 
            +
                    Args:
         | 
| 661 | 
            +
                        dim: Number of embedding dimensions.
         | 
| 662 | 
            +
                        head_dim: Number of hidden dimensions per head. Default: ``32``
         | 
| 663 | 
            +
                        qkv_bias: Use bias or not. Default: ``False``
         | 
| 664 | 
            +
                        attn_drop: Dropout rate for attention tensor.
         | 
| 665 | 
            +
                        proj_drop: Dropout rate for projection tensor.
         | 
| 666 | 
            +
                    """
         | 
| 667 | 
            +
                    super().__init__()
         | 
| 668 | 
            +
                    assert dim % head_dim == 0, "dim should be divisible by head_dim"
         | 
| 669 | 
            +
                    self.head_dim = head_dim
         | 
| 670 | 
            +
                    self.num_heads = dim // head_dim
         | 
| 671 | 
            +
                    self.scale = head_dim**-0.5
         | 
| 672 | 
            +
             | 
| 673 | 
            +
                    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         | 
| 674 | 
            +
                    self.attn_drop = nn.Dropout(attn_drop)
         | 
| 675 | 
            +
                    self.proj = nn.Linear(dim, dim)
         | 
| 676 | 
            +
                    self.proj_drop = nn.Dropout(proj_drop)
         | 
| 677 | 
            +
             | 
| 678 | 
            +
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 679 | 
            +
                    shape = x.shape
         | 
| 680 | 
            +
                    B, C, H, W = shape
         | 
| 681 | 
            +
                    N = H * W
         | 
| 682 | 
            +
                    if len(shape) == 4:
         | 
| 683 | 
            +
                        x = torch.flatten(x, start_dim=2).transpose(-2, -1)  # (B, N, C)
         | 
| 684 | 
            +
                    qkv = (
         | 
| 685 | 
            +
                        self.qkv(x)
         | 
| 686 | 
            +
                        .reshape(B, N, 3, self.num_heads, self.head_dim)
         | 
| 687 | 
            +
                        .permute(2, 0, 3, 1, 4)
         | 
| 688 | 
            +
                    )
         | 
| 689 | 
            +
                    q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
         | 
| 690 | 
            +
             | 
| 691 | 
            +
                    # trick here to make q@k.t more stable
         | 
| 692 | 
            +
                    attn = (q * self.scale) @ k.transpose(-2, -1)
         | 
| 693 | 
            +
                    attn = attn.softmax(dim=-1)
         | 
| 694 | 
            +
                    attn = self.attn_drop(attn)
         | 
| 695 | 
            +
             | 
| 696 | 
            +
                    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
         | 
| 697 | 
            +
                    x = self.proj(x)
         | 
| 698 | 
            +
                    x = self.proj_drop(x)
         | 
| 699 | 
            +
                    if len(shape) == 4:
         | 
| 700 | 
            +
                        x = x.transpose(-2, -1).reshape(B, C, H, W)
         | 
| 701 | 
            +
             | 
| 702 | 
            +
                    return x
         | 
| 703 | 
            +
             | 
| 704 | 
            +
             | 
| 705 | 
            +
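The `(q * self.scale) @ k.transpose(-2, -1)` formulation in `MHSA` is numerically equivalent to PyTorch's fused attention with its default scaling. A standalone check, assuming `torch >= 2.0` for `scaled_dot_product_attention`:

```python
# Standalone sketch: manual scaled-softmax attention vs. the fused op.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, heads, N, head_dim = 1, 4, 49, 32
q, k, v = (torch.randn(B, heads, N, head_dim) for _ in range(3))

scale = head_dim ** -0.5
attn = ((q * scale) @ k.transpose(-2, -1)).softmax(dim=-1)
manual = attn @ v

ref = F.scaled_dot_product_attention(q, k, v)  # default scale = head_dim**-0.5
assert torch.allclose(manual, ref, atol=1e-4)
```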
            class PatchEmbed(nn.Module):
         | 
| 706 | 
            +
                """Convolutional patch embedding layer."""
         | 
| 707 | 
            +
             | 
| 708 | 
            +
                def __init__(
         | 
| 709 | 
            +
                    self,
         | 
| 710 | 
            +
                    patch_size: int,
         | 
| 711 | 
            +
                    stride: int,
         | 
| 712 | 
            +
                    in_channels: int,
         | 
| 713 | 
            +
                    embed_dim: int,
         | 
| 714 | 
            +
                    inference_mode: bool = False,
         | 
| 715 | 
            +
                    use_se: bool = False,
         | 
| 716 | 
            +
                ) -> None:
         | 
| 717 | 
            +
                    """Build patch embedding layer.
         | 
| 718 | 
            +
                    Args:
         | 
| 719 | 
            +
                        patch_size: Patch size for embedding computation.
         | 
| 720 | 
            +
                        stride: Stride for convolutional embedding layer.
         | 
| 721 | 
            +
                        in_channels: Number of channels of input tensor.
         | 
| 722 | 
            +
                        embed_dim: Number of embedding dimensions.
         | 
| 723 | 
            +
                        inference_mode: Flag to instantiate model in inference mode. Default: ``False``
         | 
| 724 | 
            +
                        use_se: If ``True`` SE block will be used.
         | 
| 725 | 
            +
                    """
         | 
| 726 | 
            +
                    super().__init__()
         | 
| 727 | 
            +
                    block = list()
         | 
| 728 | 
            +
                    block.append(
         | 
| 729 | 
            +
                        ReparamLargeKernelConv(
         | 
| 730 | 
            +
                            in_channels=in_channels,
         | 
| 731 | 
            +
                            out_channels=embed_dim,
         | 
| 732 | 
            +
                            kernel_size=patch_size,
         | 
| 733 | 
            +
                            stride=stride,
         | 
| 734 | 
            +
                            groups=in_channels,
         | 
| 735 | 
            +
                            small_kernel=3,
         | 
| 736 | 
            +
                            inference_mode=inference_mode,
         | 
| 737 | 
            +
                            use_se=use_se,
         | 
| 738 | 
            +
                        )
         | 
| 739 | 
            +
                    )
         | 
| 740 | 
            +
                    block.append(
         | 
| 741 | 
            +
                        MobileOneBlock(
         | 
| 742 | 
            +
                            in_channels=embed_dim,
         | 
| 743 | 
            +
                            out_channels=embed_dim,
         | 
| 744 | 
            +
                            kernel_size=1,
         | 
| 745 | 
            +
                            stride=1,
         | 
| 746 | 
            +
                            padding=0,
         | 
| 747 | 
            +
                            groups=1,
         | 
| 748 | 
            +
                            inference_mode=inference_mode,
         | 
| 749 | 
            +
                            use_se=False,
         | 
| 750 | 
            +
                            num_conv_branches=1,
         | 
| 751 | 
            +
                        )
         | 
| 752 | 
            +
                    )
         | 
| 753 | 
            +
                    self.proj = nn.Sequential(*block)
         | 
| 754 | 
            +
             | 
| 755 | 
            +
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 756 | 
            +
                    x = self.proj(x)
         | 
| 757 | 
            +
                    return x
         | 
| 758 | 
            +
             | 
| 759 | 
            +
             | 
| 760 | 
            +
            class RepMixer(nn.Module):
         | 
| 761 | 
            +
                """Reparameterizable token mixer.
         | 
| 762 | 
            +
                For more details, please refer to our paper:
         | 
| 763 | 
            +
                `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
         | 
| 764 | 
            +
                """
         | 
| 765 | 
            +
             | 
| 766 | 
            +
                def __init__(
         | 
| 767 | 
            +
                    self,
         | 
| 768 | 
            +
                    dim,
         | 
| 769 | 
            +
                    kernel_size=3,
         | 
| 770 | 
            +
                    use_layer_scale=True,
         | 
| 771 | 
            +
                    layer_scale_init_value=1e-5,
         | 
| 772 | 
            +
                    inference_mode: bool = False,
         | 
| 773 | 
            +
                ):
         | 
| 774 | 
            +
                    """Build RepMixer Module.
         | 
| 775 | 
            +
                    Args:
         | 
| 776 | 
            +
                        dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
         | 
| 777 | 
            +
                        kernel_size: Kernel size for spatial mixing. Default: 3
         | 
| 778 | 
            +
                        use_layer_scale: If True, learnable layer scale is used. Default: ``True``
         | 
| 779 | 
            +
                        layer_scale_init_value: Initial value for layer scale. Default: 1e-5
         | 
| 780 | 
            +
                        inference_mode: If True, instantiates model in inference mode. Default: ``False``
         | 
| 781 | 
            +
                    """
         | 
| 782 | 
            +
                    super().__init__()
         | 
| 783 | 
            +
                    self.dim = dim
         | 
| 784 | 
            +
                    self.kernel_size = kernel_size
         | 
| 785 | 
            +
                    self.inference_mode = inference_mode
         | 
| 786 | 
            +
             | 
| 787 | 
            +
                    if inference_mode:
         | 
| 788 | 
            +
                        self.reparam_conv = nn.Conv2d(
         | 
| 789 | 
            +
                            in_channels=self.dim,
         | 
| 790 | 
            +
                            out_channels=self.dim,
         | 
| 791 | 
            +
                            kernel_size=self.kernel_size,
         | 
| 792 | 
            +
                            stride=1,
         | 
| 793 | 
            +
                            padding=self.kernel_size // 2,
         | 
| 794 | 
            +
                            groups=self.dim,
         | 
| 795 | 
            +
                            bias=True,
         | 
| 796 | 
            +
                        )
         | 
| 797 | 
            +
                    else:
         | 
| 798 | 
            +
                        self.norm = MobileOneBlock(
         | 
| 799 | 
            +
                            dim,
         | 
| 800 | 
            +
                            dim,
         | 
| 801 | 
            +
                            kernel_size,
         | 
| 802 | 
            +
                            padding=kernel_size // 2,
         | 
| 803 | 
            +
                            groups=dim,
         | 
| 804 | 
            +
                            use_act=False,
         | 
| 805 | 
            +
                            use_scale_branch=False,
         | 
| 806 | 
            +
                            num_conv_branches=0,
         | 
| 807 | 
            +
                        )
         | 
| 808 | 
            +
                        self.mixer = MobileOneBlock(
         | 
| 809 | 
            +
                            dim,
         | 
| 810 | 
            +
                            dim,
         | 
| 811 | 
            +
                            kernel_size,
         | 
| 812 | 
            +
                            padding=kernel_size // 2,
         | 
| 813 | 
            +
                            groups=dim,
         | 
| 814 | 
            +
                            use_act=False,
         | 
| 815 | 
            +
                        )
         | 
| 816 | 
            +
                        self.use_layer_scale = use_layer_scale
         | 
| 817 | 
            +
                        if use_layer_scale:
         | 
| 818 | 
            +
                            self.layer_scale = nn.Parameter(
         | 
| 819 | 
            +
                                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
         | 
| 820 | 
            +
                            )
         | 
| 821 | 
            +
             | 
| 822 | 
            +
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 823 | 
            +
                    if hasattr(self, "reparam_conv"):
         | 
| 824 | 
            +
                        x = self.reparam_conv(x)
         | 
| 825 | 
            +
                        return x
         | 
| 826 | 
            +
                    else:
         | 
| 827 | 
            +
                        if self.use_layer_scale:
         | 
| 828 | 
            +
                            x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
         | 
| 829 | 
            +
                        else:
         | 
| 830 | 
            +
                            x = x + self.mixer(x) - self.norm(x)
         | 
| 831 | 
            +
                        return x
         | 
| 832 | 
            +
             | 
| 833 | 
            +
                def reparameterize(self) -> None:
         | 
| 834 | 
            +
                    """Reparameterize mixer and norm into a single
         | 
| 835 | 
            +
                    convolutional layer for efficient inference.
         | 
| 836 | 
            +
                    """
         | 
| 837 | 
            +
                    if self.inference_mode:
         | 
| 838 | 
            +
                        return
         | 
| 839 | 
            +
             | 
| 840 | 
            +
                    self.mixer.reparameterize()
         | 
| 841 | 
            +
                    self.norm.reparameterize()
         | 
| 842 | 
            +
             | 
| 843 | 
            +
                    if self.use_layer_scale:
         | 
| 844 | 
            +
                        w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
         | 
| 845 | 
            +
                            self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
         | 
| 846 | 
            +
                        )
         | 
| 847 | 
            +
                        b = torch.squeeze(self.layer_scale) * (
         | 
| 848 | 
            +
                            self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
         | 
| 849 | 
            +
                        )
         | 
| 850 | 
            +
                    else:
         | 
| 851 | 
            +
                        w = (
         | 
| 852 | 
            +
                            self.mixer.id_tensor
         | 
| 853 | 
            +
                            + self.mixer.reparam_conv.weight
         | 
| 854 | 
            +
                            - self.norm.reparam_conv.weight
         | 
| 855 | 
            +
                        )
         | 
| 856 | 
            +
                        b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
         | 
| 857 | 
            +
             | 
| 858 | 
            +
                    self.reparam_conv = nn.Conv2d(
         | 
| 859 | 
            +
                        in_channels=self.dim,
         | 
| 860 | 
            +
                        out_channels=self.dim,
         | 
| 861 | 
            +
                        kernel_size=self.kernel_size,
         | 
| 862 | 
            +
                        stride=1,
         | 
| 863 | 
            +
                        padding=self.kernel_size // 2,
         | 
| 864 | 
            +
                        groups=self.dim,
         | 
| 865 | 
            +
                        bias=True,
         | 
| 866 | 
            +
                    )
         | 
| 867 | 
            +
                    self.reparam_conv.weight.data = w
         | 
| 868 | 
            +
                    self.reparam_conv.bias.data = b
         | 
| 869 | 
            +
             | 
| 870 | 
            +
                    self.__delattr__("mixer")
         | 
| 871 | 
            +
                    self.__delattr__("norm")
         | 
| 872 | 
            +
                    if self.use_layer_scale:
         | 
| 873 | 
            +
                        self.__delattr__("layer_scale")
         | 
| 874 | 
            +
             | 
| 875 | 
            +
             | 
| 876 | 
            +
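The `reparameterize` method above folds `x + layer_scale * (mixer(x) - norm(x))` into a single depthwise convolution. A simplified standalone check of that identity, with plain depthwise convs standing in for the already-reparameterized `mixer`/`norm` branches:

```python
# Standalone sketch: x + s * (mixer(x) - norm(x)) equals one depthwise conv
# whose kernel is identity + s * (W_mixer - W_norm).
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, k = 16, 3
mixer = nn.Conv2d(dim, dim, k, padding=k // 2, groups=dim, bias=True)
norm = nn.Conv2d(dim, dim, k, padding=k // 2, groups=dim, bias=True)
s = 1e-2 * torch.ones(dim, 1, 1)  # stand-in for layer_scale

# Depthwise identity kernel: 1 at the centre tap of each channel.
id_kernel = torch.zeros(dim, 1, k, k)
id_kernel[:, 0, k // 2, k // 2] = 1.0

fused = nn.Conv2d(dim, dim, k, padding=k // 2, groups=dim, bias=True)
with torch.no_grad():
    fused.weight.copy_(id_kernel + s.unsqueeze(-1) * (mixer.weight - norm.weight))
    fused.bias.copy_(s.squeeze() * (mixer.bias - norm.bias))

x = torch.randn(2, dim, 14, 14)
with torch.no_grad():
    assert torch.allclose(x + s * (mixer(x) - norm(x)), fused(x), atol=1e-5)
```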
            class ConvFFN(nn.Module):
         | 
| 877 | 
            +
                """Convolutional FFN Module."""
         | 
| 878 | 
            +
             | 
| 879 | 
            +
                def __init__(
         | 
| 880 | 
            +
                    self,
         | 
| 881 | 
            +
                    in_channels: int,
         | 
| 882 | 
            +
                    hidden_channels: Optional[int] = None,
         | 
| 883 | 
            +
                    out_channels: Optional[int] = None,
         | 
| 884 | 
            +
                    act_layer: nn.Module = nn.GELU,
         | 
| 885 | 
            +
                    drop: float = 0.0,
         | 
| 886 | 
            +
                ) -> None:
         | 
| 887 | 
            +
                    """Build convolutional FFN module.
         | 
| 888 | 
            +
                    Args:
         | 
| 889 | 
            +
                        in_channels: Number of input channels.
         | 
| 890 | 
            +
                        hidden_channels: Number of channels after expansion. Default: None
         | 
| 891 | 
            +
                        out_channels: Number of output channels. Default: None
         | 
| 892 | 
            +
                        act_layer: Activation layer. Default: ``GELU``
         | 
| 893 | 
            +
                        drop: Dropout rate. Default: ``0.0``.
         | 
| 894 | 
            +
                    """
         | 
| 895 | 
            +
                    super().__init__()
         | 
| 896 | 
            +
                    out_channels = out_channels or in_channels
         | 
| 897 | 
            +
                    hidden_channels = hidden_channels or in_channels
         | 
| 898 | 
            +
                    self.conv = nn.Sequential()
         | 
| 899 | 
            +
                    self.conv.add_module(
         | 
| 900 | 
            +
                        "conv",
         | 
| 901 | 
            +
                        nn.Conv2d(
         | 
| 902 | 
            +
                            in_channels=in_channels,
         | 
| 903 | 
            +
                            out_channels=out_channels,
         | 
| 904 | 
            +
                            kernel_size=7,
         | 
| 905 | 
            +
                            padding=3,
         | 
| 906 | 
            +
                            groups=in_channels,
         | 
| 907 | 
            +
                            bias=False,
         | 
| 908 | 
            +
                        ),
         | 
| 909 | 
            +
                    )
         | 
| 910 | 
            +
             | 
| 911 | 
            +
                    # Fallback, sometimes batchnorm tensors
         | 
| 912 | 
            +
                    # do not get instantiated correctly on some processes
         | 
| 913 | 
            +
                    # when using deepspeed + accelerate
         | 
| 914 | 
            +
                    norm_layer = nn.BatchNorm2d(num_features=out_channels)
         | 
| 915 | 
            +
                    if norm_layer.weight.shape[0] == 0:
         | 
| 916 | 
            +
                        norm_layer.weight = nn.Parameter(torch.zeros(out_channels))
         | 
| 917 | 
            +
                    if norm_layer.bias.shape[0] == 0:
         | 
| 918 | 
            +
                        norm_layer.bias = nn.Parameter(torch.zeros(out_channels))
         | 
| 919 | 
            +
             | 
| 920 | 
            +
                    self.conv.add_module("bn", norm_layer)
         | 
| 921 | 
            +
                    self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
         | 
| 922 | 
            +
                    self.act = act_layer()
         | 
| 923 | 
            +
                    self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
         | 
| 924 | 
            +
                    self.drop = nn.Dropout(drop)
         | 
| 925 | 
            +
                    self.apply(self._init_weights)
         | 
| 926 | 
            +
             | 
| 927 | 
            +
                def _init_weights(self, m: nn.Module) -> None:
         | 
| 928 | 
            +
                    if isinstance(m, nn.Conv2d):
         | 
| 929 | 
            +
                        normal_(m.weight, std=0.02)
         | 
| 930 | 
            +
                        if m.bias is not None:
         | 
| 931 | 
            +
                            nn.init.constant_(m.bias, 0)
         | 
| 932 | 
            +
             | 
| 933 | 
            +
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 934 | 
            +
                    x = self.conv(x)
         | 
| 935 | 
            +
                    x = self.fc1(x)
         | 
| 936 | 
            +
                    x = self.act(x)
         | 
| 937 | 
            +
                    x = self.drop(x)
         | 
| 938 | 
            +
                    x = self.fc2(x)
         | 
| 939 | 
            +
                    x = self.drop(x)
         | 
| 940 | 
            +
                    return x
         | 
| 941 | 
            +
             | 
| 942 | 
            +
             | 
| 943 | 
            +
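`ConvFFN` above mixes channels with 1x1 convolutions, which act as a per-position MLP. A standalone check that a kernel-size-1 `Conv2d` matches `nn.Linear` applied token-wise:

```python
# Standalone sketch: 1x1 Conv2d == Linear applied to each spatial position.
import torch
import torch.nn as nn

torch.manual_seed(0)
fc = nn.Conv2d(16, 64, kernel_size=1)
lin = nn.Linear(16, 64)
with torch.no_grad():
    lin.weight.copy_(fc.weight.squeeze(-1).squeeze(-1))
    lin.bias.copy_(fc.bias)

x = torch.randn(2, 16, 7, 7)
with torch.no_grad():
    y_conv = fc(x)
    y_lin = lin(x.flatten(2).transpose(1, 2)).transpose(1, 2).reshape(2, 64, 7, 7)
assert torch.allclose(y_conv, y_lin, atol=1e-5)
```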
            class RepCPE(nn.Module):
         | 
| 944 | 
            +
                """Implementation of conditional positional encoding.
         | 
| 945 | 
            +
                For more details refer to paper:
         | 
| 946 | 
            +
                `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/pdf/2102.10882.pdf>`_
         | 
| 947 | 
            +
                In our implementation, we can reparameterize this module to eliminate a skip connection.
         | 
| 948 | 
            +
                """
         | 
| 949 | 
            +
             | 
| 950 | 
            +
                def __init__(
         | 
| 951 | 
            +
                    self,
         | 
| 952 | 
            +
                    in_channels: int,
         | 
| 953 | 
            +
                    embed_dim: int = 768,
         | 
| 954 | 
            +
                    spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
         | 
| 955 | 
            +
                    inference_mode=False,
         | 
| 956 | 
            +
                ) -> None:
         | 
| 957 | 
            +
                    """Build reparameterizable conditional positional encoding
         | 
| 958 | 
            +
                    Args:
         | 
| 959 | 
            +
                        in_channels: Number of input channels.
         | 
| 960 | 
            +
                        embed_dim: Number of embedding dimensions. Default: 768
         | 
| 961 | 
            +
                        spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
         | 
| 962 | 
            +
                        inference_mode: Flag to instantiate block in inference mode. Default: ``False``
         | 
| 963 | 
            +
                    """
         | 
| 964 | 
            +
                    super(RepCPE, self).__init__()
         | 
| 965 | 
            +
                    if isinstance(spatial_shape, int):
         | 
| 966 | 
            +
                        spatial_shape = tuple([spatial_shape] * 2)
         | 
| 967 | 
            +
                    assert isinstance(spatial_shape, Tuple), (
         | 
| 968 | 
            +
                        f'"spatial_shape" must by a sequence or int, '
         | 
| 969 | 
            +
                        f"get {type(spatial_shape)} instead."
         | 
| 970 | 
            +
                    )
         | 
| 971 | 
            +
                    assert len(spatial_shape) == 2, (
         | 
| 972 | 
            +
                        f'Length of "spatial_shape" should be 2, '
         | 
| 973 | 
            +
                        f"got {len(spatial_shape)} instead."
         | 
| 974 | 
            +
                    )
         | 
| 975 | 
            +
             | 
| 976 | 
            +
                    self.spatial_shape = spatial_shape
         | 
| 977 | 
            +
                    self.embed_dim = embed_dim
         | 
| 978 | 
            +
                    self.in_channels = in_channels
         | 
| 979 | 
            +
                    self.groups = embed_dim
         | 
| 980 | 
            +
             | 
| 981 | 
            +
                    if inference_mode:
         | 
| 982 | 
            +
                        self.reparam_conv = nn.Conv2d(
         | 
| 983 | 
            +
                            in_channels=self.in_channels,
         | 
| 984 | 
            +
                            out_channels=self.embed_dim,
         | 
| 985 | 
            +
                            kernel_size=self.spatial_shape,
         | 
| 986 | 
            +
                            stride=1,
         | 
| 987 | 
            +
                            padding=int(self.spatial_shape[0] // 2),
         | 
| 988 | 
            +
                            groups=self.embed_dim,
         | 
| 989 | 
            +
                            bias=True,
         | 
| 990 | 
            +
                        )
         | 
| 991 | 
            +
                    else:
         | 
| 992 | 
            +
                        self.pe = nn.Conv2d(
         | 
| 993 | 
            +
                            in_channels,
         | 
| 994 | 
            +
                            embed_dim,
         | 
| 995 | 
            +
                            spatial_shape,
         | 
| 996 | 
            +
                            1,
         | 
| 997 | 
            +
                            int(spatial_shape[0] // 2),
         | 
| 998 | 
            +
                            bias=True,
         | 
| 999 | 
            +
                            groups=embed_dim,
         | 
| 1000 | 
            +
                        )
         | 
| 1001 | 
            +
             | 
| 1002 | 
            +
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 1003 | 
            +
                    if hasattr(self, "reparam_conv"):
         | 
| 1004 | 
            +
                        x = self.reparam_conv(x)
         | 
| 1005 | 
            +
                        return x
         | 
| 1006 | 
            +
                    else:
         | 
| 1007 | 
            +
                        x = self.pe(x) + x
         | 
| 1008 | 
            +
                        return x
         | 
| 1009 | 
            +
             | 
| 1010 | 
            +
                def reparameterize(self) -> None:
         | 
| 1011 | 
            +
                    # Build equivalent Id tensor
         | 
| 1012 | 
            +
                    input_dim = self.in_channels // self.groups
         | 
| 1013 | 
            +
                    kernel_value = torch.zeros(
         | 
| 1014 | 
            +
                        (
         | 
| 1015 | 
            +
                            self.in_channels,
         | 
| 1016 | 
            +
                            input_dim,
         | 
| 1017 | 
            +
                            self.spatial_shape[0],
         | 
| 1018 | 
            +
                            self.spatial_shape[1],
         | 
| 1019 | 
            +
                        ),
         | 
| 1020 | 
            +
                        dtype=self.pe.weight.dtype,
         | 
| 1021 | 
            +
                        device=self.pe.weight.device,
         | 
| 1022 | 
            +
                    )
         | 
| 1023 | 
            +
                    for i in range(self.in_channels):
         | 
| 1024 | 
            +
                        kernel_value[
         | 
| 1025 | 
            +
                            i,
         | 
| 1026 | 
            +
                            i % input_dim,
         | 
| 1027 | 
            +
                            self.spatial_shape[0] // 2,
         | 
| 1028 | 
            +
                            self.spatial_shape[1] // 2,
         | 
| 1029 | 
            +
                        ] = 1
         | 
| 1030 | 
            +
                    id_tensor = kernel_value
         | 
| 1031 | 
            +
             | 
| 1032 | 
            +
                    # Reparameterize Id tensor and conv
         | 
| 1033 | 
            +
                    w_final = id_tensor + self.pe.weight
         | 
| 1034 | 
            +
                    b_final = self.pe.bias
         | 
| 1035 | 
            +
             | 
| 1036 | 
            +
                    # Introduce reparam conv
         | 
| 1037 | 
            +
                    self.reparam_conv = nn.Conv2d(
         | 
| 1038 | 
            +
                        in_channels=self.in_channels,
         | 
| 1039 | 
            +
                        out_channels=self.embed_dim,
         | 
| 1040 | 
            +
                        kernel_size=self.spatial_shape,
         | 
| 1041 | 
            +
                        stride=1,
         | 
| 1042 | 
            +
                        padding=int(self.spatial_shape[0] // 2),
         | 
| 1043 | 
            +
                        groups=self.embed_dim,
         | 
| 1044 | 
            +
                        bias=True,
         | 
| 1045 | 
            +
                    )
         | 
| 1046 | 
            +
                    self.reparam_conv.weight.data = w_final
         | 
| 1047 | 
            +
                    self.reparam_conv.bias.data = b_final
         | 
| 1048 | 
            +
             | 
| 1049 | 
            +
                    self.__delattr__("pe")
         | 
| 1050 | 
            +
             | 
| 1051 | 
            +
             | 
| 1052 | 
            +
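`RepCPE.reparameterize` above absorbs the skip connection by adding a grouped identity kernel to the positional-encoding conv. A standalone check of that identity for the depthwise case used here (`groups = embed_dim`):

```python
# Standalone sketch: pe(x) + x equals one conv whose kernel is
# pe.weight plus a grouped identity kernel.
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, k = 32, 7
groups = dim  # depthwise, as in RepCPE (groups = embed_dim)
pe = nn.Conv2d(dim, dim, k, stride=1, padding=k // 2, groups=groups, bias=True)

input_dim = dim // groups
id_kernel = torch.zeros(dim, input_dim, k, k)
for i in range(dim):
    id_kernel[i, i % input_dim, k // 2, k // 2] = 1.0

fused = nn.Conv2d(dim, dim, k, stride=1, padding=k // 2, groups=groups, bias=True)
with torch.no_grad():
    fused.weight.copy_(pe.weight + id_kernel)
    fused.bias.copy_(pe.bias)

x = torch.randn(1, dim, 14, 14)
with torch.no_grad():
    assert torch.allclose(pe(x) + x, fused(x), atol=1e-5)
```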
            class RepMixerBlock(nn.Module):
         | 
| 1053 | 
            +
                """Implementation of Metaformer block with RepMixer as token mixer.
         | 
| 1054 | 
            +
                For more details on Metaformer structure, please refer to:
         | 
| 1055 | 
            +
                `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
         | 
| 1056 | 
            +
                """
         | 
| 1057 | 
            +
             | 
| 1058 | 
            +
                def __init__(
         | 
| 1059 | 
            +
                    self,
         | 
| 1060 | 
            +
                    dim: int,
         | 
| 1061 | 
            +
                    kernel_size: int = 3,
         | 
| 1062 | 
            +
                    mlp_ratio: float = 4.0,
         | 
| 1063 | 
            +
                    act_layer: nn.Module = nn.GELU,
         | 
| 1064 | 
            +
                    drop: float = 0.0,
         | 
| 1065 | 
            +
                    drop_path: float = 0.0,
         | 
| 1066 | 
            +
                    use_layer_scale: bool = True,
         | 
| 1067 | 
            +
                    layer_scale_init_value: float = 1e-5,
         | 
| 1068 | 
            +
                    inference_mode: bool = False,
         | 
| 1069 | 
            +
                ):
         | 
| 1070 | 
            +
                    """Build RepMixer Block.
        Args:
            dim: Number of embedding dimensions.
            kernel_size: Kernel size for repmixer. Default: 3
            mlp_ratio: MLP expansion ratio. Default: 4.0
            act_layer: Activation layer. Default: ``nn.GELU``
            drop: Dropout rate. Default: 0.0
            drop_path: Drop path rate. Default: 0.0
            use_layer_scale: Flag to turn on layer scale. Default: ``True``
            layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
            inference_mode: Flag to instantiate block in inference mode. Default: ``False``
        """

        super().__init__()

        self.token_mixer = RepMixer(
            dim,
            kernel_size=kernel_size,
            use_layer_scale=use_layer_scale,
            layer_scale_init_value=layer_scale_init_value,
            inference_mode=inference_mode,
        )

        assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
            mlp_ratio
        )
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.convffn = ConvFFN(
            in_channels=dim,
            hidden_channels=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        # Drop Path
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Layer Scale
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )

    def forward(self, x):
        if self.use_layer_scale:
            x = self.token_mixer(x)
            x = x + self.drop_path(self.layer_scale * self.convffn(x))
        else:
            x = self.token_mixer(x)
            x = x + self.drop_path(self.convffn(x))
        return x

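# Usage sketch (values are illustrative, not taken from this file): a RepMixerBlock is
# shape-preserving, so a [batch, dim, H, W] input maps to an output of the same shape.
#
#   block = RepMixerBlock(dim=96, kernel_size=3, mlp_ratio=4.0)
#   y = block(torch.randn(2, 96, 64, 64))   # -> [2, 96, 64, 64]
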
class AttentionBlock(nn.Module):
    """Implementation of metaformer block with MHSA as token mixer.
    For more details on Metaformer structure, please refer to:
    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
    """

    def __init__(
        self,
        dim: int,
        mlp_ratio: float = 4.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.BatchNorm2d,
        drop: float = 0.0,
        drop_path: float = 0.0,
        use_layer_scale: bool = True,
        layer_scale_init_value: float = 1e-5,
    ):
        """Build Attention Block.
        Args:
            dim: Number of embedding dimensions.
            mlp_ratio: MLP expansion ratio. Default: 4.0
            act_layer: Activation layer. Default: ``nn.GELU``
            norm_layer: Normalization layer. Default: ``nn.BatchNorm2d``
            drop: Dropout rate. Default: 0.0
            drop_path: Drop path rate. Default: 0.0
            use_layer_scale: Flag to turn on layer scale. Default: ``True``
            layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
        """

        super().__init__()

        # Fallback, sometimes batchnorm tensors
        # do not get instantiated correctly on some processes
        # when using deepspeed + accelerate
        norm_layer_ = norm_layer(num_features=dim)
        if norm_layer_.weight.shape[0] == 0:
            norm_layer_.weight = nn.Parameter(torch.zeros(dim))
        if norm_layer_.bias.shape[0] == 0:
            norm_layer_.bias = nn.Parameter(torch.zeros(dim))

        self.norm = norm_layer_
        self.token_mixer = MHSA(dim=dim)

        assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
            mlp_ratio
        )
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.convffn = ConvFFN(
            in_channels=dim,
            hidden_channels=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        # Drop path
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Layer Scale
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )
            self.layer_scale_2 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )

    def forward(self, x):
        if self.use_layer_scale:
            x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x)))
            x = x + self.drop_path(self.layer_scale_2 * self.convffn(x))
        else:
            x = x + self.drop_path(self.token_mixer(self.norm(x)))
            x = x + self.drop_path(self.convffn(x))
        return x

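# Usage sketch (illustrative values): AttentionBlock is also shape-preserving; unlike
# RepMixerBlock, its token mixer is multi-head self-attention over the H*W spatial
# positions, applied after the normalization layer.
#
#   blk = AttentionBlock(dim=768, mlp_ratio=4.0, norm_layer=nn.BatchNorm2d)
#   y = blk(torch.randn(2, 768, 16, 16))    # -> [2, 768, 16, 16]
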
def basic_blocks(
    dim: int,
    block_index: int,
    num_blocks: List[int],
    token_mixer_type: str,
    kernel_size: int = 3,
    mlp_ratio: float = 4.0,
    act_layer: nn.Module = nn.GELU,
    norm_layer: nn.Module = nn.BatchNorm2d,
    drop_rate: float = 0.0,
    drop_path_rate: float = 0.0,
    use_layer_scale: bool = True,
    layer_scale_init_value: float = 1e-5,
    inference_mode=False,
) -> nn.Sequential:
    """Build FastViT blocks within a stage.
    Args:
        dim: Number of embedding dimensions.
        block_index: block index.
        num_blocks: List containing number of blocks per stage.
        token_mixer_type: Token mixer type.
        kernel_size: Kernel size for repmixer.
        mlp_ratio: MLP expansion ratio.
        act_layer: Activation layer.
        norm_layer: Normalization layer.
        drop_rate: Dropout rate.
        drop_path_rate: Drop path rate.
        use_layer_scale: Flag to turn on layer scale regularization.
        layer_scale_init_value: Layer scale value at initialization.
        inference_mode: Flag to instantiate block in inference mode.
    Returns:
        nn.Sequential object of all the blocks within the stage.
    """
    blocks = []
    for block_idx in range(num_blocks[block_index]):
        block_dpr = (
            drop_path_rate
            * (block_idx + sum(num_blocks[:block_index]))
            / (sum(num_blocks) - 1)
        )
        if token_mixer_type == "repmixer":
            blocks.append(
                RepMixerBlock(
                    dim,
                    kernel_size=kernel_size,
                    mlp_ratio=mlp_ratio,
                    act_layer=act_layer,
                    drop=drop_rate,
                    drop_path=block_dpr,
                    use_layer_scale=use_layer_scale,
                    layer_scale_init_value=layer_scale_init_value,
                    inference_mode=inference_mode,
                )
            )
        elif token_mixer_type == "attention":
            blocks.append(
                AttentionBlock(
                    dim,
                    mlp_ratio=mlp_ratio,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    drop=drop_rate,
                    drop_path=block_dpr,
                    use_layer_scale=use_layer_scale,
                    layer_scale_init_value=layer_scale_init_value,
                )
            )
        else:
            raise ValueError(
                "Token mixer type: {} not supported".format(token_mixer_type)
            )
    blocks = nn.Sequential(*blocks)
    return blocks

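# Usage sketch (illustrative values): building the second stage of a 5-stage backbone.
# With num_blocks=[2, 12, 24, 4, 2] and block_index=1 this returns an nn.Sequential of
# 12 RepMixerBlocks whose drop-path rate increases linearly with block depth.
#
#   stage = basic_blocks(dim=192, block_index=1, num_blocks=[2, 12, 24, 4, 2],
#                        token_mixer_type="repmixer", drop_path_rate=0.1)
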
class GlobalPool2D(nn.Module):
    """This class implements global pooling with linear projection."""

    def __init__(self, in_dim: int, out_dim: int, *args, **kwargs) -> None:
        super().__init__()
        scale = in_dim**-0.5
        self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim)))
        self.in_dim = in_dim
        self.out_dim = out_dim

    def pool(self, x) -> Tensor:
        if x.dim() == 4:
            dims = [-2, -1]
        elif x.dim() == 5:
            dims = [-3, -2, -1]
        x = torch.mean(x, dim=dims, keepdim=False)
        return x

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        # x is of shape [batch, in_dim, in_height, in_width]
        assert (
            x.dim() == 4
        ), "Input should be 4-dimensional (Batch x in_dim x in_height x in_width). Got: {}".format(
            x.shape
        )

        # [batch, in_dim, in_height, in_width] --> [batch, in_dim]
        x = self.pool(x)
        # [batch, in_dim]  x [in_dim, out_dim] --> [batch, out_dim]
        x = x @ self.proj
        return x

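# Usage sketch (illustrative values): GlobalPool2D averages over the spatial grid and then
# projects the channel dimension, i.e. [batch, in_dim, H, W] -> [batch, out_dim].
#
#   pool = GlobalPool2D(in_dim=3072, out_dim=768)
#   z = pool(torch.randn(2, 3072, 16, 16))  # -> [2, 768]
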
class FastViT(nn.Module):
    """
    This class implements `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_
    """

    def __init__(
        self,
        layers,
        token_mixers: Tuple[str, ...],
        embed_dims=None,
        mlp_ratios=None,
        downsamples=None,
        se_downsamples=None,
        repmixer_kernel_size=3,
        norm_layer: nn.Module = nn.BatchNorm2d,
        act_layer: nn.Module = nn.GELU,
        num_classes=1000,
        pos_embs=None,
        down_patch_size=7,
        down_stride=2,
        drop_rate=0.0,
        drop_path_rate=0.0,
        use_layer_scale=True,
        layer_scale_init_value=1e-5,
        init_cfg=None,
        pretrained=None,
        cls_ratio=2.0,
        inference_mode=False,
        stem_scale_branch=True,
        **kwargs,
    ) -> None:

        super().__init__()

        self.num_classes = num_classes
        if len(layers) == 4:
            self.out_indices = [0, 2, 4, 7]
        elif len(layers) == 5:
            self.out_indices = [0, 2, 4, 7, 10]
        else:
            raise NotImplementedError("FPN is not implemented for more than 5 stages.")

        if pos_embs is None:
            pos_embs = [None] * len(layers)

        if se_downsamples is None:
            se_downsamples = [False] * len(layers)

        # Convolutional stem
        self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode,
                                              use_scale_branch=stem_scale_branch)

        # Build the main stages of the network architecture
        network = []
        for i in range(len(layers)):
            # Add position embeddings if requested
            if pos_embs[i] is not None:
                network.append(
                    pos_embs[i](
                        embed_dims[i], embed_dims[i], inference_mode=inference_mode
                    )
                )
            stage = basic_blocks(
                embed_dims[i],
                i,
                layers,
                token_mixer_type=token_mixers[i],
                kernel_size=repmixer_kernel_size,
                mlp_ratio=mlp_ratios[i],
                act_layer=act_layer,
                norm_layer=norm_layer,
                drop_rate=drop_rate,
                drop_path_rate=drop_path_rate,
                use_layer_scale=use_layer_scale,
                layer_scale_init_value=layer_scale_init_value,
                inference_mode=inference_mode,
            )
            network.append(stage)
            if i >= len(layers) - 1:
                break

            # Patch merging/downsampling between stages.
            if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
                network.append(
                    PatchEmbed(
                        patch_size=down_patch_size,
                        stride=down_stride,
                        in_channels=embed_dims[i],
                        embed_dim=embed_dims[i + 1],
                        inference_mode=inference_mode,
                        use_se=se_downsamples[i + 1],
                    )
                )
        self.network = nn.ModuleList(network)

        # Classifier head
        self.conv_exp = MobileOneBlock(
            in_channels=embed_dims[-1],
            out_channels=int(embed_dims[-1] * cls_ratio),
            kernel_size=3,
            stride=1,
            padding=1,
            groups=embed_dims[-1],
            inference_mode=inference_mode,
            use_se=True,
            num_conv_branches=1,
        )
        self.head = (
            nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes)
            if num_classes > 0
            else nn.Identity()
        )
        self.apply(self.cls_init_weights)
        self.init_cfg = copy.deepcopy(init_cfg)

    def cls_init_weights(self, m: nn.Module) -> None:
        """Init. for classification"""
        if isinstance(m, nn.Linear):
            normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        return x

    def forward_tokens(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        for idx, block in enumerate(self.network):
            x = block(x)
        return x

    def forward(self, x: torch.Tensor, *args, **kwargs) -> Union[Tensor, Dict[str, Tensor]]:
        # input embedding
        x = self.forward_embeddings(x)
        # through backbone
        x = self.forward_tokens(x)
        # for image classification/embedding
        x = self.conv_exp(x)
        cls_out = self.head(x)

        out_dict = dict()
        if kwargs.get("return_image_embeddings", False):
            out_dict.update({"logits": cls_out})
            out_dict.update({"image_embeddings": x})
            return out_dict
        else:
            return cls_out

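# Note on the forward contract above: with return_image_embeddings=True, FastViT.forward
# returns a dict holding the head output under "logits" and the post-conv_exp feature map
# (shape [batch, C, H, W]) under "image_embeddings"; otherwise only the head output is
# returned. MobileCLIPVisionTower below relies on the "image_embeddings" entry.
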
@register_model
def fastvithd(pretrained=False, **kwargs):
    """Instantiate FastViTHD model variant."""
    layers = [2, 12, 24, 4, 2]
    embed_dims = [96, 192, 384, 768, 1536]
    mlp_ratios = [4, 4, 4, 4, 4]
    downsamples = [True, True, True, True, True]
    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7)), partial(RepCPE, spatial_shape=(7, 7))]
    token_mixers = ("repmixer", "repmixer", "repmixer", "attention", "attention")
    model = FastViT(
        layers,
        token_mixers=token_mixers,
        embed_dims=embed_dims,
        pos_embs=pos_embs,
        mlp_ratios=mlp_ratios,
        downsamples=downsamples,
        norm_layer=LayerNormChannel,
        stem_scale_branch=False,
        inference_mode=True,
        **kwargs,
    )
    model.default_cfg = default_cfgs["fastvit_m"]
    if pretrained:
        raise ValueError("Functionality not implemented.")
    return model

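# Usage sketch (shapes inferred from the configuration in this file, not measured here):
# fastvithd() is the FastVLM vision encoder backbone. The stem and the four stage
# transitions reduce resolution by a factor of 64 overall, matching the patch_size of 64
# in load_model_config, so a 1024x1024 input yields a 16x16 feature grid with 1536
# channels (3072 after conv_exp, since cls_ratio defaults to 2.0).
#
#   backbone = fastvithd(num_classes=0)     # num_classes=0 -> identity head
#   out = backbone(torch.randn(1, 3, 1024, 1024), return_image_embeddings=True)
#   out["image_embeddings"].shape           # expected: [1, 3072, 16, 16]
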
            +
            def load_model_config(
         | 
| 1485 | 
            +
                    model_name: str,
         | 
| 1486 | 
            +
            ) -> Any:
         | 
| 1487 | 
            +
                model_cfg = {
         | 
| 1488 | 
            +
                    "embed_dim": 768,
         | 
| 1489 | 
            +
                    "image_cfg": {
         | 
| 1490 | 
            +
                        "image_size": 1024,
         | 
| 1491 | 
            +
                        "model_name": "fastvithd",
         | 
| 1492 | 
            +
                        "embed_dim": 3072,
         | 
| 1493 | 
            +
                        "patch_size": 64
         | 
| 1494 | 
            +
                    },
         | 
| 1495 | 
            +
                    "text_cfg": {
         | 
| 1496 | 
            +
                        "context_length": 77,
         | 
| 1497 | 
            +
                        "vocab_size": 49408,
         | 
| 1498 | 
            +
                        "dim": 768,
         | 
| 1499 | 
            +
                        "ffn_multiplier_per_layer": 4.0,
         | 
| 1500 | 
            +
                        "n_heads_per_layer": 12,
         | 
| 1501 | 
            +
                        "n_transformer_layers": 12,
         | 
| 1502 | 
            +
                        "norm_layer": "layer_norm_fp32",
         | 
| 1503 | 
            +
                        "causal_masking": False,
         | 
| 1504 | 
            +
                        "model_name": "base"
         | 
| 1505 | 
            +
                    }
         | 
| 1506 | 
            +
                }
         | 
| 1507 | 
            +
                return model_cfg
         | 
| 1508 | 
            +
             | 
| 1509 | 
            +
             | 
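# Note on the config above: the visual token budget follows directly from image_cfg,
# image_size // patch_size = 1024 // 64 = 16 patches per side, i.e. 16 * 16 = 256 image
# tokens of width image_cfg["embed_dim"] = 3072, while the top-level "embed_dim" of 768
# is the shared image/text embedding width of the underlying CLIP-style model. The
# model_name argument is currently ignored and the same config is returned for any input.
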
class MCi(nn.Module):
    """
    This class implements `MCi Models <https://arxiv.org/pdf/2311.17049.pdf>`_
    """

    def __init__(self, model_name: str, *args, **kwargs) -> None:
        super().__init__()
        self.projection_dim = None
        if "projection_dim" in kwargs:
            self.projection_dim = kwargs.get("projection_dim")

        # Create model
        self.model = create_model(model_name, projection_dim=self.projection_dim)

        # Build out projection head.
        if self.projection_dim is not None:
            if hasattr(self.model, "head"):
                self.model.head = MCi._update_image_classifier(
                    image_classifier=self.model.head, projection_dim=self.projection_dim
                )

    def forward(self, x: Any, *args, **kwargs) -> Any:
        """A forward function of the model."""
        x = self.model(x, *args, **kwargs)
        return x

    @staticmethod
    def _get_in_feature_dimension(image_classifier: nn.Module) -> int:
        """Return the input feature dimension to the image classification head."""
        in_features = None
        if isinstance(image_classifier, nn.Sequential):
            # Classifier that uses nn.Sequential usually has global pooling and
            # multiple linear layers. Find the first linear layer and get its
            # in_features
            for layer in image_classifier:
                if isinstance(layer, nn.Linear):
                    in_features = layer.in_features
                    break
        elif isinstance(image_classifier, nn.Linear):
            in_features = image_classifier.in_features

        if in_features is None:
            raise NotImplementedError(
                f"Cannot get input feature dimension of {image_classifier}."
            )
        return in_features

    @staticmethod
    def _update_image_classifier(
        image_classifier: nn.Module, projection_dim: int, *args, **kwargs
    ) -> nn.Module:
        in_features = MCi._get_in_feature_dimension(image_classifier)
        new_img_classifier = GlobalPool2D(in_dim=in_features, out_dim=projection_dim)
        return new_img_classifier

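# Usage sketch (illustrative values): when projection_dim is given, MCi swaps the
# backbone's classification head for GlobalPool2D, so the default forward path maps
# [batch, 3, H, W] to a pooled, projected [batch, projection_dim] embedding.
#
#   enc = MCi(model_name="fastvithd", projection_dim=768)
#   emb = enc(torch.randn(1, 3, 1024, 1024))   # -> [1, 768]
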
class MobileCLIPVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.tune_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False)
        self.input_image_size = int(vision_tower.split("_")[-1])

        # Delay load is disabled for now
        if not delay_load:
            self.load_model()
        elif getattr(args, 'unfreeze_mm_vision_tower', False):
            self.load_model()
        else:
            model_cfg = load_model_config(self.vision_tower_name)
            self.cfg_only = model_cfg

    def load_model(self, device_map=None):
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        # Load model config
        model_cfg = load_model_config(self.vision_tower_name)

        # Override default image resolution
        model_cfg["image_cfg"]["image_size"] = self.input_image_size

        self.cfg_only = model_cfg

        # Build HF CLIPImageProcessor with MobileCLIP parameters
        self.image_processor = CLIPImageProcessor(crop_size={"height": model_cfg["image_cfg"]["image_size"],
                                                             "width": model_cfg["image_cfg"]["image_size"]},
                                                  image_mean=[0.0, 0.0, 0.0],
                                                  image_std=[1.0, 1.0, 1.0],
                                                  size={"shortest_edge": model_cfg["image_cfg"]["image_size"]})

        # Instantiate the image encoder
        self.vision_tower = MCi(model_name=model_cfg["image_cfg"]["model_name"],
                                projection_dim=model_cfg["embed_dim"])

        if not self.tune_vision_tower:
            self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        # Features from penultimate layer
        image_features = image_forward_outs["image_embeddings"]

        # Reshape 4D tensor to 3D
        B, C, H, W = image_features.shape
        image_features = image_features.reshape(B, C, H*W)
        image_features = image_features.transpose(1, 2)
        return image_features

    def forward(self, images):
        if self.tune_vision_tower:
            return self.forward_images(images)
        else:
            with torch.no_grad():
                return self.forward_images(images)

    def forward_images(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), return_image_embeddings=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), return_image_embeddings=True)
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self):
        return next(self.vision_tower.parameters()).device

    @property
    def config(self):
        return self.cfg_only

    @property
    def hidden_size(self):
        return self.config["image_cfg"]["embed_dim"]

    @property
    def num_patches_per_side(self):
        return self.config["image_cfg"]["image_size"] // self.config["image_cfg"]["patch_size"]

    @property
    def num_patches(self):
        return (self.config["image_cfg"]["image_size"] // self.config["image_cfg"]["patch_size"]) ** 2

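# Usage sketch ("mobileclip_l_1024" is a hypothetical tower name; the trailing "_1024"
# sets input_image_size, and any object with the expected attributes works as `args`):
#
#   from types import SimpleNamespace
#   tower = MobileCLIPVisionTower("mobileclip_l_1024", args=SimpleNamespace())
#   feats = tower(torch.randn(1, 3, 1024, 1024))
#   feats.shape     # expected: [1, 256, 3072], i.e. [batch, num_patches, hidden_size]
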
class IdentityMap(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": 'identity'}

def build_vision_projector(config, delay_load=False, **kwargs):
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')

def build_vision_tower(vision_tower_cfg, **kwargs):
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    return MobileCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

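# Usage sketch (hidden_size=896 is illustrative and not read from this file): an
# "mlp2x_gelu" projector maps the 3072-dim vision features to the language model width.
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=3072, hidden_size=896)
#   proj = build_vision_projector(cfg)
#   # -> Sequential(Linear(3072, 896), GELU(), Linear(896, 896))
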
class LlavaMetaModel:

    def __init__(self, config):
        super(LlavaMetaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.mm_projector = build_vision_projector(config)

            if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
                self.image_newline = nn.Parameter(
                    torch.empty(config.hidden_size, dtype=self.dtype)
                )

    def get_vision_tower(self):
        vision_tower = getattr(self, 'vision_tower', None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_vision_modules(self, model_args, fsdp=None):
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
        mm_patch_merge_type = model_args.mm_patch_merge_type

        self.config.mm_vision_tower = vision_tower

        if self.get_vision_tower() is None:
            vision_tower = build_vision_tower(model_args)

            if fsdp is not None and len(fsdp) > 0:
                self.vision_tower = [vision_tower]
            else:
                self.vision_tower = vision_tower
        else:
            if fsdp is not None and len(fsdp) > 0:
                vision_tower = self.vision_tower[0]
            else:
                vision_tower = self.vision_tower
            vision_tower.load_model()

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_hidden_size = vision_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature
        self.config.mm_patch_merge_type = mm_patch_merge_type

        if getattr(self, 'mm_projector', None) is None:
            self.mm_projector = build_vision_projector(self.config)

            if 'unpad' in mm_patch_merge_type:
                embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
                self.image_newline = nn.Parameter(
                    torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
                )
        else:
            # In case it is frozen by LoRA
            for p in self.mm_projector.parameters():
                p.requires_grad = True

        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')

            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))

def select_best_resolution(original_size, possible_resolutions):
    """
    Selects the best resolution from a list of possible resolutions based on the original size.
    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
    Returns:
        tuple: The best fit resolution in the format (width, height).
    """
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float('inf')

    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit

| 1804 | 
            +
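# Converts the chosen anyres resolution into a (cols, rows) patch-grid shape,
# e.g. get_anyres_image_grid_shape((1000, 800), [(672, 672)], 336) -> (2, 2).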
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (str): A string representation of a list of possible resolutions.
        patch_size (int): The size of each image patch.
    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
    import ast
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size

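# Mixin that adds the multimodal plumbing on top of a causal LM: encoding images
# through the vision tower + projector and splicing the resulting features into
# the text embedding sequence.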
class LlavaMetaForCausalLM(ABC):

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

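    # Core multimodal preprocessing: every IMAGE_TOKEN_INDEX placeholder in
    # input_ids is replaced by the projected image features, and labels,
    # attention_mask and position_ids are rebuilt (and re-padded) around the
    # now-longer sequence. Returns inputs_embeds instead of input_ids.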
    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels,
        images, image_sizes=None
    ):
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            if type(images) is list:
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
            image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
            if mm_patch_merge_type == 'flat':
                image_features = [x.flatten(0, 1) for x in image_features]
            elif mm_patch_merge_type.startswith('spatial'):
                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    if image_feature.shape[0] > 1:
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]
                        height = width = self.get_vision_tower().num_patches_per_side
                        assert height * width == base_image_feature.shape[0]
                        if image_aspect_ratio == 'anyres':
                            if hasattr(self.get_vision_tower(), 's2_image_size'):
                                img_size = self.get_vision_tower().s2_image_size
                            elif isinstance(self.get_vision_tower().config, dict):
                                img_size = self.get_vision_tower().config["image_cfg"]["image_size"]
                            else:
                                img_size = self.get_vision_tower().config.image_size

                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, img_size)
                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                        else:
                            raise NotImplementedError
                        if 'unpad' in mm_patch_merge_type:
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = unpad_image(image_feature, image_sizes[image_idx])
                            image_feature = torch.cat((
                                image_feature,
                                self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
                            ), dim=-1)
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        else:
                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                            image_feature = image_feature.flatten(0, 3)
                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                    else:
                        image_feature = image_feature[0]
                        if 'unpad' in mm_patch_merge_type:
                            image_feature = torch.cat((
                                image_feature,
                                self.model.image_newline[None].to(image_feature.device)
                            ), dim=0)
                    new_image_features.append(image_feature)
                image_features = new_image_features
            else:
                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
        else:
            image_features = self.encode_images(images)

        # TODO: image start / end is not implemented here to support pretraining.
        if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
            raise NotImplementedError

        # Let's just add dummy tensors if they do not exist,
        # it is a headache to deal with None all the time.
        # But it is not ideal, and if you have a better idea,
        # please open an issue / submit a PR, thanks.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # remove the padding using attention_mask -- FIXME
        _input_ids = input_ids
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []

            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
                new_input_embeds_padded.append(torch.cat((
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
                    cur_new_embed
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(torch.cat((
                    cur_new_embed,
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels

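    # Optionally registers the image patch / start / end special tokens with the
    # tokenizer, resizes the embedding matrices, and initializes any new rows
    # (from the mean embedding or from a pretrained mm_projector checkpoint).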
    def initialize_vision_tokenizer(self, model_args, tokenizer):
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

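# Qwen2 backbone with the LLaVA vision modules (vision tower + mm_projector)
# attached via LlavaMetaModel.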
class LlavaQwen2Model(LlavaMetaModel, Qwen2Model):
    config_class = LlavaConfig

    def __init__(self, config: Qwen2Config):
        super(LlavaQwen2Model, self).__init__(config)

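# The model exposed to transformers: Qwen2ForCausalLM extended with the
# multimodal input preparation from LlavaMetaForCausalLM.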
class LlavaQwen2ForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
    config_class = LlavaConfig

    def __init__(self, config):
        super(Qwen2ForCausalLM, self).__init__(config)
        self.model = LlavaQwen2Model(config)
        # self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

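    # If the caller passed raw input_ids plus images, splice the image features
    # into inputs_embeds first; then defer to the standard Qwen2 forward pass.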
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

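    # Generation entry point: pre-computes multimodal inputs_embeds (text +
    # image features) and hands them to the stock Hugging Face generate() loop.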
    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

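    # Forwards any `images` / `image_sizes` kwargs on to forward() alongside the
    # usual generation inputs.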
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs

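# Register the config/model with the Auto* factories so that
# AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True) can resolve
# the "llava_qwen2" model type.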
AutoConfig.register("llava_qwen2", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaQwen2ForCausalLM)