[ROCm] Validate block_size for explicitly selected attention backends #36846

Merged
gshtras merged 4 commits into vllm-project:main from ROCm:akaratza_fix_atten_dispatch
Mar 17, 2026

Conversation

@AndreasKaratzas (Collaborator) commented Mar 12, 2026

#36274 stripped block_size from attn_selector_config before backend validation in get_attn_backend_cls, which was correct for auto-selection (block_size may not be finalized at that point). However, this also bypassed block_size validation for explicitly user-selected backends, breaking the contract established in #36292.

  • Add an explicit supports_block_size check for the selected-backend path, before the strip
  • Auto-selection path is unaffected and block_size is still stripped there
  • Fixes: test_mla_backend_selection[env_vars1-TRITON_MLA-1-None-True]

cc @kenroche

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@mergify mergify bot added the rocm Related to AMD ROCm label Mar 12, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Mar 12, 2026
@gemini-code-assist (bot, Contributor) left a comment
Code Review

This pull request modifies the get_attn_backend_cls method in vllm/platforms/rocm.py to validate the block_size for explicitly selected attention backends. A new check is added to verify that the selected backend supports the specified block_size before this parameter is stripped from the configuration. This fixes a bug where this validation was being bypassed. The change is targeted and does not affect the automatic backend selection path. I have reviewed the changes and found no issues.

@AndreasKaratzas AndreasKaratzas added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 12, 2026
Review thread on vllm/platforms/rocm.py (excerpt):

        f"{backend_class.get_supported_kernel_block_sizes()}."
    )

attn_selector_config = attn_selector_config._replace(block_size=None)
@tjtanaa (Collaborator) commented Mar 16, 2026

@AndreasKaratzas We should also guard this block size, because there is also a use case where the user does not specify a backend but does pass --block-size explicitly.

#36274 looks more like a hotpatch.

We should look into supporting the correct size through get_supported_kernel_block_sizes of the attention backend class.
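The guard suggested here could look roughly like the sketch below: instead of silently stripping a user-specified block size on the auto-selection path, consult each candidate backend's `get_supported_kernel_block_sizes`. `BackendA`, `BackendB`, and `select_backend_for_block_size` are hypothetical illustrations, not vLLM's actual API.

```python
class BackendA:
    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int]:
        return [16]


class BackendB:
    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int]:
        return [32, 64]


def select_backend_for_block_size(block_size, candidate_backends):
    """Auto-select the first backend whose kernels support the
    user-requested --block-size, instead of dropping the request."""
    if block_size is None:
        return candidate_backends[0]  # nothing to validate
    for backend in candidate_backends:
        if block_size in backend.get_supported_kernel_block_sizes():
            return backend
    raise ValueError(
        f"No candidate backend supports block_size={block_size}; "
        "supported sizes per backend: "
        + ", ".join(
            f"{b.__name__}={b.get_supported_kernel_block_sizes()}"
            for b in candidate_backends
        )
    )
```

This way an explicit --block-size either finds a compatible backend or fails loudly, rather than being stripped before selection.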

@AndreasKaratzas (Collaborator, Author)

@tjtanaa Indeed. I removed both my patch and the patch from #36274.

@tjtanaa (Collaborator)

@AndreasKaratzas Let's try to fix get_supported_kernel_block_sizes of the attention backend in this PR. Otherwise Meta will encounter issues running Qwen3.5 after this PR.

@AndreasKaratzas (Collaborator, Author)

@tjtanaa But Qwen works fine. Are you referring to:

vllm bench throughput --model Qwen/Qwen3-Next-80B-A3B-Instruct --kv-cache-dtype auto --load-format dummy --input-len 1024 --output-len 1024 --num-prompts 128 --tensor-parallel-size 8 --dtype float16

Throughput: 3.61 requests/s, 4157.86 total tokens/s, 461.98 output tokens/s
Total num prompt tokens:  131072
Total num output tokens:  16384
vllm bench throughput --model Qwen/Qwen3-Next-80B-A3B-Instruct --kv-cache-dtype auto --load-format dummy --input-len 1024 --output-len 1024 --num-prompts 128 --tensor-parallel-size 8 --dtype float16 --attention-backend ROCM_ATTN

Throughput: 8.01 requests/s, 9228.46 total tokens/s, 1025.38 output tokens/s
Total num prompt tokens:  131072
Total num output tokens:  16384

The above are with this PR.

@AndreasKaratzas (Collaborator, Author) commented Mar 17, 2026

Tried Qwen 3.5 today as well.

vllm bench throughput --model Qwen/Qwen3.5-35B-A3B --load-format dummy --input-len 1024 --output-len 1024 --num-prompts 128 --tensor-parallel-size 8 --dtype float16

Throughput: 5.53 requests/s, 6367.64 total tokens/s, 707.52 output tokens/s
Total num prompt tokens:  131072
Total num output tokens:  16384
vllm bench throughput --model Qwen/Qwen3.5-35B-A3B --load-format dummy --input-len 1024 --output-len 1024 --num-prompts 128 --tensor-parallel-size 8 --dtype float16 --attention-backend ROCM_ATTN

Throughput: 8.34 requests/s, 9602.49 total tokens/s, 1066.94 output tokens/s
Total num prompt tokens:  131072
Total num output tokens:  16384

cc @jennyyyyzhen @Rohan138 @tjtanaa

(Contributor)

There is a code refactor effort on the CUDA side (#35122) that probably fixed the previous issue.

@gshtras (Collaborator) commented Mar 16, 2026

cc @Rohan138

@AndreasKaratzas (Collaborator, Author) commented Mar 16, 2026

AMD CI is red due to a new regression from #36204. I will address it in a different PR.

EDIT: Fixed here: #37219

@gshtras gshtras merged commit 3ed7b1e into vllm-project:main Mar 17, 2026
44 of 46 checks passed
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Mar 17, 2026
@gshtras gshtras deleted the akaratza_fix_atten_dispatch branch March 17, 2026 22:04
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026

Labels

ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm

Projects

Status: Done

4 participants