[Llama4] [multimodal] Fix misplaced dtype cast of cos_sin_cache in Llama4VisionRotaryEmbedding
#25889
Conversation
Signed-off-by: cjackal <[email protected]>
Code Review
This pull request addresses a performance degradation in Llama-4 vision models by preventing an unintended data type cast of the cos_sin_cache. The change correctly preserves the precision of the cache. My review includes a suggestion to further optimize the implementation by avoiding an unconditional and potentially expensive device transfer operation on every forward pass, which could introduce a performance regression.
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        assert key is not None
        self._match_cos_sin_cache_dtype(query)
        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
This line unconditionally moves self.cos_sin_cache to the query tensor's device on every forward pass. This operation can be expensive and is unnecessary if the cache is already on the correct device. To optimize performance, it's better to add a check to perform the device transfer only when needed, similar to the logic in the _match_cos_sin_cache_dtype method that was previously used here.
    if self.cos_sin_cache.device != query.device:
        self.cos_sin_cache = self.cos_sin_cache.to(query.device)
A somewhat valid comment, but as it was originally in this form before #21126, I'll wait for the vLLM maintainers' opinion.
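For context, the guarded device transfer under discussion can be sketched as a standalone module. This is a hypothetical toy class, not vLLM's Llama4VisionRotaryEmbedding (which holds a complex-valued cache); it only illustrates the conditional `.to()` pattern the reviewer suggests:

```python
import torch

class RotaryCacheSketch(torch.nn.Module):
    """Toy sketch of guarding a cache device transfer; names and
    shapes are illustrative, not taken from the vLLM code."""

    def __init__(self, max_pos: int = 16) -> None:
        super().__init__()
        # A toy real-valued cache; the real class holds a complex one.
        self.register_buffer("cos_sin_cache", torch.randn(max_pos))

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        # Move the cache only when it is on the wrong device, so the
        # common case (cache already co-located) pays no transfer cost.
        if self.cos_sin_cache.device != query.device:
            self.cos_sin_cache = self.cos_sin_cache.to(query.device)
        return query  # the actual rotation is elided in this sketch
```

Reassigning `self.cos_sin_cache` works here because `nn.Module.__setattr__` routes the new tensor back into the registered buffer slot.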
@mgoin @huydhn @yeqcharlotte Ping to the author of the relevant PR (#21126) and the maintainers on the Meta side that I am aware of.
My bad, thanks for finding this. Can you leave a comment noting this intentionality? Two other notes:
Signed-off-by: cjackal <[email protected]>
Comment added!
def forward_hip(  # type: ignore[override]
    self,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    return self.forward_native(query, key)
ywang96 left a comment
I think non-CUDA devices already use forward_native unless it is explicitly overridden, so this fix looks good to me.
@yeqcharlotte @houseroad Is it possible for meta side to add some unit tests on this specific rope implementation so we can prevent this from happening in the future?
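As a starting point for such a test, a regression check might assert that the complex cache dtype survives the rotation applied to a real-dtype query. The sketch below is hypothetical and standalone, not wired into the vLLM test suite:

```python
import torch

def test_complex_cache_dtype_preserved() -> None:
    # Hypothetical standalone check mirroring the bug in #25888: the
    # rotary cache must stay complex64 even though queries are a real
    # dtype (bfloat16 in the vision tower).
    angles = torch.arange(8, dtype=torch.float32)
    cache = torch.polar(torch.ones(8), angles)      # complex64 cache
    query = torch.randn(8, 2, dtype=torch.bfloat16)

    # Rotate via a complex view of the query; the cache is never cast.
    rotated = torch.view_as_real(
        torch.view_as_complex(query.float()) * cache
    ).to(query.dtype)

    assert cache.dtype == torch.complex64   # cache untouched
    assert rotated.dtype == torch.bfloat16  # output matches query dtype
    assert rotated.shape == query.shape
```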
…`Llama4VisionRotaryEmbedding` (vllm-project#25889) Signed-off-by: cjackal <[email protected]>
@mgoin It seems this PR hasn't been picked into v0.11.0rc6, would you mind double-checking it?
Purpose
Fix #25888.
Specifically, this PR rolls back the misplaced dtype cast introduced by #21126, which degrades the vision understanding performance of the Llama 4 family.
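A toy illustration of the failure mode (hypothetical sizes, not the vLLM code): the vision rope cache stores cos + i*sin as complex64, so any cast to the query's real dtype keeps only the real (cosine) half and loses the sine component that encodes the rotation.

```python
import torch

# Build a small complex64 rotary-style cache: cos(angle) + i*sin(angle).
angles = torch.arange(4, dtype=torch.float32)
cache = torch.polar(torch.ones(4), angles)     # complex64: cos + i*sin
assert cache.dtype == torch.complex64

# What a real-dtype cast retains: only the real (cosine) part.
lossy = cache.real.to(torch.bfloat16)
assert not lossy.is_complex()                  # the i*sin part is gone
assert torch.allclose(lossy.float(), torch.cos(angles), atol=1e-2)
```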
Note that the `cos_sin_cache` in `Llama4VisionRotaryEmbedding` is always complex (`torch.complex64`) but the query is always real (`torch.bfloat16`), so we cannot cast the trigonometric cache to the query's dtype.
Test Plan
Check that the Llama 4 model families correctly recognize the "4 dice" image used in the vLLM multimodal test CI pipeline.
Test Result