
[Hardware][AMD] Add fused QK RoPE and reshape & cache flash support for ROCm#28850

Closed
mjkvaak-amd wants to merge 1 commit into vllm-project:main from mjkvaak-amd:feat/add_fused_rope_zeros_kv_cache_support_for_rocm

Conversation

mjkvaak-amd commented Nov 17, 2025

Purpose

This PR adds support for fusing QK RoPE, zeros, and reshape_and_cache on ROCm, and implements the fusion for the Qwen3 and Qwen3-MoE models. The fusion kernel, _fused_qk_rope_reshape_and_cache_kernel, slightly improves model speed without affecting output quality.

Test Plan

Please advise on the test plan, since it's not entirely clear based on existing tests whether/how this feature should be tested. Maybe something similar to test_silu_mul_quant_fusion.py?

Test Result

See "Test Plan"


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

mergify bot added the qwen (Related to Qwen models), rocm (Related to AMD ROCm), and v1 labels Nov 17, 2025
github-actions bot commented:

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which executes a small, essential subset of tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

gemini-code-assist bot left a comment


Code Review

This pull request introduces a fused kernel for ROCm to enhance performance for Qwen3 and Qwen3-MoE models. The changes are well-contained and the approach of using a fused kernel for RoPE, zeros, and caching is sound. However, I've identified a critical bug that could cause a NameError on non-ROCm platforms, along with some code duplication and a magic number that should be refactored for better maintainability. Addressing these points will improve the robustness and clarity of the code.

Comment on lines 46 to 56
    if current_platform.is_rocm():
        from vllm.platforms.rocm import on_gfx9

        if envs.VLLM_ROCM_USE_AITER:
            VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
                envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
            )
        else:
            VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
    else:
        on_gfx9 = lambda *args, **kwargs: False

critical

The variable VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE is only defined if current_platform.is_rocm() is true. This will cause a NameError on other platforms (e.g., CUDA) where this variable is used later in unified_attention_with_output. The definition should be refactored to ensure it's always defined, regardless of the platform.

Suggested change (replace the block above with):

    if current_platform.is_rocm():
        from vllm.platforms.rocm import on_gfx9
    else:
        on_gfx9 = lambda *args, **kwargs: False

    if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
        VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
            envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
        )
    else:
        VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
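The hazard and the fix can be illustrated with a minimal, self-contained sketch. The `is_rocm` stub and the flag names here are hypothetical stand-ins for vLLM's `current_platform` and `envs` objects, not the real API:

```python
# Minimal sketch of the NameError hazard: the flag must be bound on every
# platform branch, not only inside the ROCm branch.

def is_rocm() -> bool:
    # Pretend we are on a CUDA (non-ROCm) platform.
    return False

ROCM_USE_AITER = True  # stand-in for envs.VLLM_ROCM_USE_AITER

# Buggy shape (commented out): the flag is bound only inside the ROCm branch,
# so reading it later on a non-ROCm platform raises NameError.
#
# if is_rocm():
#     FUSED_ROPE_FLAG = ROCM_USE_AITER

# Fixed shape: every branch binds the flag, so it is always defined.
if is_rocm() and ROCM_USE_AITER:
    FUSED_ROPE_FLAG = True
else:
    FUSED_ROPE_FLAG = False

print(FUSED_ROPE_FLAG)  # False on this simulated non-ROCm platform
```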

Comment on lines +54 to +59
    if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
        VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
            envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
        )
    else:
        VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False

high

This logic for setting VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE is duplicated from vllm/attention/layer.py. To improve maintainability and avoid potential inconsistencies, this flag should be defined in a single location and imported where needed. Please remove this duplicated block and import the flag from vllm.attention.layer like so:

    from vllm.attention.layer import VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE

Comment on lines +82 to +87
    if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
        VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
            envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
        )
    else:
        VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False

high

This logic for setting VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE is duplicated from vllm/attention/layer.py. To improve maintainability and avoid potential inconsistencies, this flag should be defined in a single location and imported where needed. Please remove this duplicated block and import the flag from vllm.attention.layer like so:

    from vllm.attention.layer import VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE

        self.kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    if positions is not None and query.shape[0] <= 256:

high

The value 256 is a magic number that determines the token threshold for using the fused kernel. It should be defined as a named constant, for example _FUSED_QK_ROPE_RESHAPE_AND_CACHE_MAX_TOKENS = 256, at the top of the file to improve readability and maintainability.

Suggested change:

    - if positions is not None and query.shape[0] <= 256:
    + if positions is not None and query.shape[0] <= 256:  # TODO: make this a constant
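The constant extraction the reviewer asks for can be sketched as follows; the constant name follows the reviewer's suggestion, and the wrapper function is a hypothetical stand-in for the real call site:

```python
# Named-constant refactor of the magic number 256 flagged above.
_FUSED_QK_ROPE_RESHAPE_AND_CACHE_MAX_TOKENS = 256

def should_use_fused_kernel(positions, num_tokens: int) -> bool:
    """Mirrors the guard `positions is not None and query.shape[0] <= 256`."""
    return (
        positions is not None
        and num_tokens <= _FUSED_QK_ROPE_RESHAPE_AND_CACHE_MAX_TOKENS
    )

print(should_use_fused_kernel([0, 1], 128))  # True
print(should_use_fused_kernel([0, 1], 512))  # False
print(should_use_fused_kernel(None, 128))    # False
```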

chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 402 to +410

    if self.use_output:
        output_shape = output_shape if output_shape is not None else query.shape
    -   output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
    +   if positions is not None:
    +       output = torch.empty(
    +           output_shape, dtype=query.dtype, device=query.device
    +       )
    +   else:
    +       output = torch.zeros(
    +           output_shape, dtype=query.dtype, device=query.device
    +       )

P0: Output buffer now allocated in quantized FP8 dtype

In the attention forward path, the output tensor is now created with dtype=query.dtype (lines 402‑410). When FP8 query quantization is active, self.query_quant converts query to an FP8 tensor before this allocation. The previous code cached the pre‑quantization dtype (output_dtype) so the output buffer remained fp16/bf16. After this change the output is allocated in FP8, but downstream attention kernels expect the regular activation dtype, so the call either fails or produces incorrect results whenever query quantization is enabled. Capture the original dtype before quantizing and use it for output to avoid creating an FP8 output buffer.
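The fix Codex describes, capturing the activation dtype before quantization and using it for the output buffer, can be sketched with a toy tensor model (the `FakeTensor` class and function names are illustrative stand-ins, not vLLM code):

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Toy stand-in for a torch.Tensor, carrying only shape and dtype."""
    shape: tuple
    dtype: str


def quantize_fp8(t: FakeTensor) -> FakeTensor:
    # Stand-in for self.query_quant: the query becomes an FP8 tensor.
    return FakeTensor(t.shape, "float8_e4m3fn")


def allocate_output(query: FakeTensor, fp8_active: bool) -> FakeTensor:
    # The fix: capture the pre-quantization dtype BEFORE quantizing.
    output_dtype = query.dtype
    if fp8_active:
        query = quantize_fp8(query)
    # Allocating with query.dtype here would produce an FP8 buffer (the bug);
    # the captured dtype keeps the buffer in the fp16/bf16 activation dtype.
    return FakeTensor(query.shape, output_dtype)


out = allocate_output(FakeTensor((4, 8), "bfloat16"), fp8_active=True)
print(out.dtype)  # bfloat16, not float8_e4m3fn
```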


heheda12345 (Collaborator) commented:

CC @tjtanaa

    if current_platform.is_rocm():
        from vllm.platforms.rocm import on_gfx9

        if envs.VLLM_ROCM_USE_AITER:

AITER flag management is done in _aiter_ops.py. Please move all the flags there, use rocm_aiter_ops.is_enabled(), and add any new flags there.
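One possible shape for the centralization the reviewer asks for is a single module owning the flags behind `is_..._enabled()` accessors. This is only a sketch of the pattern under assumed names; the real helpers live in vLLM's `_aiter_ops.py`:

```python
# Illustrative single-source flag module, not vLLM's actual _aiter_ops.py.
import os


class RocmAiterOps:
    @staticmethod
    def is_enabled() -> bool:
        # Stand-in for the ROCm-platform + VLLM_ROCM_USE_AITER check.
        return os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"

    @classmethod
    def is_fused_rope_zeros_kv_cache_enabled(cls) -> bool:
        # Callers import this accessor instead of re-deriving the flag
        # locally in attention/model files, avoiding duplicated logic.
        return cls.is_enabled() and os.environ.get(
            "VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE", "1"
        ) == "1"


rocm_aiter_ops = RocmAiterOps()
print(rocm_aiter_ops.is_fused_rope_zeros_kv_cache_enabled())  # False unless the env vars are set
```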

)
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionImpl

if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and isinstance(

AITER flag management is done in _aiter_ops.py. Please move all the flags there, use rocm_aiter_ops.is_xxx_enabled(), and add any new flags there.


mergify bot commented Nov 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mjkvaak-amd.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 19, 2025
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE: bool = True
tjtanaa (Collaborator) commented Nov 19, 2025

I saw that this is enabled by default.
Does this apply to all models?
If it can be applied to all models, do we see improvement in general?
If it does, maybe we don't need a flag to manage it; just use the fusion op whenever AITER is enabled.


logger = init_logger(__name__)
if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (

AITER flag management is done in _aiter_ops.py. Please move all the flags there, use rocm_aiter_ops.is_enabled(), and add any new flags there.

    else {},
    rotary_emb=(
        self.rotary_emb
        if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE

likewise

k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:

likewise


logger = init_logger(__name__)
if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (

likewise

    else {},
    rotary_emb=(
        self.rotary_emb
        if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE

likewise


from vllm.triton_utils import tl, triton

if envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:

likewise

k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:

likewise

tjtanaa (Collaborator) commented Nov 19, 2025

@mjkvaak-amd I think it is possible to turn this into a fusion pass that merges the rotary embedding with the attention ops, similar to what happens in fusion_attn.py. Using a fusion pass also avoids adding a new AITER flag. Regarding fusion passes, let's get @ProExpertProg's feedback.

@ProExpertProg Do you think it is a better way to enable this feature through fusion pass?

ProExpertProg (Collaborator) commented:

Yes, we should do this fusion using a fusion pass. @ElizaWszola has a PR that will be merged soon that separates the cache op from unified_attention. Then we can replace rope -> cache with a fused rope_cache op. More is described in #24678.

To start, this PR can just integrate the fused kernels, if you want. Then you (or someone else) can work on a pass in a follow-up. Or you can work on the whole thing now, whichever you prefer. The first option is probably easier (kernels in this PR, pass in the next).
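The rope -> cache rewrite described here can be pictured as a tiny pattern rewrite over an op sequence. This is purely illustrative; real fusion passes in vLLM operate on compiled graph representations, and the op names below are made up:

```python
def fuse_rope_cache(ops: list[str]) -> list[str]:
    """Replace each adjacent (rope, reshape_and_cache) pair with a fused op."""
    fused, i = [], 0
    while i < len(ops):
        if i + 1 < len(ops) and ops[i] == "rope" and ops[i + 1] == "reshape_and_cache":
            fused.append("fused_rope_cache")
            i += 2  # consume both ops of the matched pattern
        else:
            fused.append(ops[i])
            i += 1
    return fused


print(fuse_rope_cache(["qkv_proj", "rope", "reshape_and_cache", "attention"]))
# ['qkv_proj', 'fused_rope_cache', 'attention']
```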

mjkvaak-amd (Author) commented:

> Yes, we should do this fusion using a fusion pass. @ElizaWszola has a PR that will be merged soon that separates the cache op from unified_attention. Then we can replace rope -> cache with a fused rope_cache op. More is described in #24678.
>
> To start, this PR can just integrate the fused kernels, if you want. And then you (or someone else) can work on a pass in a follow-up. Or, you can work on the whole thing now, whatever you prefer. The first option is probably easier (kernels in this PR, pass in the next).

Apologies, I didn't realize there were ongoing efforts to refactor the cache op out of unified_attention. I agree that the fusion pass is a more suitable approach moving forward for several reasons, such as being globally available to models without requiring changes to individual model blueprints (this PR only addressed Qwen3) and its ability to work cross-platform (not just on AMD).

At the moment, I have limited bandwidth to rework this PR to include only the fusion kernels. IMHO, it might be best to simply close this one and start fresh. Feel free to close it, and I'll ping you with a new PR when I find more time—unless someone else beats me to it.


mergify bot commented Dec 5, 2025

Hi @mjkvaak-amd, the pre-commit checks have failed. Please run:

    uv pip install pre-commit
    pre-commit install
    pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.


github-actions bot commented Mar 6, 2026

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

github-actions bot added the stale (Over 90 days of inactivity) label Mar 6, 2026
github-project-automation bot moved this to Todo in AMD Mar 6, 2026
@ProExpertProg
Copy link
Collaborator

Done in #33443

github-project-automation bot moved this from Todo to Done in AMD Mar 6, 2026

Labels

needs-rebase, qwen (Related to Qwen models), rocm (Related to AMD ROCm), stale (Over 90 days of inactivity), v1

Projects

Status: Done

4 participants