
[Bugfix] [ROCm] [UX] Reorganize ROCm Backend Selection Logic#26980

Merged
tjtanaa merged 11 commits into vllm-project:main from EmbeddedLLM:fix-rocmattnselection
Nov 24, 2025

Conversation

@vllmellm
Contributor

@vllmellm vllmellm commented Oct 16, 2025

Purpose

This simplifies the user experience so that users no longer need to keep track of multiple environment variables to select a backend: they can specify it directly through VLLM_ATTENTION_BACKEND.

Fix Attention Backend Configuration on ROCm

This PR fixes several issues with attention backend configuration on AMD ROCm platforms to ensure proper backend selection based on environment variables and flags.

Issues Fixed

  1. Missing ROCM_AITER_FA in V1 Oracle List

    • Added ROCM_AITER_FA to the V1 oracle list in vllm/engine/arg_utils.py to allow using AITER Flash Attention backend without enabling all AITER kernels
    • Fixes the case: VLLM_ATTENTION_BACKEND="ROCM_AITER_FA" vllm serve ...
  2. Incorrect Backend Selection with VLLM_ROCM_USE_AITER=1

    • Fixed backend selection logic to properly respect VLLM_ATTENTION_BACKEND when VLLM_ROCM_USE_AITER=1 is set
    • Previously, setting VLLM_ROCM_USE_AITER=1 would override explicit backend choices
    • Now supports:
      • VLLM_ROCM_USE_AITER=1 VLLM_ATTENTION_BACKEND="TRITON_ATTN" → Uses Triton attention
      • VLLM_ROCM_USE_AITER=1 VLLM_ATTENTION_BACKEND="ROCM_ATTN" → Uses ROCm chunked prefill/paged decode
      • VLLM_ROCM_USE_AITER=1 VLLM_ATTENTION_BACKEND="ROCM_AITER_UNIFIED_ATTN" → Uses AITER unified attention
  3. Improved Flag Combinations

    • Enhanced handling of VLLM_ROCM_USE_AITER_MHA flag to properly disable AITER MHA when set to 0
    • Better interaction between VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION and backend selection
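The precedence these fixes establish can be sketched in a few lines of Python. This is an illustrative simplification, not vLLM's actual implementation (the real logic lives in vllm/platforms/rocm.py and consults additional flags such as VLLM_V1_USE_PREFILL_DECODE_ATTENTION); the function name `choose_rocm_attention_backend` is hypothetical:

```python
import os


def choose_rocm_attention_backend() -> str:
    """Hypothetical sketch: pick a ROCm attention backend, preferring
    an explicit user choice over AITER-flag auto-selection."""
    explicit = os.environ.get("VLLM_ATTENTION_BACKEND")
    if explicit:
        # After this PR, an explicit backend always wins,
        # even when VLLM_ROCM_USE_AITER=1 is set.
        return explicit

    use_aiter = os.environ.get("VLLM_ROCM_USE_AITER") == "1"
    use_aiter_mha = os.environ.get("VLLM_ROCM_USE_AITER_MHA", "1") == "1"
    use_aiter_unified = (
        os.environ.get("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION") == "1"
    )

    if use_aiter and use_aiter_unified:
        return "ROCM_AITER_UNIFIED_ATTN"
    if use_aiter and use_aiter_mha:
        return "ROCM_AITER_FA"
    # Default: vLLM's Triton unified attention.
    return "TRITON_ATTN"
```

Note how `VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MHA=0` falls through to the TRITON_ATTN default, matching the appendix examples below.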

Supported Attention Backends on ROCm

After this fix, the following backends are properly supported:

  • TRITON_ATTN: vLLM's Triton unified attention (default)
  • ROCM_ATTN: Chunked prefill (Triton) + paged decode (HIP)
  • ROCM_AITER_FA: AITER Flash Attention
  • ROCM_AITER_UNIFIED_ATTN: AITER unified attention

MLA Backend Support

  • TRITON_MLA: Triton MLA backend for DeepSeek models (requires block-size >= 16)
  • ROCM_AITER_MLA: AITER MLA backend (now supports default block-size, and block-size 1)
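The MLA fallback, together with the block-size constraints raised in review, can be sketched as follows. This is an illustrative simplification, not vLLM's code: `choose_rocm_mla_backend` is a hypothetical name, and the `aiter_mla_enabled` parameter stands in for the env-var check (`rocm_aiter_ops.is_mla_enabled()`).

```python
def choose_rocm_mla_backend(block_size: int, aiter_mla_enabled: bool) -> str:
    """Hypothetical sketch of MLA backend selection on ROCm."""
    if aiter_mla_enabled:
        # ROCM_AITER_MLA now supports the default block size as well as 1.
        return "ROCM_AITER_MLA"
    if block_size < 16:
        # TRITON_MLA does not support small block sizes; guard explicitly.
        raise ValueError("TRITON_MLA requires block_size >= 16")
    return "TRITON_MLA"
```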

Test Plan

All documented configuration combinations (see the appendix) now work correctly:

  • Explicit backend selection via VLLM_ATTENTION_BACKEND
  • Backend selection via VLLM_ROCM_USE_AITER flag combinations
  • Proper fallback behavior when flags conflict

Added unit tests: tests/v1/attention/test_rocm_attention_backends_selection.py

Test Result

All commands were validated to select the correct backend.

pytest tests/v1/attention/test_rocm_attention_backends_selection.py: all tests passed.
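The shape of such a test can be sketched with a plain table-driven check. Everything here is illustrative: `choose_backend` is a hypothetical stand-in for vLLM's real selection entry point (get_attn_backend_cls), and the env-var/expected pairs mirror a few of the combinations documented in the appendix.

```python
import os


def choose_backend() -> str:
    """Hypothetical stand-in for vLLM's backend selection, reduced to the
    two env vars exercised below."""
    explicit = os.environ.get("VLLM_ATTENTION_BACKEND")
    if explicit:
        return explicit
    if os.environ.get("VLLM_ROCM_USE_AITER") == "1":
        return "ROCM_AITER_FA"
    return "TRITON_ATTN"


# (env vars to set, expected backend) -- mirrors appendix combinations.
CASES = [
    ({}, "TRITON_ATTN"),
    ({"VLLM_ROCM_USE_AITER": "1"}, "ROCM_AITER_FA"),
    ({"VLLM_ROCM_USE_AITER": "1",
      "VLLM_ATTENTION_BACKEND": "TRITON_ATTN"}, "TRITON_ATTN"),
    ({"VLLM_ATTENTION_BACKEND": "ROCM_AITER_UNIFIED_ATTN"},
     "ROCM_AITER_UNIFIED_ATTN"),
]


def run_cases() -> None:
    """Run each case against a clean environment, then restore it."""
    for env, expected in CASES:
        saved = dict(os.environ)
        try:
            for var in ("VLLM_ROCM_USE_AITER", "VLLM_ATTENTION_BACKEND"):
                os.environ.pop(var, None)
            os.environ.update(env)
            assert choose_backend() == expected, (env, expected)
        finally:
            os.environ.clear()
            os.environ.update(saved)
```

Calling `run_cases()` exercises all four combinations; the real test file parametrizes the same idea with pytest.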

Appendix (this bugfix PR makes all of the following commands work correctly)

Attention Backend on ROCm

The attention backend can be set in several ways:

  • VLLM_ATTENTION_BACKEND
  • combinations of VLLM_ROCM_* flags.

Attention

On AMD ROCm, the available backends are TRITON_ATTN, ROCM_ATTN, ROCM_AITER_FA, and ROCM_AITER_UNIFIED_ATTN.

  • TRITON_ATTN:
    • Uses vLLM's Triton unified attention backend. Both prefill and decode are Triton kernels.
      Example command:
      # Example 1
      vllm serve meta-llama/Llama-3.1-8B-Instruct
      # Example 2
      VLLM_ATTENTION_BACKEND="TRITON_ATTN" vllm serve meta-llama/Llama-3.1-8B-Instruct
      # Example 3 (when AITER is enabled but you still want TRITON_ATTN)
      VLLM_ROCM_USE_AITER=1 VLLM_ATTENTION_BACKEND="TRITON_ATTN" vllm serve meta-llama/Llama-3.1-8B-Instruct
      # OR
      VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MHA=0 vllm serve meta-llama/Llama-3.1-8B-Instruct
  • ROCM_ATTN
    • Uses vLLM's chunked prefill paged decode kernel. The prefill is a Triton kernel and the decode is a custom HIP paged attention kernel.
    • Examples
      # Example 1
      VLLM_ATTENTION_BACKEND="ROCM_ATTN" vllm serve meta-llama/Llama-3.1-8B-Instruct
      
      # Example 2 (when AITER is enabled but you still want ROCM_ATTN)
      VLLM_ROCM_USE_AITER=1 VLLM_ATTENTION_BACKEND="ROCM_ATTN" vllm serve meta-llama/Llama-3.1-8B-Instruct
      # OR
      VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MHA=0 VLLM_ATTENTION_BACKEND="ROCM_ATTN" vllm serve meta-llama/Llama-3.1-8B-Instruct
      
      # OR
      VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MHA=0 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 vllm serve meta-llama/Llama-3.1-8B-Instruct
  • ROCM_AITER_FA
    • Uses the AITER Flash Attention backend.
    • Examples
      # Example 1 (Only use AITER FA backend without enabling other AITER kernels)
      VLLM_ATTENTION_BACKEND="ROCM_AITER_FA" vllm serve meta-llama/Llama-3.1-8B-Instruct
      
      # Example 2
      VLLM_ROCM_USE_AITER=1 vllm serve meta-llama/Llama-3.1-8B-Instruct
  • ROCM_AITER_UNIFIED_ATTN
    • Uses the AITER unified attention backend.
    • Examples
      # Example 1 (only use the AITER unified attention backend without enabling other AITER kernels)
      VLLM_ATTENTION_BACKEND="ROCM_AITER_UNIFIED_ATTN" vllm serve meta-llama/Llama-3.1-8B-Instruct
      
      # Example 2
      VLLM_ROCM_USE_AITER=1 VLLM_ATTENTION_BACKEND="ROCM_AITER_UNIFIED_ATTN" vllm serve meta-llama/Llama-3.1-8B-Instruct
      # OR 
      VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MHA=0 VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION=1 vllm serve meta-llama/Llama-3.1-8B-Instruct

MLA Backend:

On AMD ROCm, the available MLA backends are TRITON_MLA and ROCM_AITER_MLA.

  • TRITON_MLA:
    • Uses vLLM's Triton MLA backend. The prefill uses Triton flash attention / CK flash attention (varlen), and decode uses the Triton MLA decode kernel. Requires block_size >= 16.
    • Example commands:
      VLLM_ATTENTION_BACKEND="TRITON_MLA" vllm serve deepseek-ai/DeepSeek-R1 -tp 8
      
      VLLM_ROCM_USE_AITER=1 VLLM_ATTENTION_BACKEND="TRITON_MLA" vllm serve deepseek-ai/DeepSeek-R1 -tp 8
  • ROCM_AITER_MLA:
    • Uses the AITER MLA backend. Now supports the default block-size (16) and block-size 1.
    • Example commands:
      VLLM_ATTENTION_BACKEND="ROCM_AITER_MLA" vllm serve deepseek-ai/DeepSeek-R1 -tp 8
      
      VLLM_ROCM_USE_AITER=1 vllm serve deepseek-ai/DeepSeek-R1 -tp 8  --block-size 1
      
      VLLM_ROCM_USE_AITER=1 vllm serve deepseek-ai/DeepSeek-R1 -tp 8

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@mergify mergify bot added rocm Related to AMD ROCm v1 labels Oct 16, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request significantly improves the ROCm attention backend selection logic by reorganizing it for clarity and correctness. Prioritizing explicit user selection over environment variable-based auto-selection is a great change that enhances predictability. The addition of a comprehensive suite of unit tests is also excellent, ensuring the new logic is robust and well-verified. I have one suggestion to improve the clarity and correctness of a mock in the new test file.

Comment on lines +246 to +251
# Mock is_aiter_mla_enabled based on env vars and block_size
aiter_enabled = env_vars.get("VLLM_ROCM_USE_AITER") == "1" and block_size == 1
with patch(
"vllm.v1.attention.backends.mla.rocm_aiter_mla.is_aiter_mla_enabled",
return_value=aiter_enabled,
):
Contributor

high

The mock for is_aiter_mla_enabled is a bit confusing as its return value depends on block_size, which is an input to the function under test (get_attn_backend_cls), not to is_aiter_mla_enabled itself. This couples the test's mock behavior too tightly with the implementation details of the function being tested, making it harder to understand and maintain.

The purpose of is_aiter_mla_enabled is to check environment variables, while the block_size check happens within get_attn_backend_cls. The test should reflect this separation of concerns.

You can simplify the mock to only depend on the environment variables, which will make the test clearer and more robust. The current mock makes the condition is_aiter_mla_enabled() and block_size == 1 effectively become (env_vars.get("VLLM_ROCM_USE_AITER") == "1" and block_size == 1) and block_size == 1, which is redundant. The suggested change correctly tests the logic by mocking is_aiter_mla_enabled based only on its own dependencies (the environment variables).

Suggested change
# Mock is_aiter_mla_enabled based on env vars and block_size
aiter_enabled = env_vars.get("VLLM_ROCM_USE_AITER") == "1" and block_size == 1
with patch(
"vllm.v1.attention.backends.mla.rocm_aiter_mla.is_aiter_mla_enabled",
return_value=aiter_enabled,
):
# Mock is_aiter_mla_enabled based on env vars
aiter_enabled = env_vars.get("VLLM_ROCM_USE_AITER") == "1"
with patch(
"vllm.v1.attention.backends.mla.rocm_aiter_mla.is_aiter_mla_enabled",
return_value=aiter_enabled,
):

@vllmellm vllmellm changed the title [Bugfix] [ROCm] Reorganize ROCm Backend Selection Logic [Bugfix] [ROCm] [UX] Reorganize ROCm Backend Selection Logic Oct 16, 2025
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
)
# When AITER is enabled and block_size is 1, use AITER MLA
# Otherwise, use TRITON MLA
if is_aiter_mla_enabled() and block_size == 1:
Collaborator

is_aiter_mla_enabled() has been changed to rocm_aiter_ops.is_mla_enabled() in _aiter_ops.py.

Contributor Author

Okay, thanks!

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@vllmellm
Contributor Author

Hi @tjtanaa, I have updated the code and pushed it, and re-ran the appendix commands successfully. Please take a look when you have time. Thanks!

@vllmellm vllmellm requested a review from tjtanaa November 13, 2025 08:53
# When AITER is enabled and block_size is 1, use AITER MLA
# Otherwise, use TRITON MLA
if rocm_aiter_ops.is_mla_enabled() and block_size == 1:
selected_backend = AttentionBackendEnum.ROCM_AITER_MLA
Collaborator

ROCM_AITER_MLA now supports block sizes larger than 1.

Contributor Author

Okay, thanks for the suggestions. I will fix them one by one.

if rocm_aiter_ops.is_mla_enabled() and block_size == 1:
selected_backend = AttentionBackendEnum.ROCM_AITER_MLA
else:
selected_backend = AttentionBackendEnum.TRITON_MLA
Collaborator

TRITON_MLA requires a block size of at least 16; it does not support block size 1. A guard statement is needed.

def opaque_attention_op(cls) -> bool:
return True

@classmethod
Collaborator

Why is this line in the diff? I remember this has already been removed.

dtype=torch.float16,
kv_cache_dtype="auto",
block_size=16,
use_v1=True,
Collaborator

Fix this; there is no use_v1 argument anymore.

dtype=torch.float16,
kv_cache_dtype="auto",
block_size=16,
use_v1=False,
Collaborator

Fix this; there is no use_v1 argument anymore.

dtype=torch.float16,
kv_cache_dtype="auto",
block_size=16,
use_v1=False,
Collaborator

Fix this; there is no use_v1 argument anymore.

dtype=torch.float16,
kv_cache_dtype="auto",
block_size=16,
use_v1=True,
Collaborator

Fix this; there is no use_v1 argument anymore.

dtype=torch.float16,
kv_cache_dtype="auto",
block_size=block_size,
use_v1=True,
Collaborator

Fix this; there is no use_v1 argument anymore.

dtype=torch.float16,
kv_cache_dtype="auto",
block_size=16,
use_v1=True,
Collaborator

Fix this; there is no use_v1 argument anymore.

@mergify
Contributor

mergify bot commented Nov 21, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @vllmellm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 21, 2025
# Conflicts:
#	vllm/platforms/rocm.py

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@mergify mergify bot removed the needs-rebase label Nov 21, 2025
@vllmellm
Contributor Author

Hi @tjtanaa, I have updated the code according to your suggestions, and re-ran pytest and the different commands successfully. Please take a look when you have time. Thanks!

@vllmellm vllmellm requested a review from tjtanaa November 21, 2025 05:22
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@vllmellm
Contributor Author

Hi @tjtanaa, I've revised the code based on the latest main branch, and all commands have been executed successfully. Please review the code when you have time. Thanks!

(
{},
None,
"vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
Collaborator

Can you replace this with notation like AttentionBackendEnum.ROCM_AITER_FA.get_path()?

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Collaborator

@tjtanaa tjtanaa left a comment

LGTM

@tjtanaa tjtanaa added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 24, 2025
@tjtanaa tjtanaa enabled auto-merge (squash) November 24, 2025 09:07
@tjtanaa tjtanaa merged commit e48b2e6 into vllm-project:main Nov 24, 2025
45 checks passed
@AndreasKaratzas
Collaborator

@vllmellm Why was the FLEX_ATTENTION backend logic removed? It is still a V1 backend and is also supported on ROCm.

RunkaiTao pushed a commit to RunkaiTao/vllm that referenced this pull request Nov 24, 2025
…oject#26980)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: Runkai Tao <rt572@physics.rutgers.edu>
@vllmellm
Contributor Author

vllmellm commented Nov 25, 2025

@vllmellm Why was FLEX_ATTENTION backend logic removed? It is still a V1 backend and also it is supported by ROCm.

Hi, good catch. We just put up a fix here: #29371. Thanks.

devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
kitaekatt pushed a commit to kitaekatt/vllm that referenced this pull request Dec 1, 2025