(alternative to #2508) Fix/splitk tmp_out undersized buffer, avoid double-zeroing #2551
Merged
rbrugaro-amd merged 8 commits into ROCm:main on Mar 31, 2026
Conversation
valarLip approved these changes on Mar 31, 2026
daydayup-lh pushed a commit that referenced this pull request on Apr 1, 2026
…ble-zeroing (#2551)

* fix: align ck_moe_stage1 split-K tmp_out buffer with CK kernel
* Update fused_moe.py
* tmp_out to use torch.empty vs. torch.zeros to avoid double-zeroing
* tighten valid_out slice: drop redundant .contiguous()
* restore .view(dtypes.fp32) on valid_out for silu_and_mul/gelu_and_mul

Signed-off-by: rbrugaro <rita.brugarolasbrufau@amd.com>
Co-authored-by: Karan Verma <karan.verma@amd.com>
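Taken together, these commits reduce to a small change in how the stage-1 output buffer is sliced and viewed. A minimal sketch of the resulting pattern, assuming names from the commit message (tmp_out, valid_out; dtypes.fp32 is written as torch.float32 here); the real fused_moe.py context is more involved:

```python
import torch

def stage1_valid_out(tmp_out: torch.Tensor, token_num: int, topk: int) -> torch.Tensor:
    # A leading-row slice of a contiguous buffer is itself contiguous,
    # so the previously chained .contiguous() was redundant and is dropped.
    valid_out = tmp_out[: token_num * topk, :]
    # silu_and_mul / gelu_and_mul consume the fp32 split-K accumulator,
    # hence the restored dtype view (.view(dtypes.fp32) in the commit).
    return valid_out.view(torch.float32)
```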
LJ-underdog pushed a commit that referenced this pull request on Apr 24, 2026
Verified via canary tests + tp=2 inference that the three +2 row padding fixes from commit 68fc7d48b are not needed for the BF16 no-quant path on gfx950. CK kernels skip the entire block when expert_id is the sentinel (=E), so the sorted_ids sentinel (K<<24)|T never triggers an OOB scatter to a2[T*K+K] / moe_out[M].

Reverted three locations:
- L339-345: drop moe_out_padded, pass moe_buf directly
- L1262-1264: zeros((token_num+2, ...)) -> empty((token_num, ...))
- L1349: a2.view(token_num+2, ...) -> a2.view(token_num, ...)

Verification:
- /tmp/test_moe_canary.py: a2[T*K+K] pristine after stage1
- /tmp/test_moe_canary_stage2.py: moe_out[M] pristine after stage2
- tp=2 Step-3.5-Flash inference: 4 prompts complete normally, no NaN, no crash, latency 1.97 s/req
- Required fixes still in place: V1->V3 force (block_m=128) and shuffle_weight() preprocessing

Note: the PR #2551 +2 padding is only required for the split-K + per_1x128 quant path, which is a different code branch.
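The canary scripts above live outside the repository, so as illustration only, here is a minimal sketch of the stage-1 check they describe; run_stage1 is a hypothetical stand-in for the actual ck_moe_stage1 invocation:

```python
import torch

def check_stage1_canary(T: int, K: int, inter_dim: int, run_stage1) -> None:
    # Allocate one padding row past the valid T*K region and poison it.
    a2 = torch.empty((T * K + 1, inter_dim), dtype=torch.bfloat16, device="cuda")
    a2[T * K].fill_(float("nan"))

    run_stage1(a2)  # hypothetical wrapper that runs ck_moe_stage1 into a2

    # If the kernel skips blocks whose expert_id equals the sentinel (=E),
    # the padded sorted_ids entries ((K << 24) | T) never scatter past the
    # valid region and the canary row stays pristine.
    assert torch.isnan(a2[T * K]).all(), "stage1 scattered out of bounds"
```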
The C++ side of this crash was fixed in ROCm/rocm-libraries#5225, which corrected the hipMemsetAsync size from arg.M * arg.N to arg.NumTokens * arg.TopK * arg.N. This PR fixes the Python side to match.

Technical Details
Root Cause
In ck_moe_stage1(), the tmp_out buffer was allocated as (token_num, topk, w1.shape[1]), which is undersized when splitK > 1. The CK kernel operates on sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0]) rows, so the buffer must be (sorted_size, w1.shape[1]).

For DeepSeek V3 decode (token_num=1, topk=8, block_m=16), the undersized allocation comes to 1 * 8 * 4096 * 4 = 128 KB, while the correctly sized buffer is 128 * 2048 * 4 * 2 = 2 MB.
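Spelled out as arithmetic (a sketch; the per-row byte factors are taken verbatim from the figures above, and the min is assumed not to bind for this shape):

```python
token_num, topk, block_m = 1, 8, 16

# Undersized allocation: (token_num, topk, w1.shape[1]) elements of fp32.
old_bytes = token_num * topk * 4096 * 4      # 131072 bytes = 128 KB

# Required allocation: sorted_size rows, where
# sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0]).
sorted_size = token_num * topk * block_m     # 128 when the min does not bind
new_bytes = sorted_size * 2048 * 4 * 2       # 2097152 bytes = 2 MB
```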
Fix

ck_moe_stage1: allocate tmp_out with sorted_size rows using torch.empty (the CK kernel zeros the buffer via hipMemsetAsync, avoiding redundant double-zeroing). After the kernel, slice valid_out = tmp_out[:token_num * topk, :] before silu_and_mul / gelu_and_mul.

cktile_moe_stage1: added a warning comment flagging the same undersized-buffer pattern. The code is left unchanged since fp32 splitK is not yet active (see the existing TODO: support fp32 splitk), but the comment documents the fix to apply when that path is enabled.
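A sketch of the described allocation, with simplified arguments and the accumulator dtype assumed to be fp32; the actual ck_moe_stage1 call in fused_moe.py takes more parameters:

```python
import torch

def alloc_stage1_tmp_out(token_num, topk, block_m, sorted_token_ids, w1):
    # Size the buffer to what the CK kernel actually writes, not to the
    # (token_num, topk, N) logical output.
    sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0])
    # torch.empty, not torch.zeros: the kernel zeros this buffer itself
    # via hipMemsetAsync when splitK > 1, so pre-zeroing in Python would
    # launch a redundant GPU memset.
    return torch.empty(
        (sorted_size, w1.shape[1]),
        dtype=torch.float32,
        device=sorted_token_ids.device,
    )
```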
Why torch.empty instead of torch.zeros

With ROCm/rocm-libraries#5225 merged, the CK kernel correctly zeros the buffer via hipMemsetAsync when KBatch > 1. Using torch.empty avoids double-zeroing (once by Python, once by CK), eliminating a redundant GPU kernel launch.

Test Result
DeepSeek-R1-0528 (FP8, 8xMI355X): runtime error resolved, splitK active with correct results.