
[ROCm] Add MXFP4 inline dequant Triton kernel for RDNA4/gfx12#34632

Closed
laudney wants to merge 3 commits into vllm-project:main from mmonad:feat/rocm-rdna4-mxfp4

Conversation

@laudney
Contributor

@laudney laudney commented Feb 16, 2026

Summary

Enables MXFP4 (OCP MX FP4 e2m1f) quantized models on RDNA4/gfx12 hardware, which lacks tl.dot_scaled support required by the existing OAI Triton MXFP4 path.

  • New Triton kernel (fused_moe_mxfp4.py): Keeps weights packed as uint8 in VRAM (~half bf16 size) and dequantizes per-tile to bf16 inside the GEMM loop using a "two half-dots" strategy:
    • Unpack each uint8 into lo/hi FP4 nibbles
    • Dequant FP4 e2m1f → bf16 via bit manipulation
    • Apply E8M0 block scales (1 per 32 elements)
    • Load A with stride-2 for even/odd K columns
    • acc += dot(a_even, lo_bf16) + dot(a_odd, hi_bf16)
  • New backend TRITON_MXFP4_DEQUANT in mxfp4.py: Auto-selected on gfx12 via on_gfx1x() detection
  • Supports both legacy apply() path and modular kernel (Mxfp4DequantTritonExperts) with expert mapping, bias, and gated activations (SiLU/SwiGLU)
  • Subprocess crash resilience in registry.py: ROCm roctracer can fire spurious assertion failures during process cleanup on RDNA4 — check for valid output before return code so cleanup crashes don't mask successful results
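The dequant steps above can be sketched on the host with a small NumPy model. This is an illustrative stand-in for the Triton kernel's per-tile logic, not the kernel's actual code; the `dequant_mxfp4` helper and its layout assumptions (lo nibble = even K index, one E8M0 scale per 32 elements) follow the description above:

```python
import numpy as np

# FP4 e2m1 nibble layout: sign(1) | exp(2) | mant(1), exponent bias 1.
E2M1_VALUES = np.array(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=np.float32)

def dequant_mxfp4(packed, raw_scales):
    """packed: [K//2] uint8, lo nibble = even K index, hi nibble = odd;
    raw_scales: [K//32] uint8 E8M0 exponents (scale = 2**(raw - 127))."""
    lo = E2M1_VALUES[packed & 0x0F]          # even K indices
    hi = E2M1_VALUES[(packed >> 4) & 0x0F]   # odd K indices
    out = np.empty(packed.size * 2, dtype=np.float32)
    out[0::2], out[1::2] = lo, hi            # interleave even/odd columns
    scales = np.exp2(raw_scales.astype(np.float32) - 127.0)
    return out * np.repeat(scales, 32)       # one scale per 32 elements

# 32 logical elements from 16 packed bytes; byte 0x21 unpacks to
# lo = 0x1 (0.5) and hi = 0x2 (1.0); raw scale 128 -> block scale 2.0.
vals = dequant_mxfp4(np.full(16, 0x21, dtype=np.uint8),
                     np.array([128], dtype=np.uint8))
# vals alternates 1.0, 2.0
```

The kernel performs the same unpack/scale in registers but never materializes the interleaved bf16 tensor; instead it feeds the lo and hi halves into two separate `tl.dot` calls against the stride-2 A loads.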

Test plan

  • MXFP4 MoE model inference on RDNA4/gfx12 (e.g. GPT-OSS-20B MXFP4)
  • Verify no regression on MI300X (existing TRITON backend still selected)
  • Verify no regression on CUDA (MXFP4 backend selection unchanged)

@mergify mergify bot added new-model Requests to new models rocm Related to AMD ROCm labels Feb 16, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Feb 16, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new Triton kernel for MXFP4 dequantization on RDNA4/gfx12 hardware, a significant feature enablement. The changes include the new kernel implementation, backend modifications to select this kernel on the appropriate hardware, and a resilience improvement for subprocess handling in the model registry. The overall approach is sound, but I identified a critical correctness issue in the new Triton kernel's memory access that could produce incorrect results when the N dimension is not evenly divisible by the block size. My review includes a detailed explanation and a suggested fix for this issue.

Comment on lines +195 to +245
    offs_bn = (pid_n * BLOCK_SIZE_N +
               tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    if HAS_BIAS:
        bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
        bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)

    # Half-K for packed dimension
    HALF_K: tl.constexpr = BLOCK_SIZE_K // 2

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Iterate over K in steps of BLOCK_SIZE_K (logical unpacked elements)
    num_k_iters = tl.cdiv(K, BLOCK_SIZE_K)
    for k_iter in range(0, num_k_iters):
        k_start = k_iter * BLOCK_SIZE_K  # logical K offset
        k_packed_start = k_start // 2    # packed K offset

        # Remaining elements mask
        k_remaining = K - k_start

        # --- Load B packed: [HALF_K, BLOCK_N] uint8 ---
        offs_bk = tl.arange(0, HALF_K)
        b_ptrs = (b_ptr
                  + off_experts * stride_be
                  + offs_bn[None, :] * stride_bn
                  + (k_packed_start + offs_bk[:, None]) * stride_bk)
        b_mask = offs_bk[:, None] < (k_remaining // 2)
        b_packed = tl.load(b_ptrs, mask=b_mask, other=0)

        # Unpack nibbles
        lo_nibble = b_packed & 0x0F          # even K indices
        hi_nibble = (b_packed >> 4) & 0x0F   # odd K indices

        # Dequant to bf16
        lo_bf16 = dequant_mxfp4_nibble_to_bf16(lo_nibble.to(tl.int32))
        hi_bf16 = dequant_mxfp4_nibble_to_bf16(hi_nibble.to(tl.int32))

        # --- Load and apply E8M0 scales ---
        # Scale shape: [E, N, K//32], one scale per 32 logical elements.
        # In packed space, 32 logical = 16 packed rows.
        # Each packed row j maps to scale group (k_start // 32 + j // 16).
        # Load scales directly as [HALF_K, BLOCK_N] by computing per-row
        # scale pointers.
        scale_k_start = k_start // 32
        scale_k_offs = offs_bk // 16  # [HALF_K] - scale group for each row
        scale_ptrs = (b_scale_ptr
                      + off_experts * stride_bse
                      + offs_bn[None, :] * stride_bsn
                      + (scale_k_start + scale_k_offs[:, None]) * stride_bsk)
        scale_mask = (scale_k_start + scale_k_offs[:, None]) < tl.cdiv(K, 32)
        raw_scales = tl.load(scale_ptrs, mask=scale_mask, other=127)
critical

There's a potential memory access bug here when N is not a multiple of BLOCK_SIZE_N. The offs_bn is calculated with a modulo N, which prevents out-of-bounds memory access but can lead to logically incorrect data being read for padded elements in the last block. The mask (offs_bn < N) is always true and thus ineffective.

This affects loading of bias, b_packed, and raw_scales, as they are not correctly masked along the N dimension. This can lead to incorrect computation results.

The fix involves calculating an explicit mask n_mask for the N dimension and applying it to all loads that depend on offs_bn.

    # N-dimension offsets
    unmasked_offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    n_mask = unmasked_offs_bn < N
    offs_bn = unmasked_offs_bn

    if HAS_BIAS:
        bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
        bias = tl.load(bias_ptrs, mask=n_mask, other=0.0)

    # Half-K for packed dimension
    HALF_K: tl.constexpr = BLOCK_SIZE_K // 2

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Iterate over K in steps of BLOCK_SIZE_K (logical unpacked elements)
    num_k_iters = tl.cdiv(K, BLOCK_SIZE_K)
    for k_iter in range(0, num_k_iters):
        k_start = k_iter * BLOCK_SIZE_K  # logical K offset
        k_packed_start = k_start // 2     # packed K offset

        # Remaining elements mask
        k_remaining = K - k_start

        # --- Load B packed: [HALF_K, BLOCK_N] uint8 ---
        offs_bk = tl.arange(0, HALF_K)
        b_ptrs = (b_ptr
                  + off_experts * stride_be
                  + offs_bn[None, :] * stride_bn
                  + (k_packed_start + offs_bk[:, None]) * stride_bk)
        b_mask = (offs_bk[:, None] < (k_remaining // 2)) & n_mask[None, :]
        b_packed = tl.load(b_ptrs, mask=b_mask, other=0)

        # Unpack nibbles
        lo_nibble = b_packed & 0x0F          # even K indices
        hi_nibble = (b_packed >> 4) & 0x0F   # odd K indices

        # Dequant to bf16
        lo_bf16 = dequant_mxfp4_nibble_to_bf16(lo_nibble.to(tl.int32))
        hi_bf16 = dequant_mxfp4_nibble_to_bf16(hi_nibble.to(tl.int32))

        # --- Load and apply E8M0 scales ---
        # Scale shape: [E, N, K//32], one scale per 32 logical elements.
        # In packed space, 32 logical = 16 packed rows.
        # Each packed row j maps to scale group (k_start // 32 + j // 16).
        # Load scales directly as [HALF_K, BLOCK_N] by computing per-row
        # scale pointers.
        scale_k_start = k_start // 32
        scale_k_offs = offs_bk // 16  # [HALF_K] - scale group for each row
        scale_ptrs = (b_scale_ptr
                      + off_experts * stride_bse
                      + offs_bn[None, :] * stride_bsn
                      + (scale_k_start + scale_k_offs[:, None]) * stride_bsk)
        scale_k_mask = (scale_k_start + scale_k_offs[:, None]) < tl.cdiv(K, 32)
        scale_mask = scale_k_mask & n_mask[None, :]
        raw_scales = tl.load(scale_ptrs, mask=scale_mask, other=127)
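The failure mode described above is easy to demonstrate on the host. A minimal NumPy sketch (the array, N, and block size are made-up illustrations, with indexing standing in for `tl.load`):

```python
import numpy as np

N, BLOCK_SIZE_N = 10, 8
pid_n = 1  # last tile spans logical columns 8..15, but only 8 and 9 exist

offs = pid_n * BLOCK_SIZE_N + np.arange(BLOCK_SIZE_N)
data = np.arange(N, dtype=np.float32)

# Original pattern: wrap with % N first, so the (offs < N) mask filters
# nothing and the tail tile silently reloads columns 0..5 as duplicates.
wrapped = offs % N            # [8 9 0 1 2 3 4 5]
always_true = wrapped < N     # all True -- the mask is ineffective
dup_tile = data[wrapped]      # duplicated data, not zeros

# Fixed pattern: mask the unwrapped offsets and pad out-of-range with 0.
n_mask = offs < N             # [T T F F F F F F]
safe_tile = np.where(n_mask, data[np.minimum(offs, N - 1)], 0.0)
```

Any reduction over the tile (here, the dot products accumulating into `accumulator`) then double-counts the wrapped columns unless they are zeroed out.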

@mergify

mergify bot commented Feb 16, 2026

Hi @laudney, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

1 similar comment
@laudney
Contributor Author

laudney commented Feb 17, 2026

Related PRs (RDNA4/gfx12 series)

This PR is part of a series enabling RDNA4 (gfx12) support in vLLM:

Each PR is independent and can be reviewed/merged separately.

@laudney laudney force-pushed the feat/rocm-rdna4-mxfp4 branch 2 times, most recently from 52fbea8 to d3f57b8 on February 17, 2026 at 20:42
@mergify

mergify bot commented Feb 18, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @laudney.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 18, 2026
L.B.R. added 2 commits February 19, 2026 03:03
Hardware without tl.dot_scaled (e.g. RDNA4/gfx12) cannot use the
standard OAI Triton MXFP4 path. Add a custom fused MoE kernel that
keeps weights packed as uint8 in VRAM and dequantizes per-tile to
bf16 inside the GEMM loop using a "two half-dots" strategy:

- Unpack each uint8 into lo/hi FP4 nibbles
- Dequant FP4 e2m1f -> bf16 via bit manipulation
- Apply E8M0 block scales (1 per 32 elements)
- Load A with stride-2 for even/odd K columns
- acc += dot(a_even, lo_bf16) + dot(a_odd, hi_bf16)

Supports both the legacy apply() path and the modular kernel
(Mxfp4DequantTritonExperts) with expert mapping, bias, and
gated activations (SiLU/SwiGLU).

Signed-off-by: L.B.R. <lbr@mmonad.com>
Upstream migrated activation from str to MoEActivation enum.
Update type annotations, _supports_activation, and is_gated check.

Signed-off-by: L.B.R. <lbr@mmonad.com>
@laudney laudney force-pushed the feat/rocm-rdna4-mxfp4 branch from d3f57b8 to 1ca8e28 on February 19, 2026 at 03:03
@mergify mergify bot removed the needs-rebase label Feb 19, 2026
The modulo `% N` on offs_bn made `offs_bn < N` always true, so when N
was not a multiple of BLOCK_SIZE_N the last tile wrapped around and
loaded duplicate data instead of zeros.  Replace with an explicit
n_mask and apply it to all B-side loads (bias, packed weights, scales).

Signed-off-by: L.B.R. <lbr@mmonad.com>
@AndreasKaratzas
Collaborator

For fusions, @Rohan138 and @tjtanaa, could you please take a look at the correctness of this PR? I ran CI, but I don't know whether it actually invoked this fused MoE class, and I did not evaluate beyond the CI run.

@ptrojahn

Are you sure that this is necessary? Triton 3.5 and newer should definitely support dot_scaled on gfx12 and even gfx11.

@laudney
Contributor Author

laudney commented Feb 22, 2026

Good question! tl.dot_scaled does compile and run on gfx12/gfx11 in Triton 3.5+, but it's lowered via DecomposeScaledBlocked (source) — pure software dequantization to bf16 followed by a regular WMMA tl.dot. The native scaled WMMA instruction (v_wmma_scale_f32_16x16x128_f8f6f4) requires wmmaVersion == 3, which is gfx1250-only — gfx1200/1201 get wmmaVersion == 2.

Since dot_scaled on gfx12 decomposes to the same dequant→bf16→dot that this kernel does manually, there's no throughput advantage. The custom kernel integrates the dequant directly into the fused-MoE pattern (expert routing + weight unpacking in one kernel launch), which is cleaner than trying to plumb dot_scaled through vLLM's existing MoE dispatch.

@ptrojahn

I'm still not sure why this patch is necessary. There should be no special case for gfx12, since the operator is supported.

@mergify

mergify bot commented Feb 26, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @laudney.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 26, 2026
@laudney
Contributor Author

laudney commented Mar 22, 2026

Closing this — @ptrojahn is right that a custom kernel shouldn't be necessary here.

I dug into why the standard Triton path doesn't work on gfx12. It's not just the Python capability check — tl.dot_scaled actually crashes the compiler on gfx1201:

'ttg.convert_layout' op requires the same shape for all operands and results
Pipeline failed while executing [`TritonAMDGPUAccelerateMatmul`]

The DecomposeAMDScaledBlocked pattern is in libtriton.so but doesn't match for RDNA4's WMMA layout (only tested against CDNA). Once that's fixed upstream, we'd just need to widen the cap range in _supports_current_device().

@laudney laudney closed this Mar 22, 2026
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Mar 22, 2026
@ptrojahn

Could you create a small reproducer for this crash? Generally, dot_scaled works fine on gfx11/gfx12. We have tests in upstream Triton and the gpt-oss MoE implementation makes use of this implementation already. This must be a specific edge case you are hitting. Thanks!

@laudney
Contributor Author

laudney commented Mar 22, 2026

@ptrojahn you were right — my reproducer had the wrong scale shape for RHS ([K//32, N] instead of [N, K//32]). With the correct shapes, tl.dot_scaled works fine on gfx1201.

The only thing blocking MXFP4 on gfx12 is the (9, 0) <= cap < (11, 0) check in _supports_current_device(). Opened #37826 with just that fix — tested with gpt-oss-20b on an R9700.
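The gate in question reduces to a tuple comparison on the reported GPU capability. A hedged sketch of what widening it means; this is not the actual vLLM code, and the widened upper bound here is an assumption for illustration (the real bound chosen in the follow-up PR may differ):

```python
def supports_mxfp4(cap: tuple[int, int]) -> bool:
    """Illustrative stand-in for the original _supports_current_device()
    gate, which admits only gfx9-era (CDNA) capabilities."""
    return (9, 0) <= cap < (11, 0)

def supports_mxfp4_widened(cap: tuple[int, int]) -> bool:
    """Hypothetical widened gate also admitting gfx11/gfx12, where
    Triton >= 3.5 lowers tl.dot_scaled via software decomposition."""
    return (9, 0) <= cap < (13, 0)

# gfx1201 reports a (12, x)-style capability, rejected by the old gate.
```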
