diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 02d6ce0a69..78422707fc 100755 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -1707,13 +1707,14 @@ def ck_moe_stage1( ): token_num = hidden_states.shape[0] is_splitk = quant_type is aiter.QuantType.per_1x128 and splitk > 1 - tmp_out = ( - torch.zeros( - (token_num, topk, w1.shape[1]), dtype=dtypes.fp32, device=out.device + if is_splitk: + # CK splitK kernel hipMemsetAsync zeros sorted_size * w1.shape[1] floats + sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0]) + tmp_out = torch.zeros( + (sorted_size, w1.shape[1]), dtype=dtypes.fp32, device=out.device ) - if is_splitk - else out - ) + else: + tmp_out = out aiter.ck_moe_stage1_fwd( hidden_states, w1, @@ -1735,10 +1736,11 @@ def ck_moe_stage1( out.dtype, ) if is_splitk: + valid_out = tmp_out[: token_num * topk, :].contiguous() if activation == ActivationType.Silu: - aiter.silu_and_mul(out, tmp_out.view(dtypes.fp32)) + aiter.silu_and_mul(out, valid_out.view(dtypes.fp32)) else: - aiter.gelu_and_mul(out, tmp_out.view(dtypes.fp32)) + aiter.gelu_and_mul(out, valid_out.view(dtypes.fp32)) return out