
fix: ck_moe_stage1 split-K buffer overflow from padding scatter (alternative to #2508) #2547

Open
ChuanLi1101 wants to merge 1 commit into main from chuan/fix-ck-moe-stage1-splitk-scatter

Conversation


ChuanLi1101 commented Mar 31, 2026

Summary

Fix an out-of-bounds buffer overflow in `ck_moe_stage1` when splitK is enabled.

Root cause

The CK MoE kernel uses `sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0])` as its M dimension. The kernel launches tile-based blocks covering the entire M range and scatters results to the output buffer, which must therefore be large enough to accommodate `sorted_size` rows.

The original code allocated only `(token_num, topk, w1.shape[1])` = `token_num * topk` rows (a 3D tensor). For the padding entries in `sorted_token_ids`, the sentinel value `(topk << 24 | token_num)` decodes to scatter position `token_num * topk + topk`, which exceeds the allocated buffer. Additionally, the kernel expects the output buffer to span at least `sorted_size` rows to match its tile-based computation grid.
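
To make the decode concrete, here is a minimal pure-Python sketch of the sentinel arithmetic (the decode formula is the one quoted in the commit message below; the config is the DeepSeek-R1 decode case it mentions):

```python
# Sentinel decode sketch (standalone arithmetic, no kernel involved).
token_num, topk = 1, 8  # DeepSeek-R1 decode case from the commit message

# Padding entries in sorted_token_ids carry this sentinel:
sentinel = (topk << 24) | token_num

# The CK kernel decodes every entry the same way:
token_offset = (sentinel & 0xFFFFFF) * topk + (sentinel >> 24)

assert token_offset == token_num * topk + topk  # == 16
# Valid output rows are [0, token_num * topk) = [0, 8), so the padding
# scatter writes past the end of a buffer with only token_num * topk rows.
```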

Fix

  • Compute `sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0])` (matching the C++ wrapper logic)
  • Allocate a 2D fp32 buffer of shape `(sorted_size, w1.shape[1])` instead of the undersized 3D `(token_num, topk, w1.shape[1])` (see the sketch after this list)
  • After the kernel, slice only the valid rows `tmp_out[:token_num*topk, :]` before passing to `silu_and_mul` / `gelu_and_mul`
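
A minimal sketch of the allocation logic described above; the function name and signature are illustrative rather than the verbatim patch, with `sorted_len` standing in for `sorted_token_ids.shape[0]` and `inter_dim` for `w1.shape[1]`:

```python
import torch

def alloc_stage1_tmp(token_num: int, topk: int, block_m: int,
                     sorted_len: int, inter_dim: int,
                     device: torch.device) -> torch.Tensor:
    # M dimension the C++ wrapper computes for its tile grid:
    sorted_size = min(token_num * topk * block_m, sorted_len)
    # 2D fp32 scratch spanning the full M range, so every tile write
    # (padding scatters included) stays in bounds:
    return torch.empty((sorted_size, inter_dim),
                       dtype=torch.float32, device=device)

# After ck_moe_stage1 runs, only the real tokens feed the activation:
#   valid = tmp_out[: token_num * topk, :]
#   silu_and_mul(out, valid)  # or gelu_and_mul
```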

Verification

Tested on MI355X (gfx950) with multiple token/topk/expert configurations:
```

SplitK Scatter Fix Verification

[tok=1 topk=8 E=256] OK shape=torch.Size([1, 8, 256]) nan=False inf=False
[tok=2 topk=8 E=256] OK shape=torch.Size([2, 8, 256]) nan=False inf=False
[tok=4 topk=8 E=256] OK shape=torch.Size([4, 8, 256]) nan=False inf=False
[tok=16 topk=8 E=256] OK shape=torch.Size([16, 8, 256]) nan=False inf=False
[tok=1 topk=4 E=64] OK shape=torch.Size([1, 4, 256]) nan=False inf=False
[tok=3 topk=6 E=128] OK shape=torch.Size([3, 6, 256]) nan=False inf=False

Results: 6 passed, 0 failed out of 6
ALL TESTS PASSED!
```

Comparison to PR #2508

PR #2508 uses `sorted_token_ids.shape[0]` rows (safe but over-allocates). This PR uses `sorted_size` (the exact M dimension the C++ wrapper computes), which is the minimal correct size. Both are valid; this PR is tighter on memory.
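
A one-function sketch of the two sizing strategies (hypothetical helper; both expressions come straight from the descriptions above):

```python
def stage1_rows(token_num: int, topk: int, block_m: int,
                sorted_len: int, tight: bool = True) -> int:
    """Rows to allocate for the split-K scratch buffer.

    tight=True mirrors this PR (the kernel's exact M dimension);
    tight=False mirrors PR #2508 (the full sorted_token_ids length).
    """
    return min(token_num * topk * block_m, sorted_len) if tight else sorted_len
```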

Test plan

  • Verified on MI355X with 6 different token/topk/expert configs
  • All tests pass with no NaN/Inf in output
  • Non-splitK path is unchanged (`tmp_out = out`)

ChuanLi1101 requested a review from a team March 31, 2026 06:20
@github-actions (Contributor)

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when `aiter/ops/triton/**` or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| `ci:triton-355` | Run Triton tests on MI355 in addition to MI325 |
| `ci:sglang` | SGLang integration tests |
| `ci:atom` | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| `ci:vllm` | vLLM benchmark |
| `ci:all` | All of the above |

Add labels via the sidebar or `gh pr edit 2547 --add-label <label>`

The CK kernel scatters output via sorted_token_ids using:
  token_offset = (fused_token & 0xffffff) * topk + (fused_token >> 24)

Padding entries use the sentinel value (topk << 24 | token_num),
which decodes to scatter position (token_num * topk + topk) -- beyond
the valid output range [0, token_num * topk). The original buffer
(token_num, topk, w1.shape[1]) only has token_num * topk rows, so
the padding scatter writes out of bounds, causing "HIP runtime error:
invalid argument" during CUDA graph capture (e.g. DeepSeek-R1 decode
with token_num=1, topk=8, block_m=16).

Fix: allocate (token_num * topk + topk + 1) rows -- the exact minimum
needed to absorb all padding scatter writes. After the kernel, slice
only the valid [0, token_num * topk) rows for the activation.

Related: #2508
Made-with: Cursor
ChuanLi1101 force-pushed the chuan/fix-ck-moe-stage1-splitk-scatter branch from 1ff7e70 to ab58051 on March 31, 2026 07:46
frida-andersson added a commit to frida-andersson/aiter that referenced this pull request Apr 1, 2026
Align split-K tmp_out allocation with CK sorted_size and scatter padding
so tile writes stay in bounds; slice valid rows for silu/gelu_and_mul.

Upstream: ROCm#2547
Made-with: Cursor
frida-andersson added a commit to frida-andersson/aiter that referenced this pull request Apr 1, 2026
Allow callers to supply a pre-allocated (M, model_dim) buffer for
moe_sorting instead of torch.empty each forward, for DSv32/vLLM integration.

Keeps ck_moe_stage1 split-K fix from ROCm#2547.

docs: update dsv32-opt-branch provenance (moe_buf + ROCm#2547).
Made-with: Cursor