Skip to content

Enable Triton MLA for prefill#738

Merged
dllehr-amd merged 5 commits into355_wipfrom
cagri/triton_MHA
Oct 17, 2025
Merged

Enable Triton MLA for prefill#738
dllehr-amd merged 5 commits into355_wipfrom
cagri/triton_MHA

Conversation

@cagrikymk
Copy link
Copy Markdown

@cagrikymk cagrikymk commented Oct 15, 2025

This PR adds a flag (VLLM_ROCM_USE_AITER_TRITON_MLA) that enables Triton MLA when the flag is turned on.

The corresponding PR in aiter: ROCm/aiter#1203

@ZJLi2013
Copy link
Copy Markdown

looks a few gaps 1) it's not aligned with main branch; 2) when cherry-pick to main branch, still got runtime error for gfx950(mi35x):

 05:04:23 [multiproc_executor.py:585]     aiter_triton_fp8_bmm(x,^M
 05:04:23 [multiproc_executor.py:585]   File "/usr/local/lib/python3.12/dist-packages/aiter-0.1.5.dev196+gb5f0b0a05.d20251016-py3.12.egg/aiter/ops/triton/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py", line 315, in batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant^M
 05:04:23 [multiproc_executor.py:585]     _batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant_kernel[^M
 05:04:23 [multiproc_executor.py:585]   File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 570, in run^M
 05:04:23 [multiproc_executor.py:585]     options = backend.parse_options(kwargs)^M
 05:04:23 [multiproc_executor.py:585]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^M
 05:04:23 [multiproc_executor.py:585]   File "/usr/local/lib/python3.12/dist-packages/triton/backends/amd/compiler.py", line 124, in parse_options^M
 05:04:23 [multiproc_executor.py:585]     return HIPOptions(**args)^M
 05:04:23 [multiproc_executor.py:585]            ^^^^^^^^^^^^^^^^^^^M
 05:04:23 [multiproc_executor.py:585]   File "<string>", line 24, in __init__^M
 05:04:23 [multiproc_executor.py:585]   File "/usr/local/lib/python3.12/dist-packages/triton/backends/amd/compiler.py", line 74, in __post_init__^M
 05:04:23 [multiproc_executor.py:585]     assert self.kpack == 1, "gfx950 only accepts kpack == 1"^M

@cagrikymk
Copy link
Copy Markdown
Author

@ZJLi2013 That might be about the Triton version, can you try: https://github.com/ROCm/triton/tree/pytorch/rocm7.1_internal_testing

Also, I have this branch for the vllm upstream for testing: https://github.com/ROCm/vllm/tree/cagri/triton_MHA_upstream

I will also run more tests to see if there are any issues.

Copy link
Copy Markdown
Collaborator

@dllehr-amd dllehr-amd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved! Merging into 355_wip

@dllehr-amd dllehr-amd merged commit c28077e into 355_wip Oct 17, 2025
0 of 2 checks passed
@gshtras gshtras deleted the cagri/triton_MHA branch January 16, 2026 15:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants