[Bugfix][Kernel] Fix mHC fused-RMSNorm big-fuse miscompile for hidden_size != 4096 #44692
Merged
…_size != 4096

mhc_pre_big_fuse_with_norm_tilelang pipelined the fused-RMSNorm pass at num_stages=3. Combined with the loop-carried sumsq reduction and the persistent output_shared buffer, the tilelang software pipeliner miscompiles it: NaN when hidden_size // 1024 <= 3 (e.g. 2048, 3072) and finite-but-wrong, slightly non-deterministic output when hidden_size // 1024 > 4 (e.g. 7168, the DeepSeek-V4 size). Only hidden_size=4096 (trip count 4) was correct.

Drop to num_stages=2, matching the correct no-norm sibling kernel mhc_pre_big_fuse_tilelang.

Verified against an fp32 RMSNorm reference: all hidden sizes 2048-8192 now match to ~1.6e-3 (bf16 floor), deterministically, across token counts 1..16384. num_stages=3 without the loop-carried state (the no-norm kernel) compiles correctly, confirming the issue is the pipeliner's handling of that combination.

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
jeejeelee approved these changes on Jun 6, 2026
Purpose
mhc_pre_big_fuse_with_norm_tilelang (the RMSNorm-fused mHC pre big-fuse used by DeepSeek-V4 when norm_weight is supplied) pipelines its fused-RMSNorm pass at num_stages=3. Combined with the loop-carried sumsq reduction and the persistent output_shared staging buffer written at per-iteration offsets, the TileLang software pipeliner miscompiles the loop. The result is silently wrong layer_input for every hidden size except 4096. The outcome depends on the loop trip count (H // 1024): NaN when it is ≤ 3 (e.g. H=2048, 3072), correct when it is 4 (H=4096), and finite but wrong when it is > 4 (e.g. H=7168).

The failure is occupancy-dependent (at H=7168 it is correct for ≤256 tokens, breaks at ≥512) and slightly non-deterministic run-to-run, which is the signature of a pipeline prologue/epilogue codegen issue, not an arithmetic one. The path was evidently only validated at hidden_size=4096.

Fix
Drop the fused-RMSNorm pass to num_stages=2, matching the correct no-norm sibling kernel mhc_pre_big_fuse_tilelang (which already uses depth 2). One line.

Root cause is in the pipeliner, not the algorithm: forcing the no-norm kernel (no loop-carried state) to num_stages=3 stays correct, so depth-3 alone is fine. It only breaks when combined with the loop-carried reduction + persistent shared buffer. Depth-2 handles that combination correctly.
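For readers unfamiliar with the terms, here is a minimal plain-Python sketch of what a software-pipelined loop with a loop-carried reduction is supposed to compute. It is a toy model, not TileLang's pipeliner and not the kernel from this PR: the function names, the staging-ring layout, and the test vector are all hypothetical. The only detail taken from the description is that the loop trip count is hidden_size // 1024. A correctly generated pipeline must match the sequential loop at every depth and every trip count, including trip counts below, equal to, and above num_stages.

```python
def sumsq_sequential(tiles):
    """Reference: accumulate the sum of squares one tile at a time."""
    total = 0.0
    for tile in tiles:
        total += sum(v * v for v in tile)
    return total


def sumsq_pipelined(tiles, num_stages):
    """Toy software pipeline: loads run num_stages - 1 tiles ahead of compute."""
    n = len(tiles)
    slots = [None] * num_stages  # ring of staging buffers (hypothetical layout)
    lookahead = num_stages - 1

    # Prologue: prefetch the first tiles, clamped to the trip count so a
    # short loop (n < lookahead) never touches a tile that does not exist.
    for i in range(min(lookahead, n)):
        slots[i % num_stages] = list(tiles[i])

    total = 0.0  # loop-carried state, must survive every stage transition
    for i in range(n):
        nxt = i + lookahead
        if nxt < n:  # steady state: keep the ring full
            slots[nxt % num_stages] = list(tiles[nxt])
        # When nxt >= n the loop is in its epilogue: drain without loading.
        total += sum(v * v for v in slots[i % num_stages])
    return total


def make_tiles(hidden_size, tile=1024):
    """Split a deterministic test vector into hidden_size // tile tiles."""
    values = [float((7 * k) % 5 - 2) for k in range(hidden_size)]
    return [values[s:s + tile] for s in range(0, hidden_size, tile)]


for hidden_size in (2048, 3072, 4096, 7168):
    tiles = make_tiles(hidden_size)
    ref = sumsq_sequential(tiles)
    for num_stages in (2, 3):
        # A correct pipeline is depth-independent.
        assert sumsq_pipelined(tiles, num_stages) == ref
    print(hidden_size, "trip count", len(tiles), "sumsq", ref)
```

The prologue clamp and the epilogue guard are exactly the places where a code generator can go wrong when the trip count is close to the pipeline depth, which is consistent with the failure pattern described above, though this sketch does not reproduce the actual miscompile.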
Not a duplicate
Checked open mHC PRs (#42735 bf16 shared staging, #44144 XPU fused_post_pre, #43950 ROCm aiter default, #42893 / #41834 DSv4 platform fixes). None touch the num_stages / fused-norm correctness. #42735 is a separate float32→bf16 staging perf change and already uses num_stages=2 in its variants.

Test
Verified on GB200 (CUDA 13, TileLang) against an fp32 RMSNorm reference built on top of the existing mhc_pre_ref in tests/kernels/test_mhc_kernels.py:

Before the fix (num_stages=3): hidden_size 2048/3072 → NaN; 5120/6144/7168/8192 → rel-err 0.17–0.55; only 4096 correct.
After the fix, torch.ops.vllm.mhc_pre_tilelang(..., norm_weight, norm_eps) (deep_gemm + kernel) at H=3072/7168: 1.67e-3, all finite.
(… mhc_pre, mhc_post) are unchanged and unaffected.
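As an illustration of the kind of check described above, the following dependency-free sketch computes an RMSNorm reference in full precision, rounds it to bf16, and measures a relative error. It is not the PR's test: rmsnorm_ref, to_bf16, max_rel_err, the eps value, and the error metric (max absolute difference divided by the largest reference magnitude) are all assumptions, and how the real check composes with mhc_pre_ref is not shown here. It does show why an error near 1.6e-3 is plausible as a bf16 floor: bf16 keeps 8 significant bits, so round-to-nearest alone can cost up to 2^-8 ≈ 3.9e-3 relative error per element.

```python
import math
import struct


def rmsnorm_ref(x, weight, eps):
    """RMSNorm in Python floats: x / sqrt(mean(x^2) + eps) * weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def to_bf16(value):
    """Round a finite float to bf16 (round-to-nearest-even), returned as float."""
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias on the low 16 bits
    bits &= 0xFFFF0000                   # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def max_rel_err(actual, expected):
    """Assumed metric: max |a - e| / max |e|; any NaN or inf reports inf."""
    if not all(math.isfinite(a) for a in actual):
        return math.inf
    scale = max(abs(e) for e in expected)
    return max(abs(a - e) for a, e in zip(actual, expected)) / scale


hidden_size = 3072  # one of the sizes that produced NaN before the fix
x = [math.sin(0.37 * k) for k in range(hidden_size)]
weight = [1.0 + 0.001 * (k % 7) for k in range(hidden_size)]  # made-up weights

expected = rmsnorm_ref(x, weight, eps=1e-6)   # eps is an assumed value
actual = [to_bf16(v) for v in expected]       # stand-in for a bf16 kernel output

err = max_rel_err(actual, expected)
print("max relative error:", err)
assert err < 2.0 ** -8                        # within the bf16 rounding bound
assert max_rel_err([math.nan] + actual[1:], expected) == math.inf
```

Treating any non-finite value as an infinite error keeps NaN outputs from slipping through a comparison, since NaN compares false against every threshold.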