
[ROCm][AITER] Fix AITER import regression for explicit backend selection#33749

Merged
tjtanaa merged 10 commits into vllm-project:main from ROCm:akaratza_flash_attn_builtin_fix
Feb 6, 2026
Conversation

@AndreasKaratzas
Collaborator

@AndreasKaratzas AndreasKaratzas commented Feb 4, 2026

A regression was introduced that broke explicit AITER backend selection on ROCm when VLLM_ROCM_USE_AITER=0 (or unset). Users could not explicitly select the AITER backend via attention_config={"backend": "ROCM_AITER_FA"} even though the backend was available.

Error observed:

AttributeError: 'builtin_function_or_method' object has no attribute 'flash_attn_varlen_func'

This occurred because:

  1. is_aiter_found_and_supported() incorrectly checked the env var, preventing aiter import during explicit selection
  2. Eagle's allowed_attn_types validation rejected AiterFlashAttentionMetadata when env var was unset
  3. Missing ROCm-specific imports in fa_utils.py broke Eagle tests with import errors
  4. Test suite had an obsolete test using a removed custom op

Root Cause

The original motivation to avoid JIT compilation warnings was correct, but the implementation was incomplete:

  • Auto-discovery path: Should respect VLLM_ROCM_USE_AITER to avoid JIT warnings ✅
  • Explicit selection path: Should work regardless of env var (was broken) ❌

The fix needed to implement proper OR logic: import aiter if VLLM_ROCM_USE_AITER=1 OR if AITER backend is explicitly selected.
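The OR logic can be sketched as a small predicate. This is a hypothetical stand-in with illustrative names, not vLLM's actual API; only the backend string "ROCM_AITER_FA" comes from the PR:

```python
from typing import Optional

def should_import_aiter(env_use_aiter: bool, explicit_backend: Optional[str]) -> bool:
    """Import aiter if the master switch is on OR the user explicitly
    selected an AITER backend via attention_config."""
    return env_use_aiter or explicit_backend == "ROCM_AITER_FA"
```

With this shape, explicit selection works even when the env var is unset: `should_import_aiter(False, "ROCM_AITER_FA")` is true, while the pure auto-discovery case with both inputs off stays false.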

Brief explanation & summary of changes

Changes

1. vllm/_aiter_ops.py - Separate Availability from Preference

Before:

def is_aiter_found_and_supported() -> bool:
    # Incorrectly included the env var check in an availability test
    if current_platform.is_rocm() and IS_AITER_FOUND and envs.VLLM_ROCM_USE_AITER:
        ...

After:

def is_aiter_found_and_supported() -> bool:
    # Only checks platform + library + arch (availability).
    # Env var check moved to module-level imports for JIT warning prevention.
    if current_platform.is_rocm() and IS_AITER_FOUND:
        ...
  • Removed env var from availability check (capability vs preference)
  • Re-added env var check to module-level AITER_FP8_DTYPE import to prevent JIT warnings
  • Updated docstrings to clarify the separation
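The availability/preference split can be sketched as two pure functions. These are illustrative stand-ins with injected inputs, not the real vLLM signatures (the real checks read the platform and env var directly):

```python
def is_available(is_rocm: bool, aiter_found: bool) -> bool:
    # Availability: can AITER run at all on this machine?
    # Platform + library presence only; no user preference involved.
    return is_rocm and aiter_found

def is_enabled(is_rocm: bool, aiter_found: bool, env_use_aiter: bool) -> bool:
    # Preference: availability AND the VLLM_ROCM_USE_AITER master switch.
    return is_available(is_rocm, aiter_found) and env_use_aiter
```

Explicit backend selection should consult only the first; auto-discovery consults the second.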

2. vllm/v1/attention/backends/rocm_aiter_fa.py - Document Unconditional Import

Added comments explaining why import aiter is unconditional at module level:

# IMPORTANT: This import is UNCONDITIONAL on ROCm (no env var check).
# This enables explicit backend selection via attention_config to work
# even when VLLM_ROCM_USE_AITER=0.

This clarifies the intentional design decision.

3. vllm/v1/spec_decode/eagle.py - Fix Allowed Types Validation

Before:

if rocm_aiter_ops.is_enabled() and find_spec(...):
    # Only added to allowed_attn_types if env var set

After:

if find_spec(...):
    # Always add if module exists (capability check)

This allows explicit backend selection to work by ensuring AiterFlashAttentionMetadata is always in allowed_attn_types if the backend module exists, regardless of the env var.
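The capability check can be mimicked with importlib's find_spec, as in this toy sketch (function name and module list are hypothetical; the real code checks the specific AITER backend module):

```python
import importlib.util

def build_allowed_attn_types(candidate_modules):
    # Capability check: a backend's metadata type is allowed whenever its
    # module is importable, independent of any env-var preference.
    return [name for name in candidate_modules
            if importlib.util.find_spec(name) is not None]
```

For example, `build_allowed_attn_types(["json", "definitely_not_a_module_xyz"])` keeps only the importable module.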

4. vllm/v1/attention/backends/fa_utils.py - Add Missing ROCm Imports

Added ROCm-specific imports that were missing:

elif current_platform.is_rocm():
    # Import flash_attn_varlen_func from flash-attn package
    from flash_attn import flash_attn_varlen_func
    
    # NEW: ROCm doesn't use scheduler metadata (FA3 feature)
    def get_scheduler_metadata(*args, **kwargs) -> None:
        return None
    
    # NEW: ROCm uses C++ custom op for reshape_and_cache
    reshape_and_cache_flash = torch.ops._C_cache_ops.reshape_and_cache_flash

Also fixed is_flash_attn_varlen_func_available() to correctly report AITER availability on ROCm:

# Before: Always returned False on ROCm (bug)
# After: Checks if aiter is available
if current_platform.is_rocm():
    from vllm._aiter_ops import is_aiter_found_and_supported
    return is_aiter_found_and_supported()

Why these were needed:

  • flash_attn.py imports these functions from fa_utils.py
  • They were only defined for CUDA/XPU, causing import errors on ROCm
  • This broke all Eagle tests on ROCm with AITER backend
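The failure mode above can be illustrated with a minimal sketch of the platform-conditional pattern: every branch must bind the full set of names that downstream modules import, even as no-op stubs. Platform detection is faked here for the example:

```python
PLATFORM = "rocm"  # stand-in for current_platform detection

if PLATFORM in ("cuda", "xpu"):
    def get_scheduler_metadata(*args, **kwargs):
        # FA3-only feature; the real implementation lives in flash-attn.
        return {"fa3": True}
elif PLATFORM == "rocm":
    # Before the fix this branch was missing, so importing the name
    # failed on ROCm. A stub keeps the import surface uniform.
    def get_scheduler_metadata(*args, **kwargs):
        return None
```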

5. tests/kernels/attention/test_aiter_flash_attn.py - Migrate Obsolete Test

Problem: Test called torch.ops.vllm.flash_attn_varlen_func, a custom op that was removed during backend refactoring.

Solution: Migrated test to use current two-step architecture:

  1. cp_mha_gather_cache() - Fetch paged KV cache via Triton kernel
  2. aiter.flash_attn_varlen_func() - Compute attention on gathered KV

This accurately reflects how the production AITER backend actually works.
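The gather step can be mimicked in pure Python. The real cp_mha_gather_cache is a Triton kernel, so this toy version only mirrors the block-table indexing, with a simplified cache representation:

```python
def gather_paged_cache(kv_cache, block_table, seq_len, block_size):
    """Gather one sequence's KV entries from a paged cache.

    kv_cache: mapping block_id -> list of up to block_size entries.
    block_table: physical block ids for this sequence, in order.
    """
    gathered = []
    for block_id in block_table:
        gathered.extend(kv_cache[block_id])
    # The last block may be partially filled; trim to the true length.
    return gathered[:seq_len]
```

The gathered, contiguous KV is then what a varlen flash-attention kernel consumes in step 2.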

Design Philosophy

The fix implements a clear separation:

| Concept | What it checks | Example | Where used |
|---|---|---|---|
| Availability | Platform + library + arch | is_aiter_found_and_supported() | Tests, capability checks |
| Preference | Availability + env var | rocm_aiter_ops.is_enabled() | Auto-discovery, default backend |

This ensures:

  • Auto-discovery respects user preference (env var)
  • Explicit selection respects user intent (config parameter)
  • Tests run when capability exists, regardless of preference

Testing

Proposed fixes were tested with:

  • pytest -s -v tests/models/multimodal/generation/test_granite_speech.py::test_models[10-128-512-float16-ibm-granite/granite-speech-3.3-2b]
  • pytest -s -v tests/kernels/attention/test_aiter_flash_attn.py
  • pytest -s -v tests/v1/spec_decode/test_eagle.py
  • pytest -s -v tests/v1/e2e/test_spec_decode.py
  • TP_SIZE=2 DP_SIZE=2 pytest -v -s tests/v1/distributed/test_eagle_dp.py

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request effectively resolves a regression that prevented explicit selection of the AITER attention backend on ROCm. The core of the fix is a well-reasoned separation of backend availability from user preference, which is now correctly handled. The changes are consistently applied across the codebase, including updates to is_aiter_found_and_supported, unconditional imports in the backend module, and adjustments to the speculative decoding logic. Additionally, the pull request includes a necessary fix for an obsolete test case, updating it to align with the current attention implementation. The code is well-documented, and the changes are robust. I have no major concerns.

@AndreasKaratzas
Collaborator Author

cc @vllmellm @rabi @tjtanaa

This PR attempts to address the regression introduced here: #32902
Let me know what you think. This PR should be merged quickly to mitigate the current AITER FA failures.

# This file is not going to be imported unless the user explicitly
# imports it or selects it via attention_config, so the aiter import
# here should not cause any issues.
import aiter
Collaborator


@AndreasKaratzas Let's do lazy loading.

We do this within function scope: from aiter import pa_fwd_asm.

Given that flash_attn_varlen_func is used repeatedly, we should define it in _aiter_ops.py and invoke it through aiter_ops: from aiter import flash_attn_varlen_func.

Make sure this rocm_aiter_ops.flash_attn_varlen_func is not wrapped with @is_aiter_supported, as we want to invoke it even if the master switch is not enabled.

Collaborator Author


@tjtanaa Lmk if the recent commit addresses your point :)

if current_platform.is_rocm():
from vllm._aiter_ops import is_aiter_found_and_supported

return is_aiter_found_and_supported()
Collaborator

@tjtanaa tjtanaa Feb 5, 2026


This condition does not fit the current state of the code.

is_flash_attn_varlen_func_available() is a helper for the ops imported in vllm/v1/attention/backends/fa_utils.py, lines 10-32. Currently, the condition for whether aiter.flash_attn_varlen_func can or should be used is handled explicitly elsewhere in the codebase.

Let me piece together more context, but I will need some time.

Collaborator Author


Did my recent commit remedy (a little bit) your concern? If possible, we would like this PR merged because the whole AITER_FA pipeline is currently broken on our CI. We could work on optimizing this on a follow-up PR.

… import flag

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@tjtanaa
Collaborator

tjtanaa commented Feb 5, 2026

@AndreasKaratzas you also have to fix the issue with the op torch.ops.aiter.paged_attention_v1

+                    # import so that aiter register the op to the namespace of
+                    # torch.ops.aiter
+                    import aiter # noqa: F401

                    torch.ops.aiter.paged_attention_v1(
                        output[:num_decode_tokens],
                        workspace_buffer,
                        query[:num_decode_tokens],
                        key_cache,
                        value_cache,
                        self.scale,
                        attn_metadata.block_table[:num_decodes],
                        attn_metadata.query_start_loc[:num_decodes],
                        attn_metadata.seq_lens[:num_decodes],
                        attn_metadata.max_seq_len,
                        self.alibi_slopes,
                        self.kv_cache_dtype,
                        "NHD",
                        self.logits_soft_cap,
                        layer._k_scale,
                        layer._v_scale,
                        None,
                        _PARTITION_SIZE_ROCM,
                    )

Without this, there will be an error if you run

vllm serve Qwen/Qwen3-0.6B --attention-backend "ROCM_AITER_FA"

if current_platform.is_rocm():
# Use the flag set during module import to check if
# upstream flash-attn was successfully imported
return _ROCM_FLASH_ATTN_AVAILABLE
Collaborator


Thanks @AndreasKaratzas . This implementation fits current semantics.

@tjtanaa
Collaborator

tjtanaa commented Feb 5, 2026

@AndreasKaratzas Please also change the condition of is_shuffle_kv_cache_enabled, as it fits the topic of this PR.

is_shuffle_kv_cache_enabled is tied to ROCM_AITER_FA; we also want it to be usable even when the master switch VLLM_ROCM_USE_AITER is not enabled.

    @classmethod
    @if_aiter_supported
    def is_shuffle_kv_cache_enabled(cls) -> bool:
-        return cls._AITER_ENABLED and cls._SHUFFLE_KV_CACHE_ENABLED
+       return cls._SHUFFLE_KV_CACHE_ENABLED

So that we can do this to deploy the preshuffle kvcache pa_fwd_asm kernel

VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT=1 vllm serve Qwen/Qwen3-0.6B --attention-backend "ROCM_AITER_FA"

@AndreasKaratzas
Collaborator Author

@tjtanaa I've added the suggested fixes. Lmk if things look better now. Happy to add any more suggestions.

# 1. Backend modules (rocm_aiter_fa.py) import aiter directly when loaded
# 2. Individual op implementations import aiter locally when called
# 3. This module's ops are only called when an aiter backend is actually in use
if is_aiter_found_and_supported() and envs.VLLM_ROCM_USE_AITER:
Collaborator

@tjtanaa tjtanaa Feb 6, 2026


This chain of bugfixes was caused by not doing a lazy import of aiter, especially for this DTYPE issue.

I have looked further into the issue and double-checked the dtype definition from AITER: it is the same as current_platform.fp8_dtype(). So we don't have to import the dtype from aiter here.

https://github.com/ROCm/aiter/blob/6af8b687480509d67f42a69bca7ed092c432e8dc/aiter/utility/dtypes.py#L15 (commit used in Dockerfile.rocm_base, it is the same definition)

https://github.com/ROCm/aiter/blob/12fe5f0291dad871584db71c49a4c33556519bbf/aiter/utility/dtypes.py#L17 (even on latest main commit, it is the same definition)

So the suggested change in this file is to define this at the top of the file:

# fp8_dtype is not cached.
# on ROCm the fp8_dtype always call is_fp8_fnuz
# which is a host op
FP8_DTYPE = current_platform.fp8_dtype()

Then replace all occurrences of _FP8_DTYPE and AITER_FP8_DTYPE with FP8_DTYPE.
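As a rough illustration of why the two definitions agree, here is a sketch of the arch-dependent dtype choice. The dtype names are torch's real FP8 dtypes, but the function and the gfx mapping are simplified assumptions, not AITER's or vLLM's actual code:

```python
def fp8_dtype_name(gfx_arch: str) -> str:
    # gfx940/941/942-class GPUs (MI300 family) use the fnuz FP8 encoding;
    # other archs fall back to the OCP float8_e4m3fn encoding.
    if gfx_arch in ("gfx940", "gfx941", "gfx942"):
        return "float8_e4m3fnuz"
    return "float8_e4m3fn"
```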

We don't need this section of the code:

# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might because that the get_gfx() is wrapped as a custom op.
#
# Import strategy to avoid unwanted JIT warnings:
# - Only import aiter dtypes at module load if VLLM_ROCM_USE_AITER=1 (env var set)
# - This prevents JIT warnings during backend auto-discovery when aiter is not preferred
# - Explicit backend selection (via attention_config) still works because:
#   1. Backend modules (rocm_aiter_fa.py) import aiter directly when loaded
#   2. Individual op implementations import aiter locally when called
#   3. This module's ops are only called when an aiter backend is actually in use
if is_aiter_found_and_supported() and envs.VLLM_ROCM_USE_AITER:
    from aiter import dtypes

    AITER_FP8_DTYPE = dtypes.fp8
else:
    # Placeholder when AITER is not the default - prevents NameError during module load.
    # Note: This fallback is used for fake implementations and type checking.
    # If an AITER backend is explicitly selected (even with env var=0),
    # the backend module will import aiter directly (rocm_aiter_fa.py line 35).
    AITER_FP8_DTYPE = _FP8_DTYPE

Collaborator

@tjtanaa tjtanaa Feb 6, 2026


I have validated locally with the following command.
The ops relying on this are the fusion-pass ops for rmsnorm and block quant.

sudo rm -rf ~/.cache/vllm

ATTN_BACKEND="ROCM_AITER_FA"

VLLM_LOGGING_LEVEL=DEBUG \
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT=1 \
vllm serve Qwen/Qwen3-32B-FP8 \
--attention-backend $ATTN_BACKEND \
-O3 \
> launch_server_$ATTN_BACKEND-preshuffled-compilation.log 2>&1

Log showing that the replacement occurs:

(APIServer pid=1155) DEBUG 02-06 03:32:37 [v1/engine/utils.py:980] Waiting for 1 local, 0 remote core engine proc(s) to start.
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/noop_elimination.py:105] Removed 0 no-op reshapes and slices
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/vllm_inductor_pass.py:79] NoOpEliminationPass completed in 0.5 ms
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/fusion.py:558] Replaced 0 patterns
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/vllm_inductor_pass.py:79] RMSNormQuantFusionPass completed in 0.8 ms
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/rocm_aiter_fusion.py:315] Replaced 1 patterns
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/vllm_inductor_pass.py:79] RocmAiterRMSNormQuantFusionPass completed in 15.7 ms
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/activation_quant_fusion.py:207] Replaced 0 patterns
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/vllm_inductor_pass.py:79] ActivationQuantFusionPass completed in 0.4 ms
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/rocm_aiter_fusion.py:394] Replaced 0 patterns
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/vllm_inductor_pass.py:79] RocmAiterSiluMulFp8GroupQuantFusionPass completed in 0.4 ms
(EngineCore_DP0 pid=1441) DEBUG 02-06 03:32:39 [compilation/vllm_inductor_pass.py:79] PostCleanupPass completed in 0.2 ms

The lmeval score is

2026-02-06:03:40:49 INFO     [loggers.evaluation_tracker:316] Output path not provided, skipping saving results aggregated
local-completions ({'model': 'Qwen/Qwen3-32B-FP8', 'base_url': 'http://127.0.0.1:8000/v1/completions'}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 100
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6353|±  |0.0133|
|     |       |strict-match    |     5|exact_match|↑  |0.7521|±  |0.0119|

…cessary aiter import

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@AndreasKaratzas
Collaborator Author

@tjtanaa I updated per the suggested changes; let me know if it looks good now :) I accidentally pushed a line that was intended for a different PR, and removed it when I realized. Sorry about that, guys.

Collaborator

@tjtanaa tjtanaa left a comment


LGTM now. Thanks a lot.

Collaborator

@tjtanaa tjtanaa left a comment


LGTM. Thanks for the fix.
The status of the following test groups:

  • Entrypoints Integration Test (API Server 1): passed
  • Multi-Modal Models Test (Extended) 1: passed
  • Multi-Modal Models Test (Standard): passed (this takes more than 1 hour to run)
  • V1 Test others: failed due to a numerical issue, so it shouldn't be related to aiter

@tjtanaa tjtanaa added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 6, 2026
@tjtanaa tjtanaa enabled auto-merge (squash) February 6, 2026 12:22
@tjtanaa tjtanaa merged commit 350ca72 into vllm-project:main Feb 6, 2026
52 checks passed
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Feb 6, 2026
@AndreasKaratzas AndreasKaratzas deleted the akaratza_flash_attn_builtin_fix branch February 6, 2026 17:28
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
…ion (vllm-project#33749)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
…ion (vllm-project#33749)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>

Labels

ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), speculative-decoding, v1

Projects

Status: Done

2 participants