diff --git "a/pr-16104-summary.md" "b/pr-16104-summary.md" new file mode 100644--- /dev/null +++ "b/pr-16104-summary.md" @@ -0,0 +1,3984 @@ +# PR #16104: [Model] Support Llama4 in vLLM + +## § Overview +- **Author:** @houseroad +- **Status:** merged +- **Created:** 2025-04-05 +- **Merged:** 2025-04-06 +- **Base:** v0.8.3 ← **Head:** init_pr + +## § Description + +Add the support for Llama4 Scout (17B x 16 Experts) and Maverick (17B x 128 Experts) in vLLM. + +Using 8xH100, vLLM can serve Scout with 1M context and Maverick with about 430K. +``` +vllm serve meta-llama/Llama-4-Scout-17B-16E-Instruct \ + --tensor-parallel-size 8 \ + --max-model-len 1280000 + +vllm serve meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 \ + --tensor-parallel-size 8 \ + --max-model-len 430000 +``` + +Using 8xH200, vLLM can serve Scout with 3.6M context and Maverick with full 1M context. +``` +vllm serve meta-llama/Llama-4-Scout-17B-16E-Instruct \ + --tensor-parallel-size 8 \ + --max-model-len 3600000 + +vllm serve meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 \ + --tensor-parallel-size 8 +``` + +Using MI300x, we can run with default settings. +``` +VLLM_WORKER_MULTIPROC_METHOD=spawn \ +VLLM_USE_MODELSCOPE=False \ +SAFETENSORS_FAST_GPU=1 VLLM_USE_V1=1 vllm serve meta-llama/Llama-4-Maverick-17B-128E-Instruct \ + --disable-log-requests -tp 8 \ + --max-num-seqs 64 +``` + +Check out blog post [link coming soon] for performance enhancement and leveraging long context. + + +FIX #16106 + +## § Linked Issues +- **#16106: [New Model]: Llama4 Support** (CLOSED) + ### 🚀 The feature, motivation and pitch + + Meta released 2 Variants: + + Llama 4 Scout: + A high-performing small model with 17B activated parameters across 16 experts. Extremely fast, natively multimodal, supports a 10M+ token context window, and runs on a single GPU. + + Llama 4 Maverick: + A top-tier multimodal model outperforming GPT-4o and Gemini 2.0 Flash, with performance on par with DeepSeek V3 at half the active parameters. ELO 1417 on LMArena and runs on a single host. + + ### Alternatives + + _No response_ + + ### Additional context + + _No response_ + + ### Before submitting a new issue... + + - [x] Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. 
+ + +## § Files Changed (35 files) +- `benchmarks/kernels/benchmark_moe.py` +- `docs/source/models/supported_models.md` +- `examples/offline_inference/vision_language.py` +- `examples/offline_inference/vision_language_multi_image.py` +- `requirements/common.txt` +- `requirements/test.in` +- `requirements/test.txt` +- `tests/models/decoder_only/vision_language/test_models.py` +- `tests/models/multimodal/processing/test_common.py` +- `tests/models/multimodal/processing/test_llama4.py` +- `tests/models/registry.py` +- `tests/models/test_registry.py` +- `vllm/config.py` +- `vllm/entrypoints/chat_utils.py` +- `vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json` +- `vllm/model_executor/layers/fused_moe/cutlass_moe.py` +- `vllm/model_executor/layers/fused_moe/fused_moe.py` +- `vllm/model_executor/layers/fused_moe/layer.py` +- `vllm/model_executor/layers/layernorm.py` +- `vllm/model_executor/layers/quantization/awq_marlin.py` +- `vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py` +- `vllm/model_executor/layers/quantization/experts_int8.py` +- `vllm/model_executor/layers/quantization/fp8.py` +- `vllm/model_executor/layers/quantization/gguf.py` +- `vllm/model_executor/layers/quantization/gptq_marlin.py` +- `vllm/model_executor/layers/quantization/moe_wna16.py` +- `vllm/model_executor/layers/quantization/quark/quark_moe.py` +- `vllm/model_executor/layers/rotary_embedding.py` +- `vllm/model_executor/models/llama.py` +- `vllm/model_executor/models/llama4.py` +- `vllm/model_executor/models/mllama4.py` +- `vllm/model_executor/models/registry.py` +- `vllm/v1/attention/backends/flash_attn.py` +- `vllm/v1/attention/backends/triton_attn.py` +- `vllm/v1/worker/gpu_model_runner.py` + +## § Code Review Comments +### § Review by @ywang96 (Commented) +> Multimodal part looks fine to me - left some nits but we can fix them later + +**File:** `tests/models/registry.py:4` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -337,6 +337,7 @@ def check_available_online( + tokenizer="facebook/bart-base", + trust_remote_code=True), # noqa: E501 + "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 ++ "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501 +``` +**Comment:** We need to add a `min_transformers_version` on this + +**File:** `vllm/model_executor/models/mllama4.py:438` +**Context:** +```diff +... ++ grads on inputs. ++ """ ++ return self.patch_embedding +``` +**Comment:** I don't think this function is used anywhere? Can we remove it? + +**File:** `vllm/model_executor/models/mllama4.py:474` +**Context:** +```diff +... ++ # Remove CLS token output ++ hidden_state = hidden_state[:, :-1, :] ++ ++ # now, we use Llama4VisionPixelShuffle + mlp to project embeddings ++ hidden_state = self.vision_adapter(hidden_state) +... +``` +**Comment:** We don't really need to wrap the output hidden state insinde `BaseModelOutput` class (this is something only `transformers` requires) + + +### § Review by @mgoin (Commented) +**File:** `vllm/model_executor/layers/fused_moe/cutlass_moe.py:13` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -96,8 +97,14 @@ def cutlass_moe_fp8( + n = w2_q.size(1) + + topk = topk_ids.size(1) ++ assert topk == 1, \ +``` +**Comment:** Should we move this assert to be in the `if apply_router_weight_on_input:` conditional? 
This seems restrictive without checking if apply_router_weight_on_input is true + +**File:** `vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py:26` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -240,24 +240,28 @@ def apply( + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + +- return fused_experts(x, +``` +**Comment:** Forgot to add attribute like in other method + + +### § Review by @simon-mo (Commented) +**File:** `vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py:26` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -240,24 +240,28 @@ def apply( + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + +- return fused_experts(x, +``` +**Comment:** This is WIP by @luccafong + + +### § Review by @ywang96 (Commented) +**File:** `tests/models/registry.py:4` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -337,6 +337,7 @@ def check_available_online( + tokenizer="facebook/bart-base", + trust_remote_code=True), # noqa: E501 + "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 ++ "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501 +``` +**Comment:** Resolving per [a19cf7b](https://github.com/vllm-project/vllm/pull/16104/commits/a19cf7bafcb37adb359742cb4258da94605a7a3e) + + +### § Review by @ywang96 (Commented) +**File:** `vllm/model_executor/models/mllama4.py:474` +**Context:** +```diff +... ++ # Remove CLS token output ++ hidden_state = hidden_state[:, :-1, :] ++ ++ # now, we use Llama4VisionPixelShuffle + mlp to project embeddings ++ hidden_state = self.vision_adapter(hidden_state) +... +``` +**Comment:** Resolving - we will fix it in a later PR + + +### § Review by @ywang96 (Commented) +**File:** `vllm/model_executor/models/mllama4.py:438` +**Context:** +```diff +... ++ grads on inputs. ++ """ ++ return self.patch_embedding +``` +**Comment:** Resolved in [bacd195](https://github.com/vllm-project/vllm/pull/16104/commits/bacd1954f5df85b9bed917ae13179a4a2abba1e3) + + +### § Review by @tlrmchlsmth (Commented) +**File:** `examples/offline_inference/vision_language_multi_image.py:11` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -253,6 +253,43 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData: + ) + + ++def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: +``` +**Comment:** Is tp 8 overkill for the scout model? + + +### § Review by @houseroad (Commented) +**File:** `examples/offline_inference/vision_language_multi_image.py:11` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -253,6 +253,43 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData: + ) + + ++def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: +``` +**Comment:** we can do tp4. 
+ + +### § Review by @ywang96 (Commented) +**File:** `examples/offline_inference/vision_language_multi_image.py:11` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -253,6 +253,43 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData: + ) + + ++def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: +``` +**Comment:** ~I think TP=4 should be fine for this model on most devices?~ nvm - confirmed it's better to recommend running TP=8 + + +### § Review by @houseroad (Commented) +**File:** `examples/offline_inference/vision_language_multi_image.py:11` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -253,6 +253,43 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData: + ) + + ++def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: +``` +**Comment:** TP2 may not have enough HBM. + + +### § Review by @LucasWilkinson (Commented) +**File:** `vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py:26` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -240,24 +240,28 @@ def apply( + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + +- return fused_experts(x, +``` +**Comment:** Are you guys referring to the pre-commit failure? sorry I think this was from my changes, @luccafong I have a fix for this I can push if you want, otherwise I can send you a patch (if you haven't already fixed it) + + +### § Review by @yeqcharlotte (Commented) +**File:** `examples/offline_inference/vision_language_multi_image.py:11` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -253,6 +253,43 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData: + ) + + ++def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: +``` +**Comment:** we can try to create an fp8 version with llm compressor for scout which is runnable on tp2 later + + +### § Review by @ywang96 (Commented) +**File:** `examples/offline_inference/vision_language_multi_image.py:11` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -253,6 +253,43 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData: + ) + + ++def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: +``` +**Comment:** Yea this is fine for this release - examples aren't tied to release and they can be updated anytime anyways + + +### § Review by @tlrmchlsmth (Commented) +**File:** `vllm/v1/attention/backends/flash_attn.py:75` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -96,6 +96,183 @@ class FlashAttentionMetadata: + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. + ++ # for local attention +``` +**Comment:** Should it be named `q_seqlens_np`? + +**File:** `vllm/v1/attention/backends/flash_attn.py` (File-level comment) +**Comment:** `LocalAttentionMetadata` and `make_local_attention_virtual_batches` look good to me. BTW has anybody profiled this? 
We should look at writing a "kernel" as a followup + + +### § Review by @yeqcharlotte (Commented) +**File:** `vllm/model_executor/models/registry.py:4` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -73,6 +73,7 @@ + "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), + "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), ++ "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), +``` +**Comment:** @houseroad @ywang96 do you think we can just delete this + + +### § Review by @ywang96 (Commented) +**File:** `vllm/model_executor/models/registry.py:4` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -73,6 +73,7 @@ + "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), + "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), ++ "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), +``` +**Comment:** We actually need it for the following to work +```python + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + architectures=["Llama4ForCausalLM"], + prefix=maybe_prefix(prefix, "language_model")) +``` +[b4533e3](https://github.com/vllm-project/vllm/pull/16104/commits/b4533e3aaa54ba48a3a4814b6d3b8d47347cad18) should fix the CI error on this. + + + +### § Review by @ywang96 (Commented) +**File:** `vllm/model_executor/models/registry.py:4` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -73,6 +73,7 @@ + "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), + "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), ++ "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), +``` +**Comment:** We can also change the code in `mllama4.py` to import `Llama4ForCausalLM` directly so that we don't need to register it, but I feel like it's probably better to do it in a follow-up PR. + + +### § Review by @AlekseyKorshuk (Commented) +**File:** `vllm/model_executor/models/mllama4.py:491` +**Context:** +```diff +... ++ ++ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: ++ return {"image": 10} +``` +**Comment:** Why limit it to 10 images only if the model has to support way more, given its context length and benchmark results published by Meta claiming of processing up to 20 hours of video? + + +### § Review by @ywang96 (Commented) +**File:** `vllm/model_executor/models/mllama4.py:491` +**Context:** +```diff +... ++ ++ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: ++ return {"image": 10} +``` +**Comment:** I don't think video inference is the scope of this release yet? + +This PR doesn't support video modality so I guess it'll come in the next model update? + + +### § Review by @yeqcharlotte (Commented) +**File:** `vllm/model_executor/models/mllama4.py:491` +**Context:** +```diff +... ++ ++ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: ++ return {"image": 10} +``` +**Comment:** @AlekseyKorshuk 8-10 image is the recommended mm limit giving you acceptable quality although from the infra perspective it can do more. + +Llama4’s video tokenizer works slightly different form image and we’ll update that once it’s available. + + +### § Review by @AlekseyKorshuk (Commented) +**File:** `vllm/model_executor/models/mllama4.py:491` +**Context:** +```diff +... 
++ ++ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: ++ return {"image": 10} +``` +**Comment:** That's a fair point, but it raises an error if set cli argument to >10 multimodal limit. Shouldn't 10 be a default value, but not the hard limit that is not possible to overcome without changing the code? + + +### § Review by @DarkLight1337 (Commented) +**File:** `docs/source/models/supported_models.md:43` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -982,10 +989,10 @@ See [this page](#generative-models) for more information on how to use generativ + * ✅︎ + ::: + +-^ You need to set the architecture name via `--hf-overrides` to match the one in vLLM. +``` +**Comment:** These whitespaces should not be removed. They are intentional to add paragraph spacing + + +### § Review by @ywang96 (Commented) +**File:** `vllm/model_executor/models/mllama4.py:491` +**Context:** +```diff +... ++ ++ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: ++ return {"image": 10} +``` +**Comment:** I think I'm okay with not capping it at 10, but setting a default value for this will be something model-dependent which we currently don't support today on vLLM (and it's tricky to do that since today there's no standard on how many images a model can support up to), so we let user do it by passing `limit-mm-per-prompt`. + + +### § Review by @AlekseyKorshuk (Commented) +**File:** `vllm/model_executor/models/mllama4.py:491` +**Context:** +```diff +... ++ ++ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: ++ return {"image": 10} +``` +**Comment:** Sounds good, just wanted to make sure that this value is easy for users to change based on their needs. Thanks for the reply, gonna resolve the conversation + + +### § Review by @ywang96 (Commented) +**File:** `docs/source/models/supported_models.md:43` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -982,10 +989,10 @@ See [this page](#generative-models) for more information on how to use generativ + * ✅︎ + ::: + +-^ You need to set the architecture name via `--hf-overrides` to match the one in vLLM. +``` +**Comment:** Given this is only added to the 0.8.3 release branch, I will fix it on main once it's merged there + + +### § Review by @LucasWilkinson (Commented) +**File:** `vllm/v1/attention/backends/flash_attn.py` (File-level comment) +**Comment:** I don't believe so, atleast I never did. I think as a first cut we could even just write a C++ op, this code is ALOT easier to understand as a loop and honestly would probably be faster as a loop (assuming its a C++ loop and not a python loop) since theres sooo many numpy calls in this version. I just wrote it this way assuming it would scale to larger batch sizes better than a python loop. + + +### § Review by @LucasWilkinson (Commented) +**File:** `vllm/v1/attention/backends/flash_attn.py:75` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -96,6 +96,183 @@ class FlashAttentionMetadata: + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. 
+ ++ # for local attention +``` +**Comment:** could be, I just dropped the np suffixes in this function since they are all numpy arrays, but we could add them back in a future PR + + +### § Review by @wenmengzhou (Commented) +**File:** `examples/offline_inference/vision_language.py:23` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -582,6 +582,42 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData: + ) + + ++def run_llama4(questions: list[str], modality: str): +``` +**Comment:** missing content of image, it should be +{ + "type": "image", + "image": "https://path/to/your/image.jpg" +} + + +### § Review by @ywang96 (Commented) +**File:** `examples/offline_inference/vision_language.py:23` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -582,6 +582,42 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData: + ) + + ++def run_llama4(questions: list[str], modality: str): +``` +**Comment:** The way it works with our offline inference `llm.generate` interface is actually a bit different from huggingface interface. In this case we're adding this chunk here only for it to insert the image placeholder token into the prompt when we apply the chat template from the tokenizer. + + +### § Review by @jianyuh (Commented) +**File:** `examples/offline_inference/vision_language_multi_image.py:11` +**Context:** +```diff +(Could not locate the specific line in the diff hunk) +@@ -253,6 +253,43 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData: + ) + + ++def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: +``` +**Comment:** @tlrmchlsmth nice to see you here :) In Llama4 blog (https://ai.meta.com/blog/llama-4-multimodal-intelligence/), it mentions we can fit scout model in a single H100 GPU so using tp1: + +The former fits on a single H100 GPU (with Int4 quantization) while the latter fits on a single H100 host. + +https://github.com/meta-llama/llama-stack/blob/3f92b2bf85df6762b039c22ef54c6cad3c45f2c9/llama_stack/providers/inline/inference/meta_reference/llama4/generation.py#L44 + +--> + +https://github.com/meta-llama/llama-stack/blob/3f92b2bf85df6762b039c22ef54c6cad3c45f2c9/llama_stack/providers/inline/inference/meta_reference/quantize_impls.py#L284 + + + + +## § General Comments +### § @github-actions - 2025-04-05 19:09 +👋 Hi! Thank you for contributing to the vLLM project. + +💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. + +Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org. + +Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. + +To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge. 
+ +🚀 + +### § @robertgshaw2-redhat - 2025-04-05 19:39 +🔥 + +### § @dsingal0 - 2025-04-05 20:50 +Is it expected to get this error: + File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/registry.py", line 451, in inspect_model_cls + + return self._raise_for_unsupported(architectures) + + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/registry.py", line 401, in _raise_for_unsupported + + raise ValueError( + +ValueError: Model architectures ['Llama4ForConditionalGeneration'] failed to be inspected. Please check the logs for more details. + +### § @ywang96 - 2025-04-05 20:54 +> Is it expected to get this error: File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/registry.py", line 451, in inspect_model_cls +> +> ``` +> return self._raise_for_unsupported(architectures) +> +> ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +> ``` +> +> File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/registry.py", line 401, in _raise_for_unsupported +> +> ``` +> raise ValueError( +> ``` +> +> ValueError: Model architectures ['Llama4ForConditionalGeneration'] failed to be inspected. Please check the logs for more details. + +@dsingal0 Which version of ` transformers ` are you on? + +### § @dsingal0 - 2025-04-05 20:58 +> > Is it expected to get this error: File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/registry.py", line 451, in inspect_model_cls +> > ``` +> > return self._raise_for_unsupported(architectures) +> > +> > ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +> > ``` +> > +> > +> > +> > +> > +> > +> > +> > +> > +> > +> > +> > File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/registry.py", line 401, in _raise_for_unsupported +> > ``` +> > raise ValueError( +> > ``` +> > +> > +> > +> > +> > +> > +> > +> > +> > +> > +> > +> > ValueError: Model architectures ['Llama4ForConditionalGeneration'] failed to be inspected. Please check the logs for more details. +> +> @dsingal0 Which version of `transformers` are you on? + +transformers-4.52.0.dev0 + +### § @dsingal0 - 2025-04-05 21:33 +I think transformers.models.llama4.image_processing_llama4 needs to be changed to transformers.models.llama4.image_processing_llama4_fast + +### § @ywang96 - 2025-04-05 21:33 +> I think transformers.models.llama4.image_processing_llama4 needs to be changed to transformers.models.llama4.image_processing_llama4_fast + +Yea it's been addressed in [62e9744](https://github.com/vllm-project/vllm/pull/16104/commits/62e974401ae3cd4240a6eb109cb585c25a40da29) already + +### § @fsaudm - 2025-04-06 16:25 +Quantization support? 
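Regarding the `Model architectures ['Llama4ForConditionalGeneration'] failed to be inspected` reports above: this PR raises the floor to `transformers>=4.51.0` in `requirements/common.txt`, and the remaining processor issue seen on a dev build was fixed on the vLLM side in commit 62e9744. The snippet below is only a convenience sketch (not part of the PR) for ruling out the most common cause, an older `transformers` without the Llama4 classes:

```python
# Sanity-check the installed transformers version before loading Llama4.
# The 4.51.0 floor matches the bump in requirements/common.txt; this helper
# itself is not part of the PR.
from importlib.metadata import version

from packaging.version import Version

installed = Version(version("transformers"))
if installed < Version("4.51.0"):
    raise RuntimeError(
        f"transformers {installed} is too old for Llama4; "
        "install transformers>=4.51.0"
    )
print(f"transformers {installed} OK")
```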
+ + +## § Commits (16 commits) +- `58d9c2f`: Initial Llama4 enablement for vLLM +- `dcb2c77`: revert changes in vllm/assets/image.py (#116) +- `89083a6`: fix inplace_fused_experts_fake (#117) +- `188bb52`: Bump transformers version to 4.51.0 (#119) +- `6ad393f`: clean up model names and whitespaces (#120) +- `ee170a7`: Revert "Bump transformers version to 4.51.0 (#119)" +- `a19cf7b`: Reapply "Bump transformers version to 4.51.0 (#119)" +- `bacd195`: remove used method +- `ec6cdaa`: fix MOE lint (#2) +- `62e9744`: fix llama4 processing (#3) +- `b4533e3`: skip llama4 standalone +- `0587bc7`: add marks +- `c0ca739`: Add apply_router_weight_on_input to CompressedTensorsWNA16MoEMethod's… +- `866b94a`: precommit +- `1b8b67a`: Add apply_router_weight_on_input to all FusedMoEMethodBase classes (#5) +- `4e45bfc`: fix basic model test + +## § PR Diff + +```diff +diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py +index f1803b39c883..afe0b53077a7 100644 +--- a/benchmarks/kernels/benchmark_moe.py ++++ b/benchmarks/kernels/benchmark_moe.py +@@ -553,6 +553,9 @@ def main(args: argparse.Namespace): + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size + else: ++ if not hasattr(config, "hidden_size"): ++ # Support for llama4 ++ config = config.text_config + # Default: Mixtral. + E = config.num_local_experts + topk = config.num_experts_per_tok +diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md +index 74b4eab92043..bb318be988b3 100644 +--- a/docs/source/models/supported_models.md ++++ b/docs/source/models/supported_models.md +@@ -24,7 +24,7 @@ vLLM also supports model implementations that are available in Transformers. Thi + + To check if the modeling backend is Transformers, you can simply do this: + +-```python ++```python + from vllm import LLM + llm = LLM(model=..., task="generate") # Name or path of your model + llm.apply_model(lambda model: print(type(model))) +@@ -55,7 +55,7 @@ If your model is neither supported natively by vLLM or Transformers, you can sti + Simply set `trust_remote_code=True` and vLLM will run any model on the Model Hub that is compatible with Transformers. + Provided that the model writer implements their model in a compatible way, this means that you can run new models before they are officially supported in Transformers or vLLM! + +-```python ++```python + from vllm import LLM + llm = LLM(model=..., task="generate", trust_remote_code=True) # Name or path of your model + llm.apply_model(lambda model: print(model.__class__)) +@@ -840,6 +840,13 @@ See [this page](#generative-models) for more information on how to use generativ + * + * ✅︎ + * ✅︎ ++- * `Llama4ForConditionalGeneration` ++ * Llama-4-17B-Omni-Instruct ++ * T + I+ ++ * `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. ++ * ++ * ++ * ✅︎ + - * `LlavaForConditionalGeneration` + * LLaVA-1.5 + * T + IE+ +@@ -982,10 +989,10 @@ See [this page](#generative-models) for more information on how to use generativ + * ✅︎ + ::: + +-^ You need to set the architecture name via `--hf-overrides` to match the one in vLLM. +-    • For example, to use DeepSeek-VL2 series models: +-      `--hf-overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` +-E Pre-computed embeddings can be inputted for this modality. ++^ You need to set the architecture name via `--hf-overrides` to match the one in vLLM. 
++    • For example, to use DeepSeek-VL2 series models: ++      `--hf-overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` ++E Pre-computed embeddings can be inputted for this modality. + + Multiple items can be inputted per text prompt for this modality. + + :::{important} +diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py +index c1115708505a..61d53dda1c47 100644 +--- a/examples/offline_inference/vision_language.py ++++ b/examples/offline_inference/vision_language.py +@@ -582,6 +582,42 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData: + ) + + ++def run_llama4(questions: list[str], modality: str): ++ assert modality == "image" ++ ++ model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct" ++ ++ engine_args = EngineArgs( ++ model=model_name, ++ max_model_len=8192, ++ max_num_seqs=4, ++ tensor_parallel_size=8, ++ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, ++ gpu_memory_utilization=0.4, ++ ) ++ ++ tokenizer = AutoTokenizer.from_pretrained(model_name) ++ messages = [[{ ++ "role": ++ "user", ++ "content": [{ ++ "type": "image" ++ }, { ++ "type": "text", ++ "text": f"{question}" ++ }] ++ }] for question in questions] ++ prompts = tokenizer.apply_chat_template(messages, ++ add_generation_prompt=True, ++ tokenize=False) ++ stop_token_ids = None ++ return ModelRequestData( ++ engine_args=engine_args, ++ prompts=prompts, ++ stop_token_ids=stop_token_ids, ++ ) ++ ++ + # Molmo + def run_molmo(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" +@@ -907,6 +943,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: + "minicpmv": run_minicpmv, + "mistral3": run_mistral3, + "mllama": run_mllama, ++ "llama4": run_llama4, + "molmo": run_molmo, + "NVLM_D": run_nvlm_d, + "paligemma": run_paligemma, +diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py +index 39951e5e89c4..e03ebe485eaa 100644 +--- a/examples/offline_inference/vision_language_multi_image.py ++++ b/examples/offline_inference/vision_language_multi_image.py +@@ -253,6 +253,43 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData: + ) + + ++def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: ++ model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct" ++ ++ engine_args = EngineArgs( ++ model=model_name, ++ max_model_len=8192, ++ max_num_seqs=4, ++ tensor_parallel_size=8, ++ limit_mm_per_prompt={"image": len(image_urls)}, ++ ) ++ ++ placeholders = [{"type": "image", "image": url} for url in image_urls] ++ messages = [{ ++ "role": ++ "user", ++ "content": [ ++ *placeholders, ++ { ++ "type": "text", ++ "text": question ++ }, ++ ], ++ }] ++ ++ processor = AutoProcessor.from_pretrained(model_name) ++ ++ prompt = processor.apply_chat_template(messages, ++ tokenize=False, ++ add_generation_prompt=True) ++ ++ return ModelRequestData( ++ engine_args=engine_args, ++ prompt=prompt, ++ image_data=[fetch_image(url) for url in image_urls], ++ ) ++ ++ + def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" + +@@ -567,6 +604,7 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData: + "h2ovl_chat": load_h2ovl, + "idefics3": load_idefics3, + "internvl_chat": load_internvl, ++ "llama4": load_llama4, + "mistral3": load_mistral3, + "mllama": load_mllama, + 
"NVLM_D": load_nvlm_d, +diff --git a/requirements/common.txt b/requirements/common.txt +index 7365a5b46a30..24a1e6d67ac2 100644 +--- a/requirements/common.txt ++++ b/requirements/common.txt +@@ -6,7 +6,7 @@ requests >= 2.26.0 + tqdm + blake3 + py-cpuinfo +-transformers >= 4.50.3 ++transformers >= 4.51.0 + huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads. + tokenizers >= 0.19.1 # Required for Llama 3. + protobuf # Required by LlamaTokenizer. +diff --git a/requirements/test.in b/requirements/test.in +index 364747e9c08f..ac7f451e96a8 100644 +--- a/requirements/test.in ++++ b/requirements/test.in +@@ -30,7 +30,7 @@ mistral_common[opencv] >= 1.5.4 # required for pixtral test + opencv-python-headless >= 4.11.0 # required for video test + datamodel_code_generator # required for minicpm3 test + lm-eval[api]==0.4.8 # required for model evaluation test +-transformers==4.50.3 ++transformers==4.51.0 + huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads. + # quantization + bitsandbytes>=0.45.3 +diff --git a/requirements/test.txt b/requirements/test.txt +index 236b8be32805..39d6ed1acff0 100644 +--- a/requirements/test.txt ++++ b/requirements/test.txt +@@ -645,7 +645,7 @@ tqdm==4.66.6 + # transformers + tqdm-multiprocess==0.0.11 + # via lm-eval +-transformers==4.50.3 ++transformers==4.51.0 + # via + # -r requirements/test.in + # genai-perf +diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py +index 3b34f012f626..a9386971a034 100644 +--- a/tests/models/decoder_only/vision_language/test_models.py ++++ b/tests/models/decoder_only/vision_language/test_models.py +@@ -536,6 +536,22 @@ + limit_mm_per_prompt={"image": 1}, + )], + ), ++ "llama4": VLMTestInfo( ++ models=["meta-llama/Llama-4-Scout-17B-16E-Instruct"], ++ prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n", # noqa: E501 ++ img_idx_to_prompt=lambda _: "<|image|>", ++ test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), ++ distributed_executor_backend="mp", ++ image_size_factors=[(.25, 0.5, 1.0)], ++ hf_model_kwargs={"device_map": "auto"}, ++ max_model_len=8192, ++ max_num_seqs=4, ++ dtype="bfloat16", ++ auto_cls=AutoModelForImageTextToText, ++ tensor_parallel_size=8, ++ vllm_runner_kwargs={"gpu_memory_utilization": 0.8}, ++ marks=[large_gpu_mark(min_gb=80), multi_gpu_marks(num_gpus=8)], ++ ), + } + # yapf: enable + +diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py +index fdcd7a9e1738..cb4e4cdb3eee 100644 +--- a/tests/models/multimodal/processing/test_common.py ++++ b/tests/models/multimodal/processing/test_common.py +@@ -280,6 +280,7 @@ def _test_processing_correctness_mistral( + "Skywork/Skywork-R1V-38B", + "fixie-ai/ultravox-v0_5-llama-3_2-1b", + "openai/whisper-large-v3", ++ "meta-llama/Llama-4-Scout-17B-16E-Instruct", + ]) + @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) + @pytest.mark.parametrize("num_batches", [32]) +diff --git a/tests/models/multimodal/processing/test_llama4.py b/tests/models/multimodal/processing/test_llama4.py +new file mode 100644 +index 000000000000..7ec7c8002974 +--- /dev/null ++++ b/tests/models/multimodal/processing/test_llama4.py +@@ -0,0 +1,99 @@ ++# SPDX-License-Identifier: Apache-2.0 ++"""Tests for Llama4's multimodal preprocessing kwargs.""" ++ ++import pytest ++ ++from vllm.multimodal import MULTIMODAL_REGISTRY ++from 
vllm.transformers_utils.tokenizer import encode_tokens ++ ++from ....conftest import _ImageAssets ++from ...utils import build_model_context ++ ++ ++@pytest.mark.parametrize("model_id", ++ ["meta-llama/Llama-4-Scout-17B-16E-Instruct"]) ++@pytest.mark.parametrize("mm_processor_kwargs", [{}]) ++@pytest.mark.parametrize("num_imgs", [1, 5]) ++@pytest.mark.parametrize("disable_mm_preprocessor_cache", [True, False]) ++@pytest.mark.parametrize("tokenized_prompt", [True, False]) ++def test_processor_override( ++ image_assets: _ImageAssets, ++ model_id: str, ++ mm_processor_kwargs: dict, ++ num_imgs: int, ++ disable_mm_preprocessor_cache: bool, ++ tokenized_prompt: bool, ++): ++ """Ensure llama4 processor works properly.""" ++ ctx = build_model_context( ++ model_id, ++ mm_processor_kwargs=mm_processor_kwargs, ++ limit_mm_per_prompt={"image": num_imgs}, ++ disable_mm_preprocessor_cache=disable_mm_preprocessor_cache, ++ ) ++ processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) ++ config = processor.info.get_hf_config() ++ tokenizer = processor.info.get_tokenizer() ++ hf_processor = processor.info.get_hf_processor() ++ vocab = tokenizer.get_vocab() ++ ++ prompt = "<|begin_of_text|><|header_start|>user<|header_end|>" \ ++ + "<|image|>" * num_imgs \ ++ + "<|eot|><|header_start|>assistant<|header_end|>" ++ mm_data = { ++ "image": [ ++ image_assets[(i % len(image_assets))].pil_image ++ for i in range(num_imgs) ++ ] ++ } ++ if tokenized_prompt: ++ prompt = encode_tokens(tokenizer, prompt) ++ ++ processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) ++ mm_kwargs = processed_inputs["mm_kwargs"] ++ ++ # place holder replacements ++ prompt_token_ids = processed_inputs["prompt_token_ids"] ++ assert prompt_token_ids.count(config.boi_token_index) == num_imgs ++ assert prompt_token_ids.count(config.eoi_token_index) == num_imgs ++ assert prompt_token_ids.count(vocab[hf_processor.image_token]) == num_imgs ++ aspect_ratios = mm_kwargs["aspect_ratios"] ++ num_x_separators = num_y_separators = 0 ++ for tiles_y, tiles_x in aspect_ratios: ++ if tiles_x * tiles_y > 1: ++ num_x_separators += (tiles_x - 1) * tiles_y ++ num_y_separators += tiles_y ++ assert prompt_token_ids.count(vocab[hf_processor.tile_token]) \ ++ == num_x_separators ++ assert prompt_token_ids.count(vocab[hf_processor.tile_global_token]) \ ++ == num_y_separators ++ ++ # image token offsets ++ img_locs = processed_inputs["mm_placeholders"].get("image", []) ++ assert len(img_locs) == num_imgs ++ assert [img_loc["offset"] for img_loc in img_locs] == \ ++ [i for i, v in enumerate(prompt_token_ids) \ ++ if v == config.boi_token_index] ++ ++ # patch sizes and masks ++ assert prompt_token_ids.count(config.image_token_index) \ ++ == sum(img_patch.sum() for img_patch in mm_kwargs["embed_is_patch"]) ++ patch_token_id = vocab[hf_processor.img_patch_token] ++ num_patches = processed_inputs["prompt_token_ids"].count(patch_token_id) ++ mm_counts = {"image": num_imgs} ++ assert num_patches / num_imgs <= \ ++ processor.info.get_mm_max_tokens_per_item(32768, mm_counts)["image"] ++ num_patches_per_chunk = processor.info.get_patch_per_chunk( ++ config.vision_config) ++ assert prompt_token_ids.count(config.image_token_index) \ ++ == mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk ++ assert mm_kwargs["pixel_values"].shape[0] \ ++ == mm_kwargs["patches_per_image"].sum() ++ ++ for embed_is_patch, aspect_ratio in zip(mm_kwargs["embed_is_patch"], ++ mm_kwargs["aspect_ratios"]): ++ assert embed_is_patch.shape[0] == \ ++ 
len(tokenizer.encode( ++ hf_processor._prompt_split_image( ++ aspect_ratio, num_patches_per_chunk), ++ add_special_tokens=False)) +diff --git a/tests/models/registry.py b/tests/models/registry.py +index 39e104a11ab1..d508c5f44dab 100644 +--- a/tests/models/registry.py ++++ b/tests/models/registry.py +@@ -337,6 +337,7 @@ def check_available_online( + tokenizer="facebook/bart-base", + trust_remote_code=True), # noqa: E501 + "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 ++ "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501 + "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 + } + +diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py +index 3282284b6b27..4c5572a569ea 100644 +--- a/tests/models/test_registry.py ++++ b/tests/models/test_registry.py +@@ -23,6 +23,11 @@ + + @pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs()) + def test_registry_imports(model_arch): ++ ++ # Llama4ForCausalLM does not have a standalone model ++ if model_arch == "Llama4ForCausalLM": ++ return ++ + model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) + model_info.check_transformers_version(on_fail="skip") + +@@ -91,8 +96,11 @@ def test_registry_is_pp(model_arch, is_pp, init_cuda): + + + def test_hf_registry_coverage(): +- untested_archs = (ModelRegistry.get_supported_archs() - +- HF_EXAMPLE_MODELS.get_supported_archs()) ++ untested_archs = set(ModelRegistry.get_supported_archs() - ++ HF_EXAMPLE_MODELS.get_supported_archs()) ++ ++ # Llama4ForCausalLM does not have a standalone model ++ untested_archs.discard("Llama4ForCausalLM") + + assert not untested_archs, ( + "Please add the following architectures to " +diff --git a/vllm/config.py b/vllm/config.py +index 2669d1a13b37..bd52fc90b0a2 100644 +--- a/vllm/config.py ++++ b/vllm/config.py +@@ -354,6 +354,8 @@ def __init__( + self.hf_config = hf_config + + self.hf_text_config = get_hf_text_config(self.hf_config) ++ self.attention_chunk_size = getattr(self.hf_text_config, ++ "attention_chunk_size", None) + self.encoder_config = self._get_encoder_config() + self.hf_image_processor_config = get_hf_image_processor_config( + self.model, revision) +diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py +index ff2d1aacbece..e48ffae9bf58 100644 +--- a/vllm/entrypoints/chat_utils.py ++++ b/vllm/entrypoints/chat_utils.py +@@ -500,7 +500,7 @@ def _placeholder_str(self, modality: ModalityStr, + "internvl_chat", "skywork_chat", "NVLM_D", + "h2ovl_chat"): + return "" +- if model_type == "mllama": ++ if model_type in ("mllama", "llama4"): + return "<|image|>" + if model_type in ("qwen2_vl", "qwen2_5_vl"): + return "<|vision_start|><|image_pad|><|vision_end|>" +diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json +new file mode 100644 +index 000000000000..f10e39482e58 +--- /dev/null ++++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json +@@ -0,0 +1,200 @@ ++{ ++ "1": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 16, ++ "BLOCK_SIZE_K": 256, ++ "GROUP_SIZE_M": 1, ++ "num_warps": 4, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 1 ++ }, ++ "2": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 16, ++ "BLOCK_SIZE_K": 256, ++ "GROUP_SIZE_M": 1, 
++ "num_warps": 4, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 1 ++ }, ++ "4": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 16, ++ "BLOCK_SIZE_K": 256, ++ "GROUP_SIZE_M": 1, ++ "num_warps": 2, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 1 ++ }, ++ "8": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 128, ++ "GROUP_SIZE_M": 1, ++ "num_warps": 4, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 2 ++ }, ++ "16": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 32, ++ "BLOCK_SIZE_K": 128, ++ "GROUP_SIZE_M": 1, ++ "num_warps": 2, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 1 ++ }, ++ "24": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 256, ++ "GROUP_SIZE_M": 1, ++ "num_warps": 4, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 2 ++ }, ++ "32": { ++ "BLOCK_SIZE_M": 32, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 128, ++ "GROUP_SIZE_M": 1, ++ "num_warps": 4, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 1 ++ }, ++ "48": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 128, ++ "GROUP_SIZE_M": 1, ++ "num_warps": 4, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 1 ++ }, ++ "64": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 128, ++ "GROUP_SIZE_M": 1, ++ "num_warps": 2, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 2 ++ }, ++ "96": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 256, ++ "GROUP_SIZE_M": 1, ++ "num_warps": 4, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 2 ++ }, ++ "128": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 256, ++ "GROUP_SIZE_M": 1, ++ "num_warps": 4, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 2 ++ }, ++ "256": { ++ "BLOCK_SIZE_M": 32, ++ "BLOCK_SIZE_N": 32, ++ "BLOCK_SIZE_K": 256, ++ "GROUP_SIZE_M": 8, ++ "num_warps": 4, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 2 ++ }, ++ "512": { ++ "BLOCK_SIZE_M": 64, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 128, ++ "GROUP_SIZE_M": 1, ++ "num_warps": 8, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 2 ++ }, ++ "1024": { ++ "BLOCK_SIZE_M": 64, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 128, ++ "GROUP_SIZE_M": 1, ++ "num_warps": 8, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 2 ++ }, ++ "1536": { ++ "BLOCK_SIZE_M": 128, ++ "BLOCK_SIZE_N": 128, ++ "BLOCK_SIZE_K": 128, ++ "GROUP_SIZE_M": 8, ++ "num_warps": 8, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 2 ++ }, ++ "2048": { ++ "BLOCK_SIZE_M": 64, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 128, ++ "GROUP_SIZE_M": 1, ++ "num_warps": 8, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 2 ++ }, ++ "3072": { ++ "BLOCK_SIZE_M": 128, ++ "BLOCK_SIZE_N": 128, ++ "BLOCK_SIZE_K": 64, ++ "GROUP_SIZE_M": 1, ++ "num_warps": 8, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 2 ++ }, ++ "4096": { ++ "BLOCK_SIZE_M": 64, ++ "BLOCK_SIZE_N": 128, ++ "BLOCK_SIZE_K": 64, ++ "GROUP_SIZE_M": 1, ++ "num_warps": 4, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ 
"kpack": 2 ++ } ++} +diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py +index a17afd1b357e..d6a27aa0ddc4 100644 +--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py ++++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py +@@ -23,6 +23,7 @@ def cutlass_moe_fp8( + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.half, ++ apply_router_weight_on_input: bool = False, + ) -> torch.Tensor: + """ + This function computes a a8w8-quantized Mixture of Experts (MoE) layer +@@ -96,8 +97,14 @@ def cutlass_moe_fp8( + n = w2_q.size(1) + + topk = topk_ids.size(1) ++ + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) ++ if apply_router_weight_on_input: ++ assert topk == 1, \ ++ "apply_router_weight_on_input is only implemented for topk=1" ++ # TODO: this only works for topK=1, will need to update for topK>1 ++ a = a * topk_weights.to(out_dtype) + + a_q, a1_scale = ops.scaled_fp8_quant( + a, a1_scale, use_per_token_if_dynamic=per_act_token) +@@ -139,6 +146,8 @@ def cutlass_moe_fp8( + ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale, + expert_offsets[:-1], problem_sizes2, ab_strides2, + ab_strides2, c_strides2) +- +- return (c2[c_map].view(m, topk, k) * +- topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1) ++ # Gather tokens ++ c2 = c2[c_map].view(m, topk, k) ++ if not apply_router_weight_on_input: ++ c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype) ++ return c2.sum(dim=1) +diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py +index 0817879c4d57..4ab99acb742f 100644 +--- a/vllm/model_executor/layers/fused_moe/fused_moe.py ++++ b/vllm/model_executor/layers/fused_moe/fused_moe.py +@@ -954,6 +954,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", ++ apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, +@@ -967,10 +968,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None) -> None: + fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, +- activation, use_fp8_w8a8, use_int8_w8a16, +- use_int4_w4a16, global_num_experts, expert_map, +- w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, +- block_shape) ++ activation, apply_router_weight_on_input, use_fp8_w8a8, ++ use_int8_w8a16, use_int4_w4a16, global_num_experts, ++ expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, ++ a2_scale, block_shape) + + + def inplace_fused_experts_fake( +@@ -980,6 +981,7 @@ def inplace_fused_experts_fake( + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", ++ apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, +@@ -1010,6 +1012,7 @@ def outplace_fused_experts( + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", ++ apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, +@@ -1023,10 +1026,11 @@ def outplace_fused_experts( + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None) -> 
torch.Tensor: + return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, +- False, activation, use_fp8_w8a8, use_int8_w8a16, +- use_int4_w4a16, global_num_experts, expert_map, +- w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, +- a2_scale, block_shape) ++ False, activation, apply_router_weight_on_input, ++ use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, ++ global_num_experts, expert_map, w1_scale, ++ w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, ++ block_shape) + + + def outplace_fused_experts_fake( +@@ -1084,6 +1088,7 @@ def fused_experts(hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", ++ apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, +@@ -1099,6 +1104,7 @@ def fused_experts(hidden_states: torch.Tensor, + allow_deep_gemm: bool = False) -> torch.Tensor: + if (allow_deep_gemm and use_fp8_w8a8 + and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): ++ assert apply_router_weight_on_input is False + return deep_gemm_moe_fp8( + hidden_states=hidden_states, + w1=w1, +@@ -1122,6 +1128,7 @@ def fused_experts(hidden_states: torch.Tensor, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, ++ apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, +@@ -1143,6 +1150,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", ++ apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, +@@ -1270,7 +1278,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, +- False, ++ apply_router_weight_on_input, + top_k_num, + config, + compute_type=compute_type, +@@ -1307,7 +1315,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, +- True, ++ not apply_router_weight_on_input, + 1, + config, + compute_type=compute_type, +diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py +index 661fb52bbee2..0e35d8a80988 100644 +--- a/vllm/model_executor/layers/fused_moe/layer.py ++++ b/vllm/model_executor/layers/fused_moe/layer.py +@@ -65,7 +65,9 @@ def apply( + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", +- e_score_correction_bias: Optional[torch.Tensor] = None ++ e_score_correction_bias: Optional[torch.Tensor] = None, ++ apply_router_weight_on_input: bool = False, ++ activation: str = "silu", + ) -> torch.Tensor: + raise NotImplementedError + +@@ -156,22 +158,25 @@ def apply( + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ++ apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: +- return self.forward(x=x, +- layer=layer, +- router_logits=router_logits, +- top_k=top_k, +- renormalize=renormalize, +- use_grouped_topk=use_grouped_topk, +- topk_group=topk_group, +- num_expert_group=num_expert_group, +- global_num_experts=global_num_experts, +- expert_map=expert_map, +- custom_routing_function=custom_routing_function, +- scoring_func=scoring_func, +- e_score_correction_bias=e_score_correction_bias, +- 
activation=activation) ++ return self.forward( ++ x=x, ++ layer=layer, ++ router_logits=router_logits, ++ top_k=top_k, ++ renormalize=renormalize, ++ use_grouped_topk=use_grouped_topk, ++ topk_group=topk_group, ++ num_expert_group=num_expert_group, ++ global_num_experts=global_num_experts, ++ expert_map=expert_map, ++ custom_routing_function=custom_routing_function, ++ scoring_func=scoring_func, ++ e_score_correction_bias=e_score_correction_bias, ++ activation=activation, ++ apply_router_weight_on_input=apply_router_weight_on_input) + + def forward_cuda( + self, +@@ -188,6 +193,7 @@ def forward_cuda( + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ++ apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + topk_weights, topk_ids = FusedMoE.select_experts( +@@ -202,15 +208,17 @@ def forward_cuda( + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + +- return fused_experts(hidden_states=x, +- w1=layer.w13_weight, +- w2=layer.w2_weight, +- topk_weights=topk_weights, +- topk_ids=topk_ids, +- inplace=True, +- activation=activation, +- global_num_experts=global_num_experts, +- expert_map=expert_map) ++ return fused_experts( ++ hidden_states=x, ++ w1=layer.w13_weight, ++ w2=layer.w2_weight, ++ topk_weights=topk_weights, ++ topk_ids=topk_ids, ++ inplace=True, ++ activation=activation, ++ apply_router_weight_on_input=apply_router_weight_on_input, ++ global_num_experts=global_num_experts, ++ expert_map=expert_map) + + def forward_cpu( + self, +@@ -228,9 +236,11 @@ def forward_cpu( + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ++ apply_router_weight_on_input: bool = False, + **kwargs, + ): + assert activation == "silu", f"{activation} is not supported." 
++ assert apply_router_weight_on_input is False + return layer.ipex_fusion( + x, + use_grouped_topk, +@@ -259,6 +269,7 @@ def forward_hpu( + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ++ apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert not use_grouped_topk +@@ -266,6 +277,7 @@ def forward_hpu( + assert topk_group is None + assert custom_routing_function is None + assert layer is not None ++ assert apply_router_weight_on_input is False + if scoring_func != "softmax": + raise NotImplementedError( + "Only softmax scoring function is supported for HPU.") +@@ -290,12 +302,14 @@ def forward_tpu( + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ++ apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert not use_grouped_topk + assert num_expert_group is None + assert topk_group is None + assert custom_routing_function is None ++ assert apply_router_weight_on_input is False + if scoring_func != "softmax": + raise NotImplementedError( + "Only softmax scoring function is supported for TPU.") +@@ -401,6 +415,7 @@ def __init__( + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ++ apply_router_weight_on_input: bool = False, + activation: str = "silu", + ): + super().__init__() +@@ -486,6 +501,7 @@ def __init__( + self.quant_method = quant_config.get_quant_method(self, prefix) + assert self.quant_method is not None + ++ self.apply_router_weight_on_input = apply_router_weight_on_input + moe_quant_params = { + "num_experts": self.local_num_experts, + "hidden_size": hidden_size, +@@ -853,6 +869,7 @@ def forward_impl(self, hidden_states: torch.Tensor, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, ++ apply_router_weight_on_input=self.apply_router_weight_on_input, + ) + + if self.dp_size > 1: +diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py +index 76d3acb92fb8..5e8eb6c54c89 100644 +--- a/vllm/model_executor/layers/layernorm.py ++++ b/vllm/model_executor/layers/layernorm.py +@@ -92,6 +92,7 @@ def __init__( + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + has_weight: bool = True, ++ dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__() + +@@ -100,8 +101,10 @@ def __init__( + self.variance_size_override = (None if var_hidden_size == hidden_size + else var_hidden_size) + self.has_weight = has_weight +- +- self.weight = torch.ones(hidden_size) ++ if dtype is not None: ++ self.weight = torch.ones(hidden_size, dtype=dtype) ++ else: ++ self.weight = torch.ones(hidden_size) + if self.has_weight: + self.weight = nn.Parameter(self.weight) + +diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py +index 473816fcc3ec..cb1d5400f3a0 100644 +--- a/vllm/model_executor/layers/quantization/awq_marlin.py ++++ b/vllm/model_executor/layers/quantization/awq_marlin.py +@@ -469,6 +469,7 @@ def apply( + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ++ apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: 
+ assert activation == "silu", "Only SiLU activation is supported." +@@ -476,6 +477,10 @@ def apply( + raise NotImplementedError( + "Expert Parallelism is not supported for " + "fused Marlin MoE method.") ++ if apply_router_weight_on_input: ++ raise NotImplementedError( ++ "Apply router weight on input is not supported for" ++ "fused Marlin MoE method.") + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, +diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +index bf32bee89e89..f573c8ae5131 100644 +--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py ++++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +@@ -224,6 +224,7 @@ def apply( + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ++ apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts +@@ -240,20 +241,22 @@ def apply( + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + +- return fused_experts(x, +- layer.w13_weight, +- layer.w2_weight, +- topk_weights=topk_weights, +- topk_ids=topk_ids, +- inplace=True, +- activation=activation, +- use_fp8_w8a8=True, +- global_num_experts=global_num_experts, +- expert_map=expert_map, +- w1_scale=layer.w13_weight_scale, +- w2_scale=layer.w2_weight_scale, +- a1_scale=layer.w13_input_scale, +- a2_scale=layer.w2_input_scale) ++ return fused_experts( ++ x, ++ layer.w13_weight, ++ layer.w2_weight, ++ topk_weights=topk_weights, ++ topk_ids=topk_ids, ++ inplace=True, ++ activation=activation, ++ apply_router_weight_on_input=apply_router_weight_on_input, ++ use_fp8_w8a8=True, ++ global_num_experts=global_num_experts, ++ expert_map=expert_map, ++ w1_scale=layer.w13_weight_scale, ++ w2_scale=layer.w2_weight_scale, ++ a1_scale=layer.w13_input_scale, ++ a2_scale=layer.w2_input_scale) + + + class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): +@@ -438,6 +441,7 @@ def apply( + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ++ apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + +@@ -474,6 +478,7 @@ def apply( + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + out_dtype=x.dtype, ++ apply_router_weight_on_input=apply_router_weight_on_input, + ) + + +@@ -778,6 +783,7 @@ def apply( + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ++ apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert activation == "silu", "Only SiLU activation is supported." 
+@@ -785,6 +791,10 @@ def apply( + raise NotImplementedError( + "Expert Parallelism is not supported for " + "fused Marlin MoE method.") ++ if apply_router_weight_on_input: ++ raise NotImplementedError( ++ "Apply router weight on input is not supported for " ++ "fused Marlin MoE method.") + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, +diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py +index d18ca55afebd..be19b80975ec 100644 +--- a/vllm/model_executor/layers/quantization/experts_int8.py ++++ b/vllm/model_executor/layers/quantization/experts_int8.py +@@ -113,6 +113,7 @@ def apply( + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ++ apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts +@@ -129,18 +130,20 @@ def apply( + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + +- return fused_experts(x, +- layer.w13_weight, +- layer.w2_weight, +- topk_weights=topk_weights, +- topk_ids=topk_ids, +- inplace=True, +- activation=activation, +- use_int8_w8a16=True, +- global_num_experts=global_num_experts, +- expert_map=expert_map, +- w1_scale=layer.w13_scale, +- w2_scale=layer.w2_scale) ++ return fused_experts( ++ x, ++ layer.w13_weight, ++ layer.w2_weight, ++ topk_weights=topk_weights, ++ topk_ids=topk_ids, ++ inplace=True, ++ activation=activation, ++ use_int8_w8a16=True, ++ global_num_experts=global_num_experts, ++ apply_router_weight_on_input=apply_router_weight_on_input, ++ expert_map=expert_map, ++ w1_scale=layer.w13_scale, ++ w2_scale=layer.w2_scale) + + @staticmethod + def quantizing_weight_loader(layer, weight_loader): +diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py +index e7c733db5c00..4435644c4f84 100644 +--- a/vllm/model_executor/layers/quantization/fp8.py ++++ b/vllm/model_executor/layers/quantization/fp8.py +@@ -773,6 +773,7 @@ def apply( + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ++ apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts +@@ -800,6 +801,7 @@ def apply( + activation=activation, + use_fp8_w8a8=True, + global_num_experts=global_num_experts, ++ apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=(layer.w13_weight_scale_inv + if self.block_quant else layer.w13_weight_scale), +diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py +index 9861e0a85b3f..6b499f81c55f 100644 +--- a/vllm/model_executor/layers/quantization/gguf.py ++++ b/vllm/model_executor/layers/quantization/gguf.py +@@ -338,9 +338,15 @@ def apply( + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ++ apply_router_weight_on_input: bool = False, + activation: str = "silu", + ): + assert activation == "silu", "Only SiLU activation is supported." 
++ if apply_router_weight_on_input: ++ raise NotImplementedError( ++ "Apply router weight on input is not supported for" ++ "fused GGUF MoE method.") ++ + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, +diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py +index 9f53ffc1d7f6..0615bb4ab4df 100644 +--- a/vllm/model_executor/layers/quantization/gptq_marlin.py ++++ b/vllm/model_executor/layers/quantization/gptq_marlin.py +@@ -592,9 +592,14 @@ def apply( + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ++ apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert activation == "silu", "Only SiLU activation is supported." ++ if apply_router_weight_on_input is not None: ++ raise NotImplementedError( ++ "Apply router weight on input is not supported for" ++ "fused Marlin MoE method.") + + # The input must currently be float16 + orig_dtype = x.dtype +diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py +index 41b75c9be05a..00c4b661ef2c 100644 +--- a/vllm/model_executor/layers/quantization/moe_wna16.py ++++ b/vllm/model_executor/layers/quantization/moe_wna16.py +@@ -293,6 +293,7 @@ def apply( + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ++ apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts +@@ -312,21 +313,23 @@ def apply( + weight_bits = self.quant_config.weight_bits + has_zp = self.quant_config.has_zp + +- return fused_experts(x, +- layer.w13_qweight, +- layer.w2_qweight, +- topk_weights=topk_weights, +- topk_ids=topk_ids, +- inplace=True, +- use_int4_w4a16=weight_bits == 4, +- use_int8_w8a16=weight_bits == 8, +- global_num_experts=global_num_experts, +- expert_map=expert_map, +- w1_scale=layer.w13_scales, +- w2_scale=layer.w2_scales, +- w1_zp=layer.w13_qzeros if has_zp else None, +- w2_zp=layer.w2_qzeros if has_zp else None, +- block_shape=[0, layer.group_size]) ++ return fused_experts( ++ x, ++ layer.w13_qweight, ++ layer.w2_qweight, ++ topk_weights=topk_weights, ++ topk_ids=topk_ids, ++ inplace=True, ++ use_int4_w4a16=weight_bits == 4, ++ use_int8_w8a16=weight_bits == 8, ++ global_num_experts=global_num_experts, ++ apply_router_weight_on_input=apply_router_weight_on_input, ++ expert_map=expert_map, ++ w1_scale=layer.w13_scales, ++ w2_scale=layer.w2_scales, ++ w1_zp=layer.w13_qzeros if has_zp else None, ++ w2_zp=layer.w2_qzeros if has_zp else None, ++ block_shape=[0, layer.group_size]) + + @staticmethod + def get_weight_loader(layer, weight_loader): +diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py +index bc26a455c6f2..d1146c0f039d 100644 +--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py ++++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py +@@ -202,6 +202,8 @@ def apply( + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ++ apply_router_weight_on_input: bool = False, ++ activation: str = "silu", + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe 
import fused_experts + +@@ -217,16 +219,18 @@ def apply( + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + +- return fused_experts(x, +- layer.w13_weight, +- layer.w2_weight, +- topk_weights=topk_weights, +- topk_ids=topk_ids, +- inplace=True, +- use_fp8_w8a8=True, +- global_num_experts=global_num_experts, +- expert_map=expert_map, +- w1_scale=layer.w13_weight_scale, +- w2_scale=layer.w2_weight_scale, +- a1_scale=layer.w13_input_scale, +- a2_scale=layer.w2_input_scale) ++ return fused_experts( ++ x, ++ layer.w13_weight, ++ layer.w2_weight, ++ topk_weights=topk_weights, ++ topk_ids=topk_ids, ++ inplace=True, ++ use_fp8_w8a8=True, ++ global_num_experts=global_num_experts, ++ apply_router_weight_on_input=apply_router_weight_on_input, ++ expert_map=expert_map, ++ w1_scale=layer.w13_weight_scale, ++ w2_scale=layer.w2_weight_scale, ++ a1_scale=layer.w13_input_scale, ++ a2_scale=layer.w2_input_scale) +diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py +index fd27775b7dc0..624ed63ab8b4 100644 +--- a/vllm/model_executor/layers/rotary_embedding.py ++++ b/vllm/model_executor/layers/rotary_embedding.py +@@ -851,6 +851,70 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + return new_freqs + + ++class Llama4VisionRotaryEmbedding(RotaryEmbedding): ++ ++ def __init__( ++ self, ++ head_size: int, ++ rotary_dim: int, ++ max_position_embeddings: int, ++ base: int, ++ is_neox_style: bool, ++ dtype: torch.dtype, ++ ): ++ super().__init__(head_size, rotary_dim, max_position_embeddings, base, ++ is_neox_style, dtype) ++ ++ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: ++ inv_freqs = super()._compute_inv_freq(base) ++ inv_freqs = inv_freqs[:(self.rotary_dim // 2)] ++ return inv_freqs ++ ++ def _compute_cos_sin_cache(self) -> torch.Tensor: ++ inv_freq = self._compute_inv_freq(self.base) ++ ++ # self.max_position_embeddings here is number of image patches ++ # i.e. 
(image_size // patch_size) ** 2 ++ num_patches = self.max_position_embeddings ++ img_idx = torch.arange(num_patches, ++ dtype=torch.int32) \ ++ .reshape(num_patches, 1) ++ img_idx = torch.cat([img_idx, img_idx[:1]], dim=0) ++ img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN ++ num_patches_single_dim = int(math.sqrt(num_patches)) ++ frequencies_x = img_idx % num_patches_single_dim ++ frequencies_y = img_idx // num_patches_single_dim ++ freqs_x = ((frequencies_x + 1)[..., None] * ++ inv_freq[None, None, :]).repeat_interleave(2, dim=-1) ++ freqs_y = ((frequencies_y + 1)[..., None] * ++ inv_freq[None, None, :]).repeat_interleave(2, dim=-1) ++ freqs = torch.cat([freqs_x, freqs_y], ++ dim=-1).float().contiguous()[..., ::2] ++ freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0) ++ cache = torch.view_as_complex( ++ torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)) ++ return cache ++ ++ def forward( ++ self, ++ query: torch.Tensor, ++ key: torch.Tensor, ++ ) -> Tuple[torch.Tensor, torch.Tensor]: ++ self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device) ++ query_ = torch.view_as_complex(query.float().reshape( ++ *query.shape[:-1], -1, 2)) ++ key_ = torch.view_as_complex(key.float().reshape( ++ *key.shape[:-1], -1, 2)) ++ broadcast_shape = [ ++ d if i == 1 or i == (query_.ndim - 1) else 1 ++ for i, d in enumerate(query_.shape) ++ ] ++ freqs_ci = self.cos_sin_cache.view(*broadcast_shape) ++ query_out = torch.view_as_real(query_ * freqs_ci).flatten(3) ++ key_out = torch.view_as_real(key_ * freqs_ci).flatten(3) ++ return query_out.type_as(query), key_out.type_as(key) ++ ++ + class MRotaryEmbedding(RotaryEmbedding): + """Rotary Embedding with Multimodal Sections.""" + +@@ -1130,6 +1194,10 @@ def get_rope( + scaling_factor, low_freq_factor, + high_freq_factor, + original_max_position) ++ elif scaling_type == "mllama4": ++ rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim, ++ max_position, base, ++ is_neox_style, dtype) + elif scaling_type == "default": + if "mrope_section" in rope_scaling: + rotary_emb = MRotaryEmbedding( +diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py +index 81b5d9bda9ac..4a5982ecbcb6 100644 +--- a/vllm/model_executor/models/llama.py ++++ b/vllm/model_executor/models/llama.py +@@ -65,6 +65,7 @@ def __init__( + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", ++ reduce_results: bool = True, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( +@@ -79,6 +80,7 @@ def __init__( + output_size=hidden_size, + bias=bias, + quant_config=quant_config, ++ reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": +@@ -466,10 +468,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + "ffn_norm": "post_attention_layernorm", + "tok_embeddings": "model.embed_tokens", + "output": "lm_head", +- "norm": "model.norm" ++ "norm": "model.norm", + } + +- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ++ def __init__(self, ++ *, ++ vllm_config: VllmConfig, ++ prefix: str = "", ++ layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config +@@ -478,7 +484,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + self.lora_config = lora_config + + self.model = self._init_model(vllm_config=vllm_config, +- prefix=maybe_prefix(prefix, "model")) ++ prefix=maybe_prefix(prefix, 
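The vision rotary embedding above keeps its cache as complex numbers and rotates q/k by complex multiplication (`view_as_complex`, multiply, `view_as_real`). A standalone check, independent of the vLLM classes, that this equals the usual cos/sin rotation of adjacent (real, imag) pairs:

```python
import torch

torch.manual_seed(0)
q = torch.randn(1, 4, 8)                 # (tokens, heads, head_dim)
theta = torch.randn(1, 1, 4)             # one angle per (real, imag) pair

# Complex path: pair adjacent elements and multiply by e^{i*theta}.
q_c = torch.view_as_complex(q.reshape(*q.shape[:-1], -1, 2))
rot = torch.polar(torch.ones_like(theta), theta)   # cos(theta) + i*sin(theta)
q_complex = torch.view_as_real(q_c * rot).flatten(-2)

# Explicit path: rotate each (x, y) pair with the 2x2 rotation matrix.
x, y = q[..., 0::2], q[..., 1::2]
cos, sin = torch.cos(theta), torch.sin(theta)
q_manual = torch.stack([x * cos - y * sin, x * sin + y * cos], dim=-1).flatten(-2)

assert torch.allclose(q_complex, q_manual, atol=1e-5)
```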
"model"), ++ layer_type=layer_type) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size +@@ -513,8 +520,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + +- def _init_model(self, vllm_config: VllmConfig, prefix: str = ""): +- return LlamaModel(vllm_config=vllm_config, prefix=prefix) ++ def _init_model(self, ++ vllm_config: VllmConfig, ++ prefix: str = "", ++ layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): ++ return LlamaModel(vllm_config=vllm_config, ++ prefix=prefix, ++ layer_type=layer_type) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) +diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py +new file mode 100644 +index 000000000000..27c872072041 +--- /dev/null ++++ b/vllm/model_executor/models/llama4.py +@@ -0,0 +1,530 @@ ++# SPDX-License-Identifier: Apache-2.0 ++# ++# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team. ++# All rights reserved. ++# ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++"""Inference-only LLaMA model compatible with HuggingFace weights.""" ++from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type ++ ++import torch ++from torch import nn ++from transformers import Llama4TextConfig ++ ++from vllm.attention import Attention ++from vllm.compilation.decorators import support_torch_compile ++from vllm.config import CacheConfig, VllmConfig ++from vllm.distributed import (get_tensor_model_parallel_world_size, ++ tensor_model_parallel_all_reduce) ++from vllm.model_executor.layers.fused_moe import FusedMoE ++from vllm.model_executor.layers.layernorm import RMSNorm ++from vllm.model_executor.layers.linear import (QKVParallelLinear, ++ ReplicatedLinear, ++ RowParallelLinear) ++from vllm.model_executor.layers.quantization import QuantizationConfig ++from vllm.model_executor.layers.rotary_embedding import get_rope ++from vllm.model_executor.model_loader.weight_utils import default_weight_loader ++ ++from .llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP, LlamaModel ++from .utils import (AutoWeightsLoader, extract_layer_index, ++ is_pp_missing_parameter) ++ ++ ++class Llama4MoE(nn.Module): ++ ++ @staticmethod ++ def custom_routing_function( ++ hidden_states: torch.Tensor, ++ gating_output: torch.Tensor, ++ topk: int, ++ renormalize: bool, ++ ) -> Tuple[torch.Tensor, torch.Tensor]: ++ router_scores, router_indices = torch.topk(gating_output, topk, dim=-1) ++ router_scores = torch.sigmoid(router_scores.float()).to( ++ hidden_states.dtype) ++ return (router_scores, router_indices.to(torch.int32)) ++ ++ def __init__(self, ++ config: Llama4TextConfig, ++ quant_config: Optional[QuantizationConfig] = None, ++ prefix: str = ""): ++ super().__init__() ++ self.tp_size = get_tensor_model_parallel_world_size() ++ self.top_k = config.num_experts_per_tok ++ ++ 
intermediate_size_moe = config.intermediate_size ++ self.router = ReplicatedLinear(config.hidden_size, ++ config.num_local_experts, ++ bias=False, ++ quant_config=None, ++ prefix=f"{prefix}.router") ++ ++ self.experts = FusedMoE( ++ num_experts=config.num_local_experts, ++ top_k=config.num_experts_per_tok, ++ hidden_size=config.hidden_size, ++ custom_routing_function=Llama4MoE.custom_routing_function, ++ intermediate_size=intermediate_size_moe, ++ apply_router_weight_on_input=True, ++ reduce_results=False, ++ renormalize=False, ++ quant_config=quant_config, ++ prefix=f"{prefix}.experts") ++ ++ self.shared_expert = LlamaMLP( ++ hidden_size=config.hidden_size, ++ intermediate_size=intermediate_size_moe, ++ hidden_act="silu", ++ quant_config=quant_config, ++ bias=False, ++ prefix=f"{prefix}.shared_expert", ++ reduce_results=False, # We need to do scatter before reduce ++ ) ++ ++ def forward(self, hidden_states): ++ router_logits, _ = self.router(hidden_states) ++ shared_out = self.shared_expert(hidden_states) ++ routed_out = self.experts( ++ hidden_states=hidden_states, ++ router_logits=router_logits, ++ ) ++ experts_out = routed_out + shared_out ++ ++ if self.tp_size > 1: ++ experts_out = tensor_model_parallel_all_reduce(experts_out) ++ ++ return experts_out ++ ++ ++class Llama4Attention(nn.Module): ++ ++ def __init__(self, ++ config: Llama4TextConfig, ++ hidden_size: int, ++ num_heads: int, ++ num_kv_heads: int, ++ rope_theta: float = 10000, ++ rope_scaling: Optional[Dict[str, Any]] = None, ++ max_position_embeddings: int = 8192, ++ quant_config: Optional[QuantizationConfig] = None, ++ bias: bool = False, ++ bias_o_proj: bool = False, ++ cache_config: Optional[CacheConfig] = None, ++ prefix: str = "") -> None: ++ super().__init__() ++ self.layer_idx = extract_layer_index(prefix) ++ self.hidden_size = hidden_size ++ self.no_rope_layers = config.no_rope_layers ++ self.nope = self.no_rope_layers[self.layer_idx] == 0 ++ self.use_qk_norm = config.use_qk_norm and not self.nope ++ tp_size = get_tensor_model_parallel_world_size() ++ self.total_num_heads = num_heads ++ assert self.total_num_heads % tp_size == 0 ++ self.num_heads = self.total_num_heads // tp_size ++ self.total_num_kv_heads = num_kv_heads ++ if self.total_num_kv_heads >= tp_size: ++ # Number of KV heads is greater than TP size, so we partition ++ # the KV heads across multiple tensor parallel GPUs. ++ assert self.total_num_kv_heads % tp_size == 0 ++ else: ++ # Number of KV heads is less than TP size, so we replicate ++ # the KV heads across multiple tensor parallel GPUs. 
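Two details of `Llama4MoE` above are worth calling out: the router takes top-k over the raw logits and then applies a sigmoid to the selected scores (no softmax across experts), and both the routed experts and the shared expert return partial sums (`reduce_results=False`) so a tensor-parallel MoE block needs only one all-reduce. A tiny illustration of the routing behavior:

```python
import torch

def llama4_routing(gating_output: torch.Tensor, topk: int = 1):
    # Top-k over raw router logits, then sigmoid on the winners
    # (no renormalization across experts).
    scores, indices = torch.topk(gating_output, topk, dim=-1)
    return torch.sigmoid(scores.float()), indices.to(torch.int32)

logits = torch.tensor([[2.0, -1.0, 0.5, 3.0]])
weights, experts = llama4_routing(logits, topk=1)
print(experts)  # tensor([[3]], dtype=torch.int32)
print(weights)  # tensor([[0.9526]]) -> sigmoid(3.0), not a softmax weight
```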
++ assert tp_size % self.total_num_kv_heads == 0 ++ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) ++ self.head_dim = config.head_dim ++ self.q_size = self.num_heads * self.head_dim ++ self.kv_size = self.num_kv_heads * self.head_dim ++ self.scaling = self.head_dim**-0.5 ++ # TODO: attn_temperature_tuning should be a bool in huggingface ++ self.attn_temperature_tuning = self.nope and \ ++ config.attn_temperature_tuning > 0 ++ ++ self.floor_scale = getattr(config, "floor_scale", 8192.0) ++ self.attn_scale = getattr(config, "attn_scale", 0.1) ++ self.rope_theta = rope_theta ++ self.max_position_embeddings = max_position_embeddings ++ self.n_rep = self.num_heads // self.num_kv_heads ++ self.q_norm = RMSNorm( ++ hidden_size=self.q_size, ++ eps=config.rms_norm_eps, ++ has_weight=False, ++ dtype=torch.float32, ++ ) if self.use_qk_norm else None ++ self.k_norm = RMSNorm( ++ hidden_size=self.kv_size, ++ eps=config.rms_norm_eps, ++ has_weight=False, ++ dtype=torch.float32, ++ ) if self.use_qk_norm else None ++ self.qkv_proj = QKVParallelLinear( ++ hidden_size=hidden_size, ++ head_size=self.head_dim, ++ total_num_heads=self.total_num_heads, ++ total_num_kv_heads=self.total_num_kv_heads, ++ bias=bias, ++ quant_config=quant_config, ++ prefix=f"{prefix}.qkv_proj", ++ ) ++ ++ self.o_proj = RowParallelLinear( ++ input_size=self.total_num_heads * self.head_dim, ++ output_size=hidden_size, ++ bias=bias_o_proj, ++ quant_config=quant_config, ++ prefix=f"{prefix}.o_proj", ++ ) ++ is_neox_style = True ++ is_gguf = quant_config and quant_config.get_name() == "gguf" ++ if is_gguf and config.model_type == "llama": ++ is_neox_style = False ++ ++ self.rotary_emb = get_rope( ++ self.head_dim, ++ rotary_dim=self.head_dim, ++ max_position=max_position_embeddings, ++ base=int(rope_theta), ++ rope_scaling=rope_scaling if rope_scaling != "default" else None, ++ is_neox_style=is_neox_style, ++ ) if not self.nope else None ++ ++ self.attn = Attention( ++ self.num_heads, ++ self.head_dim, ++ self.scaling, ++ num_kv_heads=self.num_kv_heads, ++ cache_config=cache_config, ++ quant_config=quant_config, ++ per_layer_sliding_window=None, ++ use_irope=not self.nope, ++ prefix=f"{prefix}.attn", ++ ) ++ ++ def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor: ++ floor = torch.floor((positions + 1.0) / self.floor_scale) ++ attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0 ++ ++ return attn_scale.unsqueeze(-1) ++ ++ def forward( ++ self, ++ positions: torch.Tensor, ++ hidden_states: torch.Tensor, ++ ) -> torch.Tensor: ++ qkv, _ = self.qkv_proj(hidden_states) ++ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) ++ ++ if self.rotary_emb is not None: ++ q, k = self.rotary_emb(positions, q, k) ++ if self.q_norm is not None: ++ q = self.q_norm(q.float()).to(q.dtype) ++ if self.k_norm is not None: ++ k = self.k_norm(k.float()).to(k.dtype) ++ ++ # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) ++ # to NoPE layers, where the inference-time temperature tuning function ++ # is customized to not affect short context ++ # while working at very long context ++ # https://arxiv.org/abs/2501.19399 ++ # ++ # We should apply temperature tuning between (after) rotary / QK norm ++ # and (before) attention. 
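`_get_attn_scale` above is the inference-time temperature tuning for NoPE layers: with the defaults read from the config (`floor_scale=8192.0`, `attn_scale=0.1`) it is a no-op below `floor_scale` and grows logarithmically with position. A quick numeric check of the same formula:

```python
import torch

def get_attn_scale(positions, floor_scale=8192.0, attn_scale=0.1):
    floor = torch.floor((positions + 1.0) / floor_scale)
    return torch.log(floor + 1.0) * attn_scale + 1.0

positions = torch.tensor([0.0, 8190.0, 65535.0, 1_000_000.0])
print(get_attn_scale(positions))
# ~[1.000, 1.000, 1.220, 1.481]: untouched at short context, slow growth after
```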
++ if self.attn_temperature_tuning and self.nope: ++ attn_scale = self._get_attn_scale(positions) ++ q = (q * attn_scale).to(q.dtype) ++ attn_output = self.attn(q, k, v) ++ output, _ = self.o_proj(attn_output) ++ return output ++ ++ ++class Llama4DecoderLayer(LlamaDecoderLayer): ++ ++ def __init__( ++ self, ++ config: Llama4TextConfig, ++ cache_config: Optional[CacheConfig] = None, ++ quant_config: Optional[QuantizationConfig] = None, ++ prefix: str = "", ++ ) -> None: ++ self.layer_idx = extract_layer_index(prefix) ++ nn.Module.__init__(self) ++ self.hidden_size = config.hidden_size ++ rope_theta = config.rope_theta ++ rope_scaling = config.rope_scaling ++ max_position_embeddings = config.max_position_embeddings ++ ++ self.self_attn = Llama4Attention( ++ config=config, ++ hidden_size=self.hidden_size, ++ num_heads=config.num_attention_heads, ++ num_kv_heads=config.num_key_value_heads, ++ rope_theta=rope_theta, ++ rope_scaling=rope_scaling, ++ max_position_embeddings=max_position_embeddings, ++ quant_config=quant_config, ++ bias=False, ++ bias_o_proj=False, ++ cache_config=cache_config, ++ prefix=f"{prefix}.self_attn", ++ ) ++ is_moe_layer = (self.layer_idx + ++ 1) % config.interleave_moe_layer_step == 0 ++ if is_moe_layer: ++ self.feed_forward = Llama4MoE( ++ config=config, ++ quant_config=quant_config, ++ prefix=f"{prefix}.feed_forward", ++ ) ++ else: ++ self.feed_forward = LlamaMLP( ++ hidden_size=self.hidden_size, ++ intermediate_size=config.intermediate_size_mlp, ++ hidden_act="silu", ++ quant_config=quant_config, ++ bias=False, ++ prefix=f"{prefix}.feed_forward", ++ ) ++ self.input_layernorm = RMSNorm(config.hidden_size, ++ eps=config.rms_norm_eps) ++ self.post_attention_layernorm = RMSNorm(config.hidden_size, ++ eps=config.rms_norm_eps) ++ ++ def forward( ++ self, ++ positions: torch.Tensor, ++ hidden_states: torch.Tensor, ++ residual: Optional[torch.Tensor], ++ ) -> Tuple[torch.Tensor, torch.Tensor]: ++ # Self Attention ++ if residual is None: ++ residual = hidden_states ++ hidden_states = self.input_layernorm(hidden_states) ++ else: ++ hidden_states, residual = self.input_layernorm( ++ hidden_states, residual) ++ hidden_states = self.self_attn(positions=positions, ++ hidden_states=hidden_states) ++ ++ # Fully Connected ++ hidden_states, residual = self.post_attention_layernorm( ++ hidden_states, residual) ++ hidden_states = self.feed_forward(hidden_states) ++ return hidden_states, residual ++ ++ ++@support_torch_compile ++class Llama4Model(LlamaModel): ++ ++ def __init__(self, ++ *, ++ vllm_config: VllmConfig, ++ prefix: str = "", ++ layer_type: Type[Llama4DecoderLayer] = Llama4DecoderLayer): ++ self.num_experts = vllm_config.model_config.hf_config.num_local_experts ++ super().__init__(vllm_config=vllm_config, ++ prefix=prefix, ++ layer_type=layer_type) ++ ++ def load_moe_expert_weights( ++ self, ++ name: str, ++ loaded_weight: torch.Tensor, ++ params_dict: Dict[str, nn.Parameter], ++ loaded_params: Set[str], ++ expert_params_mapping: List[Tuple[str, str, int, str]], ++ fused: bool = True, ++ ) -> bool: ++ expert_param_loaded = False ++ if "experts.gate_up_proj" in name: ++ loaded_weight = loaded_weight.chunk(2, dim=-1) ++ for (param_name, weight_name, expert_id, ++ shard_id) in expert_params_mapping: ++ new_loaded_weight = loaded_weight ++ if fused: ++ e_str, _, proj_str, _ = weight_name.split('.') ++ weight_name = f"{e_str}.{proj_str}" ++ param_name = f"{param_name}weight" ++ if weight_name not in name: ++ continue ++ full_param_name = name.replace(weight_name, param_name) ++ # 
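Whether a decoder layer gets the MoE block or a dense `LlamaMLP` is decided purely by `interleave_moe_layer_step`. The step values below are illustrative (the real one comes from the HF text config):

```python
def moe_layer_indices(num_layers: int, interleave_moe_layer_step: int):
    # Mirrors: is_moe_layer = (layer_idx + 1) % config.interleave_moe_layer_step == 0
    return [i for i in range(num_layers)
            if (i + 1) % interleave_moe_layer_step == 0]

print(moe_layer_indices(8, 1))  # [0, 1, 2, 3, 4, 5, 6, 7] -> every layer is MoE
print(moe_layer_indices(8, 2))  # [1, 3, 5, 7]             -> alternating dense/MoE
```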
Skip layers on other devices. ++ if is_pp_missing_parameter(name, self): ++ continue ++ if ((name.endswith(".bias") or name.endswith("_bias")) ++ and name not in params_dict): ++ continue ++ param = params_dict[full_param_name] ++ weight_loader = param.weight_loader ++ if fused: ++ if "w13" in full_param_name: ++ shard_idx = 0 if shard_id == "w1" else 1 ++ new_loaded_weight = new_loaded_weight[shard_idx] ++ new_loaded_weight = new_loaded_weight.transpose(-1, -2) ++ layer_idx = extract_layer_index(name) ++ # EP mapping ++ expert_map = self.layers[ ++ layer_idx].feed_forward.experts.expert_map ++ if expert_map is not None: ++ local_expert_indices = (expert_map != -1) \ ++ .nonzero() \ ++ .flatten() \ ++ .to(new_loaded_weight.device) ++ new_loaded_weight = new_loaded_weight[local_expert_indices] ++ expert_id = local_expert_indices[0].item() ++ else: ++ # TODO: add EP support for non fused weights ++ pass ++ weight_loader(param, ++ new_loaded_weight, ++ full_param_name, ++ shard_id=shard_id, ++ expert_id=expert_id) ++ ++ loaded_params.add(full_param_name) ++ expert_param_loaded = True ++ return expert_param_loaded ++ ++ def load_weights(self, weights: Iterable[Tuple[str, ++ torch.Tensor]]) -> Set[str]: ++ stacked_params_mapping = [ ++ # (param_name, shard_name, shard_id) ++ (".qkv_proj", ".q_proj", "q"), ++ (".qkv_proj", ".k_proj", "k"), ++ (".qkv_proj", ".v_proj", "v"), ++ (".gate_up_proj", ".gate_proj", 0), ++ (".gate_up_proj", ".up_proj", 1), ++ ] ++ fused_experts_params = False ++ expert_params_mapping = FusedMoE.make_expert_params_mapping( ++ ckpt_gate_proj_name="gate_proj", ++ ckpt_down_proj_name="down_proj", ++ ckpt_up_proj_name="up_proj", ++ num_experts=self.num_experts) ++ expert_params_mapping_fused = FusedMoE.make_expert_params_mapping( ++ ckpt_gate_proj_name="gate_up_proj", ++ ckpt_down_proj_name="down_proj", ++ ckpt_up_proj_name="gate_up_proj", ++ num_experts=1) ++ params_dict = dict(self.named_parameters()) ++ loaded_params: Set[str] = set() ++ for name, loaded_weight in weights: ++ if "experts.gate_up_proj" in name or "experts.down_proj" in name: ++ fused_experts_params = True ++ expert_params_mapping = expert_params_mapping_fused ++ if (self.quant_config is not None and ++ (scale_name := self.quant_config.get_cache_scale(name))): ++ # Loading kv cache quantization scales ++ param = params_dict[scale_name] ++ weight_loader = getattr(param, "weight_loader", ++ default_weight_loader) ++ loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else ++ loaded_weight[0]) ++ weight_loader(param, loaded_weight) ++ loaded_params.add(scale_name) ++ continue ++ for param_name, weight_name, shard_id in stacked_params_mapping: ++ if weight_name not in name or "experts" in name: ++ continue ++ name = name.replace(weight_name, param_name) ++ if is_pp_missing_parameter(name, self): ++ continue ++ param = params_dict[name] ++ weight_loader = param.weight_loader ++ weight_loader(param, loaded_weight, shard_id) ++ loaded_params.add(name) ++ break ++ else: ++ moe_loaded = self.load_moe_expert_weights( ++ name, ++ loaded_weight, ++ params_dict, ++ loaded_params, ++ expert_params_mapping, ++ fused=fused_experts_params) ++ ++ if not moe_loaded: ++ if is_pp_missing_parameter(name, self): ++ continue ++ param = params_dict[name] ++ weight_loader = getattr(param, "weight_loader", ++ default_weight_loader) ++ weight_loader(param, loaded_weight) ++ loaded_params.add(name) ++ return loaded_params ++ ++ ++class Llama4ForCausalLM(LlamaForCausalLM): ++ ++ packed_modules_mapping = { ++ "qkv_proj": ["q_proj", 
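For fused checkpoints, `load_moe_expert_weights` splits each layer's `experts.gate_up_proj` tensor into its gate (w1) and up (w3) halves and transposes them before handing them to the fused `w13` weight loader. A shape-only sketch under assumed checkpoint shapes (the loader above additionally slices experts via `expert_map` under expert parallelism):

```python
import torch

# Assumed fused-checkpoint layout, with illustrative sizes.
num_experts, hidden, intermediate = 16, 64, 128
gate_up = torch.randn(num_experts, hidden, 2 * intermediate)

# Split into the gate (w1) and up (w3) halves ...
w1, w3 = gate_up.chunk(2, dim=-1)      # each (E, hidden, intermediate)
# ... and transpose so each half matches the (E, output_dim, input_dim)
# layout that the fused w13 parameter is stored in.
w1_t, w3_t = w1.transpose(-1, -2), w3.transpose(-1, -2)

print(w1_t.shape, w3_t.shape)          # (16, 128, 64) twice
```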
"k_proj", "v_proj"], ++ "gate_up_proj": ["gate_proj", "up_proj"], ++ } ++ ++ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ++ # Update temperature tuning config from generation config ++ gen_config = vllm_config.model_config.try_get_generation_config() ++ gen_config.update(vllm_config.model_config.override_generation_config) ++ vllm_config.model_config.hf_config.attn_temperature_tuning \ ++ = gen_config.get("attn_temperature_tuning", False) ++ LlamaForCausalLM.__init__(self, ++ vllm_config=vllm_config, ++ prefix=prefix, ++ layer_type=Llama4DecoderLayer) ++ ++ def _init_model(self, ++ vllm_config: VllmConfig, ++ prefix: str = "", ++ layer_type: Type[Llama4DecoderLayer] = Llama4DecoderLayer): ++ return Llama4Model(vllm_config=vllm_config, ++ prefix=prefix, ++ layer_type=layer_type) ++ ++ def load_weights(self, weights: Iterable[Tuple[str, ++ torch.Tensor]]) -> Set[str]: ++ loader = AutoWeightsLoader( ++ self, ++ skip_prefixes=(["lm_head."] ++ if self.config.tie_word_embeddings else None), ++ ) ++ weights = [ ++ self.permute_qk_weight_for_rotary(name, loaded_weight) ++ for name, loaded_weight in weights ++ ] ++ return loader.load_weights(weights) ++ ++ def permute_qk_weight_for_rotary( ++ self, ++ name: str, ++ loaded_weight: torch.Tensor, ++ ) -> Tuple[str, torch.Tensor]: ++ ++ def permute(w: torch.Tensor, n_heads: int): ++ attn_in = self.config.head_dim * n_heads ++ attn_out = self.config.hidden_size ++ ++ return w.view(n_heads, attn_in // n_heads // 2, 2, ++ attn_out).transpose(1, 2).reshape(attn_in, attn_out) ++ ++ modules = name.split(".") ++ ++ # rotary embeds should be sliced ++ if ("wk" in modules or "k_proj" in modules) \ ++ and modules[-1] == "weight": ++ loaded_weight = permute(loaded_weight, ++ self.config.num_key_value_heads) ++ elif ("wq" in modules or "q_proj" in modules) \ ++ and modules[-1] == "weight": ++ loaded_weight = permute(loaded_weight, ++ self.config.num_attention_heads) ++ ++ return name, loaded_weight +diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py +new file mode 100644 +index 000000000000..012178c7c093 +--- /dev/null ++++ b/vllm/model_executor/models/mllama4.py +@@ -0,0 +1,886 @@ ++# SPDX-License-Identifier: Apache-2.0 ++# ++# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team. ++# All rights reserved. ++# ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. 
++import math ++from collections.abc import Iterable, Mapping ++from itertools import tee ++from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union ++ ++import torch ++from torch import nn ++from transformers import BatchFeature, Llama4Config, Llama4VisionConfig ++from transformers.image_utils import SizeDict ++from transformers.modeling_outputs import BaseModelOutput ++from transformers.models.llama4 import Llama4Processor ++from transformers.models.llama4.image_processing_llama4_fast import ( ++ find_supported_resolutions, get_best_fit) ++ ++from vllm.attention.layer import MultiHeadAttention ++from vllm.config import VllmConfig ++from vllm.distributed import get_tensor_model_parallel_world_size ++from vllm.inputs import InputProcessingContext ++from vllm.logger import init_logger ++from vllm.model_executor.layers.linear import (ColumnParallelLinear, ++ QKVParallelLinear, ++ RowParallelLinear) ++from vllm.model_executor.layers.quantization import QuantizationConfig ++from vllm.model_executor.layers.rotary_embedding import get_rope ++from vllm.model_executor.layers.sampler import SamplerOutput ++from vllm.model_executor.model_loader.weight_utils import default_weight_loader ++from vllm.model_executor.sampling_metadata import SamplingMetadata ++from vllm.multimodal import MULTIMODAL_REGISTRY ++from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, ++ NestedTensors) ++from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, ++ MultiModalDataItems) ++from vllm.multimodal.processing import (BaseMultiModalProcessor, ++ BaseProcessingInfo, PromptReplacement, ++ PromptUpdate) ++from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs ++from vllm.sequence import IntermediateTensors ++from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config ++ ++from .interfaces import MultiModalEmbeddings, SupportsMultiModal ++from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, ++ maybe_prefix, merge_multimodal_embeddings) ++from .vision import scatter_patch_features, select_patch_features ++ ++logger = init_logger(__name__) ++ ++ ++class Llama4ImagePatchInputs(TypedDict): ++ type: Literal["pixel_values"] ++ flat_data: torch.Tensor ++ """ ++ Shape: ++ `(batch_size * num_chunks, num_channels, image size, image size)` ++ """ ++ patches_per_image: torch.Tensor ++ """ ++ The number of total patches for each image in the batch. ++ ++ This is used to split the embeddings which has the first two dimensions ++ flattened just like `flat_data`. ++ """ ++ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] ++ """ ++ A boolean mask indicating which image embeddings correspond ++ to patch tokens. ++ """ ++ aspect_ratios: Union[torch.Tensor, list[torch.Tensor]] ++ """ ++ A list of aspect ratios corresponding to the number of tiles ++ in each dimension that each image in the batch corresponds to. 
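Stepping back to `permute_qk_weight_for_rotary` at the end of `llama4.py` above: it reorders each head's q/k output rows from an ordering where the two rotary components of a frequency sit next to each other into two contiguous half-blocks, matching the rotate-half (neox-style) rotary used by the attention layer. The reordering made visible on a tiny tagged weight (sizes are illustrative):

```python
import torch

def permute(w: torch.Tensor, n_heads: int, head_dim: int, hidden: int):
    # Same reshape as permute_qk_weight_for_rotary:
    # per head, rows go from [a0, b0, a1, b1, ...] to [a0, a1, ..., b0, b1, ...]
    attn_in = head_dim * n_heads
    return (w.view(n_heads, attn_in // n_heads // 2, 2, hidden)
            .transpose(1, 2)
            .reshape(attn_in, hidden))

n_heads, head_dim, hidden = 2, 4, 3
# Tag each row so the reordering is visible: row r of head h holds h*10 + r.
rows = torch.arange(n_heads * head_dim)
w = (rows % head_dim + 10 * (rows // head_dim)).unsqueeze(-1).repeat(1, hidden)

print(w[:, 0].tolist())                                      # [0, 1, 2, 3, 10, 11, 12, 13]
print(permute(w, n_heads, head_dim, hidden)[:, 0].tolist())  # [0, 2, 1, 3, 10, 12, 11, 13]
```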
++ ++ Shape: ++ `(batch_size, ratio)` where ratio is a pair `(ratio_h, ratio_w)` ++ """ ++ ++ ++class Llama4VisionMLP(nn.Module): ++ ++ def __init__(self, ++ input_size: int, ++ intermediate_size: int, ++ output_size: int, ++ bias: bool, ++ output_activation: bool, ++ quant_config: Optional[QuantizationConfig] = None, ++ prefix: str = ""): ++ super().__init__() ++ self.fc1 = ColumnParallelLinear( ++ input_size=input_size, ++ output_size=intermediate_size, ++ bias=bias, ++ quant_config=quant_config, ++ prefix=f"{prefix}.fc1", ++ ) ++ self.fc2 = RowParallelLinear( ++ input_size=intermediate_size, ++ output_size=output_size, ++ bias=bias, ++ quant_config=quant_config, ++ prefix=f"{prefix}.fc2", ++ ) ++ self.activation_fn = nn.GELU() ++ self.output_activation = output_activation ++ ++ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ++ hidden_states, _ = self.fc1(hidden_states) ++ hidden_states = self.activation_fn(hidden_states) ++ hidden_states, _ = self.fc2(hidden_states) ++ if self.output_activation: ++ return self.activation_fn(hidden_states) ++ return hidden_states ++ ++ ++class Llama4MultiModalProjector(nn.Module): ++ ++ def __init__( ++ self, ++ config, ++ quant_config: Optional[QuantizationConfig] = None, ++ prefix: str = "", ++ ): ++ super().__init__() ++ self.linear_1 = ColumnParallelLinear( ++ input_size=config.vision_config.vision_output_dim, ++ output_size=config.text_config.hidden_size, ++ bias=False, ++ quant_config=quant_config, ++ gather_output=True, ++ prefix=f"{prefix}.linear_1", ++ ) ++ ++ def forward(self, image_features): ++ hidden_states, _ = self.linear_1(image_features) ++ return hidden_states ++ ++ ++def pixel_shuffle(input_tensor, shuffle_ratio): ++ # input_tensor: [batch_size, num_patches, channels] ++ batch_size, num_patches, channels = input_tensor.shape ++ patch_size = int(math.sqrt(num_patches)) ++ ++ input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1) ++ batch_size, height, width, channels = input_tensor.size() ++ ++ reshaped_tensor = input_tensor.view(batch_size, height, ++ int(width * shuffle_ratio), ++ int(channels / shuffle_ratio)) ++ reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() ++ ++ reshaped_tensor = reshaped_tensor.view(batch_size, ++ int(height * shuffle_ratio), ++ int(width * shuffle_ratio), ++ int(channels / (shuffle_ratio**2))) ++ reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() ++ ++ output_tensor = reshaped_tensor.view(batch_size, -1, ++ reshaped_tensor.shape[-1]) ++ return output_tensor ++ ++ ++class Llama4VisionPixelShuffleMLP(nn.Module): ++ ++ def __init__( ++ self, ++ config, ++ quant_config: Optional[QuantizationConfig] = None, ++ prefix: str = "", ++ ): ++ super().__init__() ++ self.pixel_shuffle_ratio = config.pixel_shuffle_ratio ++ self.inner_dim = int(config.projector_input_dim // ++ (self.pixel_shuffle_ratio**2)) ++ self.output_dim = config.projector_output_dim ++ self.mlp = Llama4VisionMLP( ++ input_size=config.intermediate_size, ++ intermediate_size=config.projector_input_dim, ++ output_size=config.projector_output_dim, ++ bias=config.multi_modal_projector_bias, ++ output_activation=True, ++ quant_config=quant_config, ++ prefix=f"{prefix}.mlp") ++ ++ def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor: ++ encoded_patches = pixel_shuffle(encoded_patches, ++ self.pixel_shuffle_ratio) ++ return self.mlp(encoded_patches) ++ ++ ++class Llama4VisionAttention(nn.Module): ++ ++ def __init__( ++ self, ++ config: Llama4VisionConfig, ++ quant_config: 
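With a shuffle ratio below one, `pixel_shuffle` trades spatial positions for channels: the number of patch tokens shrinks by `1 / ratio**2` while the channel width grows by the same factor, so the adapter MLP projects fewer but wider vision tokens. A quick shape check of the same operation (sizes are illustrative):

```python
import math
import torch

def pixel_shuffle(x: torch.Tensor, ratio: float) -> torch.Tensor:
    # Same reshapes as Llama4's pixel_shuffle: (B, P, C) -> (B, P*r^2, C/r^2).
    b, p, c = x.shape
    side = int(math.sqrt(p))
    x = x.view(b, side, side, -1)
    x = x.view(b, side, int(side * ratio), int(c / ratio))
    x = x.permute(0, 2, 1, 3).contiguous()
    x = x.view(b, int(side * ratio), int(side * ratio), int(c / ratio**2))
    x = x.permute(0, 2, 1, 3).contiguous()
    return x.view(b, -1, x.shape[-1])

patches = torch.randn(1, 576, 1408)
print(pixel_shuffle(patches, 0.5).shape)   # torch.Size([1, 144, 5632])
```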
Optional[QuantizationConfig], ++ prefix: str = "", ++ ): ++ super().__init__() ++ self.config = config ++ self.tp_size = get_tensor_model_parallel_world_size() ++ self.embed_dim = config.hidden_size ++ self.num_heads = config.num_attention_heads ++ self.head_dim = config.hidden_size // self.num_heads ++ assert self.num_heads % self.tp_size == 0 ++ self.num_local_heads = self.num_heads // self.tp_size ++ self.q_size = self.num_local_heads * self.head_dim ++ self.kv_size = self.num_local_heads * self.head_dim ++ self.attention_dropout = config.attention_dropout ++ self.scaling = self.head_dim**-0.5 ++ ++ self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim, ++ self.scaling) ++ self.qkv_proj = QKVParallelLinear( ++ self.embed_dim, ++ self.head_dim, ++ self.num_heads, ++ bias=True, ++ quant_config=quant_config, ++ prefix=f"{prefix}.qkv_proj", ++ ) ++ self.o_proj = RowParallelLinear( ++ self.num_heads * self.head_dim, ++ self.embed_dim, ++ bias=True, ++ input_is_parallel=True, ++ quant_config=quant_config, ++ prefix=f"{prefix}.o_proj", ++ ) ++ ++ self.rotary_emb = get_rope( ++ head_size=self.head_dim, ++ rotary_dim=config.hidden_size // config.num_attention_heads // 2, ++ # number of image patches ++ max_position=(config.image_size // config.patch_size)**2, ++ base=config.rope_theta, ++ rope_scaling={"rope_type": "mllama4"}, ++ is_neox_style=False, ++ dtype=torch.complex64, # important ++ ) ++ ++ def forward( ++ self, ++ hidden_states: torch.Tensor, ++ ) -> torch.Tensor: ++ input_shape = hidden_states.shape[:-1] ++ ++ qkv, _ = self.qkv_proj(hidden_states) ++ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) ++ ++ q = q.view(q.shape[0], q.shape[1], self.num_local_heads, self.head_dim) ++ k = k.view(k.shape[0], k.shape[1], self.num_local_heads, self.head_dim) ++ q, k = self.rotary_emb(q, k) ++ ++ q = q.view(q.shape[0], q.shape[1], -1) ++ k = k.view(k.shape[0], k.shape[1], -1) ++ ++ attn_output = self.attn(q, k, v) ++ attn_output = attn_output.reshape(*input_shape, -1).contiguous() ++ attn_output, _ = self.o_proj(attn_output) ++ ++ return attn_output ++ ++ ++class Llama4VisionEncoderLayer(nn.Module): ++ ++ def __init__( ++ self, ++ config: Llama4VisionConfig, ++ quant_config: Optional[QuantizationConfig], ++ prefix: str = "", ++ ): ++ super().__init__() ++ self.hidden_size = config.hidden_size ++ self.num_attention_heads = config.num_attention_heads ++ self.intermediate_size = config.intermediate_size ++ ++ self.self_attn = Llama4VisionAttention(config, ++ quant_config=quant_config, ++ prefix=f"{prefix}.self_attn") ++ self.mlp = Llama4VisionMLP(input_size=config.hidden_size, ++ intermediate_size=config.intermediate_size, ++ output_size=config.hidden_size, ++ bias=True, ++ output_activation=False, ++ quant_config=quant_config, ++ prefix=f"{prefix}.mlp") ++ ++ self.input_layernorm = nn.LayerNorm(config.hidden_size) ++ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size) ++ ++ def forward( ++ self, ++ hidden_state: torch.Tensor, ++ ): ++ # Self Attention ++ residual = hidden_state ++ hidden_state = self.input_layernorm(hidden_state) ++ hidden_state = self.self_attn(hidden_state) ++ hidden_state = residual + hidden_state ++ ++ # Feed forward ++ residual = hidden_state ++ hidden_state = self.post_attention_layernorm(hidden_state) ++ hidden_state = self.mlp(hidden_state) ++ hidden_state = residual + hidden_state ++ ++ outputs = (hidden_state, ) ++ return outputs ++ ++ ++class Llama4VisionEncoder(nn.Module): ++ ++ def __init__( ++ self, ++ config: 
Llama4VisionConfig, ++ quant_config: Optional[QuantizationConfig], ++ prefix: str = "", ++ ): ++ super().__init__() ++ self.config = config ++ self.layers = nn.ModuleList([ ++ Llama4VisionEncoderLayer( ++ config, ++ quant_config=quant_config, ++ prefix=f"{prefix}.layers.{layer_idx}", ++ ) for layer_idx in range(config.num_hidden_layers) ++ ]) ++ ++ def forward( ++ self, ++ hidden_states: torch.Tensor, ++ ) -> BaseModelOutput: ++ r""" ++ Args: ++ inputs_embeds (`torch.FloatTensor` of shape ++ `(batch_size, sequence_length, hidden_size)`): ++ Optionally, instead of passing `input_ids` you can choose to ++ directly pass an embedded representation. This is useful if you ++ want more control over how to convert `input_ids` indices into ++ associated vectors than the model's internal embedding ++ lookup matrix. ++ """ ++ ++ for encoder_layer in self.layers: ++ layer_outputs = encoder_layer(hidden_states) ++ hidden_states = layer_outputs[0] ++ ++ return BaseModelOutput(last_hidden_state=hidden_states, ) ++ ++ ++class Llama4UnfoldConvolution(nn.Module): ++ ++ def __init__(self, ++ config: Llama4VisionConfig, ++ quant_config: Optional[QuantizationConfig] = None, ++ prefix: str = ""): ++ super().__init__() ++ kernel_size = config.patch_size ++ if isinstance(kernel_size, int): ++ kernel_size = (kernel_size, kernel_size) ++ self.unfold = torch.nn.Unfold(kernel_size=kernel_size, ++ stride=config.patch_size) ++ self.linear = ColumnParallelLinear(config.num_channels * ++ kernel_size[0] * kernel_size[1], ++ config.hidden_size, ++ bias=False, ++ quant_config=quant_config, ++ gather_output=True, ++ prefix=f"{prefix}.linear") ++ ++ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ++ hidden_states = self.unfold(hidden_states) ++ hidden_states = hidden_states.permute(0, 2, 1) ++ hidden_states, _ = self.linear(hidden_states) ++ return hidden_states ++ ++ ++class Llama4VisionModel(nn.Module): ++ ++ def __init__( ++ self, ++ config: Llama4VisionConfig, ++ quant_config: Optional[QuantizationConfig] = None, ++ prefix: str = "", ++ ): ++ super().__init__() ++ self.config = config ++ self.image_size = config.image_size ++ self.patch_size = config.patch_size ++ self.hidden_size = config.hidden_size ++ self.num_channels = config.num_channels ++ ++ self.num_patches = (self.image_size // self.patch_size)**2 + 1 ++ self.scale = config.hidden_size**-0.5 ++ ++ self.patch_embedding = Llama4UnfoldConvolution( ++ config, ++ quant_config=quant_config, ++ prefix=f"{prefix}.patch_embedding") ++ ++ self.class_embedding = nn.Parameter(self.scale * ++ torch.randn(self.hidden_size)) ++ self.positional_embedding_vlm = nn.Parameter( ++ self.scale * torch.randn(self.num_patches, self.hidden_size)) ++ ++ # layer norms ++ self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5) ++ self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5) ++ ++ # encoders ++ self.model = Llama4VisionEncoder(config, ++ quant_config=quant_config, ++ prefix=f"{prefix}.model") ++ self.vision_adapter = Llama4VisionPixelShuffleMLP( ++ config, quant_config, prefix=f"{prefix}.vision_adapter") ++ ++ def forward( ++ self, ++ images_flattened: torch.Tensor, ++ ) -> BaseModelOutput: ++ # Patch embedding ++ hidden_state = self.patch_embedding(images_flattened) ++ num_tiles, num_patches, hidden_dim = hidden_state.shape ++ ++ # Add cls token ++ class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1, ++ hidden_state.shape[-1]) ++ hidden_state = torch.cat([hidden_state, class_embedding], dim=1) ++ num_patches += 1 ++ ++ # Position 
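`Llama4UnfoldConvolution` implements the patch embedding as `nn.Unfold` followed by a linear projection, which is numerically the same as a strided `Conv2d` patchifier. A standalone equivalence check with plain `torch.nn` modules (not the column-parallel linear used above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
channels, patch, hidden, image = 3, 14, 32, 56

unfold = nn.Unfold(kernel_size=(patch, patch), stride=patch)
linear = nn.Linear(channels * patch * patch, hidden, bias=False)
conv = nn.Conv2d(channels, hidden, kernel_size=patch, stride=patch, bias=False)
conv.weight.data = linear.weight.data.view(hidden, channels, patch, patch)

images = torch.randn(2, channels, image, image)

# Unfold path: (B, C*p*p, L) -> (B, L, C*p*p) -> (B, L, hidden)
out_unfold = linear(unfold(images).permute(0, 2, 1))
# Conv path:   (B, hidden, H/p, W/p) -> (B, L, hidden)
out_conv = conv(images).flatten(2).transpose(1, 2)

assert torch.allclose(out_unfold, out_conv, atol=1e-4)
```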
embeddings ++ hidden_state = hidden_state.reshape( ++ num_tiles, ++ 1, ++ num_patches, ++ hidden_dim, ++ ) ++ positional_embedding = self.positional_embedding_vlm.to( ++ dtype=hidden_state.dtype, device=hidden_state.device) ++ hidden_state = hidden_state + positional_embedding ++ hidden_state = self.layernorm_pre(hidden_state) ++ hidden_state = hidden_state.view(num_tiles, -1, hidden_dim) ++ ++ # Apply encoder ++ output = self.model(hidden_state) ++ hidden_state = output.last_hidden_state ++ hidden_state = self.layernorm_post(hidden_state) ++ ++ # Remove CLS token output ++ hidden_state = hidden_state[:, :-1, :] ++ ++ # now, we use Llama4VisionPixelShuffle + mlp to project embeddings ++ hidden_state = self.vision_adapter(hidden_state) ++ ++ return BaseModelOutput( ++ last_hidden_state=hidden_state, ++ attentions=None, ++ ) ++ ++ ++class Mllama4ProcessingInfo(BaseProcessingInfo): ++ ++ def __init__(self, ctx: InputProcessingContext) -> None: ++ super().__init__(ctx) ++ ++ def get_hf_config(self) -> Llama4Config: ++ return self.ctx.get_hf_config(Llama4Config) ++ ++ def get_hf_processor(self, **kwargs: object) -> Llama4Processor: ++ return self.ctx.get_hf_processor(Llama4Processor, ++ use_fast=True, ++ **kwargs) ++ ++ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: ++ return {"image": 10} ++ ++ @staticmethod ++ def get_patch_per_chunk(vision_config: Llama4VisionConfig) -> int: ++ image_size = vision_config.image_size ++ patch_size = vision_config.patch_size ++ ++ assert ( ++ image_size % ++ patch_size == 0), f"chunk size {image_size} should be multiple of " ++ f"patch_size {patch_size}" ++ ++ ds_ratio = int(round(1.0 / (vision_config.pixel_shuffle_ratio**2))) ++ return (image_size // patch_size)**2 // ds_ratio ++ ++ def get_max_num_tiles(self) -> int: ++ image_processor = self.get_hf_processor().image_processor ++ return image_processor.max_patches ++ ++ def get_mm_max_tokens_per_item( ++ self, ++ seq_len: int, ++ mm_counts: Mapping[str, int], ++ ) -> Mapping[str, int]: ++ vision_config = self.get_hf_config().vision_config ++ # image_start + local tiles * (patches + 1 x separator) + ++ # 1 global tile * (image x 1 + patches) + image_end ++ token_per_chunk = self.get_patch_per_chunk(vision_config) + 1 ++ mm_max_tokens = (self.get_max_num_tiles() + 1) * token_per_chunk + 2 ++ return {"image": mm_max_tokens} ++ ++ def get_image_size_with_most_features(self) -> ImageSize: ++ vision_config = self.get_hf_config().vision_config ++ image_size = vision_config.image_size ++ # Result in the max possible feature size (h:w = 16:1) ++ return ImageSize(height=self.get_max_num_tiles() * image_size, ++ width=image_size) ++ ++ ++class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] ++ ): ++ ++ def _call_hf_processor( ++ self, ++ prompt: str, ++ mm_data: Mapping[str, object], ++ mm_kwargs: Mapping[str, object], ++ ) -> BatchFeature: ++ tokenizer = self.info.get_tokenizer() ++ ++ if mm_data is None: ++ return tokenizer(prompt, add_special_tokens=False) # exclude bos ++ processed_outputs = super()._call_hf_processor( ++ prompt=prompt, ++ mm_data=mm_data, ++ mm_kwargs=mm_kwargs, ++ ) ++ ++ processor = self.info.get_hf_processor(**mm_kwargs) ++ image_processor = processor.image_processor ++ vision_config = self.info.get_hf_config().vision_config ++ ++ if processed_outputs.get("pixel_values") is not None: ++ assert "images" in mm_data, \ ++ "images expected to be in mm_data when pixel_values is present" ++ ++ images = mm_data["images"] ++ parsed_images = 
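The per-image token budget in `get_mm_max_tokens_per_item` is plain arithmetic over the vision config. With illustrative values (`image_size=336`, `patch_size=14`, `pixel_shuffle_ratio=0.5`, `max_patches=16`; the real numbers come from the HF config and image processor):

```python
def max_image_tokens(image_size=336, patch_size=14,
                     pixel_shuffle_ratio=0.5, max_num_tiles=16):
    # Patches per chunk after pixel-shuffle downsampling.
    ds_ratio = int(round(1.0 / (pixel_shuffle_ratio ** 2)))
    patch_per_chunk = (image_size // patch_size) ** 2 // ds_ratio
    # (local tiles + 1 global tile) * (patches + 1 separator token) + image start/end.
    token_per_chunk = patch_per_chunk + 1
    return (max_num_tiles + 1) * token_per_chunk + 2

print(max_image_tokens())  # 2467 with the illustrative values above
```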
(self._get_data_parser().parse_mm_data({ ++ "image": ++ images ++ }).get_items("image", ImageProcessorItems)) ++ ++ tile_size = vision_config.image_size ++ possible_resolutions = find_supported_resolutions( ++ max_num_chunks=self.info.get_max_num_tiles(), ++ patch_size=SizeDict(height=tile_size, width=tile_size), ++ ) ++ best_fit_sizes = [ ++ get_best_fit( ++ (image.size[1], image.size[0]), ++ torch.tensor(possible_resolutions), ++ resize_to_max_canvas=image_processor.resize_to_max_canvas) ++ for image in parsed_images ++ ] ++ # TODO tile height/width do not necessarily need to match ++ aspect_ratios = [(image_size[0] // tile_size, ++ image_size[1] // tile_size) ++ for image_size in best_fit_sizes] ++ patches_per_image = [ ++ 1 if r_h * r_w == 1 else 1 + r_h * r_w ++ for (r_h, r_w) in aspect_ratios ++ ] ++ ++ # embed_is_patch should have one feature per image-related token: ++ # <|image_start|>, <|tile_*_separator|>, <|image|>, <|image_end|> ++ # -> False ++ # <|patch|> -> True ++ # embed_is_patch has no entries corresponding to non-image-related ++ # tokens. ++ patch_id = tokenizer.get_vocab()[processor.img_patch_token] ++ num_patches_per_chunk = self.info.get_patch_per_chunk( ++ vision_config) ++ expanded_image_tokens_list = [ ++ processor._prompt_split_image(aspect_ratio, ++ num_patches_per_chunk) ++ for aspect_ratio in aspect_ratios ++ ] ++ expanded_image_token_ids = [ ++ tokenizer.encode(image_tokens, add_special_tokens=False) ++ for image_tokens in expanded_image_tokens_list ++ ] ++ embed_is_patch = [ ++ torch.tensor(tokens) == patch_id ++ for tokens in expanded_image_token_ids ++ ] ++ ++ processed_outputs["aspect_ratios"] = aspect_ratios ++ processed_outputs["patches_per_image"] = torch.tensor( ++ patches_per_image) ++ processed_outputs["embed_is_patch"] = embed_is_patch ++ ++ return processed_outputs ++ ++ def _get_mm_fields_config( ++ self, ++ hf_inputs: BatchFeature, ++ hf_processor_mm_kwargs: Mapping[str, object], ++ ) -> Mapping[str, MultiModalFieldConfig]: ++ patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0)) ++ return dict( ++ pixel_values=MultiModalFieldConfig.flat_from_sizes( ++ "image", patches_per_image), ++ patches_per_image=MultiModalFieldConfig.batched("image"), ++ aspect_ratios=MultiModalFieldConfig.batched("image"), ++ embed_is_patch=MultiModalFieldConfig.batched("image"), ++ ) ++ ++ def _get_prompt_updates( ++ self, ++ mm_items: MultiModalDataItems, ++ hf_processor_mm_kwargs: Mapping[str, object], ++ out_mm_kwargs: MultiModalKwargs, ++ ) -> List[PromptUpdate]: ++ assert ( ++ mm_items.get_count("image", strict=False) == 0 ++ or "aspect_ratios" in out_mm_kwargs ++ ), "Transformers expect to include aspect_ratios in out_mm_kwargs" ++ ++ config = self.info.get_hf_config() ++ vision_config = config.vision_config ++ ++ num_patches_per_chunk = self.info.get_patch_per_chunk(vision_config) ++ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) ++ image_token = hf_processor.image_token ++ ++ def get_replacement(item_idx: int): ++ aspect_ratio = out_mm_kwargs["aspect_ratios"][item_idx] ++ return hf_processor._prompt_split_image( ++ aspect_ratio=aspect_ratio, ++ num_patches_per_chunk=num_patches_per_chunk) ++ ++ return [ ++ PromptReplacement( ++ modality="image", ++ target=image_token, ++ replacement=get_replacement, ++ ) ++ ] ++ ++ ++class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]): ++ ++ def get_dummy_processor_inputs( ++ self, ++ seq_len: int, ++ mm_counts: Mapping[str, int], ++ ) -> ProcessorInputs: ++ 
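Once the best-fit canvas is chosen (via `transformers`' `find_supported_resolutions` and `get_best_fit`), the remaining tile bookkeeping is simple: divide the canvas into `image_size`-sized tiles and add one global tile whenever there is more than one local tile. A toy version, assuming a 336-pixel tile:

```python
def tiles_and_patches(best_fit_size, tile_size=336):
    # best_fit_size: (height, width) of the canvas picked by the resolution search.
    ratio_h, ratio_w = best_fit_size[0] // tile_size, best_fit_size[1] // tile_size
    # A single-tile image gets no extra global tile; otherwise add one.
    patches_per_image = 1 if ratio_h * ratio_w == 1 else 1 + ratio_h * ratio_w
    return (ratio_h, ratio_w), patches_per_image

print(tiles_and_patches((336, 336)))   # ((1, 1), 1)
print(tiles_and_patches((672, 1008)))  # ((2, 3), 7) -> 6 local tiles + 1 global
```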
num_images = mm_counts.get("image", 0) ++ ++ (target_width, ++ target_height) = self.info.get_image_size_with_most_features() ++ ++ mm_data = { ++ "image": ++ self._get_dummy_images(width=target_width, ++ height=target_height, ++ num_images=num_images) ++ } ++ ++ image_token = self.info.get_hf_processor().fake_image_token ++ return ProcessorInputs( ++ prompt_text=image_token * num_images, ++ mm_data=mm_data, ++ ) ++ ++ ++@MULTIMODAL_REGISTRY.register_processor( ++ Mllama4MultiModalProcessor, ++ info=Mllama4ProcessingInfo, ++ dummy_inputs=Mllama4DummyInputsBuilder, ++) ++class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal): ++ packed_modules_mapping = { ++ "qkv_proj": ["q_proj", "k_proj", "v_proj"], ++ } ++ ++ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ++ super().__init__() ++ config = vllm_config.model_config.hf_config ++ quant_config = vllm_config.quant_config ++ multimodal_config = vllm_config.model_config.multimodal_config ++ self.config = config ++ self.quant_config = quant_config ++ self.multimodal_config = multimodal_config ++ self.vision_model = Llama4VisionModel(config.vision_config, ++ None, ++ prefix=maybe_prefix( ++ prefix, "vision_model")) ++ self.multi_modal_projector = Llama4MultiModalProjector( ++ self.config, ++ None, ++ prefix=maybe_prefix(prefix, "multi_modal_projector")) ++ self.language_model = init_vllm_registered_model( ++ vllm_config=vllm_config, ++ hf_config=config.text_config, ++ architectures=["Llama4ForCausalLM"], ++ prefix=maybe_prefix(prefix, "language_model")) ++ ++ self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config) ++ ++ def _parse_and_validate_image_input( ++ self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]: ++ # num_images, 1, num_chunks, channel, image_size, image_size ++ pixel_values = kwargs.pop("pixel_values", None) ++ if pixel_values is None: ++ return None ++ ++ # num_images x num_chunks, channel, image_size, image_size ++ # TODO: confirm handling for variable lengths ++ flat_pixel_values = flatten_bn(pixel_values, concat=True) ++ patches_per_image = flatten_bn(kwargs.pop("patches_per_image")) ++ ++ embed_is_patch = kwargs.pop("embed_is_patch", None) ++ if not isinstance(embed_is_patch, (torch.Tensor, list)): ++ raise ValueError("Incorrect type of embed_is_patch. " ++ f"Got type: {type(embed_is_patch)}") ++ ++ aspect_ratios = kwargs.pop("aspect_ratios", None) ++ if not isinstance(aspect_ratios, (torch.Tensor, list)): ++ raise ValueError("Incorrect type of aspect_ratios. 
" ++ f"Got type: {type(aspect_ratios)}") ++ ++ return Llama4ImagePatchInputs( ++ type="pixel_values", ++ flat_data=flat_pixel_values, ++ patches_per_image=patches_per_image, ++ embed_is_patch=embed_is_patch, ++ aspect_ratios=aspect_ratios, ++ ) ++ ++ def _process_image_input( ++ self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings: ++ flat_data = image_input["flat_data"] ++ patches_per_image = image_input["patches_per_image"].tolist() ++ vision_embeddings_flat = self.vision_model(flat_data).last_hidden_state ++ return vision_embeddings_flat.split(patches_per_image, dim=0) ++ ++ def get_multimodal_embeddings(self, ++ **kwargs) -> Optional[MultiModalEmbeddings]: ++ image_input = self._parse_and_validate_image_input(**kwargs) ++ if image_input is None: ++ return None ++ ++ # num_images x [num_chunks, num_patches, hidden_dim] ++ image_features = self._process_image_input(image_input) ++ # num_images x [num_chunks x num_patches, hidden_dim] ++ image_features_flat = [img.flatten(0, 1) for img in image_features] ++ # num_images x [1, input_len] -> num_images x [input_len] ++ embed_is_patch_flat = [ ++ is_patch.flatten(0, 1) ++ for is_patch in image_input["embed_is_patch"] ++ ] ++ ++ return scatter_patch_features( ++ image_features_flat, ++ embed_is_patch_flat, ++ ) ++ ++ def get_input_embeddings( ++ self, ++ input_ids: torch.Tensor, ++ multimodal_embeddings: Optional[NestedTensors] = None, ++ ) -> torch.Tensor: ++ inputs_embeds = self.language_model.get_input_embeddings(input_ids) ++ ++ if multimodal_embeddings is not None: ++ multimodal_embeddings = torch.cat(multimodal_embeddings) ++ mm_embeddings = self.multi_modal_projector(multimodal_embeddings) ++ inputs_embeds = merge_multimodal_embeddings( ++ input_ids, inputs_embeds, select_patch_features(mm_embeddings), ++ self.config.image_token_index) ++ ++ return inputs_embeds ++ ++ def forward( ++ self, ++ input_ids: torch.Tensor, ++ positions: torch.Tensor, ++ intermediate_tensors: Optional[IntermediateTensors] = None, ++ inputs_embeds: Optional[torch.Tensor] = None, ++ **kwargs: object, ++ ) -> Union[torch.Tensor, IntermediateTensors]: ++ # NOTE: In v1, inputs_embeds is always generated at model runner, this ++ # condition is for v0 compatibility. 
++        if "pixel_values" in kwargs:
++            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
++            inputs_embeds = self.get_input_embeddings(input_ids,
++                                                      vision_embeddings)
++            input_ids = None
++
++        return self.language_model(input_ids, positions, intermediate_tensors,
++                                   inputs_embeds)
++
++    def compute_logits(
++        self,
++        hidden_states: torch.Tensor,
++        sampling_metadata: SamplingMetadata,
++    ) -> Optional[torch.Tensor]:
++        return self.language_model.compute_logits(hidden_states,
++                                                  sampling_metadata)
++
++    def sample(self, logits: torch.Tensor,
++               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
++        return self.language_model.sample(logits, sampling_metadata)
++
++    def separate_weights(
++        self,
++        weights: Iterable[Tuple[str, torch.Tensor]],
++        prefix: str,
++    ) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[
++            str, torch.Tensor]]]:
++        weights1, weights2 = tee(weights, 2)
++
++        def get_prefix_weights() -> Iterable[Tuple[str, torch.Tensor]]:
++            for name, data in weights1:
++                if name.startswith(prefix):
++                    yield (name, data)
++
++        def get_other_weights() -> Iterable[Tuple[str, torch.Tensor]]:
++            for name, data in weights2:
++                if not name.startswith(prefix):
++                    yield (name, data)
++
++        return get_prefix_weights(), get_other_weights()
++
++    def load_weights(self, weights: Iterable[Tuple[str,
++                                                   torch.Tensor]]) -> Set[str]:
++
++        stacked_params_mapping = [
++            # (param_name, shard_name, shard_id)
++            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
++            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
++            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
++        ]
++        params_dict = dict(self.named_parameters())
++        updated_params: Set[str] = set()
++
++        # language_model is a Llama4ForCausalLM instance; we load its weights
++        # with Llama4ForCausalLM's own load_weights routine.
++        language_model_prefix = "language_model.model."
++ language_model_weights, other_weights = self.separate_weights( ++ weights, prefix=language_model_prefix) ++ loader = AutoWeightsLoader(self) ++ loaded_language_model_params = loader.load_weights( ++ language_model_weights) ++ assert loaded_language_model_params is not None ++ updated_params.update(loaded_language_model_params) ++ ++ for name, loaded_weight in other_weights: ++ for param_name, weight_name, shard_id in stacked_params_mapping: ++ if weight_name not in name: ++ continue ++ name = name.replace(weight_name, param_name) ++ param = params_dict[name] ++ updated_params.add(name) ++ weight_loader = param.weight_loader ++ weight_loader(param, loaded_weight, shard_id) ++ break ++ else: ++ param = params_dict[name] ++ weight_loader = getattr(param, "weight_loader", ++ default_weight_loader) ++ ++ weight_loader(param, loaded_weight) ++ updated_params.add(name) ++ return updated_params +diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py +index 6ead6509bfe8..c0a3c59ba257 100644 +--- a/vllm/model_executor/models/registry.py ++++ b/vllm/model_executor/models/registry.py +@@ -73,6 +73,7 @@ + "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), + "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), ++ "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), + # For decapoda-research/llama-* + "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), + "MambaForCausalLM": ("mamba", "MambaForCausalLM"), +@@ -194,6 +195,7 @@ + # [Encoder-decoder] + "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 + "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 ++ "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501 + "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"), + "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501 + } +diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py +index 92e4ffd0371a..c271f438e858 100755 +--- a/vllm/v1/attention/backends/flash_attn.py ++++ b/vllm/v1/attention/backends/flash_attn.py +@@ -96,6 +96,183 @@ class FlashAttentionMetadata: + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. + ++ # for local attention ++ @dataclass ++ class LocalAttentionMetadata: ++ local_query_start_loc: torch.Tensor ++ local_seqused_k: torch.Tensor ++ local_block_table: torch.Tensor ++ local_max_query_len: int ++ local_max_seq_len: int ++ ++ local_attn_metadata: Optional[LocalAttentionMetadata] = None ++ ++ ++# ++# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into ++# local attention blocks, where each block is passed to the attention kernel ++# as an independent local ("virtual") batch item. 
++# ++# For example, if are performing a chunked prefill a batch of 3 sequences: ++# q_seqlens = [4, 10, 5] ++# kv_seqlens = [6, 17, 9] ++# Then normally for regular attention we would compute with an attention mask ++# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like: ++# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6) ++# k_toks > 0 1 2 3 4 5 ++# q_toks v _____________ ++# 0 | 1 1 1 ++# 1 | 1 1 1 1 ++# 2 | 1 1 1 1 1 ++# 3 | 1 1 1 1 1 1 ++# ++# for local attention (with attn_chunk_size = 4) we would compute with an ++# attention mask like: ++# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4) ++# k_toks > 0 1 2 3 4 5 ++# q_toks v _____________ ++# 0 | 1 1 1 ++# 1 | 1 1 1 1 ++# 2 | 1 ++# 3 | 1 1 ++# ++# We can simulate this mask using standard flash-attention by breaking the ++# sequences into local ("virtual") batches, where each local batch item is a ++# local attention block, so in this case batch idx 0 would be broken up into: ++# ++# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0) ++# k_toks > 0 1 2 3 ++# q_toks v _____________ ++# 0 | 1 1 1 ++# 1 | 1 1 1 1 ++# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0) ++# k_toks > 4 5 ++# q_toks v _____________ ++# 2 | 1 ++# 3 | 1 1 ++# ++# e.g. if we have: ++# attn_chunk_size = 4 ++# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5]) ++# Then this function would return: ++# __b0__ ______b1______ __b2__ < orig batch indices ++# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1] ++# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24] ++# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1] ++# block_table_local : shape[local_virtual_batches, pages_per_local_batch] ++def make_local_attention_virtual_batches( ++ attn_chunk_size: int, ++ query_start_loc_np: np.ndarray, ++ seq_lens_np: np.ndarray, ++ block_table: torch.tensor, ++ page_size: int = 0, ++) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.tensor]: ++ q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1] ++ actual_batch_size = seq_lens_np.shape[0] ++ ++ # Handle if we are starting in the middle of a local attention block, ++ # we assume q_seqlens > 0 (for all elements), for each batch idx we compute ++ # the number of tokens that are not in the first local attention block and ++ # then we can simply use a cdiv for the rest. ++ # For example if we have: ++ # attn_chunk_size = 4 ++ # q_seqlens = [4, 10, 5] ++ # k_seqlens = [6, 17, 9] ++ # Then we would get: ++ # new_tokens_in_first_block = [2, 1, 4] ++ # local_blocks = [2, 4, 2] ++ q_tokens_in_first_block = np.minimum( ++ attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), ++ q_seqlens).astype(np.int32) ++ tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size) ++ local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, ++ attn_chunk_size) ++ ++ # Once we know the number of local blocks we can compute the request spans ++ # for each batch idx, we can figure out the number of "virtual" requests we ++ # have to make, ++ # For the above example we would get: ++ # seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1] ++ # ++ # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1]) ++ # (TODO: max a utility to share this code with _prepare_inputs) ++ # arange step 1. [2, 4, 2] -> [2, 6, 8] ++ cu_num_blocks = np.cumsum(local_blocks) ++ virtual_batches = cu_num_blocks[-1] ++ # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6] ++ block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks) ++ # arange step 3. 
[0, 1, 0, 1, 2, 3, 0, 1] ++ arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets ++ # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0]) ++ rarange = np.repeat(local_blocks, local_blocks) - arange - 1 ++ # Then we can compute the seqlens_q_local, handling the fact that the ++ # first and last blocks could be partial ++ seqlens_q_local = \ ++ np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) ++ # set the first block since this may be a partial block ++ seqlens_q_local[arange == 0] = q_tokens_in_first_block ++ # set the remaining blocks ++ seqlens_q_local[arange > 0] = np.minimum( ++ seqlens_q_local - attn_chunk_size * (arange - 1), ++ attn_chunk_size)[arange > 0] ++ ++ # convert from q_seqlens to cu_seqlens_q ++ cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\ ++ .astype(np.int32) ++ ++ # compute the seqlens_k_local, ++ # basically a full local attention block for all but the last block in each ++ # batch ++ # For our example this will be: ++ # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] ++ seqlens_k_local = np.full(cu_num_blocks[-1], ++ attn_chunk_size, ++ dtype=np.int32) ++ seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block ++ ++ k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \ ++ (rarange * attn_chunk_size + \ ++ np.repeat(tokens_in_last_block, local_blocks)) ++ # For the example the local attention blocks start at: ++ # _b0_ _____b1_____ _b2_ ++ # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] ++ block_starts = k_seqstarts_absolute // page_size ++ assert attn_chunk_size % page_size == 0, \ ++ f"attn_chunk_size {attn_chunk_size} is not " \ ++ f"divisible by page_size {page_size}" ++ pages_per_local_batch = attn_chunk_size // page_size ++ ++ # Create a block_table for the local attention blocks ++ # For out example if we have a block-table like (assuming page_size=2): ++ # block_table = [ ++ # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0 ++ # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1 ++ # [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2 ++ # ] ++ # Then for the local batches we would want a block-table like ++ # block_table_local = [ ++ # [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0]) ++ # [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4]) ++ # [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4]) ++ # [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8]) ++ # [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12]) ++ # [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16]) ++ # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) ++ # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) ++ # ] ++ block_indices= np.broadcast_to( ++ np.arange(pages_per_local_batch, dtype=np.int32), ++ (virtual_batches, pages_per_local_batch)) \ ++ + np.expand_dims(block_starts, axis=1) ++ block_indices = block_indices.flatten() ++ batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), ++ local_blocks * pages_per_local_batch) ++ block_table_local = block_table[batch_indices, block_indices]\ ++ .view(virtual_batches, -1) ++ ++ return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \ ++ block_table_local ++ + + class FlashAttentionMetadataBuilder: + +@@ -109,18 +286,40 @@ def reorder_batch(self, input_batch: "InputBatch", + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, + common_prefix_len: int): + max_seq_len = self.runner.seq_lens_np[:num_reqs].max() +- query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( +- 
self.runner.device, non_blocking=True) +- seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device, +- non_blocking=True) ++ query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1] ++ query_start_loc = query_start_loc_cpu.to(self.runner.device, ++ non_blocking=True) ++ seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] ++ seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True) + block_table = ( + self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) + slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( + self.runner.device, non_blocking=True).long() + ++ # for local attention ++ local_attn_metadata = None ++ if self.runner.attention_chunk_size is not None: ++ seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \ ++ virt_block_table = make_local_attention_virtual_batches( ++ self.runner.attention_chunk_size, ++ self.runner.query_start_loc_np[:num_reqs + 1], ++ self.runner.seq_lens_np[:num_reqs], ++ block_table, ++ self.runner.block_size, ++ ) ++ local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata( ++ local_query_start_loc=torch.from_numpy( ++ virt_q_cu_seqlens_np).to(self.runner.device, ++ non_blocking=True), ++ local_seqused_k=torch.from_numpy(virt_k_seqlens_np).to( ++ self.runner.device, non_blocking=True), ++ local_block_table=virt_block_table, ++ local_max_query_len=seqlens_q_local_np.max(), ++ local_max_seq_len=virt_k_seqlens_np.max(), ++ ) ++ + use_cascade = common_prefix_len > 0 + if use_cascade: +- # TODO: Optimize. + cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], + dtype=torch.int32, + device=self.runner.device) +@@ -149,6 +348,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, + cu_prefix_query_lens=cu_prefix_query_lens, + prefix_kv_lens=prefix_kv_lens, + suffix_kv_lens=suffix_kv_lens, ++ local_attn_metadata=local_attn_metadata, + ) + return attn_metadata + +@@ -167,6 +367,7 @@ def __init__( + blocksparse_params: Optional[dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, ++ use_irope: bool = False, + ) -> None: + if blocksparse_params is not None: + raise ValueError( +@@ -203,6 +404,7 @@ def __init__( + "encoder/decoder cross-attention " + "are not implemented for " + "FlashAttentionImpl") ++ self.use_irope = use_irope + self.vllm_flash_attn_version = get_flash_attn_version() + if is_quantized_kv_cache(self.kv_cache_dtype) \ + and not flash_attn_supports_fp8(): +@@ -265,8 +467,7 @@ def forward( + layer._k_scale, + layer._v_scale, + ) +- descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, +- key.shape[1]) ++ + if self.kv_cache_dtype.startswith("fp8"): + key_cache = key_cache.view(torch.float8_e4m3fn) + value_cache = value_cache.view(torch.float8_e4m3fn) +@@ -278,22 +479,41 @@ def forward( + query = query.reshape((num_tokens, num_heads, head_size)) + + # Compute attention and update output up to `num_actual_tokens`. +- if not attn_metadata.use_cascade: +- # Regular attention (common case). 
++ use_local_attn = \ ++ (self.use_irope and attn_metadata.local_attn_metadata is not None) ++ ++ if not attn_metadata.use_cascade or use_local_attn: ++ if use_local_attn: ++ assert attn_metadata.local_attn_metadata is not None ++ local_metadata = attn_metadata.local_attn_metadata ++ cu_seqlens_q = local_metadata.local_query_start_loc ++ seqused_k = local_metadata.local_seqused_k ++ max_seqlen_q = local_metadata.local_max_query_len ++ max_seqlen_k = local_metadata.local_max_seq_len ++ block_table = local_metadata.local_block_table ++ else: ++ cu_seqlens_q = attn_metadata.query_start_loc ++ seqused_k = attn_metadata.seq_lens ++ max_seqlen_q = attn_metadata.max_query_len ++ max_seqlen_k = attn_metadata.max_seq_len ++ block_table = attn_metadata.block_table ++ ++ descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) ++ + flash_attn_varlen_func( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], +- cu_seqlens_q=attn_metadata.query_start_loc, +- max_seqlen_q=attn_metadata.max_query_len, +- seqused_k=attn_metadata.seq_lens, +- max_seqlen_k=attn_metadata.max_seq_len, ++ cu_seqlens_q=cu_seqlens_q, ++ max_seqlen_q=max_seqlen_q, ++ seqused_k=seqused_k, ++ max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, +- block_table=attn_metadata.block_table, ++ block_table=block_table, + softcap=self.logits_soft_cap, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), +@@ -302,6 +522,8 @@ def forward( + ) + return output + ++ assert not use_local_attn, ( ++ "Cascade attention does not support local attention.") + # Cascade attention (rare case). + cascade_attention( + output[:num_actual_tokens], +diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py +index 15b49b14c1dd..5f9610470567 100644 +--- a/vllm/v1/attention/backends/triton_attn.py ++++ b/vllm/v1/attention/backends/triton_attn.py +@@ -70,6 +70,7 @@ def __init__( + blocksparse_params: Optional[dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, ++ use_irope: bool = False, + ) -> None: + if blocksparse_params is not None: + raise ValueError( +@@ -86,6 +87,7 @@ def __init__( + else: + self.sliding_window = (sliding_window - 1, 0) + self.kv_cache_dtype = kv_cache_dtype ++ self.use_irope = use_irope + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads +@@ -156,24 +158,41 @@ def forward( + layer._v_scale, + ) + ++ use_local_attn = \ ++ (self.use_irope and attn_metadata.local_attn_metadata is not None) ++ ++ if use_local_attn: ++ assert attn_metadata.local_attn_metadata is not None ++ local_metadata = attn_metadata.local_attn_metadata ++ cu_seqlens_q = local_metadata.local_query_start_loc ++ sequesd_k = local_metadata.local_seqused_k ++ max_seqlen_q = local_metadata.local_max_query_len ++ max_seqlen_k = local_metadata.local_max_seq_len ++ block_table = local_metadata.local_block_table ++ else: ++ cu_seqlens_q = attn_metadata.query_start_loc ++ sequesd_k = attn_metadata.seq_lens ++ max_seqlen_q = attn_metadata.max_query_len ++ max_seqlen_k = attn_metadata.max_seq_len ++ block_table = attn_metadata.block_table ++ + # Compute attention and update output up to `num_actual_tokens`. 
+- chunked_prefill_paged_decode( +- query=query[:num_actual_tokens], +- key=key[:num_actual_tokens], +- value=value[:num_actual_tokens], +- output=output[:num_actual_tokens], +- kv_cache_dtype=self.kv_cache_dtype, +- key_cache=key_cache, +- value_cache=value_cache, +- block_table=attn_metadata.block_table, +- query_start_loc=attn_metadata.query_start_loc, +- seq_lens=attn_metadata.seq_lens, +- max_seq_len=attn_metadata.max_seq_len, +- max_query_len=attn_metadata.max_query_len, +- k_scale=layer._k_scale, +- v_scale=layer._v_scale, +- alibi_slopes=self.alibi_slopes, +- sliding_window=self.sliding_window[0], +- sm_scale=self.scale) ++ chunked_prefill_paged_decode(query=query[:num_actual_tokens], ++ key=key[:num_actual_tokens], ++ value=value[:num_actual_tokens], ++ output=output[:num_actual_tokens], ++ kv_cache_dtype=self.kv_cache_dtype, ++ key_cache=key_cache, ++ value_cache=value_cache, ++ block_table=block_table, ++ query_start_loc=cu_seqlens_q, ++ seq_lens=sequesd_k, ++ max_seq_len=max_seqlen_k, ++ max_query_len=max_seqlen_q, ++ k_scale=layer._k_scale, ++ v_scale=layer._v_scale, ++ alibi_slopes=self.alibi_slopes, ++ sliding_window=self.sliding_window[0], ++ sm_scale=self.scale) + + return output +diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py +index 82b07c6cd327..5133c637f0e0 100644 +--- a/vllm/v1/worker/gpu_model_runner.py ++++ b/vllm/v1/worker/gpu_model_runner.py +@@ -113,6 +113,7 @@ def __init__( + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + self.head_size = model_config.get_head_size() + self.hidden_size = model_config.get_hidden_size() ++ self.attention_chunk_size = model_config.attention_chunk_size + + self.attn_backend = get_attn_backend( + self.head_size, +``` \ No newline at end of file
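
## § Worked Examples

The multimodal processor in `mllama4.py` derives `aspect_ratios`, `patches_per_image`, and `embed_is_patch` from the best-fit tile grid chosen for each image. Below is a minimal sketch of that bookkeeping arithmetic; the tile size of 336 and the toy token ids are illustrative assumptions, not values read from the actual Llama4 config or tokenizer.

```python
# Minimal sketch of the per-image bookkeeping in Mllama4MultiModalProcessor.
# tile_size and the token ids below are assumptions for illustration only.
import torch

tile_size = 336                      # assumed vision_config.image_size
best_fit_h, best_fit_w = 672, 1008   # assumed best-fit canvas for one image

# Tile grid for this image -> stored in "aspect_ratios".
ratio_h, ratio_w = best_fit_h // tile_size, best_fit_w // tile_size  # (2, 3)

# Chunks fed to the vision tower: one per tile, plus a global thumbnail
# chunk whenever the image is split into more than one tile.
# -> stored in "patches_per_image".
patches_per_image = 1 if ratio_h * ratio_w == 1 else 1 + ratio_h * ratio_w  # 7

# "embed_is_patch" marks which positions of the expanded image prompt are
# real <|patch|> tokens (True) vs. structural tokens such as <|image_start|>,
# tile separators, <|image|> and <|image_end|> (False).
PATCH_ID, IMG_START_ID, IMG_END_ID = 7, 5, 6   # hypothetical token ids
expanded_ids = torch.tensor(
    [IMG_START_ID, PATCH_ID, PATCH_ID, PATCH_ID, IMG_END_ID])
embed_is_patch = expanded_ids == PATCH_ID
print(embed_is_patch)   # tensor([False, True, True, True, False])
```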
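The long comment above `make_local_attention_virtual_batches` in `flash_attn.py` works through an example with `attn_chunk_size = 4`, `q_seqlens = [4, 10, 5]`, and `kv_seqlens = [6, 17, 9]`. The following standalone NumPy sketch reproduces the per-virtual-batch query and key lengths for that example by following the same steps; it is an illustration, not the vLLM implementation itself, and `cdiv` here is a local helper standing in for vLLM's utility.

```python
import numpy as np

def cdiv(a, b):
    return -(-a // b)  # ceiling division

attn_chunk_size = 4
q_seqlens = np.array([4, 10, 5])     # new query tokens per request
k_seqlens = np.array([6, 17, 9])     # total KV length per request

# Query tokens that still belong to the (possibly partial) first local block.
q_first = np.minimum(
    attn_chunk_size - ((k_seqlens - q_seqlens) % attn_chunk_size),
    q_seqlens)                                                  # [2, 1, 4]

# KV tokens in the last (possibly partial) local block of each request.
k_last = attn_chunk_size + (k_seqlens % -attn_chunk_size)       # [2, 1, 1]

# Number of local ("virtual") batches each request is split into.
local_blocks = 1 + cdiv(q_seqlens - q_first, attn_chunk_size)   # [2, 4, 2]

cu_blocks = np.cumsum(local_blocks)                             # [2, 6, 8]
arange = np.arange(cu_blocks[-1]) - np.repeat(cu_blocks - local_blocks,
                                              local_blocks)
# arange = [0, 1, 0, 1, 2, 3, 0, 1]

seqlens_q_local = np.repeat(q_seqlens - q_first, local_blocks)
seqlens_q_local[arange == 0] = q_first
seqlens_q_local[arange > 0] = np.minimum(
    seqlens_q_local - attn_chunk_size * (arange - 1),
    attn_chunk_size)[arange > 0]

seqlens_k_local = np.full(cu_blocks[-1], attn_chunk_size)
seqlens_k_local[cu_blocks - 1] = k_last

print(seqlens_q_local)   # [2 2 1 4 4 1 4 1]
print(seqlens_k_local)   # [4 2 4 4 4 1 4 1]
```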
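The same comment also shows how the per-request block table is re-indexed into a per-virtual-batch block table for the documented example with `page_size = 2`. The sketch below performs that gather with NumPy, taking `local_blocks` and the absolute local-block start offsets (`k_seqstarts`) directly from the example in the comment; again this is a standalone illustration rather than the function itself.

```python
import numpy as np

page_size = 2
attn_chunk_size = 4
pages_per_local_batch = attn_chunk_size // page_size            # 2

# Values from the example above / the docstring comment.
local_blocks = np.array([2, 4, 2])
k_seqstarts = np.array([0, 4, 4, 8, 12, 16, 4, 8])   # absolute KV start of
                                                     # each local block

# Per-request block table (3 requests x 10 pages), as in the comment.
block_table = np.arange(30).reshape(3, 10)

block_starts = k_seqstarts // page_size              # [0 2 2 4 6 8 2 4]
block_indices = (np.arange(pages_per_local_batch)
                 + block_starts[:, None]).reshape(-1)
batch_indices = np.repeat(np.arange(3),
                          local_blocks * pages_per_local_batch)

block_table_local = block_table[batch_indices, block_indices] \
    .reshape(-1, pages_per_local_batch)
print(block_table_local)
# [[ 0  1]
#  [ 2  3]
#  [12 13]
#  [14 15]
#  [16 17]
#  [18 19]
#  [22 23]
#  [24 25]]
```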