
[ROCm] Widen OAI Triton MoE capability range to include gfx12 (RDNA4) #37826

Open
laudney wants to merge 1 commit into vllm-project:main from mmonad:feat/rocm-rdna4-mxfp4

Conversation


@laudney laudney commented Mar 22, 2026

Summary

  • Widen _supports_current_device() cap range from < (11, 0) to < (13, 0) in BaseOAITritonExperts and OAITritonMxfp4ExpertsMonolithic
  • Same change in oracle/mxfp4.py for the LoRA path

gfx12 (RDNA4) maps to capability (12, 0) which was excluded by the old upper bound. Triton 3.5+ supports tl.dot_scaled on gfx12 via DecomposeScaledBlocked, so the standard MXFP4 MoE path works without any custom kernel.
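Python compares capability tuples lexicographically (major first, then minor), which is why (12, 0) fell outside the old < (11, 0) upper bound. A minimal sketch of the widened check (`supports_cap` is an illustrative helper, not vLLM's actual API):

```python
# Illustrative sketch of the widened capability check; `supports_cap`
# is a hypothetical helper, not vLLM's actual function.
def supports_cap(cap: tuple[int, int]) -> bool:
    # Tuples compare lexicographically: major version first, then minor.
    return (9, 0) <= cap < (13, 0)

print(supports_cap((9, 4)))   # gfx942: True under both old and new bounds
print(supports_cap((12, 0)))  # gfx12/RDNA4: rejected by < (11, 0), True now
print(supports_cap((13, 0)))  # upper bound is exclusive: False
```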

Replaces #34632 which tried to add a custom dequant kernel — turns out that's unnecessary (thanks @ptrojahn for pushing back on that).

Tested with openai/gpt-oss-20b (MXFP4 MoE) on AMD Radeon AI PRO R9700 (gfx1201).

AI assistance was used. Test commands run: curl localhost:8080/v1/chat/completions against a running vLLM instance with the model loaded.

Test plan

  • Load MXFP4 MoE model (gpt-oss-20b) on gfx1201 — model loads, correct output
  • CI

The _supports_current_device() check in BaseOAITritonExperts and
OAITritonMxfp4ExpertsMonolithic rejects gfx12 (capability 12.0)
because the upper bound is (11, 0). Triton 3.5+ supports
tl.dot_scaled on gfx12 via DecomposeScaledBlocked, so the standard
MXFP4 MoE path works without a custom kernel.

Widen the range from <(11,0) to <(13,0) to cover RDNA4 (gfx1200/1201).
Also widen the same check in oracle/mxfp4.py for the LoRA path.

Tested with openai/gpt-oss-20b (MXFP4 MoE) on AMD Radeon AI PRO R9700
(gfx1201) — model loads and produces correct output.

Signed-off-by: L.B.R. <lbr@mmonad.com>

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request widens the device capability range for Triton MoE kernels to include gfx12 (RDNA4), which is a necessary and well-justified change. However, the implementation uses a broad range that also enables support for gfx11 (RDNA3), which was not mentioned as a target for this PR and may be untested. To mitigate potential risks for users on that hardware, I've suggested making the conditions more specific to only include the intended gfx12 architecture alongside the existing supported ranges.

Comment on lines +555 to +557
# (9,0) <= cap < (13,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
# and ROCm gfx942/gfx950 (9.4/9.5) + gfx1200/gfx1201 (12.0).
return (9, 0) <= (cap.major, cap.minor) < (13, 0)


Severity: high

The change to < (13, 0) enables support for devices with major capability versions 11 and 12. The PR description and updated comment only mention gfx12 (RDNA4). If gfx11 (RDNA3) is not intended to be supported or has not been tested, it would be safer to use a more specific condition to only enable gfx12. This prevents potential issues for users on gfx11 hardware.

If gfx11 is also supported, please consider updating the comment to reflect that.

Suggested change

-# (9,0) <= cap < (13,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
-# and ROCm gfx942/gfx950 (9.4/9.5) + gfx1200/gfx1201 (12.0).
-return (9, 0) <= (cap.major, cap.minor) < (13, 0)
+# (9,0) <= cap < (11,0) or cap.major == 12 covers CUDA SM90 (Hopper),
+# SM100+ (Blackwell), ROCm gfx942/gfx950 (9.4/9.5), and gfx12 (RDNA4).
+return (9, 0) <= (cap.major, cap.minor) < (11, 0) or cap.major == 12
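To see what the narrower condition changes in practice, here is a quick standalone comparison of the two predicates over (major, minor) tuples (sketches of the logic under discussion, not vLLM code):

```python
# Standalone sketches of the two conditions being compared.
def broad(cap: tuple[int, int]) -> bool:
    # As proposed in the PR: one contiguous range.
    return (9, 0) <= cap < (13, 0)

def narrow(cap: tuple[int, int]) -> bool:
    # As suggested in the review: skip major version 11.
    return (9, 0) <= cap < (11, 0) or cap[0] == 12

# gfx11 (RDNA3, capability 11.x) is where the two predicates differ:
for cap in [(9, 4), (11, 0), (12, 0)]:
    print(cap, broad(cap), narrow(cap))
# (9, 4) True True
# (11, 0) True False
# (12, 0) True True
```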

Comment on lines +887 to +889
# (9,0) <= cap < (13,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
# and ROCm gfx942/gfx950 (9.4/9.5) + gfx1200/gfx1201 (12.0).
return (9, 0) <= (cap.major, cap.minor) < (13, 0)


Severity: high

Similar to the previous comment, this change to < (13, 0) also enables support for gfx11. To be safer and only enable the tested gfx12 architecture, a more specific condition is recommended.

Suggested change

-# (9,0) <= cap < (13,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
-# and ROCm gfx942/gfx950 (9.4/9.5) + gfx1200/gfx1201 (12.0).
-return (9, 0) <= (cap.major, cap.minor) < (13, 0)
+# (9,0) <= cap < (11,0) or cap.major == 12 covers CUDA SM90 (Hopper),
+# SM100+ (Blackwell), ROCm gfx942/gfx950 (9.4/9.5), and gfx12 (RDNA4).
+return (9, 0) <= (cap.major, cap.minor) < (11, 0) or cap.major == 12

Comment on lines 202 to +205
 triton_kernels_supported = has_triton_kernels() and (
     9,
     0,
-) <= current_platform.get_device_capability() < (11, 0)
+) <= current_platform.get_device_capability() < (13, 0)


Severity: high

This condition is now broad enough to enable Triton kernels for gfx11 as well, which is not mentioned in the PR. To avoid potential issues on untested hardware, it's better to make the condition more specific to the architectures that are known to be supported. The suggested change also improves clarity by using a local variable for the device capability.

    cap = current_platform.get_device_capability()
    triton_kernels_supported = has_triton_kernels() and cap and (
        (9, 0) <= cap < (11, 0) or cap.major == 12)


Labels

gpt-oss (Related to GPT-OSS models), rocm (Related to AMD ROCm)
