[ROCm] Widen OAI Triton MoE capability range to include gfx12 (RDNA4) #37826

laudney wants to merge 1 commit into vllm-project:main from
Conversation
The `_supports_current_device()` check in `BaseOAITritonExperts` and `OAITritonMxfp4ExpertsMonolithic` rejects gfx12 (capability 12.0) because the upper bound is `(11, 0)`. Triton 3.5+ supports `tl.dot_scaled` on gfx12 via `DecomposeScaledBlocked`, so the standard MXFP4 MoE path works without a custom kernel.

Widen the range from `< (11, 0)` to `< (13, 0)` to cover RDNA4 (gfx1200/gfx1201). Also widen the same check in `oracle/mxfp4.py` for the LoRA path.

Tested with `openai/gpt-oss-20b` (MXFP4 MoE) on an AMD Radeon AI PRO R9700 (gfx1201): the model loads and produces correct output.

Signed-off-by: L.B.R. <lbr@mmonad.com>
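For illustration, the widened check can be sketched as a standalone function (a simplified stand-in for the real method, which lives on `BaseOAITritonExperts` and takes the platform's capability object):

```python
# Hypothetical standalone sketch of the widened capability check.
# In vLLM the real check compares a DeviceCapability named tuple;
# here plain (major, minor) ints are used for clarity.
def supports_capability(major: int, minor: int) -> bool:
    # The old upper bound (11, 0) excluded gfx12, whose capability
    # is (12, 0). Widening to (13, 0) admits gfx1200/gfx1201 (RDNA4)
    # while keeping CUDA SM90/SM100+ and ROCm gfx942/gfx950 in range.
    return (9, 0) <= (major, minor) < (13, 0)
```

Python's tuple comparison is lexicographic, so `(12, 0) < (13, 0)` holds and gfx12 now passes, while `(12, 0) < (11, 0)` did not under the old bound.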
Code Review
This pull request widens the device capability range for Triton MoE kernels to include gfx12 (RDNA4), which is a necessary and well-justified change. However, the implementation uses a broad range that also enables support for gfx11 (RDNA3), which was not mentioned as a target for this PR and may be untested. To mitigate potential risks for users on that hardware, I've suggested making the conditions more specific to only include the intended gfx12 architecture alongside the existing supported ranges.
```python
# (9,0) <= cap < (13,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
# and ROCm gfx942/gfx950 (9.4/9.5) + gfx1200/gfx1201 (12.0).
return (9, 0) <= (cap.major, cap.minor) < (13, 0)
```
The change to < (13, 0) enables support for devices with major capability versions 11 and 12. The PR description and updated comment only mention gfx12 (RDNA4). If gfx11 (RDNA3) is not intended to be supported or has not been tested, it would be safer to use a more specific condition to only enable gfx12. This prevents potential issues for users on gfx11 hardware.
If gfx11 is also supported, please consider updating the comment to reflect that.
Suggested change:

```diff
-# (9,0) <= cap < (13,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
-# and ROCm gfx942/gfx950 (9.4/9.5) + gfx1200/gfx1201 (12.0).
-return (9, 0) <= (cap.major, cap.minor) < (13, 0)
+# (9,0) <= cap < (11,0) or cap.major == 12 covers CUDA SM90 (Hopper),
+# SM100+ (Blackwell), ROCm gfx942/gfx950 (9.4/9.5), and gfx12 (RDNA4).
+return (9, 0) <= (cap.major, cap.minor) < (11, 0) or cap.major == 12
```
```python
# (9,0) <= cap < (13,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
# and ROCm gfx942/gfx950 (9.4/9.5) + gfx1200/gfx1201 (12.0).
return (9, 0) <= (cap.major, cap.minor) < (13, 0)
```
Similar to the previous comment, this change to < (13, 0) also enables support for gfx11. To be safer and only enable the tested gfx12 architecture, a more specific condition is recommended.
Suggested change:

```diff
-# (9,0) <= cap < (13,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
-# and ROCm gfx942/gfx950 (9.4/9.5) + gfx1200/gfx1201 (12.0).
-return (9, 0) <= (cap.major, cap.minor) < (13, 0)
+# (9,0) <= cap < (11,0) or cap.major == 12 covers CUDA SM90 (Hopper),
+# SM100+ (Blackwell), ROCm gfx942/gfx950 (9.4/9.5), and gfx12 (RDNA4).
+return (9, 0) <= (cap.major, cap.minor) < (11, 0) or cap.major == 12
```
```diff
 triton_kernels_supported = has_triton_kernels() and (
     9,
     0,
-) <= current_platform.get_device_capability() < (11, 0)
+) <= current_platform.get_device_capability() < (13, 0)
```
This condition is now broad enough to enable Triton kernels for gfx11 as well, which is not mentioned in the PR. To avoid potential issues on untested hardware, it's better to make the condition more specific to the architectures that are known to be supported. The suggested change also improves clarity by using a local variable for the device capability.
```python
cap = current_platform.get_device_capability()
triton_kernels_supported = has_triton_kernels() and cap and (
    (9, 0) <= cap < (11, 0) or cap.major == 12)
```
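For illustration, the narrower predicate suggested above behaves as follows (a sketch using plain `(major, minor)` tuples in place of the platform's capability object):

```python
# Hypothetical sketch of the reviewer's narrower predicate: it keeps
# the existing (9,0)..(11,0) range and adds only major version 12,
# so gfx11 (RDNA3) stays disabled until it is actually tested.
def narrower_check(major: int, minor: int) -> bool:
    return (9, 0) <= (major, minor) < (11, 0) or major == 12
```

The difference from the blanket `< (13, 0)` bound is exactly the gfx11 band: capabilities with major version 11 are rejected, while 9.x, 10.x, and 12.x are accepted.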
Summary
- Widen the `_supports_current_device()` cap range from `< (11, 0)` to `< (13, 0)` in `BaseOAITritonExperts` and `OAITritonMxfp4ExpertsMonolithic`
- Widen the same check in `oracle/mxfp4.py` for the LoRA path

gfx12 (RDNA4) maps to capability `(12, 0)`, which was excluded by the old upper bound. Triton 3.5+ supports `tl.dot_scaled` on gfx12 via `DecomposeScaledBlocked`, so the standard MXFP4 MoE path works without any custom kernel.

Replaces #34632, which tried to add a custom dequant kernel; that turns out to be unnecessary (thanks @ptrojahn for pushing back on that).

Test plan

Tested with `openai/gpt-oss-20b` (MXFP4 MoE) on an AMD Radeon AI PRO R9700 (gfx1201). AI assistance was used. Test commands run:

- `curl localhost:8080/v1/chat/completions` against a running vLLM instance with the model loaded.
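The smoke test mentioned above might look like the following; this is a sketch assuming a vLLM server on port 8080 serving the model with its OpenAI-compatible API (the port and prompt are assumptions, not copied from the PR):

```shell
# Assumed invocation: vLLM serving openai/gpt-oss-20b on localhost:8080.
# A correct response indicates the MXFP4 MoE path works on gfx12.
curl -s localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "openai/gpt-oss-20b",
        "messages": [{"role": "user", "content": "Say hello."}]
      }'
```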