Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/design/attention_backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,4 +213,4 @@ configuration.
| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
25 changes: 17 additions & 8 deletions tests/v1/attention/test_rocm_attention_backends_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,18 @@ def mock_vllm_config():

@pytest.fixture
def mock_on_gfx9():
    """Force gfx9 architecture detection to report True for the test's duration."""
    patcher = patch("vllm.platforms.rocm.on_gfx9", return_value=True)
    patcher.start()
    try:
        yield
    finally:
        # Undo the patch even if the test body raises.
        patcher.stop()


@pytest.fixture
def mock_on_mi3xx():
    """Force mi3xx architecture detection to report True for the test's duration."""
    patcher = patch("vllm.platforms.rocm.on_mi3xx", return_value=True)
    patcher.start()
    try:
        yield
    finally:
        # Undo the patch even if the test body raises.
        patcher.stop()


@pytest.mark.parametrize(
"env_vars, selected_backend, expected_backend_path",
[
Expand Down Expand Up @@ -122,6 +129,7 @@ def test_standard_attention_backend_selection(
expected_backend_path,
mock_vllm_config,
mock_on_gfx9,
mock_on_mi3xx,
monkeypatch,
):
"""Test standard attention backend selection with various configurations."""
Expand Down Expand Up @@ -313,16 +321,16 @@ def test_mla_backend_selection(
assert backend_path == expected_backend_path


def test_aiter_fa_requires_gfx9(mock_vllm_config):
"""Test that ROCM_AITER_FA requires gfx9 architecture."""
def test_aiter_fa_requires_mi3xx(mock_vllm_config):
"""Test that ROCM_AITER_FA requires mi3xx architecture."""
from vllm.platforms.rocm import RocmPlatform

# Mock on_gfx9 to return False
# Mock on_mi3xx to return False (used by supports_compute_capability)
with (
patch("vllm.platforms.rocm.on_gfx9", return_value=False),
patch("vllm.platforms.rocm.on_mi3xx", return_value=False),
pytest.raises(
ValueError,
match="only supported on gfx9",
match="compute capability not supported",
),
):
attn_selector_config = AttentionSelectorConfig(
Expand All @@ -342,11 +350,12 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config):


def test_sparse_not_supported(mock_vllm_config):
"""Test that sparse attention is not supported on ROCm."""
"""Test that sparse MLA without use_mla flag raises an error."""
from vllm.platforms.rocm import RocmPlatform

with pytest.raises(
AssertionError, match="Sparse MLA backend on ROCm only supports block size 1"
ValueError,
match="No valid attention backend found",
):
attn_selector_config = AttentionSelectorConfig(
head_size=128,
Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ class AiterMLABackend(MLACommonBackend):
"fp8_e5m2",
]

@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    """Return an empty list: no explicit head-size allowlist for this backend."""
    return list()

@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
    """Only a kernel block size of exactly 1 is supported."""
    supported_sizes = [1]
    return supported_sizes
Expand Down
15 changes: 15 additions & 0 deletions vllm/v1/attention/backends/mla/triton_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from vllm.v1.attention.backend import (
AttentionLayer,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
Expand All @@ -33,6 +34,20 @@ class TritonMLABackend(MLACommonBackend):
"bfloat16",
]

@classmethod
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we do this or override supports_head_size and always return True instead?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Rohan138 I think that having in each attention backend the proper supported head sizes is the way to go. Besides, regarding AITER, there may be some unsupported head sizes if I am not wrong. It's just that we have not precisely identified them, right?

def get_supported_head_sizes(cls) -> list[int]:
return []

@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
    """Any block size that is a multiple of 16 is accepted by the kernel."""
    sixteen_aligned = MultipleOf(16)
    return [sixteen_aligned]

@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
    """Accept an unspecified block size, or any multiple of 16."""
    # None means no block size has been pinned yet, which is always fine.
    return block_size is None or block_size % 16 == 0

@staticmethod
def get_name() -> str:
return "TRITON_MLA"
Expand Down
6 changes: 6 additions & 0 deletions vllm/v1/attention/backends/rocm_aiter_unified_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]

@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
    """Accept an unspecified block size, or any multiple of 16."""
    # None means no block size has been pinned yet, which is always fine.
    return block_size is None or block_size % 16 == 0

Comment on lines +32 to +37
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gshtras This is also what we had before the two PRs landed correct?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you don't need this, get_supported_kernel_block_sizes above achieves the exact same effect

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some tests were failing before I added this.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without these for example, nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 is not supported on ROCm.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Repro commands:

  • cd .buildkite/lm-eval-harness; pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
  • bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040

@classmethod
def supports_head_size(cls, head_size: int) -> bool:
    """Any head size of at least 32 is handled by this backend."""
    minimum_head_size = 32
    return head_size >= minimum_head_size
Expand Down
6 changes: 6 additions & 0 deletions vllm/v1/attention/backends/rocm_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
# uses our optimized kernel logic.
return [16, 32, 544]

@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
    """Only the explicitly enumerated block sizes (or an unspecified one) pass."""
    if block_size is None:
        # No block size pinned yet: no constraint to violate.
        return True
    return block_size in {16, 32, 544}

Comment on lines +191 to +196
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gshtras Do we need to add more here?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't need this, get_supported_kernel_block_sizes is called by AttentionBackend.supports_block_size already

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Responding in the comment below on this. I believe that these definitions are necessary as there are failures if they are deleted.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is going to fix the block size issue of Qwen3.5 #35923

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tjtanaa I see that both attempt to address this indeed. You need me to remove my patch here or shall we close the other one?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tjtanaa I have integrated @JartX 's contribution here in hindsight. Let's first merge @JartX PR and then confirm that everything is smooth here.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JartX Missed that part. let me see if I can just merge your commit.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh you did that already 😅

@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    """Head sizes this backend advertises as supported."""
    supported = (32, 64, 80, 96, 128, 160, 192, 224, 256)
    return list(supported)
Expand Down
6 changes: 6 additions & 0 deletions vllm/v1/attention/backends/triton_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,12 @@ class TritonAttentionBackend(AttentionBackend):
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]

@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
    """Accept an unspecified block size, or any multiple of 16."""
    # None means no block size has been pinned yet, which is always fine.
    return block_size is None or block_size % 16 == 0

Comment on lines +276 to +281
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gshtras (Same question I had for ROCm AITER Unified Attn): This is also what we had before the two PRs landed correct?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without these for example, nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 is not supported on ROCm.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did we triage why?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a "block size not supported message". So the straight forward solution was to define what is the supported block size set.

forward_includes_kv_cache_update: bool = False

@staticmethod
Expand Down