[ROCm][AITER] Fix AITER import regression for explicit backend selection#33749
tjtanaa merged 10 commits into vllm-project:main
Conversation
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Code Review
This pull request effectively resolves a regression that prevented explicit selection of the AITER attention backend on ROCm. The core of the fix is a well-reasoned separation of backend availability from user preference, which is now correctly handled. The changes are consistently applied across the codebase, including updates to is_aiter_found_and_supported, unconditional imports in the backend module, and adjustments to the speculative decoding logic. Additionally, the pull request includes a necessary fix for an obsolete test case, updating it to align with the current attention implementation. The code is well-documented, and the changes are robust. I have no major concerns.
```python
# This file is not going to be imported unless the user explicitly
# imports it or selects it via attention_config, so the aiter import
# here should not cause any issues.
import aiter
```
@AndreasKaratzas Let's do lazy loading.
We do this within the function scope: `from aiter import pa_fwd_asm`.
Given that `flash_attn_varlen_func` is used repeatedly,
we should define it in `_aiter_ops.py` and invoke it through `aiter_ops`: `from aiter import flash_attn_varlen_func`.
Make sure this `rocm_aiter_ops.flash_attn_varlen_func` is not wrapped with `@is_aiter_supported`, as we want to invoke it even if the master switch is not enabled.
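The lazy-loading pattern suggested here could look roughly like the sketch below. The names are illustrative, not the actual vLLM code, and the stdlib `json` module stands in for aiter so the snippet runs anywhere without ROCm installed:

```python
import importlib

# Cache for the lazily resolved function; None until first use.
_varlen_fn = None


def flash_attn_varlen_func(*args, **kwargs):
    """Resolve the heavy dependency on first call, then reuse it.

    In the real code the import would be `from aiter import
    flash_attn_varlen_func`; here `json.dumps` stands in so the
    sketch is runnable without aiter installed.
    """
    global _varlen_fn
    if _varlen_fn is None:
        heavy_mod = importlib.import_module("json")  # stand-in for aiter
        _varlen_fn = heavy_mod.dumps  # stand-in for the kernel entry point
    return _varlen_fn(*args, **kwargs)
```

Because the import happens inside the function body, merely importing the module that defines the wrapper never triggers the heavy dependency's import-time work (such as aiter's JIT machinery).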
@tjtanaa Lmk if the recent commit addresses your point :)
…rocm_aiter_ops Signed-off-by: Andreas Karatzas <akaratza@amd.com>
```python
if current_platform.is_rocm():
    from vllm._aiter_ops import is_aiter_found_and_supported

    return is_aiter_found_and_supported()
```
This condition does not fit the current semantics of the code.
`is_flash_attn_varlen_func_available()` is a helper function for the ops imported in `vllm/v1/attention/backends/fa_utils.py` (Line 10 - Line 32). Currently in the codebase, the condition for whether `aiter.flash_attn_varlen_func` can be used is handled explicitly in other parts of the code.
Let me piece together more context, but I will need some time.
Did my recent commit remedy (at least a little) your concern? If possible, we would like this PR merged because the whole AITER_FA pipeline is currently broken on our CI. We could work on optimizing this in a follow-up PR.
… import flag Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@AndreasKaratzas you also have to fix the issue with the op:

```diff
+ # import so that aiter registers the op to the namespace of
+ # torch.ops.aiter
+ import aiter  # noqa: F401
  torch.ops.aiter.paged_attention_v1(
      output[:num_decode_tokens],
      workspace_buffer,
      query[:num_decode_tokens],
      key_cache,
      value_cache,
      self.scale,
      attn_metadata.block_table[:num_decodes],
      attn_metadata.query_start_loc[:num_decodes],
      attn_metadata.seq_lens[:num_decodes],
      attn_metadata.max_seq_len,
      self.alibi_slopes,
      self.kv_cache_dtype,
      "NHD",
      self.logits_soft_cap,
      layer._k_scale,
      layer._v_scale,
      None,
      _PARTITION_SIZE_ROCM,
  )
```

Without this, there will be an error if you run
```python
if current_platform.is_rocm():
    # Use the flag set during module import to check if
    # upstream flash-attn was successfully imported
    return _ROCM_FLASH_ATTN_AVAILABLE
```
Thanks @AndreasKaratzas . This implementation fits current semantics.
@AndreasKaratzas Please also help to change the condition of `is_shuffle_kv_cache_enabled`:

```diff
  @classmethod
  @if_aiter_supported
  def is_shuffle_kv_cache_enabled(cls) -> bool:
-     return cls._AITER_ENABLED and cls._SHUFFLE_KV_CACHE_ENABLED
+     return cls._SHUFFLE_KV_CACHE_ENABLED
```

So that we can deploy the preshuffled KV cache like this:

```shell
VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT=1 vllm serve Qwen/Qwen3-0.6B --attention-backend "ROCM_AITER_FA"
```
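As a rough, self-contained sketch of the requested decoupling (class and attribute names abbreviated; the env-var default of `"0"` is an assumption), the switch would read only its own flag:

```python
import os


class RocmAiterOps:
    # Read once at class definition time; "0" is assumed as the default.
    _SHUFFLE_KV_CACHE_ENABLED = (
        os.environ.get("VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT", "0") == "1"
    )

    @classmethod
    def is_shuffle_kv_cache_enabled(cls) -> bool:
        # No longer ANDed with the master switch (_AITER_ENABLED), so the
        # preshuffled layout can be enabled with an explicitly selected
        # AITER backend even when VLLM_ROCM_USE_AITER is unset.
        return cls._SHUFFLE_KV_CACHE_ENABLED
```

Availability is still guarded separately (by the `@if_aiter_supported` decorator in the real code), so dropping the master-switch check here only affects user preference, not correctness.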
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@tjtanaa I've added the suggested fixes. Lmk if things look better now. Happy to add any more suggestions.
vllm/_aiter_ops.py
Outdated
```python
# 1. Backend modules (rocm_aiter_fa.py) import aiter directly when loaded
# 2. Individual op implementations import aiter locally when called
# 3. This module's ops are only called when an aiter backend is actually in use
if is_aiter_found_and_supported() and envs.VLLM_ROCM_USE_AITER:
```
This chain of bugfixes is caused by not doing a lazy import for aiter, especially this dtype issue.
I have looked further into the issue and double-checked the definition of the dtype from AITER: it is the same as `current_platform.fp8_dtype()`, so we don't have to import the dtype from aiter here.
https://github.com/ROCm/aiter/blob/6af8b687480509d67f42a69bca7ed092c432e8dc/aiter/utility/dtypes.py#L15 (commit used in Dockerfile.rocm_base; it is the same definition)
https://github.com/ROCm/aiter/blob/12fe5f0291dad871584db71c49a4c33556519bbf/aiter/utility/dtypes.py#L17 (even on the latest main commit, it is the same definition)
So the suggested change in this file is to define this at the top of the file:

```python
# fp8_dtype is not cached.
# On ROCm, fp8_dtype() always calls is_fp8_fnuz,
# which is a host op.
FP8_DTYPE = current_platform.fp8_dtype()
```

Then replace all occurrences of `_FP8_DTYPE` and `AITER_FP8_DTYPE` with `FP8_DTYPE`.
We don't need this section of the code:

```python
# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might because that the get_gfx() is wrapped as a custom op.
#
# Import strategy to avoid unwanted JIT warnings:
# - Only import aiter dtypes at module load if VLLM_ROCM_USE_AITER=1 (env var set)
# - This prevents JIT warnings during backend auto-discovery when aiter is not preferred
# - Explicit backend selection (via attention_config) still works because:
#   1. Backend modules (rocm_aiter_fa.py) import aiter directly when loaded
#   2. Individual op implementations import aiter locally when called
#   3. This module's ops are only called when an aiter backend is actually in use
if is_aiter_found_and_supported() and envs.VLLM_ROCM_USE_AITER:
    from aiter import dtypes

    AITER_FP8_DTYPE = dtypes.fp8
else:
    # Placeholder when AITER is not the default - prevents NameError during module load.
    # Note: This fallback is used for fake implementations and type checking.
    # If an AITER backend is explicitly selected (even with env var=0),
    # the backend module will import aiter directly (rocm_aiter_fa.py line 35).
    AITER_FP8_DTYPE = _FP8_DTYPE
```
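For illustration only (plain strings instead of torch dtypes, and the arch mapping is an assumption based on MI300-class hardware using the fnuz FP8 encoding): both libraries derive the FP8 dtype from the same architecture check, roughly like this:

```python
def fp8_dtype_name(gcn_arch: str) -> str:
    """Pick the FP8 variant by GPU architecture.

    gfx94x (MI300-class) uses the fnuz encoding; other archs are
    assumed here to use the OCP e4m3 encoding. This mirrors the idea
    behind current_platform.fp8_dtype(), not its exact code.
    """
    if gcn_arch.startswith("gfx94"):
        return "float8_e4m3fnuz"
    return "float8_e4m3fn"


# Computed once at module load, as suggested above.
FP8_DTYPE = fp8_dtype_name("gfx942")
```

Since both aiter's `dtypes.fp8` and `current_platform.fp8_dtype()` resolve to the same value, a single module-level constant removes the need for the conditional aiter import entirely.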
I have validated locally with the following command.
The ops relying on this are the fusion-pass ops for RMSNorm and block quant.
```shell
sudo rm -rf ~/.cache/vllm
ATTN_BACKEND="ROCM_AITER_FA"
VLLM_LOGGING_LEVEL=DEBUG \
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT=1 \
vllm serve Qwen/Qwen3-32B-FP8 \
  --attention-backend $ATTN_BACKEND \
  -O3 \
  > launch_server_$ATTN_BACKEND-preshuffled-compilation.log 2>&1
```
Log showing the replacement occurs:

```text
(APIServer pid=1155) DEBUG 02-06 03:32:37 [v1/engine/utils.py:980] Waiting for 1 local, 0 remote core engine proc(s) to start.
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/noop_elimination.py:105] Removed 0 no-op reshapes and slices
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/vllm_inductor_pass.py:79] NoOpEliminationPass completed in 0.5 ms
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/fusion.py:558] Replaced 0 patterns
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/vllm_inductor_pass.py:79] RMSNormQuantFusionPass completed in 0.8 ms
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/rocm_aiter_fusion.py:315] Replaced 1 patterns
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/vllm_inductor_pass.py:79] RocmAiterRMSNormQuantFusionPass completed in 15.7 ms
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/activation_quant_fusion.py:207] Replaced 0 patterns
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/vllm_inductor_pass.py:79] ActivationQuantFusionPass completed in 0.4 ms
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/rocm_aiter_fusion.py:394] Replaced 0 patterns
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/vllm_inductor_pass.py:79] RocmAiterSiluMulFp8GroupQuantFusionPass completed in 0.4 ms
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/vllm_inductor_pass.py:79] PostCleanupPass completed in 0.2 ms
```
The lmeval score is
2026-02-06:03:40:49 INFO [loggers.evaluation_tracker:316] Output path not provided, skipping saving results aggregated
local-completions ({'model': 'Qwen/Qwen3-32B-FP8', 'base_url': 'http://127.0.0.1:8000/v1/completions'}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 100
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.6353|± |0.0133|
| | |strict-match | 5|exact_match|↑ |0.7521|± |0.0119|
…cessary aiter import Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@tjtanaa I updated per the suggested changes, let me know if it looks good now :) I accidentally pushed a line that was intended for a different PR, and I removed it as soon as I realized. Sorry about that, guys.
tjtanaa
left a comment
LGTM now. Thanks a lot.
tjtanaa
left a comment
LGTM. Thanks for the fix.
The status of the following test groups:
- Entrypoints Integration Test (API Server 1) [Passed]
- Multi-Modal Models Test (Extended) 1 [Passed]
- Multi-Modal Models Test (Standard) [Passed] (this takes more than 1 hour to run)
- V1 Test others (failed due to a numerical issue, so it shouldn't be related to aiter)
…ion (vllm-project#33749) Signed-off-by: Andreas Karatzas <akaratza@amd.com>
A regression was introduced that broke explicit AITER backend selection on ROCm when `VLLM_ROCM_USE_AITER=0` (or unset). Users could not explicitly select the AITER backend via `attention_config={"backend": "ROCM_AITER_FA"}` even though the backend was available.

Error observed:

This occurred because:

- `is_aiter_found_and_supported()` incorrectly checked the env var, preventing the aiter import during explicit selection
- `allowed_attn_types` validation rejected `AiterFlashAttentionMetadata` when the env var was unset
- `fa_utils.py` broke Eagle tests with import errors

Root Cause
The original motivation to avoid JIT compilation warnings was correct, but the implementation was incomplete:

- Gating the import on `VLLM_ROCM_USE_AITER` avoided the JIT warnings ✅

The fix needed to implement proper OR logic: import aiter if `VLLM_ROCM_USE_AITER=1` OR if an AITER backend is explicitly selected.

Brief explanation & summary of changes
Changes

1. `vllm/_aiter_ops.py` - Separate Availability from Preference

   Before:

   After:

   - Removed the conditional `AITER_FP8_DTYPE` import to prevent JIT warnings

2. `vllm/v1/attention/backends/rocm_aiter_fa.py` - Document Unconditional Import

   Added comments explaining why `import aiter` is unconditional at module level. This clarifies the intentional design decision.
3. `vllm/v1/spec_decode/eagle.py` - Fix Allowed Types Validation

   Before:

   After:

   This allows explicit backend selection to work by ensuring `AiterFlashAttentionMetadata` is always in `allowed_attn_types` if the backend module exists, regardless of the env var.
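Schematically (metadata class names are taken from the description above; the list contents are otherwise hypothetical), the validation change amounts to:

```python
def build_allowed_attn_types(is_rocm: bool, aiter_module_found: bool) -> list:
    # Baseline metadata types accepted by the Eagle drafter
    # (illustrative entries, not the exact vLLM list).
    allowed = ["FlashAttentionMetadata"]
    if is_rocm and aiter_module_found:
        # Keyed on module availability, NOT on VLLM_ROCM_USE_AITER,
        # so explicit backend selection passes validation.
        allowed.append("AiterFlashAttentionMetadata")
    return allowed
```

The key point is that membership depends only on whether the backend module could be imported, never on the user-preference env var.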
4. `vllm/v1/attention/backends/fa_utils.py` - Add Missing ROCm Imports

   Added ROCm-specific imports that were missing. Also fixed `is_flash_attn_varlen_func_available()` to correctly report AITER availability on ROCm.

   Why these were needed: `flash_attn.py` imports these functions from `fa_utils.py`.
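The import-flag pattern referenced here boils down to a try/except at module load. In this sketch, `some_optional_dep` is a deliberately nonexistent placeholder for the real flash-attn import, so the snippet runs (and takes the except branch) anywhere:

```python
try:
    import some_optional_dep  # placeholder for the upstream flash-attn import
    _ROCM_FLASH_ATTN_AVAILABLE = True
except ImportError:
    _ROCM_FLASH_ATTN_AVAILABLE = False


def is_flash_attn_varlen_func_available() -> bool:
    # Report what actually happened at import time rather than
    # re-checking env vars on every call.
    return _ROCM_FLASH_ATTN_AVAILABLE
```

This keeps the availability check cheap and makes it reflect the real outcome of the import, which is what the reviewer meant by "fits current semantics".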
5. `tests/kernels/attention/test_aiter_flash_attn.py` - Migrate Obsolete Test

   Problem: the test called `torch.ops.vllm.flash_attn_varlen_func`, a custom op that was removed during backend refactoring.

   Solution: migrated the test to use the current two-step architecture:

   1. `cp_mha_gather_cache()` - fetch the paged KV cache via a Triton kernel
   2. `aiter.flash_attn_varlen_func()` - compute attention on the gathered KV

   This accurately reflects how the production AITER backend actually works.
Design Philosophy

The fix implements a clear separation:

- Availability: `is_aiter_found_and_supported()`
- Preference: `rocm_aiter_ops.is_enabled()`

This ensures:
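A minimal sketch of this separation, including the OR logic for explicit selection. Function names mirror the description above, but the bodies are stand-ins (the real availability check also verifies GPU support), and aiter is assumed absent in most environments:

```python
import importlib.util
import os
from typing import Optional


def is_aiter_found_and_supported() -> bool:
    # Availability: is the package importable at all?
    return importlib.util.find_spec("aiter") is not None


def is_enabled() -> bool:
    # Preference: the user turned the master switch on AND aiter is available.
    return (is_aiter_found_and_supported()
            and os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1")


def should_import_aiter(selected_backend: Optional[str]) -> bool:
    # OR logic from the fix: import when the master switch is on OR when
    # an AITER backend was explicitly selected, provided aiter is available.
    explicit = (selected_backend or "").startswith("ROCM_AITER")
    return is_aiter_found_and_supported() and (is_enabled() or explicit)
```

With this split, auto-discovery (no explicit backend, env var unset) never touches aiter, while `--attention-backend "ROCM_AITER_FA"` works regardless of the env var.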
Testing

Proposed fixes were tested with:

- `pytest -s -v tests/models/multimodal/generation/test_granite_speech.py::test_models[10-128-512-float16-ibm-granite/granite-speech-3.3-2b]`
- `pytest -s -v tests/kernels/attention/test_aiter_flash_attn.py`
- `pytest -s -v tests/v1/spec_decode/test_eagle.py`
- `pytest -s -v tests/v1/e2e/test_spec_decode.py`
- `TP_SIZE=2 DP_SIZE=2 pytest -v -s tests/v1/distributed/test_eagle_dp.py`