
Conversation

@cjackal cjackal (Contributor) commented Sep 29, 2025

Purpose

Fix #25888.

In particular, this PR rolls back the misplaced dtype cast introduced by #21126, which degrades the vision-understanding performance of the Llama 4 family.

Note that the cos_sin_cache in Llama4VisionRotaryEmbedding is always complex (torch.complex64) while the query is always real (torch.bfloat16), so we cannot cast the trigonometric cache to the query's dtype.
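To see why the cast is harmful, here is a minimal sketch (assumed shapes and names, not vLLM code) showing that converting a complex64 rotary cache to the query's real dtype silently drops the sine half of each rotation:

```python
import warnings
import torch

# The rotary cache is built as complex64 (e.g. via torch.polar), while the
# vision query tensors are real bfloat16.
angles = torch.tensor([0.0, torch.pi / 2, torch.pi])
cache = torch.polar(torch.ones_like(angles), angles)  # dtype: torch.complex64

# Casting the complex cache to the query's dtype keeps only the real (cosine)
# part and discards the imaginary (sine) part, corrupting the rotation:
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # torch warns when dropping imaginary parts
    bad = cache.to(torch.bfloat16)

print(cache.dtype, bad.dtype)  # torch.complex64 torch.bfloat16
```

This is why the fix keeps the cache in complex64 and never matches it to the query dtype.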

Test Plan

Check that the Llama 4 model family correctly recognizes the "4 dice" image used in the vLLM multimodal test CI pipeline.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="a")

# This is the 4 dice image used in the vllm multimodal test CI pipeline
image_url = "https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/tests/multimodal/assets/rgba.png"

messages = [{"role": "user", "content": [{"type": "text", "text": "Explain the image."}, {"type": "image_url", "image_url": {"url": image_url}}]}]

for t in client.chat.completions.create(model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", messages=messages, stream=True):
    if t.choices:
        d = t.choices[0].delta
        print(getattr(d, "reasoning_content", "") or d.content or "", end="", flush=True)

Test Result

The image depicts four dice in mid-air, each with a distinct color and orientation.

*   The dice are arranged in a staggered formation, with the red die at the center and the other three dice positioned around it.
    *   The red die is in sharp focus, while the other three dice are slightly blurred.
    *   The blue die is located in the upper left corner of the image, the green die is in the upper right corner, and the yellow die is below the red die.
*   The dice have a reflective surface, giving them a shiny appearance.
    *   The dice are translucent, allowing the background to be visible through them.
    *   The dice have white dots on their faces, which represent the numbers 1-6.
*   The background of the image is solid white, providing a clean and simple backdrop for the dice.
    *   There are no other objects or features in the image besides the dice.

Overall, the image presents a visually appealing and dynamic representation of four dice in motion, with a focus on their colors and reflective surfaces.


@mergify mergify bot added the llama Related to Llama models label Sep 29, 2025
@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request addresses a performance degradation in Llama-4 vision models by preventing an unintended data type cast of the cos_sin_cache. The change correctly preserves the precision of the cache. My review includes a suggestion to further optimize the implementation by avoiding an unconditional and potentially expensive device transfer operation on every forward pass, which could introduce a performance regression.

    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        assert key is not None
        self._match_cos_sin_cache_dtype(query)
        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
Severity: high

This line unconditionally moves self.cos_sin_cache to the query tensor's device on every forward pass. This operation can be expensive and is unnecessary if the cache is already on the correct device. To optimize performance, it's better to add a check to perform the device transfer only when needed, similar to the logic in the _match_cos_sin_cache_dtype method that was previously used here.

if self.cos_sin_cache.device != query.device:
    self.cos_sin_cache = self.cos_sin_cache.to(query.device)

@cjackal (Contributor, Author) replied:
A somewhat valid comment, but as it was originally in this form before #21126, I'll wait for the vLLM maintainers' opinion.

@cjackal cjackal (Author) commented Sep 29, 2025

@mgoin @huydhn @yeqcharlotte Pinging the author of the relevant PR (#21126) and the Meta-side maintainers I am aware of.

@cjackal cjackal changed the title [Llama4] [multimodal] Fix unwanted dtype cast of cos_sin_cache in Llama4VisionRotaryEmbedding [Llama4] [multimodal] Fix misplaced dtype cast of cos_sin_cache in Llama4VisionRotaryEmbedding Sep 29, 2025
@mgoin mgoin (Member) commented Sep 29, 2025

My bad, thanks for finding this. Can you leave a comment noting this intentionality?

Two other notes:

  • we are missing unit tests for these custom rope classes if this wasn't caught in CI
  • since this custom rope is only overriding the native impl, I think it is possible that a hardware backend with a forward_cuda/hip/etc impl in the base class may not be running the right code. We should make sure to force this

Signed-off-by: cjackal <[email protected]>
@cjackal cjackal (Author) commented Sep 29, 2025

> My bad, thanks for finding this. Can you leave a comment noting this intentionality?
>
> Two other notes:
>
>   • we are missing unit tests for these custom rope classes if this wasn't caught in CI
>   • since this custom rope is only overriding the native impl, I think it is possible that a hardware backend with a forward_cuda/hip/etc impl in the base class may not be running the right code. We should make sure to force this

Comment added!

  1. From the look of the test file tree, we indeed do not have a dedicated test for Llama4VisionRope. Do we want a full E2E test under test_maverick.py (which sounds too heavy, I think) or a lightweight unit test like test_rope_module_cache in test_pos_encoding.py?
  2. I'm not familiar with non-CUDA hardware; do you want something like the following?
    def forward_hip(  # type: ignore[override]
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        return self.forward_native(query, key)
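A lightweight unit test of the kind discussed above could exercise the dtype-preservation property along these lines. This is a self-contained sketch using a stand-in rotation helper, not vLLM's actual Llama4VisionRotaryEmbedding or its test code:

```python
import torch

def rotate(query: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
    # View the real query's last dim as complex pairs, rotate by the
    # complex cache, then view back as real and restore the input dtype.
    q = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2))
    out = torch.view_as_real(q * cache).flatten(-2)
    return out.to(query.dtype)

def test_cache_dtype_preserved():
    cache = torch.polar(torch.ones(4), torch.arange(4.0))  # complex64 cache
    query = torch.randn(2, 8, dtype=torch.bfloat16)        # bfloat16 query
    out = rotate(query, cache)
    assert cache.dtype == torch.complex64  # cache was never cast
    assert out.dtype == torch.bfloat16     # output matches query dtype

test_cache_dtype_preserved()
```

A real test would instantiate the actual rope module and compare against a reference implementation, but the key assertion is the same: applying the rope to a bfloat16 query must leave the complex64 cache untouched.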

@DarkLight1337 DarkLight1337 (Member) left a comment

Seems this was changed by @mgoin in #21126

Edit: Sorry I didn't see your comments earlier

@ywang96 ywang96 (Member) left a comment
I think non-CUDA devices already use forward_native unless it is explicitly overridden, so this fix looks good to me.

@yeqcharlotte @houseroad Is it possible for meta side to add some unit tests on this specific rope implementation so we can prevent this from happening in the future?
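The dispatch concern raised in the review can be illustrated with a toy sketch (hypothetical class names, not vLLM's actual CustomOp machinery): if a base class routes to forward_cuda on CUDA platforms, a subclass that only overrides forward_native can be silently bypassed.

```python
class BaseOp:
    """Toy stand-in for a platform-dispatching op base class."""
    def __init__(self, platform: str):
        # Bind the platform-specific implementation once at construction.
        self._impl = self.forward_cuda if platform == "cuda" else self.forward_native

    def forward(self, x):
        return self._impl(x)

    def forward_native(self, x):
        return x + 1

    def forward_cuda(self, x):
        return x + 1  # base fallback mirrors the native path

class CustomRope(BaseOp):
    def forward_native(self, x):
        return x * 10  # custom math lives only in the native override

# On a "cuda" platform the override is skipped and the base path runs:
print(CustomRope("cuda").forward(3))  # -> 4 (base forward_cuda, not 30)
print(CustomRope("cpu").forward(3))   # -> 30 (native override)
```

This is why forcing such subclasses onto forward_native (e.g. by also overriding forward_cuda/forward_hip to delegate to it, as in the snippet above) is the safe pattern.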

@mgoin mgoin added bug Something isn't working ready ONLY add when PR is ready to merge/full CI is needed labels Sep 30, 2025
@mgoin mgoin enabled auto-merge (squash) September 30, 2025 19:58
@mgoin mgoin merged commit 43b752c into vllm-project:main Sep 30, 2025
46 checks passed
@mgoin mgoin added this to the v0.11.0 Cherry Picks milestone Oct 1, 2025
@cjackal cjackal deleted the fix-llama4 branch October 2, 2025 02:01
pdasigi pushed a commit to pdasigi/vllm that referenced this pull request Oct 2, 2025
@cjackal cjackal (Author) commented Oct 3, 2025

@mgoin It seems this PR hasn't been cherry-picked into v0.11.0rc6; would you mind double-checking?

yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
…`Llama4VisionRotaryEmbedding` (#25889)

Signed-off-by: cjackal <[email protected]>
Signed-off-by: yewentao256 <[email protected]>
tomeras91 pushed a commit to tomeras91/vllm that referenced this pull request Oct 6, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025

Labels

  • bug — Something isn't working
  • llama — Related to Llama models
  • ready — ONLY add when PR is ready to merge/full CI is needed


Development

Successfully merging this pull request may close these issues.

[Bug]: [llama4] Llama 4 vision understanding performance degrades significantly since #21126
