fix: ck_moe_stage1 split-K buffer overflow from padding scatter (alternative to #2508)#2547
Open
ChuanLi1101 wants to merge 1 commit intomainfrom
Open
fix: ck_moe_stage1 split-K buffer overflow from padding scatter (alternative to #2508)#2547ChuanLi1101 wants to merge 1 commit intomainfrom
ChuanLi1101 wants to merge 1 commit intomainfrom
Conversation
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
The CK kernel scatters output via sorted_token_ids using: token_offset = (fused_token & 0xffffff) * topk + (fused_token >> 24) Padding entries use the sentinel value (topk << 24 | token_num), which decodes to scatter position (token_num * topk + topk) -- beyond the valid output range [0, token_num * topk). The original buffer (token_num, topk, w1.shape[1]) only has token_num * topk rows, so the padding scatter writes out of bounds, causing "HIP runtime error: invalid argument" during CUDA graph capture (e.g. DeepSeek-R1 decode with token_num=1, topk=8, block_m=16). Fix: allocate (token_num * topk + topk + 1) rows -- the exact minimum needed to absorb all padding scatter writes. After the kernel, slice only the valid [0, token_num * topk) rows for the activation. Related: #2508 Made-with: Cursor
1ff7e70 to
ab58051
Compare
frida-andersson
added a commit
to frida-andersson/aiter
that referenced
this pull request
Apr 1, 2026
Align split-K tmp_out allocation with CK sorted_size and scatter padding so tile writes stay in bounds; slice valid rows for silu/gelu_and_mul. Upstream: ROCm#2547 Made-with: Cursor
frida-andersson
added a commit
to frida-andersson/aiter
that referenced
this pull request
Apr 1, 2026
frida-andersson
added a commit
to frida-andersson/aiter
that referenced
this pull request
Apr 30, 2026
Align split-K tmp_out allocation with CK sorted_size and scatter padding so tile writes stay in bounds; slice valid rows for silu/gelu_and_mul. Upstream: ROCm#2547 Made-with: Cursor
frida-andersson
added a commit
to frida-andersson/aiter
that referenced
this pull request
Apr 30, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Fix out-of-bounds buffer overflow in \ck_moe_stage1\ when splitK is enabled.
Root cause
The CK MoE kernel uses \sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0])\ as its M dimension. The kernel launches tile-based blocks covering the entire M range and scatters results to the output buffer. The output buffer must be large enough to accommodate \sorted_size\ rows.
The original code allocated only (token_num, topk, w1.shape[1])\ = \ oken_num * topk\ rows (a 3D tensor). For the padding entries in \sorted_token_ids, the sentinel value (topk << 24 | token_num)\ decodes to scatter position \ oken_num * topk + topk, which exceeds the allocated buffer. Additionally, the kernel expects the output buffer to span at least \sorted_size\ rows to match its tile-based computation grid.
Fix
Verification
Tested on MI355X (gfx950) with multiple token/topk/expert configurations:
\\
SplitK Scatter Fix Verification
[tok=1 topk=8 E=256] OK shape=torch.Size([1, 8, 256]) nan=False inf=False
[tok=2 topk=8 E=256] OK shape=torch.Size([2, 8, 256]) nan=False inf=False
[tok=4 topk=8 E=256] OK shape=torch.Size([4, 8, 256]) nan=False inf=False
[tok=16 topk=8 E=256] OK shape=torch.Size([16, 8, 256]) nan=False inf=False
[tok=1 topk=4 E=64] OK shape=torch.Size([1, 4, 256]) nan=False inf=False
[tok=3 topk=6 E=128] OK shape=torch.Size([3, 6, 256]) nan=False inf=False
Results: 6 passed, 0 failed out of 6
ALL TESTS PASSED!
\\
Comparison to PR #2508
PR #2508 uses \sorted_token_ids.shape[0]\ rows (safe but over-allocates). This PR uses \sorted_size\ (the exact M dimension the C++ wrapper computes), which is the minimal correct size. Both are valid; this PR is tighter on memory.
Test plan