[ROCm] Add MXFP4 inline dequant Triton kernel for RDNA4/gfx12 #34632

laudney wants to merge 3 commits into vllm-project:main from
Conversation
Code Review
This pull request introduces a new Triton kernel for MXFP4 dequantization on RDNA4/gfx12 hardware, which is a significant feature enablement. The changes include the new kernel implementation, modifications to the quantization backend to select this new kernel on appropriate hardware, and a resilience improvement for subprocess handling in the model registry. While the overall approach is sound, I've identified a critical correctness issue in the new Triton kernel related to memory access, which could lead to incorrect results when the 'N' dimension is not perfectly divisible by the block size. My review includes a detailed explanation and a suggested fix for this issue.
offs_bn = (pid_n * BLOCK_SIZE_N +
           tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
if HAS_BIAS:
    bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
    bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)

# Half-K for packed dimension
HALF_K: tl.constexpr = BLOCK_SIZE_K // 2

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

# Iterate over K in steps of BLOCK_SIZE_K (logical unpacked elements)
num_k_iters = tl.cdiv(K, BLOCK_SIZE_K)
for k_iter in range(0, num_k_iters):
    k_start = k_iter * BLOCK_SIZE_K   # logical K offset
    k_packed_start = k_start // 2     # packed K offset

    # Remaining elements mask
    k_remaining = K - k_start

    # --- Load B packed: [HALF_K, BLOCK_N] uint8 ---
    offs_bk = tl.arange(0, HALF_K)
    b_ptrs = (b_ptr
              + off_experts * stride_be
              + offs_bn[None, :] * stride_bn
              + (k_packed_start + offs_bk[:, None]) * stride_bk)
    b_mask = offs_bk[:, None] < (k_remaining // 2)
    b_packed = tl.load(b_ptrs, mask=b_mask, other=0)

    # Unpack nibbles
    lo_nibble = b_packed & 0x0F         # even K indices
    hi_nibble = (b_packed >> 4) & 0x0F  # odd K indices

    # Dequant to bf16
    lo_bf16 = dequant_mxfp4_nibble_to_bf16(lo_nibble.to(tl.int32))
    hi_bf16 = dequant_mxfp4_nibble_to_bf16(hi_nibble.to(tl.int32))

    # --- Load and apply E8M0 scales ---
    # Scale shape: [E, N, K//32], one scale per 32 logical elements.
    # In packed space, 32 logical = 16 packed rows.
    # Each packed row j maps to scale group (k_start // 32 + j // 16).
    # Load scales directly as [HALF_K, BLOCK_N] by computing per-row
    # scale pointers.
    scale_k_start = k_start // 32
    scale_k_offs = offs_bk // 16  # [HALF_K] - scale group for each row
    scale_ptrs = (b_scale_ptr
                  + off_experts * stride_bse
                  + offs_bn[None, :] * stride_bsn
                  + (scale_k_start + scale_k_offs[:, None]) * stride_bsk)
    scale_mask = (scale_k_start + scale_k_offs[:, None]) < tl.cdiv(K, 32)
    raw_scales = tl.load(scale_ptrs, mask=scale_mask, other=127)
There's a correctness bug here when N is not a multiple of BLOCK_SIZE_N. offs_bn is computed with a modulo N, which prevents out-of-bounds memory access but makes the padded elements of the last block wrap around and read logically incorrect data. Because of the modulo, the mask (offs_bn < N) is always true and thus ineffective.
This affects the loads of bias, b_packed, and raw_scales, none of which are correctly masked along the N dimension, and can lead to incorrect computation results.
The fix is to compute an explicit n_mask for the N dimension and apply it to all loads that depend on offs_bn:
# N-dimension offsets
unmasked_offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
n_mask = unmasked_offs_bn < N
offs_bn = unmasked_offs_bn
if HAS_BIAS:
    bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
    bias = tl.load(bias_ptrs, mask=n_mask, other=0.0)

# Half-K for packed dimension
HALF_K: tl.constexpr = BLOCK_SIZE_K // 2

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

# Iterate over K in steps of BLOCK_SIZE_K (logical unpacked elements)
num_k_iters = tl.cdiv(K, BLOCK_SIZE_K)
for k_iter in range(0, num_k_iters):
    k_start = k_iter * BLOCK_SIZE_K   # logical K offset
    k_packed_start = k_start // 2     # packed K offset

    # Remaining elements mask
    k_remaining = K - k_start

    # --- Load B packed: [HALF_K, BLOCK_N] uint8 ---
    offs_bk = tl.arange(0, HALF_K)
    b_ptrs = (b_ptr
              + off_experts * stride_be
              + offs_bn[None, :] * stride_bn
              + (k_packed_start + offs_bk[:, None]) * stride_bk)
    b_mask = (offs_bk[:, None] < (k_remaining // 2)) & n_mask[None, :]
    b_packed = tl.load(b_ptrs, mask=b_mask, other=0)

    # Unpack nibbles
    lo_nibble = b_packed & 0x0F         # even K indices
    hi_nibble = (b_packed >> 4) & 0x0F  # odd K indices

    # Dequant to bf16
    lo_bf16 = dequant_mxfp4_nibble_to_bf16(lo_nibble.to(tl.int32))
    hi_bf16 = dequant_mxfp4_nibble_to_bf16(hi_nibble.to(tl.int32))

    # --- Load and apply E8M0 scales ---
    # Scale shape: [E, N, K//32], one scale per 32 logical elements.
    # In packed space, 32 logical = 16 packed rows.
    # Each packed row j maps to scale group (k_start // 32 + j // 16).
    # Load scales directly as [HALF_K, BLOCK_N] by computing per-row
    # scale pointers.
    scale_k_start = k_start // 32
    scale_k_offs = offs_bk // 16  # [HALF_K] - scale group for each row
    scale_ptrs = (b_scale_ptr
                  + off_experts * stride_bse
                  + offs_bn[None, :] * stride_bsn
                  + (scale_k_start + scale_k_offs[:, None]) * stride_bsk)
    scale_k_mask = (scale_k_start + scale_k_offs[:, None]) < tl.cdiv(K, 32)
    scale_mask = scale_k_mask & n_mask[None, :]
    raw_scales = tl.load(scale_ptrs, mask=scale_mask, other=127)
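The wrap-around behavior is easy to reproduce outside Triton. The NumPy sketch below (with made-up tile sizes, purely for illustration) shows why the modulo defeats the mask:

```python
import numpy as np

N, BLOCK_SIZE_N = 5, 4
pid_n = 1  # last tile: logical columns 4..7, but only column 4 is in range

# Buggy pattern: modulo wraps indices back into range, so the mask never fires
offs_bn = (pid_n * BLOCK_SIZE_N + np.arange(BLOCK_SIZE_N)) % N
assert (offs_bn < N).all()            # always true -> the mask is a no-op
assert list(offs_bn) == [4, 0, 1, 2]  # columns 5..7 silently alias columns 0..2

# Fixed pattern: keep the raw offsets and mask explicitly
unmasked_offs_bn = pid_n * BLOCK_SIZE_N + np.arange(BLOCK_SIZE_N)
n_mask = unmasked_offs_bn < N
assert list(n_mask) == [True, False, False, False]
```

With the explicit n_mask, the out-of-range lanes load the `other=` fill value (0 for weights, 127 for scales) instead of wrapped data.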
Hi @laudney, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.
Related PRs (RDNA4/gfx12 series)

This PR is part of a series enabling RDNA4 (gfx12) support in vLLM.
Each PR is independent and can be reviewed/merged separately.
Force-pushed 52fbea8 to d3f57b8
This pull request has merge conflicts that must be resolved before it can be merged.
Hardware without tl.dot_scaled (e.g. RDNA4/gfx12) cannot use the standard OAI Triton MXFP4 path. Add a custom fused MoE kernel that keeps weights packed as uint8 in VRAM and dequantizes per-tile to bf16 inside the GEMM loop using a "two half-dots" strategy:

- Unpack each uint8 into lo/hi FP4 nibbles
- Dequant FP4 e2m1f -> bf16 via bit manipulation
- Apply E8M0 block scales (1 per 32 elements)
- Load A with stride-2 for even/odd K columns
- acc += dot(a_even, lo_bf16) + dot(a_odd, hi_bf16)

Supports both the legacy apply() path and the modular kernel (Mxfp4DequantTritonExperts) with expert mapping, bias, and gated activations (SiLU/SwiGLU).

Signed-off-by: L.B.R. <lbr@mmonad.com>
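The nibble unpack and e2m1f decode described above can be sketched on the host with NumPy. Note this is an illustrative stand-in, not the kernel's dequant_mxfp4_nibble_to_bf16; the function name and lookup-table approach here are my own:

```python
import numpy as np

# E2M1 code points: sign(1) | exp(2) | mant(1); exp==0 is subnormal (mant * 0.5)
E2M1 = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                 -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
                dtype=np.float32)

def dequant_mxfp4(packed: np.ndarray, scales_e8m0: np.ndarray) -> np.ndarray:
    """Unpack each uint8 into two FP4 values, then apply one E8M0 scale
    per 32 logical elements (illustrative host-side sketch)."""
    lo = packed & 0x0F          # even logical K indices
    hi = (packed >> 4) & 0x0F   # odd logical K indices
    # Interleave so logical order is [lo0, hi0, lo1, hi1, ...]
    vals = np.empty(packed.size * 2, dtype=np.float32)
    vals[0::2] = E2M1[lo]
    vals[1::2] = E2M1[hi]
    # E8M0 scale is a pure power of two with bias 127
    scale = np.exp2(scales_e8m0.astype(np.int32) - 127).astype(np.float32)
    return vals.reshape(-1, 32) * scale[:, None]

# 16 packed bytes = 32 logical FP4 values = one scale group
packed = np.full(16, 0x21, dtype=np.uint8)         # lo nibble 1 (0.5), hi nibble 2 (1.0)
out = dequant_mxfp4(packed, np.array([128], dtype=np.uint8))  # scale = 2^1
```

In the kernel the same decode is done with bit manipulation directly into bf16 rather than a table lookup, but the numeric result is the same.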
Upstream migrated activation from str to MoEActivation enum. Update type annotations, _supports_activation, and is_gated check. Signed-off-by: L.B.R. <lbr@mmonad.com>
Force-pushed d3f57b8 to 1ca8e28
The modulo `% N` on offs_bn made `offs_bn < N` always true, so when N was not a multiple of BLOCK_SIZE_N the last tile wrapped around and loaded duplicate data instead of zeros. Replace with an explicit n_mask and apply it to all B-side loads (bias, packed weights, scales).

Signed-off-by: L.B.R. <lbr@mmonad.com>
Are you sure that this is necessary? Triton 3.5 and newer should definitely support dot_scaled on gfx12 and even gfx11.
Good question! Since …
Still not sure why this patch is necessary. There should be no special case for gfx12 as the operator is supported?
This pull request has merge conflicts that must be resolved before it can be merged.
Closing this — @ptrojahn is right that a custom kernel shouldn't be necessary here. I dug into why the standard Triton path doesn't work on gfx12. It's not just the Python capability check — the …
Could you create a small reproducer for this crash? Generally, dot_scaled works fine on gfx11/gfx12. We have tests in upstream Triton and the gpt-oss MoE implementation makes use of this implementation already. This must be a specific edge case you are hitting. Thanks!
@ptrojahn you were right — my reproducer had the wrong scale shape for RHS (…). The only thing blocking MXFP4 on gfx12 is the …
Summary

Enables MXFP4 (OCP MX FP4 e2m1f) quantized models on RDNA4/gfx12 hardware, which lacks the tl.dot_scaled support required by the existing OAI Triton MXFP4 path.

- New Triton kernel (fused_moe_mxfp4.py): keeps weights packed as uint8 in VRAM (~half bf16 size) and dequantizes per-tile to bf16 inside the GEMM loop using a "two half-dots" strategy: acc += dot(a_even, lo_bf16) + dot(a_odd, hi_bf16)
- TRITON_MXFP4_DEQUANT in mxfp4.py: auto-selected on gfx12 via on_gfx1x() detection
- Supports both the legacy apply() path and the modular kernel (Mxfp4DequantTritonExperts) with expert mapping, bias, and gated activations (SiLU/SwiGLU)
- registry.py: ROCm roctracer can fire spurious assertion failures during process cleanup on RDNA4 — check for valid output before return code so cleanup crashes don't mask successful results

Test plan

- … (TRITON backend still selected)
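The "two half-dots" identity the kernel relies on (splitting the K reduction into even and odd columns) can be checked with a quick NumPy sketch; the shapes here are illustrative, not the kernel's tile sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
M, K, N = 4, 8, 3  # K must be even: each uint8 packs two FP4 values
A = rng.standard_normal((M, K))
B = rng.standard_normal((K, N))

# B rows split by logical K parity, as the lo/hi nibbles would produce them
B_lo, B_hi = B[0::2, :], B[1::2, :]

# Two half-dots: even A columns against lo rows, odd A columns against hi rows
acc = A[:, 0::2] @ B_lo + A[:, 1::2] @ B_hi

assert np.allclose(acc, A @ B)
```

The reduction over K is just reordered, so the result is exact; this is what lets the kernel load A with stride 2 and never materialize the unpacked weight matrix.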