
[Misc][LLaMa4] Compile LLaMa Vision Encoder #30709

Merged
ProExpertProg merged 5 commits into vllm-project:main from Lucaskabela:lucaskabela/mllama4_compilation
Jan 10, 2026
Conversation

@Lucaskabela Lucaskabela commented Dec 15, 2025

Purpose

We want to speed up inference for mllama4 by applying torch.compile to its compute-intensive workload, similar to what was done in #23207. We start by enabling compilation of the VisionEncoder + PixelShuffle.

Test Plan

Unit Test

```shell
with-proxy pytest tests/compile/fullgraph/test_multimodal_compile.py::test_mllama4_vit_compilation
```

Result:

 1 passed, 27 warnings in 176.88s (0:02:56) 

Offline Test

```shell
with-proxy VLLM_USE_V1=1 python examples/offline_inference/vision_language.py -m llama4
```

With compilation_config={"compile_mm_encoder": True} monkey-patched into EngineArgs.

Results in

--------------------------------------------------
The image depicts a tower, likely Tokyo Tower, framed by cherry blossoms. The tower is white and has a distinctive shape, with a large sphere at the top and a long, thin spire extending from it. It appears to be made of metal and has a lattice-like structure.

In the foreground, there are
--------------------------------------------------
The image depicts a tower, likely Tokyo Tower, framed by cherry blossoms. The tower is white and has a distinctive shape, with a large sphere at the top and a series of latticework sections below it. It appears to be made of metal and has a tall, slender design.

In the foreground, there are
--------------------------------------------------
The image depicts a serene scene of Tokyo Tower, partially obscured by the vibrant pink blossoms of cherry blossom trees. The tower's white and gold structure is visible through the branches and flowers, set against a clear blue sky.

**Key Features:**

* **Tokyo Tower:** A prominent landmark in Tokyo, Japan,
--------------------------------------------------
The image depicts a serene scene of Tokyo Tower, partially obscured by blooming cherry blossoms. The tower's distinctive shape and structure are visible through the branches of the trees, which are adorned with vibrant pink flowers.

**Key Features:**

* **Tokyo Tower:** A prominent landmark in Tokyo, Japan, known for
--------------------------------------------------

Server Benchmark

```shell
vllm serve meta-llama/Llama-4-Scout-17B-16E-Instruct --tensor-parallel-size=8 --gpu_memory_utilization=.8 --max_model_len=8192 --compilation-config='{"compile_mm_encoder":"true"}'
vllm bench serve --backend openai-chat --model meta-llama/Llama-4-Scout-17B-16E-Instruct --endpoint /v1/chat/completions --dataset-name hf --dataset-path lmarena-ai/VisionArena-Chat --hf-split train --num-prompts 1000
```

Test Result

| Metric | Main | This PR |
|---|---|---|
| Successful requests | 998 | 998 |
| Benchmark duration (s) | 63.54 | 61.71 |
| Total generated tokens | 117050 | 117397 |
| Request throughput (req/s) | 15.71 | 16.17 |
| Output token throughput (tok/s) | 1842.17 | 1909.91 |
| Mean TTFT (ms) | 29224.49 | 28011.36 |
| Mean TPOT (ms) | 240.21 | 231.26 |
| Mean ITL (ms) | 232.45 | 223.91 |


Note

Speeds up mLLaMA4 vision by compiling its encoder path and tightening related infra.

  • Compile Llama4VisionModel (VisionEncoder + PixelShuffleMLP) via support_torch_compile, gated by compile_mm_encoder; tag with set_model_tag and run under set_forward_context
  • Update CompilationConfig.compile_mm_encoder docs to include mLLaMa4; add test test_mllama4_vit_compilation (forked/skipped in CI)
  • Fix arg order/defaults in ViT flash-attn wrapper and its fake impl; plumb args in MMEncoderAttention
  • Optimize Llama4VisionRotaryEmbedding to avoid in-place cache updates; minor embed path change in LlamaModel to prevent recompiles

Written by Cursor Bugbot for commit e1e0f0a. This will update automatically on new commits.
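The conditional-compilation gating described in the note above (compile the vision encoder only when `compile_mm_encoder` is set) can be sketched in plain Python. The names below are illustrative stand-ins, not vLLM's actual `support_torch_compile` implementation:

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class CompilationConfig:
    # Illustrative stand-in for vLLM's CompilationConfig.compile_mm_encoder flag
    compile_mm_encoder: bool = False


def fake_compile(fn: Callable) -> Callable:
    """Stand-in for torch.compile: wraps fn and marks it as compiled."""
    def compiled(*args, **kwargs):
        return fn(*args, **kwargs)
    compiled.is_compiled = True  # marker so we can observe the gating
    return compiled


def support_torch_compile_sketch(config: CompilationConfig):
    """Class decorator that compiles forward() only when the flag is on."""
    def decorate(cls):
        if config.compile_mm_encoder:
            cls.forward = fake_compile(cls.forward)
        return cls
    return decorate


config = CompilationConfig(compile_mm_encoder=True)


@support_torch_compile_sketch(config)
class VisionEncoder:
    def forward(self, x):
        return [v * 2 for v in x]


enc = VisionEncoder()
print(getattr(VisionEncoder.forward, "is_compiled", False))  # True when gated on
print(enc.forward([1, 2, 3]))  # [2, 4, 6]
```

With the flag left at its default (`False`), the decorator is a no-op and the model runs eagerly, which matches the opt-in behavior described for this PR.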



@mergify mergify bot added the llama Related to Llama models label Dec 15, 2025
@Lucaskabela
Copy link
Contributor Author

cc @ywang96 - resubmit of #27900 (previous PR fell a bit out of date so resubmitting for another review)

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request enables torch.compile for the LLaMa Vision Encoder layers in mllama4 to improve inference performance. The changes primarily involve adapting the model code to be compatible with torch.compile, such as introducing a wrapper for flash attention and decorating vision submodules. A new test is also added to verify that the model runs correctly with compilation enabled. The approach is sound and follows existing patterns in the codebase for torch.compile integration. I have one high-severity suggestion to correct misleading type hints in the new flash attention wrapper function to improve code correctness and maintainability.

@Lucaskabela Lucaskabela changed the title from "[Misc][LLaMa4] Compile LLaMa Vision Encoder layers" to "[Misc][LLaMa4] Compile LLaMa Vision Encode" Dec 15, 2025
@Lucaskabela Lucaskabela force-pushed the lucaskabela/mllama4_compilation branch from dbdda27 to af16562 Compare December 15, 2025 22:35
```python
@@ -407,6 +407,9 @@ def __init__(
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Need explicit mark here to avoid recompile from 0/1 spec
        # since VocabEmbedding uses a different torch.compile decorator
        torch._dynamo.decorators.mark_unbacked(input_ids, 0)
```
@Lucaskabela Lucaskabela commented Dec 15, 2025
This is needed to avoid a recompile (discovered via tlparse)
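For context on the recompile being avoided: Dynamo specializes on sizes 0 and 1 by default, so a batch that happens to have size 1 gets its own graph unless the dimension is marked dynamic/unbacked. The toy model below illustrates that cache behavior in plain Python; it is a conceptual sketch, not torch internals:

```python
# Toy model of Dynamo's 0/1 specialization: sizes 0 and 1 get their own
# cache entries unless the dimension is marked unbacked/dynamic.
compile_cache: dict = {}
compile_count = 0


def toy_compile(fn, batch, *, unbacked=False):
    """Run fn on batch, 'compiling' a new graph per cache key."""
    global compile_count
    n = len(batch)
    # Specialize on the exact size for 0/1 unless the dim is marked unbacked.
    key = "dynamic" if (unbacked or n > 1) else n
    if key not in compile_cache:
        compile_count += 1          # a new graph is compiled
        compile_cache[key] = fn
    return compile_cache[key](batch)


double = lambda xs: [x * 2 for x in xs]

# Without marking: a size-1 input triggers a second compilation.
toy_compile(double, [1, 2, 3])      # compiles the "dynamic" graph
toy_compile(double, [7])            # compiles a specialized size-1 graph
print(compile_count)                # 2

# With the dim marked unbacked, both inputs reuse one graph.
compile_cache.clear()
compile_count = 0
toy_compile(double, [1, 2, 3], unbacked=True)
toy_compile(double, [7], unbacked=True)
print(compile_count)                # 1
```

In the real code, `torch._dynamo.decorators.mark_unbacked(input_ids, 0)` plays the role of `unbacked=True` here, keeping dim 0 symbolic so the size-1 case does not trigger a fresh compile.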

Collaborator
And we can't handle this via the decorator and the work @laithsakka has been doing on dynamic/unbacked shapes?

Contributor Author
I think it can be handled with the decorator :)

@Lucaskabela Lucaskabela changed the title from "[Misc][LLaMa4] Compile LLaMa Vision Encode" to "[Misc][LLaMa4] Compile LLaMa Vision Encoder" Dec 15, 2025
@Lucaskabela
Copy link
Contributor Author

Note for reviewers: this is off by default and requires `compile_mm_encoder: True` to turn on.

@Lucaskabela
Copy link
Contributor Author

One more interesting note: since rebasing from last Friday (see the #27900 table), there was a pretty sizable performance dip for the compiled artifact. I know the compiled-ranges work landed in that window, but I'm wondering whether any other significant backend or code changes might have caused this.

@DarkLight1337
Copy link
Member

cc @ywang96 @Isotr0py @zou3519

@NickLucche NickLucche left a comment (Collaborator)

Hey @Lucaskabela thanks a lot for your work on another mm model!

Have you looked into the MMEncoderAttention CustomOp (https://github.com/vllm-project/vllm/blob/main/vllm/attention/layers/mm_encoder_attention.py#L44)?

I think it'd be nice to start having a more homogeneous code structure when compiling a new encoder, rather than adding an FA wrapper for each.
@Isotr0py is refactoring this part in #30684, and it should be able to satisfy your use case with q_len != k_len without requiring a separate wrapper.
At the very least, you could reuse the is_rocm + fa_version boilerplate code, which is now taken care of in that MMEncoderAttention class.

@Lucaskabela
Copy link
Contributor Author

Note: will wait for #30684 to land since that is a fairly large refactor to this code (and will eliminate the need for us to add new custom ops)

@mergify
Copy link

mergify bot commented Dec 18, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Lucaskabela.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 18, 2025
@Lucaskabela Lucaskabela marked this pull request as draft December 19, 2025 00:23
@Lucaskabela Lucaskabela force-pushed the lucaskabela/mllama4_compilation branch from 2be7440 to 38f02d1 Compare December 19, 2025 00:29
@mergify mergify bot removed the needs-rebase label Dec 19, 2025
@Lucaskabela Lucaskabela force-pushed the lucaskabela/mllama4_compilation branch from 38f02d1 to 9fb41ce Compare December 19, 2025 00:36
@mergify
Copy link

mergify bot commented Jan 9, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Lucaskabela.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 9, 2026
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
@Lucaskabela Lucaskabela force-pushed the lucaskabela/mllama4_compilation branch from 3d612d9 to e1e0f0a Compare January 9, 2026 23:21
@mergify mergify bot added v1 and removed needs-rebase labels Jan 9, 2026
@mergify
Copy link

mergify bot commented Jan 9, 2026

Hi @Lucaskabela, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
```shell
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint
```

@Lucaskabela Lucaskabela force-pushed the lucaskabela/mllama4_compilation branch from 84c9260 to 8935253 Compare January 9, 2026 23:47

```python
scale: float | None,
cu_seqlens: torch.Tensor | None,
max_seqlen: torch.Tensor | None,
scale: float | None = None,
```
@Lucaskabela (Contributor Author)
This needs to match the other API (the parameter needs a `= None` default).
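Why the fake (meta) implementation must mirror the real op's defaults can be checked mechanically: if the signatures differ, a call that relies on a default works eagerly but breaks when traced against the fake impl. A minimal sketch with hypothetical op names, using `inspect`:

```python
import inspect


# Hypothetical real op and its fake/meta implementation. For custom ops,
# the fake impl must accept exactly the same signature (defaults included).
def flash_attn_wrapper(q, k, v, scale=None, cu_seqlens=None, max_seqlen=None):
    return [qi * (scale or 1.0) for qi in q]  # stand-in for real attention


def flash_attn_wrapper_fake(q, k, v, scale=None, cu_seqlens=None, max_seqlen=None):
    return list(q)  # shape-only stand-in


def bad_fake(q, k, v, scale, cu_seqlens=None, max_seqlen=None):
    # Missing the "= None" default on scale: signature no longer matches.
    return list(q)


def signatures_match(real, fake):
    """Compare parameter names, kinds, and defaults of two callables."""
    return inspect.signature(real) == inspect.signature(fake)


print(signatures_match(flash_attn_wrapper, flash_attn_wrapper_fake))  # True
print(signatures_match(flash_attn_wrapper, bad_fake))                  # False
```

A check like this could run in a unit test to catch real/fake drift early; it is an illustration of the mismatch this commit fixes, not vLLM's actual registration code.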

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
@Lucaskabela Lucaskabela force-pushed the lucaskabela/mllama4_compilation branch from 8935253 to 95c4616 Compare January 10, 2026 00:05
@ProExpertProg ProExpertProg enabled auto-merge (squash) January 10, 2026 00:51
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 10, 2026
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
auto-merge was automatically disabled January 10, 2026 00:55

Head branch was pushed to by a user without write access

@ProExpertProg ProExpertProg merged commit ea6d067 into vllm-project:main Jan 10, 2026
64 checks passed
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
@Lucaskabela Lucaskabela deleted the lucaskabela/mllama4_compilation branch February 19, 2026 16:40
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>

Labels

llama Related to Llama models ready ONLY add when PR is ready to merge/full CI is needed v1
