
fix: align ck_moe_stage1 split-K tmp_out buffer with CK kernel #2508

Closed
karverma-amd wants to merge 4 commits into ROCm:main from karverma-amd:main

Conversation

Contributor

@karverma-amd karverma-amd commented Mar 27, 2026

Motivation

The splitK optimization in aiter's CK MoE kernel reduces average kernel time from 44us to 22us (a 2x improvement). However, enabling it at aiter commit f80d9c0 causes:

RuntimeError: HIP runtime error: invalid argument. hip_check_error.hpp: 18
in function: hip_check_error

Technical Details

The crash occurs during CUDA graph capture in vLLM when ksplit > 1 for decode batches with small token counts (token_num=1, topk=8).

Root Cause

Buffer overflow in the tmp_out allocation in ck_moe_stage1() (/app/aiter/aiter/fused_moe.py, line ~1638).

The Mismatch

When splitK is active (KBatch > 1), the CK C++ kernel calls hipMemsetAsync to zero-initialize the output buffer for atomic accumulation. The size it zeroes is:

M * N * sizeof(float) * 2

where M = sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0]).

But Python allocated tmp_out as:

torch.zeros((token_num, topk, w1.shape[1]), dtype=fp32)

For DeepSeek V3 decode with token_num=1, topk=8, block_m=16 (this arithmetic is reproduced in the sketch after the list):

  • Python buffer: 1 * 8 * 4096 * 4 = 131,072 bytes (128 KB)
  • C++ hipMemsetAsync: 128 * 2048 * 4 * 2 = 2,097,152 bytes (2 MB)
  • Overflow: 16x beyond the allocated tensor
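
A minimal sketch reproducing this arithmetic (assuming sorted_token_ids.shape[0] is large enough that the min() resolves to token_num * topk * block_m; the value 16384 below is a hypothetical stand-in):

# Standalone sketch of the size mismatch, using the DeepSeek V3 decode
# numbers quoted above.
token_num, topk, block_m = 1, 8, 16
w1_n = 4096        # w1.shape[1]
n = 2048           # CK gemm N; split-K doubles the buffer (the "* 2")
sizeof_float = 4

python_bytes = token_num * topk * w1_n * sizeof_float    # 131,072 (128 KB)
m = min(token_num * topk * block_m, 16384)               # sorted_size = 128
memset_bytes = m * n * sizeof_float * 2                  # 2,097,152 (2 MB)
print(memset_bytes // python_bytes)                      # 16 -> 16x overflow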

C++ Code Reference

The overflow originates in device_moe_gemm_blockscale.hpp (CK):

// When split-K is active, zero the output buffer so the partial results
// from each K split can be accumulated atomically; for the input gemm the
// buffer is doubled, hence the "? 2 : 1".
if(arg.KBatch > 1)
    hipGetErrorString(hipMemsetAsync(arg.p_c_grid, 0,
        arg.M * arg.N * sizeof(CDataType) * (IsInputGemm && IsSplitK ? 2 : 1),
        stream_config.stream_id_));

Fix

Replace the tmp_out allocation to match the kernel's expected buffer size:

# BEFORE (buggy):
tmp_out = torch.zeros(
    (token_num, topk, w1.shape[1]), dtype=dtypes.fp32, device=out.device
)

# AFTER (fixed):
sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0])
tmp_out = torch.zeros(
    (sorted_size, w1.shape[1]), dtype=dtypes.fp32, device=out.device
)
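
As a sanity check, the fixed allocation exactly covers the kernel's memset for the numbers above. The sketch below assumes the CK gemm N equals w1.shape[1] // 2, consistent with the 2x split-K doubling; the sorted-ids length 16384 is again a hypothetical stand-in:

def check_tmp_out_covers_memset(token_num, topk, block_m, w1_n, sorted_ids_len):
    # Fixed tmp_out: (sorted_size, w1.shape[1]) rows of fp32.
    sorted_size = min(token_num * topk * block_m, sorted_ids_len)
    allocated = sorted_size * w1_n * 4
    # Kernel side: M * N * sizeof(float) * 2, with N assumed to be w1_n // 2.
    memset = sorted_size * (w1_n // 2) * 4 * 2
    assert allocated >= memset, (allocated, memset)
    return allocated, memset

print(check_tmp_out_covers_memset(1, 8, 16, 4096, 16384))  # (2097152, 2097152)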

Test Result

With aiter commit f80d9c00b9c8, the DeepSeek-R1-0528 runtime error is resolved.


@karverma-amd karverma-amd requested a review from a team March 27, 2026 16:58
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label          Tests
ci:triton-355  Run Triton tests on MI355 in addition to MI325
ci:sglang      SGLang integration tests
ci:atom        ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm        vLLM benchmark
ci:all         All of the above

Add labels via the sidebar or gh pr edit 2508 --add-label <label>

@karverma-amd karverma-amd added the bug label Mar 30, 2026
@karverma-amd karverma-amd requested a review from valarLip March 30, 2026 16:33
@ChuanLi1101 ChuanLi1101 self-assigned this Mar 31, 2026
@ChuanLi1101 ChuanLi1101 left a comment

Looks good with minor suggestions. Pre-approved to speed up delivery.

Comment thread aiter/fused_moe.py
):
token_num = hidden_states.shape[0]
is_splitk = quant_type is aiter.QuantType.per_1x128 and splitk > 1
tmp_out = (

cktile_moe_stage1 has the identical undersized buffer pattern for split_k > 1. It calls aiter.moe_cktile2stages_gemm1 which is a different kernel, but if that kernel has the same hipMemsetAsync behavior, it would have the same overflow. The TODO at line 1723 (# TODO: support fp32 splitk) suggests this path may not be actively used yet, but it's worth either applying the same fix or adding a guard/comment.
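
Applied there, the fix might look like the sketch below (hypothetical: the variable names mirror the ck_moe_stage1 fix, and whether moe_cktile2stages_gemm1 performs the same hipMemsetAsync is unverified):

# Hypothetical guard for cktile_moe_stage1, mirroring the ck_moe_stage1 fix;
# sized defensively in case moe_cktile2stages_gemm1 memsets the same range.
if split_k > 1:
    sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0])
    tmp_out = torch.zeros(
        (sorted_size, w1.shape[1]), dtype=dtypes.fp32, device=out.device
    )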

Comment thread aiter/fused_moe.py
if is_splitk:
# CK splitK kernel hipMemsetAsync zeros sorted_size * w1.shape[1] floats
sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0])
tmp_out = torch.zeros(
Collaborator

This doesn't look good to me: it looks like we zero twice, once here and once in CK. We need to remove one of them.

Contributor

@rbrugaro-amd rbrugaro-amd Mar 31, 2026

ROCm/rocm-libraries#5225 already addresses the proper allocation on the C++ side. I have swapped to torch.empty to avoid the double zeroing in #2551 (I created a new PR because I cannot edit the current branch).
@ChuanLi1101 I see you also created another PR, but I think the double zeroing is still there?
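
The #2551 change amounts to the sketch below: since the CK kernel already zero-fills the buffer via hipMemsetAsync when KBatch > 1, Python only needs to allocate, not zero:

# Sketch of the #2551 approach: allocate uninitialized memory and let the
# CK kernel's hipMemsetAsync (KBatch > 1 path) perform the single zero fill.
sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0])
tmp_out = torch.empty(
    (sorted_size, w1.shape[1]), dtype=dtypes.fp32, device=out.device
)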

ChuanLi1101 added a commit that referenced this pull request Mar 31, 2026
The CK kernel scatters output via sorted_token_ids using:
  token_offset = (fused_token & 0xffffff) * topk + (fused_token >> 24)

Padding entries use the sentinel value (topk << 24 | token_num),
which decodes to scatter position (token_num * topk + topk) -- beyond
the valid output range [0, token_num * topk). The original buffer
(token_num, topk, w1.shape[1]) only has token_num * topk rows, so
the padding scatter writes out of bounds, causing "HIP runtime error:
invalid argument" during CUDA graph capture (e.g. DeepSeek-R1 decode
with token_num=1, topk=8, block_m=16).

Fix: allocate (token_num * topk + topk + 1) rows -- the exact minimum
needed to absorb all padding scatter writes. After the kernel, slice
only the valid [0, token_num * topk) rows for the activation.

Related: #2508
Made-with: Cursor
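
The sentinel arithmetic in this commit message can be checked directly (self-contained sketch, using the token_num=1, topk=8 decode case):

# Decode the padding sentinel exactly as the commit message describes.
token_num, topk = 1, 8
fused_token = (topk << 24) | token_num                    # padding sentinel
token_offset = (fused_token & 0xFFFFFF) * topk + (fused_token >> 24)
assert token_offset == token_num * topk + topk            # 16, outside [0, 8)
# Hence (token_num * topk + topk + 1) = 17 rows absorb every padding write.
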
ChuanLi1101 added a commit to ChuanLi1101/vllm-rocm-docker that referenced this pull request Mar 31, 2026
Base image: rocm/vllm-dev:base_custom_rocm_7.2.1_torch_triton_0330_vllm018

Patches applied:
- AITER SplitK bug fix (ROCm/aiter#2508)
- vLLM persistent MLA kernel (vllm-project/vllm#36574)
- vLLM fused AllReduce+RMSNorm (vllm-project/vllm#37891)

Made-with: Cursor
akii96 pushed a commit that referenced this pull request Mar 31, 2026
rbrugaro-amd added a commit that referenced this pull request Mar 31, 2026
…ble-zeroing (#2551)

* fix: align ck_moe_stage1 split-K tmp_out buffer with CK kernel

* Update fused_moe.py

* tmp_out to use torch.empty vs. torch.zeros to avoid double zeroing

Signed-off-by: rbrugaro <rita.brugarolasbrufau@amd.com>

* tighten valid_out slice: drop redundant .contiguous() 

Signed-off-by: rbrugaro <rita.brugarolasbrufau@amd.com>

* restore .view(dtypes.fp32) on valid_out for silu_and_mul/gelu_and_mul

---------

Signed-off-by: rbrugaro <rita.brugarolasbrufau@amd.com>
Co-authored-by: Karan Verma <karan.verma@amd.com>
@karverma-amd
Contributor Author

Alternative PR merged: #2551

daydayup-lh pushed a commit that referenced this pull request Apr 1, 2026
thpereir added a commit that referenced this pull request Apr 22, 2026
…r shared experts

Switch the MXFP4 a4w4 (fp4x2 activations + fp4x2 weights) MoE path from
CKTile to CK JIT kernels (ck_moe_stage1 / ck_moe_stage2_fwd) and fix
several dispatch bugs that caused crashes on MiniMax-M2.1-MXFP4 at TP=8.

Root cause 1 -- wrong stage1 kernel for a4w4:
  cktile_moe_stage1 dispatches fp4x2 activations through the fp8 pipeline
  (F8xMXF4FlatmmPipelineAGmemBGmemCRegV1), which misinterprets packed fp4
  as fp8 and produces garbage. Switch to ck_moe_stage1 which uses the
  correct DeviceMoeGemmMXBPreShuffle<FP4X2, FP4X2, ...> instances from the
  JIT-compiled module_moe_ck2stages_fp4x2_fp4x2_preshuffle_on module.

Root cause 2 -- ksplit>1 steals a4w4 cases for shared experts:
  MiniMax has shared expert layers with inter_dim=256 (per TP=8 rank).
  get_ksplit() returns 2 for these, and the "ksplit > 1 and is_shuffled"
  elif fired before the a4w4 elif because it only checked q_dtype_w
  (not q_dtype_a). This routed fp4x2 activations through cktile_moe_stage1
  (bf16 output) into stage2 with the b16_fp4x2_preshuffle_on module, whose
  heuristic dispatch has no bf16xfp4x2 kernels -> crash.
  Fix: add "q_dtype_a not in [dtypes.fp4x2]" guard.

Root cause 3 -- ksplit must be 0 for a4w4 in fused_moe_2stages:
  When metadata.ksplit > 1, fused_moe_2stages skips a1 fp4x2 quantization
  and leaves a2_scale=None (bf16 activations into stage2). Force ksplit=0
  in the a4w4 MOEMetadata so activation quantization always runs.

Additional fixes:
- cktile_moe_stage1 split_k>1 buffer overflow: allocate tmp_out with
  sorted_size rows (= max_num_tokens_padded) instead of token_num*topk
  rows, mirroring the fix from ck_moe_stage1 (#2508).
- cktile_moe_stage1 split_k>1 stage2: use ck_moe_stage2_fwd (JIT CK)
  instead of cktile_moe_stage2 for correct preshuffle dispatch.
- shuffle_weight_a16w4: set is_shuffled=True on returned tensor.
- Add e8m0_unshuffle utility (inverse of e8m0_shuffle).
- test_moe_2stage: add a4w4 branch using shuffle_weight_a16w4(gate_up=False)
  matching the CK JIT kernel expected layout.

Tested: MiniMax-M2.1-MXFP4 TP=8 server starts, warmup completes (all
cuda_graph_capture_sizes including 48), inference produces correct output.

Files changed:
  aiter/fused_moe.py
  aiter/ops/shuffle.py
  aiter/utility/fp4_utils.py
  op_tests/test_moe_2stage.py
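
A sketch of the Root cause 2 guard described above (hypothetical names; the surrounding dispatch chain in fused_moe.py is abbreviated to the two branches in question):

# Hypothetical sketch: keep fp4x2 activations out of the ksplit>1 cktile
# branch so they fall through to the a4w4 CK JIT branch.
if ksplit > 1 and is_shuffled and q_dtype_a not in [dtypes.fp4x2]:
    ...  # cktile_moe_stage1 split-K path (bf16 output)
elif q_dtype_a == dtypes.fp4x2 and q_dtype_w == dtypes.fp4x2:
    ...  # ck_moe_stage1 a4w4 path (FP4X2 x FP4X2 preshuffle instances)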