
[Bugfix] Add Multiple of 16 block_size to triton fallback on rocm Attention to support qwen3_5#35923

Merged
DarkLight1337 merged 16 commits into vllm-project:main from JartX:bugfix/qwen35_rocm_attn
Mar 11, 2026

Conversation

@JartX
Contributor

@JartX JartX commented Mar 3, 2026

This PR adds MultipleOf(16) to the list of supported kernel block sizes in RocmAttentionBackend.

When running Qwen3.5 models using the ROCM_ATTN backend, the model produces broken, nonsensical outputs (e.g., repeating exclamation marks like !!!!!!!!!!). This happens because Qwen3.5 uses a non-standard block size of 1056. Since this size was not explicitly permitted, the backend failed to route the value_cache through the optimized Triton kernel fallback (triton_reshape_and_cache_flash).
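The routing problem is easy to reproduce with a toy check. The sketch below is illustrative only: the `MultipleOf` helper and `is_supported` function are simplified stand-ins, not vLLM's actual implementation. It shows how an explicit allow-list rejects Qwen3.5's block size of 1056, while the multiple-of-16 rule introduced by this PR accepts it.

```python
# Hypothetical sketch of block-size matching; not vLLM's real classes.
class MultipleOf:
    """Matches any block size that is a multiple of `base`."""

    def __init__(self, base: int) -> None:
        self.base = base

    def matches(self, block_size: int) -> bool:
        return block_size % self.base == 0


def is_supported(block_size: int, supported) -> bool:
    """Return True if block_size matches any entry (int or MultipleOf)."""
    for entry in supported:
        if isinstance(entry, MultipleOf):
            if entry.matches(block_size):
                return True
        elif entry == block_size:
            return True
    return False


# Before this PR: only explicit sizes were allowed, so 1056 fell through.
print(is_supported(1056, [16, 32, 544]))     # False
# After this PR: any multiple of 16 is accepted, including 784 and 1056.
print(is_supported(1056, [MultipleOf(16)]))  # True
```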

Signed-off-by: JartX <sagformas@epdcenter.es>
@JartX JartX requested review from gshtras and tjtanaa as code owners March 3, 2026 22:11
@mergify mergify bot added labels: qwen (Related to Qwen models), rocm (Related to AMD ROCm), v1, bug (Something isn't working) Mar 3, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Mar 3, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly adds support for the Qwen3.5 model on ROCm by including its non-standard block size of 1056 in the RocmAttentionBackend. This change is a simple and effective fix for the reported issue, allowing the model to use the appropriate Triton kernel fallback. The implementation is correct and I have no further suggestions for improvement.

Signed-off-by: JartX <sagformas@epdcenter.es>
@mergify

mergify bot commented Mar 3, 2026

Hi @JartX, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

Signed-off-by: JartX <sagformas@epdcenter.es>
@mergify

mergify bot commented Mar 3, 2026

Documentation preview: https://vllm--35923.org.readthedocs.build/en/35923/

@mergify mergify bot added the documentation Improvements or additions to documentation label Mar 3, 2026
Signed-off-by: JartX <sagformas@epdcenter.es>
@JartX JartX force-pushed the bugfix/qwen35_rocm_attn branch from c07fc35 to 6212617 Compare March 4, 2026 10:28
JartX added 2 commits March 4, 2026 12:09
Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: JartX <sagformas@epdcenter.es>
@JartX JartX changed the title [Bugfix] Add 1056 block_size to triton fallback on rocm Attention to support qwen3_5 [Bugfix] Add 784,1056 block_size to triton fallback on rocm Attention to support qwen3_5 Mar 4, 2026
@JartX JartX changed the title [Bugfix] Add 784,1056 block_size to triton fallback on rocm Attention to support qwen3_5 [Bugfix] Add 784 and 1056 block_size to triton fallback on rocm Attention to support qwen3_5 Mar 4, 2026
Signed-off-by: JartX <sagformas@epdcenter.es>
@JartX JartX changed the title [Bugfix] Add 784 and 1056 block_size to triton fallback on rocm Attention to support qwen3_5 [Bugfix] Add Multiple of 16 block_size to triton fallback on rocm Attention to support qwen3_5 Mar 5, 2026
@JartX
Contributor Author

JartX commented Mar 5, 2026

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a bug where Qwen3.5 models produced incorrect outputs on the ROCm backend. The fix correctly identifies that the non-standard block size was the issue and generalizes the supported block sizes for the ROCM_ATTN backend to any multiple of 16. This is a good change that improves robustness for future models. The corresponding documentation has also been updated. However, I've found a critical issue related to this change that could cause failures for other models.

Comment on lines 168 to +177
 def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
-    # ROCM paged attention kernel only supports block sizes 16 and 32
+    # ROCM paged attention native C++ kernel only supports block sizes 16 and 32
+    # due to shared memory (LDS) constraints on AMD GPUs.
+    # See csrc/rocm/attention.cu CALL_CUSTOM_LAUNCHER_BLK macro.
-    # However, the limitations in [16, 32] are reasonable for a native C++ kernel,
-    # but vLLM should allow support for non-standard sizes via the Triton path,
-    # as addressed in this PR: https://github.com/vllm-project/vllm/pull/31380,
-    # where the Triton kernel under rocm_atten does not support inference
-    # for a non-standard qwen3-next model with a block_size of 544.
-    # We have fixed the Triton kernel so that the standard model uses the original
-    # bit-addressing logic, while the non-standard model
-    # uses our optimized kernel logic.
-    return [16, 32, 544]
+    # However, vLLM allows support for any multiple of 16 via the Triton path.
+    # As addressed in PR: https://github.com/vllm-project/vllm/pull/31380,
+    # non-standard models (like qwen3-next with block_size 544, or qwen3_5
+    # with 784 and 1056) are dynamically routed to our optimized Triton kernel
+    # in `do_kv_cache_update`.
+    return [MultipleOf(16)]

critical

While this change to allow any block size that is a multiple of 16 is correct for supporting models like Qwen3.5, it introduces a potential failure for other models.

The dispatch logic in do_kv_cache_update (lines 450-480) uses is_pow2 to decide whether to use the native C++ kernel or the Triton fallback. The native C++ kernel, as noted in the comments and confirmed in csrc/rocm/attention.cu, only supports block sizes of 16 and 32.

With this PR, a model using a block size that is a power of two but not 16 or 32 (e.g., 64) will be incorrectly routed to the native C++ kernel, which will then raise an error.

To fix this, the condition in do_kv_cache_update should be changed from if is_pow2: to if block_size in (16, 32):. This will ensure that only the explicitly supported block sizes are routed to the native kernel, and all others (including other powers of two) use the Triton fallback.
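The reviewer's concern can be sketched with a simplified dispatcher. This is a stand-in, not the real `do_kv_cache_update`: it only illustrates why a power-of-two test over-selects the native kernel, while the suggested membership test routes every size other than 16 and 32 to the Triton fallback.

```python
# Simplified stand-in for the dispatch in do_kv_cache_update; not vLLM's
# actual code. Illustrates why `is_pow2` over-selects the native kernel.
def is_pow2(n: int) -> bool:
    # True for 16, 32, 64, 128, ... (any positive power of two).
    return n > 0 and (n & (n - 1)) == 0


def select_kernel_buggy(block_size: int) -> str:
    # Power-of-two check: 64 slips through to the native C++ kernel,
    # which only supports 16 and 32, and would raise at runtime.
    return "native" if is_pow2(block_size) else "triton"


def select_kernel_fixed(block_size: int) -> str:
    # Reviewer's suggestion: route only the explicitly supported sizes
    # to the native kernel; everything else takes the Triton fallback.
    return "native" if block_size in (16, 32) else "triton"


print(select_kernel_buggy(64))    # native (wrong: native kernel rejects 64)
print(select_kernel_fixed(64))    # triton
print(select_kernel_fixed(1056))  # triton
```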

@mergify

mergify bot commented Mar 6, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @JartX.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 6, 2026
@mergify mergify bot removed the needs-rebase label Mar 6, 2026
@JartX
Contributor Author

JartX commented Mar 6, 2026

/gemini review

@JartX
Contributor Author

JartX commented Mar 6, 2026

@tjtanaa the CI failed with "no such container" in the AMD entrypoint

@JartX
Contributor Author

JartX commented Mar 7, 2026

@Rohan138
(Worker pid=618) (Worker_TP1 pid=618) ERROR 03-07 11:59:41 [multiproc_executor.py:847]   File "/usr/local/lib/python3.12/dist-packages/vllm/platforms/rocm.py", line 481, in get_attn_backend_cls
(Worker pid=618) (Worker_TP1 pid=618) ERROR 03-07 11:59:41 [multiproc_executor.py:847] raise ValueError(
(Worker pid=618) (Worker_TP1 pid=618) ERROR 03-07 11:59:41 [multiproc_executor.py:847] ValueError: No valid attention backend found for rocm with AttentionSelectorConfig(head_size=256, dtype=torch.float16, kv_cache_dtype=auto, block_size=1056, use_mla=False, has_sink=False, use_sparse=False, use_mm_prefix=False, use_per_head_quant_scales=False, attn_type=AttentionType.DECODER). Reasons: {TRITON_ATTN: [block_size not supported]}.

@JartX
Contributor Author

JartX commented Mar 7, 2026

@Rohan138 solved here: #36292

@AndreasKaratzas
Collaborator

@JartX can you rebase your branch? This test group should be green as of yesterday.

Signed-off-by: JartX <sagformas@epdcenter.es>

Co-authored-by: akaratza <akaratza@amd.com>
@JartX JartX force-pushed the bugfix/qwen35_rocm_attn branch from 4f16fcd to cd8be20 Compare March 7, 2026 20:02
@JartX
Contributor Author

JartX commented Mar 8, 2026

@AndreasKaratzas all test passed :)

@AndreasKaratzas
Collaborator

That's great :) Unfortunately, even though my tag says "member" my approval won't turn your PR green (I only have read permissions 😅). I have forwarded your PR to the right channels.

@JartX
Contributor Author

JartX commented Mar 8, 2026

@AndreasKaratzas many thanks! Haha :)

@JartX
Contributor Author

JartX commented Mar 8, 2026

@tjtanaa Please check this out when you can :)

@mergify

mergify bot commented Mar 9, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @JartX.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 9, 2026
@mergify mergify bot removed the needs-rebase label Mar 9, 2026
Collaborator

@tjtanaa tjtanaa left a comment


LGTM

@tjtanaa tjtanaa enabled auto-merge (squash) March 10, 2026 13:33
Signed-off-by: JartX <sagformas@epdcenter.es>
auto-merge was automatically disabled March 10, 2026 22:02

Head branch was pushed to by a user without write access

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) March 11, 2026 04:26
@DarkLight1337 DarkLight1337 merged commit a40ee48 into vllm-project:main Mar 11, 2026
53 checks passed
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Mar 11, 2026
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
…ention to support qwen3_5 (vllm-project#35923)

Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: akaratza <akaratza@amd.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
…ention to support qwen3_5 (vllm-project#35923)

Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: akaratza <akaratza@amd.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>

Labels

bug (Something isn't working), documentation (Improvements or additions to documentation), qwen (Related to Qwen models), ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

4 participants