[Hardware][AMD] Add fused QK RoPE and reshape & cache flash support for ROCm#28850
Conversation
…AITER FlashAttention
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a subset of checks runs automatically, and you can ask your reviewers to trigger select CI tests on top of those. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
This pull request introduces a fused kernel for ROCm to enhance performance for Qwen3 and Qwen3-MoE models. The changes are well-contained and the approach of using a fused kernel for RoPE, zeros, and caching is sound. However, I've identified a critical bug that could cause a NameError on non-ROCm platforms, along with some code duplication and a magic number that should be refactored for better maintainability. Addressing these points will improve the robustness and clarity of the code.
```python
if current_platform.is_rocm():
    from vllm.platforms.rocm import on_gfx9

    if envs.VLLM_ROCM_USE_AITER:
        VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
            envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
        )
    else:
        VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
else:
    on_gfx9 = lambda *args, **kwargs: False
```
The variable VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE is only defined if current_platform.is_rocm() is true. This will cause a NameError on other platforms (e.g., CUDA) where this variable is used later in unified_attention_with_output. The definition should be refactored to ensure it's always defined, regardless of the platform.
```diff
-if current_platform.is_rocm():
-    from vllm.platforms.rocm import on_gfx9
-    if envs.VLLM_ROCM_USE_AITER:
-        VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
-            envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
-        )
-    else:
-        VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
-else:
-    on_gfx9 = lambda *args, **kwargs: False
+if current_platform.is_rocm():
+    from vllm.platforms.rocm import on_gfx9
+else:
+    on_gfx9 = lambda *args, **kwargs: False
+
+if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
+        envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
+    )
+else:
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
```
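The failure mode described above can be reproduced in isolation. This is a minimal sketch (names are illustrative, not vLLM's) showing why a name bound only inside the ROCm branch raises `NameError` on every other platform:

```python
# Minimal repro of the bug called out above: a name bound only in the ROCm
# branch is undefined when read on any other platform, raising NameError.
def lookup_flag(is_rocm: bool):
    if is_rocm:
        fused_rope_flag = True  # only bound on the ROCm path
    try:
        return fused_rope_flag
    except NameError:           # UnboundLocalError is a NameError subclass
        return "NameError"

assert lookup_flag(False) == "NameError"  # non-ROCm path: flag never defined
```

The suggested refactor avoids this by splitting the `on_gfx9` fallback from the flag computation, so the flag is assigned on every platform.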
```python
if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
        envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
    )
else:
    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
```
This logic for setting VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE is duplicated from vllm/attention/layer.py. To improve maintainability and avoid potential inconsistencies, this flag should be defined in a single location and imported where needed. Please remove this duplicated block and import the flag from vllm.attention.layer like so:
```python
from vllm.attention.layer import VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
```

The duplicated block in question:

```python
if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
        envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
    )
else:
    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
```
This logic for setting VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE is duplicated from vllm/attention/layer.py. To improve maintainability and avoid potential inconsistencies, this flag should be defined in a single location and imported where needed. Please remove this duplicated block and import the flag from vllm.attention.layer like so:
```python
from vllm.attention.layer import VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
```

Context from the diff:

```python
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
```

```python
if positions is not None and query.shape[0] <= 256:
```
The value 256 is a magic number that determines the token threshold for using the fused kernel. It should be defined as a named constant, for example _FUSED_QK_ROPE_RESHAPE_AND_CACHE_MAX_TOKENS = 256, at the top of the file to improve readability and maintainability.
```diff
-if positions is not None and query.shape[0] <= 256:
+if positions is not None and query.shape[0] <= 256:  # TODO: make this a constant
```
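The refactor the reviewer asks for could look like the sketch below; the constant name follows the reviewer's example, and the helper function is hypothetical, not vLLM's actual code:

```python
# Named-constant refactor suggested above; the threshold gets a descriptive
# name at module scope instead of a bare 256 in the condition.
_FUSED_QK_ROPE_RESHAPE_AND_CACHE_MAX_TOKENS = 256

def use_fused_path(positions, num_tokens: int) -> bool:
    """Mirror of the guard in the diff, with the magic number named."""
    return (
        positions is not None
        and num_tokens <= _FUSED_QK_ROPE_RESHAPE_AND_CACHE_MAX_TOKENS
    )
```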
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
```diff
 if self.use_output:
     output_shape = output_shape if output_shape is not None else query.shape
-    output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
+    if positions is not None:
+        output = torch.empty(
+            output_shape, dtype=query.dtype, device=query.device
+        )
+    else:
+        output = torch.zeros(
+            output_shape, dtype=query.dtype, device=query.device
```
Output buffer now allocated in quantized FP8 dtype
In the attention forward path, the output tensor is now created with dtype=query.dtype (lines 402‑410). When FP8 query quantization is active, self.query_quant converts query to an FP8 tensor before this allocation. The previous code cached the pre‑quantization dtype (output_dtype) so the output buffer remained fp16/bf16. After this change the output is allocated in FP8, but downstream attention kernels expect the regular activation dtype, so the call either fails or produces incorrect results whenever query quantization is enabled. Capture the original dtype before quantizing and use it for output to avoid creating an FP8 output buffer.
Useful? React with 👍 / 👎.
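A torch-free sketch of the fix Codex describes: record the activation dtype before query quantization so the output buffer is never allocated in FP8. The names below are illustrative, not vLLM's actual attributes:

```python
# Capture the dtype BEFORE quantization; the output buffer keeps it even
# when the query buffer is rewritten to an FP8 dtype.
def pick_output_dtype(query_dtype: str, quantize_query: bool):
    output_dtype = query_dtype        # capture pre-quantization dtype
    if quantize_query:
        query_dtype = "fp8_e4m3"      # query buffer becomes FP8
    return query_dtype, output_dtype  # output stays bf16/fp16
```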
CC @tjtanaa
```python
if current_platform.is_rocm():
    from vllm.platforms.rocm import on_gfx9

    if envs.VLLM_ROCM_USE_AITER:
```
AITER flag management is done in `_aiter_ops.py`. Please move all the flags there and use `rocm_aiter_ops.is_enabled()` along with new flags defined there.
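The centralization the reviewer requests might be sketched as below: keep all AITER feature flags behind one helper instead of re-deriving them in every file. The class and method names here are assumptions, not vLLM's real API:

```python
# Hypothetical single source of truth for AITER feature flags; each call
# site queries the helper rather than recomputing platform + env checks.
class RocmAiterOps:
    def __init__(self, is_rocm: bool, use_aiter: bool, fused_rope_flag: bool):
        self._is_rocm = is_rocm
        self._use_aiter = use_aiter
        self._fused_rope_flag = fused_rope_flag

    def is_enabled(self) -> bool:
        return self._is_rocm and self._use_aiter

    def is_fused_rope_zeros_kv_cache_enabled(self) -> bool:
        # One place encodes "ROCm + AITER + env flag", so callers cannot drift.
        return self.is_enabled() and self._fused_rope_flag
```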
```python
)
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionImpl

if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and isinstance(
```
AITER flag management is done in `_aiter_ops.py`. Please move all the flags there and use `rocm_aiter_ops.is_xxx_enabled()` along with new flags defined there.
This pull request has merge conflicts that must be resolved before it can be merged.
```python
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE: bool = True
```
I see that this is enabled by default. Does this apply to all models? If it can be applied to all models, do we see a general improvement? If so, maybe we don't need a flag to manage it; we could simply use the fusion op whenever AITER is enabled.
```python
logger = init_logger(__name__)

if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
```
AITER flag management is done in `_aiter_ops.py`. Please move all the flags there and use `rocm_aiter_ops.is_enabled()` along with new flags defined there.
```python
else {},
rotary_emb=(
    self.rotary_emb
    if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
```
```python
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
```
```python
logger = init_logger(__name__)

if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
```
```python
else {},
rotary_emb=(
    self.rotary_emb
    if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
```
```python
from vllm.triton_utils import tl, triton

if envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
```
```python
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
```
@mjkvaak-amd I think it is possible to turn this into a fusion pass that merges rotary embedding with attention ops, similar to what happens in `fusion_attn.py`. Using a fusion pass also avoids adding a new AITER flag. Regarding fusion passes, let's get @ProExpertProg's feedback. @ProExpertProg Do you think a fusion pass is a better way to enable this feature?
Yes, we should do this fusion using a fusion pass. @ElizaWszola has a PR that will be merged soon that separates the cache op from `unified_attention`. To start, this PR can just integrate the fused kernels, if you want. Then you (or someone else) can work on a pass in a follow-up. Or you can work on the whole thing now, whatever you prefer. The first option is probably easier (kernels in this PR, pass in the next).
Apologies, I didn't realize there were ongoing efforts to refactor the cache op out of unified_attention. I agree that the fusion pass is a more suitable approach moving forward for several reasons, such as being globally available to models without requiring changes to individual model blueprints (this PR only addressed Qwen3) and its ability to work cross-platform (not just on AMD). At the moment, I have limited bandwidth to rework this PR to include only the fusion kernels. IMHO, it might be best to simply close this one and start fresh. Feel free to close it, and I'll ping you with a new PR when I find more time—unless someone else beats me to it. |
Hi @mjkvaak-amd, the pre-commit checks have failed. Please run:

```bash
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch. For future commits, the installed hook will run pre-commit automatically.
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!
Done in #33443
Purpose
This PR adds support for fusing QK RoPE, zeros, and reshape_and_cache on ROCm, and wires the fusion into the Qwen3 and Qwen3-MoE models. The fused kernel, _fused_qk_rope_reshape_and_cache_kernel, slightly improves model speed without affecting output quality.
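For readers unfamiliar with the fusion, this is a plain-Python sketch (not the actual Triton kernel) of the two steps the fused kernel combines: apply RoPE to a Q/K head vector, then write the rotated K into its KV-cache slot, in one pass. All names and shapes here are illustrative:

```python
import math

# Rotate interleaved (even, odd) pairs of a head vector by position-dependent
# angles; this is the standard RoPE formulation the fused kernel applies.
def rope(vec, pos, theta=10000.0):
    d = len(vec)
    out = [0.0] * d
    for i in range(0, d, 2):
        ang = pos * theta ** (-i / d)
        c, s = math.cos(ang), math.sin(ang)
        out[i] = vec[i] * c - vec[i + 1] * s
        out[i + 1] = vec[i] * s + vec[i + 1] * c
    return out

# One logical pass: rotate Q and K, then cache the rotated K (the
# reshape_and_cache step), instead of launching separate kernels.
def fused_rope_and_cache(q, k, pos, kv_cache, slot):
    q_rot, k_rot = rope(q, pos), rope(k, pos)
    kv_cache[slot] = k_rot
    return q_rot, k_rot
```

At position 0 the rotation is the identity, and rotations preserve the per-pair norm, which makes the sketch easy to sanity-check.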
Test Plan
Please advise on the test plan; it's not entirely clear from the existing tests whether/how this feature should be tested. Maybe something similar to test_silu_mul_quant_fusion.py?

Test Result
See "Test Plan"
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.