From f7ddc47d03ec1876dacf8eab4bcd0547bfc42deb Mon Sep 17 00:00:00 2001 From: Karan Verma Date: Fri, 27 Mar 2026 11:50:08 -0500 Subject: [PATCH 1/5] fix: align ck_moe_stage1 split-K tmp_out buffer with CK kernel --- aiter/fused_moe.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 02d6ce0a69..9ee6e100ea 100755 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -1707,13 +1707,14 @@ def ck_moe_stage1( ): token_num = hidden_states.shape[0] is_splitk = quant_type is aiter.QuantType.per_1x128 and splitk > 1 - tmp_out = ( - torch.zeros( - (token_num, topk, w1.shape[1]), dtype=dtypes.fp32, device=out.device + if is_splitk: + # CK splitK kernel hipMemsetAsync zeros sorted_size * w1.shape[1] floats + sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0]) + tmp_out = torch.zeros( + (sorted_size, w1.shape[1]), dtype=dtypes.fp32, device=out.device ) - if is_splitk - else out - ) + else: + tmp_out = out aiter.ck_moe_stage1_fwd( hidden_states, w1, @@ -1735,10 +1736,11 @@ def ck_moe_stage1( out.dtype, ) if is_splitk: + valid_out = tmp_out[: token_num * topk, :].contiguous() if activation == ActivationType.Silu: - aiter.silu_and_mul(out, tmp_out.view(dtypes.fp32)) + aiter.silu_and_mul(out, valid_out.view(dtypes.fp32)) else: - aiter.gelu_and_mul(out, tmp_out.view(dtypes.fp32)) + aiter.gelu_and_mul(out, valid_out.view(dtypes.fp32)) return out @@ -1931,3 +1933,4 @@ def fused_topk( # topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids + From 857c9d2f2468f70a7baae25bd28877160808b05b Mon Sep 17 00:00:00 2001 From: vermak95 Date: Mon, 30 Mar 2026 18:37:02 -0500 Subject: [PATCH 2/5] Update fused_moe.py: remove trailing blank line at end of file --- aiter/fused_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 9ee6e100ea..78422707fc 100755 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -1933,4 +1933,3 @@ def fused_topk( # 
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids - From 3c8ef1759a475a2ff7a9c7ff5d1056f7afd7d2ad Mon Sep 17 00:00:00 2001 From: rbrugaro Date: Tue, 31 Mar 2026 02:53:16 -0500 Subject: [PATCH 3/5] use torch.empty instead of torch.zeros for tmp_out to avoid double zeroing Signed-off-by: rbrugaro --- aiter/fused_moe.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 78422707fc..0b56105bbc 100755 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -1708,9 +1708,9 @@ def ck_moe_stage1( token_num = hidden_states.shape[0] is_splitk = quant_type is aiter.QuantType.per_1x128 and splitk > 1 if is_splitk: - # CK splitK kernel hipMemsetAsync zeros sorted_size * w1.shape[1] floats + # CK kernel zeros this buffer via hipMemsetAsync when KBatch > 1 sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0]) - tmp_out = torch.zeros( + tmp_out = torch.empty( (sorted_size, w1.shape[1]), dtype=dtypes.fp32, device=out.device ) else: @@ -1775,6 +1775,12 @@ def cktile_moe_stage1( D = D * 8 out = torch.empty((token_num, topk, D), dtype=dtype, device=hidden_states.device) + # WARNING: when split_k > 1, this allocation has the same undersized buffer + # pattern fixed in ck_moe_stage1 (see ROCm/aiter#2508). If the CK tile + # kernel calls hipMemsetAsync with sorted_size rows, this will overflow. + # When fp32 splitk is enabled, apply the same fix: use sorted_size = + # min(token_num * topk * block_m, sorted_token_ids.shape[0]) and slice + # valid_out = tmp_out[:token_num * topk, :] before silu_and_mul/gelu_and_mul. 
tmp_out = ( torch.zeros( (token_num, topk, w1.shape[1]), dtype=hidden_states.dtype, device=out.device From 7b9fe7d9fb45808a9e9d1124562121fa165a5970 Mon Sep 17 00:00:00 2001 From: rbrugaro Date: Tue, 31 Mar 2026 03:33:20 -0500 Subject: [PATCH 4/5] tighten valid_out slice: drop redundant .contiguous() and .view(dtypes.fp32) Signed-off-by: rbrugaro --- aiter/fused_moe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 0b56105bbc..0c3a18e063 100755 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -1736,11 +1736,11 @@ def ck_moe_stage1( out.dtype, ) if is_splitk: - valid_out = tmp_out[: token_num * topk, :].contiguous() + valid_out = tmp_out[: token_num * topk, :] if activation == ActivationType.Silu: - aiter.silu_and_mul(out, valid_out.view(dtypes.fp32)) + aiter.silu_and_mul(out, valid_out) else: - aiter.gelu_and_mul(out, valid_out.view(dtypes.fp32)) + aiter.gelu_and_mul(out, valid_out) return out From 2633e6881017ee5afd9453225f5a6129f5b9748d Mon Sep 17 00:00:00 2001 From: rbrugaro Date: Tue, 31 Mar 2026 03:53:33 -0500 Subject: [PATCH 5/5] restore .view(dtypes.fp32) on valid_out for silu_and_mul/gelu_and_mul --- aiter/fused_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 0c3a18e063..2f6fd1889b 100755 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -1738,9 +1738,9 @@ def ck_moe_stage1( if is_splitk: valid_out = tmp_out[: token_num * topk, :] if activation == ActivationType.Silu: - aiter.silu_and_mul(out, valid_out) + aiter.silu_and_mul(out, valid_out.view(dtypes.fp32)) else: - aiter.gelu_and_mul(out, valid_out) + aiter.gelu_and_mul(out, valid_out.view(dtypes.fp32)) return out