diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 02d6ce0a69..2f6fd1889b 100755 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -1707,13 +1707,14 @@ def ck_moe_stage1( ): token_num = hidden_states.shape[0] is_splitk = quant_type is aiter.QuantType.per_1x128 and splitk > 1 - tmp_out = ( - torch.zeros( - (token_num, topk, w1.shape[1]), dtype=dtypes.fp32, device=out.device + if is_splitk: + # CK kernel zeros this buffer via hipMemsetAsync when KBatch > 1 + sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0]) + tmp_out = torch.empty( + (sorted_size, w1.shape[1]), dtype=dtypes.fp32, device=out.device ) - if is_splitk - else out - ) + else: + tmp_out = out aiter.ck_moe_stage1_fwd( hidden_states, w1, @@ -1735,10 +1736,11 @@ def ck_moe_stage1( out.dtype, ) if is_splitk: + valid_out = tmp_out[: token_num * topk, :] if activation == ActivationType.Silu: - aiter.silu_and_mul(out, tmp_out.view(dtypes.fp32)) + aiter.silu_and_mul(out, valid_out.view(dtypes.fp32)) else: - aiter.gelu_and_mul(out, tmp_out.view(dtypes.fp32)) + aiter.gelu_and_mul(out, valid_out.view(dtypes.fp32)) return out @@ -1773,6 +1775,12 @@ def cktile_moe_stage1( D = D * 8 out = torch.empty((token_num, topk, D), dtype=dtype, device=hidden_states.device) + # WARNING: when split_k > 1, this allocation has the same undersized buffer + # pattern fixed in ck_moe_stage1 (see ROCm/aiter#2508). If the CK tile + # kernel calls hipMemsetAsync with sorted_size rows, this will overflow. + # When fp32 splitk is enabled, apply the same fix: use sorted_size = + # min(token_num * topk * block_m, sorted_token_ids.shape[0]) and slice + # valid_out = tmp_out[:token_num * topk, :] before silu_and_mul/gelu_and_mul. tmp_out = ( torch.zeros( (token_num, topk, w1.shape[1]), dtype=hidden_states.dtype, device=out.device