fix: align ck_moe_stage1 split-K tmp_out buffer with CK kernel #2508
karverma-amd wants to merge 4 commits into ROCm:main
Conversation
ChuanLi1101 left a comment:
Looks good with minor suggestions. Pre-approved to speed up delivery.
```python
):
    token_num = hidden_states.shape[0]
    is_splitk = quant_type is aiter.QuantType.per_1x128 and splitk > 1
    tmp_out = (
```
cktile_moe_stage1 has the identical undersized buffer pattern for split_k > 1. It calls aiter.moe_cktile2stages_gemm1 which is a different kernel, but if that kernel has the same hipMemsetAsync behavior, it would have the same overflow. The TODO at line 1723 (# TODO: support fp32 splitk) suggests this path may not be actively used yet, but it's worth either applying the same fix or adding a guard/comment.
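For the guard option, a minimal hypothetical sketch (not code from the PR; `is_splitk` follows the diff above):

```python
# Hypothetical guard for cktile_moe_stage1 (an assumption, not the merged fix):
# fail loudly rather than risk the same undersized-buffer overflow.
if is_splitk:
    raise NotImplementedError(
        "cktile_moe_stage1 with split_k > 1: tmp_out must cover sorted_size "
        "rows (see ROCm/aiter#2508) before this path is enabled."
    )
```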
```python
if is_splitk:
    # CK splitK kernel hipMemsetAsync zeros sorted_size * w1.shape[1] floats
    sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0])
    tmp_out = torch.zeros(
```
This doesn't look good to me: it looks like we zero the buffer twice, once here and once in CK. We need to remove one of them.
ROCm/rocm-libraries#5225 already addresses the proper allocation on the C++ side. I have swapped to torch.empty to avoid the double zeroing in #2551 (created a new PR because I cannot edit the current branch).
@ChuanLi1101 I see you also created another PR, but I think the double zeroing is still there?
The CK kernel scatters output via sorted_token_ids using:

`token_offset = (fused_token & 0xffffff) * topk + (fused_token >> 24)`

Padding entries use the sentinel value `(topk << 24 | token_num)`, which decodes to scatter position `token_num * topk + topk`, beyond the valid output range `[0, token_num * topk)`. The original buffer `(token_num, topk, w1.shape[1])` only has `token_num * topk` rows, so the padding scatter writes out of bounds, causing "HIP runtime error: invalid argument" during CUDA graph capture (e.g. DeepSeek-R1 decode with token_num=1, topk=8, block_m=16).

Fix: allocate `(token_num * topk + topk + 1)` rows, the exact minimum needed to absorb all padding scatter writes. After the kernel, slice only the valid `[0, token_num * topk)` rows for the activation.

Related: #2508
Made-with: Cursor
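A minimal sketch of the decode described above (illustrative only; `scatter_row` and the example entries are assumptions, not AITER code):

```python
# How sorted_token_ids entries decode to scatter rows, and why the padding
# sentinel lands past the valid output range.
token_num, topk = 1, 8  # DeepSeek-R1 decode case from the message above

def scatter_row(fused_token: int) -> int:
    # token_offset = (fused_token & 0xffffff) * topk + (fused_token >> 24)
    return (fused_token & 0xFFFFFF) * topk + (fused_token >> 24)

real_entry = (3 << 24) | 0            # token 0, expert slot 3
pad_entry = (topk << 24) | token_num  # padding sentinel

assert scatter_row(real_entry) == 3                       # inside [0, token_num * topk)
assert scatter_row(pad_entry) == token_num * topk + topk  # row 16, out of bounds
# A buffer of token_num * topk + topk + 1 rows (17 here) absorbs the padding writes.
```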
Base image: rocm/vllm-dev:base_custom_rocm_7.2.1_torch_triton_0330_vllm018

Patches applied:
- AITER SplitK bug fix (ROCm/aiter#2508)
- vLLM persistent MLA kernel (vllm-project/vllm#36574)
- vLLM fused AllReduce+RMSNorm (vllm-project/vllm#37891)

Made-with: Cursor
…ble-zeroing (#2551)

* fix: align ck_moe_stage1 split-K tmp_out buffer with CK kernel
* Update fused_moe.py
* tmp_out to use torch.empty vs. torch.zeros to avoid double zeroing
* tighten valid_out slice: drop redundant .contiguous()
* restore .view(dtypes.fp32) on valid_out for silu_and_mul/gelu_and_mul

Signed-off-by: rbrugaro <rita.brugarolasbrufau@amd.com>
Co-authored-by: Karan Verma <karan.verma@amd.com>
Alternative PR merged: #2551
…r shared experts

Switch the MXFP4 a4w4 (fp4x2 activations + fp4x2 weights) MoE path from CKTile to CK JIT kernels (ck_moe_stage1 / ck_moe_stage2_fwd) and fix several dispatch bugs that caused crashes on MiniMax-M2.1-MXFP4 at TP=8.

Root cause 1, wrong stage1 kernel for a4w4: cktile_moe_stage1 dispatches fp4x2 activations through the fp8 pipeline (F8xMXF4FlatmmPipelineAGmemBGmemCRegV1), which misinterprets packed fp4 as fp8 and produces garbage. Switch to ck_moe_stage1, which uses the correct DeviceMoeGemmMXBPreShuffle<FP4X2, FP4X2, ...> instances from the JIT-compiled module_moe_ck2stages_fp4x2_fp4x2_preshuffle_on module.

Root cause 2, ksplit>1 steals a4w4 cases for shared experts: MiniMax has shared expert layers with inter_dim=256 (per TP=8 rank). get_ksplit() returns 2 for these, and the "ksplit > 1 and is_shuffled" elif fired before the a4w4 elif because it only checked q_dtype_w (not q_dtype_a). This routed fp4x2 activations through cktile_moe_stage1 (bf16 output) into stage2 with the b16_fp4x2_preshuffle_on module, whose heuristic dispatch has no bf16xfp4x2 kernels, causing a crash. Fix: add a "q_dtype_a not in [dtypes.fp4x2]" guard (see the sketch after this message).

Root cause 3, ksplit must be 0 for a4w4 in fused_moe_2stages: when metadata.ksplit > 1, fused_moe_2stages skips a1 fp4x2 quantization and leaves a2_scale=None (bf16 activations into stage2). Force ksplit=0 in the a4w4 MOEMetadata so activation quantization always runs.

Additional fixes:
- cktile_moe_stage1 split_k>1 buffer overflow: allocate tmp_out with sorted_size rows (= max_num_tokens_padded) instead of token_num*topk rows, mirroring the fix from ck_moe_stage1 (#2508).
- cktile_moe_stage1 split_k>1 stage2: use ck_moe_stage2_fwd (JIT CK) instead of cktile_moe_stage2 for correct preshuffle dispatch.
- shuffle_weight_a16w4: set is_shuffled=True on the returned tensor.
- Add e8m0_unshuffle utility (inverse of e8m0_shuffle).
- test_moe_2stage: add an a4w4 branch using shuffle_weight_a16w4(gate_up=False) matching the CK JIT kernel's expected layout.

Tested: MiniMax-M2.1-MXFP4 TP=8 server starts, warmup completes (all cuda_graph_capture_sizes including 48), inference produces correct output.

Files changed:
- aiter/fused_moe.py
- aiter/ops/shuffle.py
- aiter/utility/fp4_utils.py
- op_tests/test_moe_2stage.py
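A minimal sketch of the root-cause-2 dispatch bug and its guard (simplified stand-in logic; `q_dtype_a`, `q_dtype_w`, and `is_shuffled` come from the commit message, while the function and return values here are placeholders, not AITER's actual dispatch):

```python
# Simplified dispatch sketch: the added "q_dtype_a is not fp4x2" guard keeps
# a4w4 shared-expert layers (ksplit == 2) out of the cktile branch.
FP4X2 = "fp4x2"

def pick_stage1_kernel(q_dtype_a, q_dtype_w, ksplit, is_shuffled):
    if ksplit > 1 and is_shuffled and q_dtype_a != FP4X2:
        # Without the q_dtype_a check, this branch fired first for a4w4 and
        # routed packed-fp4 activations through the bf16 cktile pipeline.
        return "cktile_moe_stage1"
    elif q_dtype_a == FP4X2 and q_dtype_w == FP4X2:
        return "ck_moe_stage1"  # correct FP4X2 CK JIT path
    return "other_stage1"

assert pick_stage1_kernel(FP4X2, FP4X2, ksplit=2, is_shuffled=True) == "ck_moe_stage1"
```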
Motivation
The `splitK` optimization in aiter's CK MoE kernel reduces average kernel time from 44us to 22us (a 2x improvement). However, enabling it at aiter commit f80d9c0 causes a "HIP runtime error: invalid argument" during CUDA graph capture.

Technical Details
The crash occurs during CUDA graph capture in vLLM when `ksplit > 1` for decode batches with small token counts (token_num=1, topk=8).

Root Cause
Buffer overflow in the `tmp_out` allocation in `ck_moe_stage1()` (`/app/aiter/aiter/fused_moe.py`, line ~1638).

The Mismatch
When `splitK` is active (KBatch > 1), the CK C++ kernel calls `hipMemsetAsync` to zero-initialize the output buffer for atomic accumulation. The size it zeroes is `M * N * sizeof(float)`, where `M = sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0])` and `N = w1.shape[1]`.

But Python allocated `tmp_out` with shape `(token_num, topk, w1.shape[1])`, which covers only `token_num * topk` rows.

For DeepSeek V3 decode with token_num=1, topk=8, block_m=16:

- Python `tmp_out` allocation: 1 * 8 * 4096 * 4 = 131,072 bytes (128 KB)
- Kernel `hipMemsetAsync`: 128 * 2048 * 4 * 2 = 2,097,152 bytes (2 MB)
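A quick arithmetic check of the two sizes (a sketch; the sorted_token_ids length of 128 is inferred from the figures above, since min(1 * 8 * 16, len) must equal 128):

```python
# Worked check of the mismatch, using the values from this PR description.
token_num, topk, block_m, N = 1, 8, 16, 4096  # N = w1.shape[1]

python_alloc_bytes = token_num * topk * N * 4        # 131,072 bytes (128 KB)
sorted_size = min(token_num * topk * block_m, 128)   # 128 rows (length assumed)
kernel_memset_bytes = sorted_size * N * 4            # 2,097,152 bytes (2 MB),
                                                     # written above as 128 * 2048 * 4 * 2

assert kernel_memset_bytes == 16 * python_alloc_bytes  # kernel zeroes 16x the buffer
```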
C++ Code Reference

The overflow originates in `device_moe_gemm_blockscale.hpp` (CK): when `KBatch > 1`, it issues the `hipMemsetAsync` that zeroes `sorted_size * w1.shape[1]` floats before atomic accumulation.

Fix
Replace the `tmp_out` allocation to match the kernel's expected buffer size:
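A sketch assembled from the diff excerpt above and the follow-up in #2551 (the dtype/device arguments and exact shape layout are assumptions; the merged code may differ in detail):

```python
if is_splitk:
    # CK splitK kernel hipMemsetAsync zeroes sorted_size * w1.shape[1] floats,
    # so tmp_out must cover sorted_size rows, not token_num * topk.
    sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0])
    # torch.empty rather than torch.zeros: the CK kernel zeroes the buffer
    # itself when KBatch > 1, so zeroing here as well would be redundant (#2551).
    tmp_out = torch.empty(
        (sorted_size, w1.shape[1]),
        dtype=torch.float32,
        device=hidden_states.device,
    )
```

After the kernel runs, only the valid [0, token_num * topk) rows are sliced out for the activation.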
Test Result

DeepSeek-R1-0528 with aiter commit f80d9c00b9c8: the runtime error is resolved.

Submission Checklist