Skip to content
24 changes: 16 additions & 8 deletions aiter/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1707,13 +1707,14 @@ def ck_moe_stage1(
):
token_num = hidden_states.shape[0]
is_splitk = quant_type is aiter.QuantType.per_1x128 and splitk > 1
tmp_out = (
torch.zeros(
(token_num, topk, w1.shape[1]), dtype=dtypes.fp32, device=out.device
if is_splitk:
# CK kernel zeros this buffer via hipMemsetAsync when KBatch > 1
sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0])
tmp_out = torch.empty(
(sorted_size, w1.shape[1]), dtype=dtypes.fp32, device=out.device
)
if is_splitk
else out
)
else:
tmp_out = out
aiter.ck_moe_stage1_fwd(
hidden_states,
w1,
Expand All @@ -1735,10 +1736,11 @@ def ck_moe_stage1(
out.dtype,
)
if is_splitk:
valid_out = tmp_out[: token_num * topk, :]
if activation == ActivationType.Silu:
aiter.silu_and_mul(out, tmp_out.view(dtypes.fp32))
aiter.silu_and_mul(out, valid_out.view(dtypes.fp32))
else:
aiter.gelu_and_mul(out, tmp_out.view(dtypes.fp32))
aiter.gelu_and_mul(out, valid_out.view(dtypes.fp32))
return out


Expand Down Expand Up @@ -1773,6 +1775,12 @@ def cktile_moe_stage1(
D = D * 8

out = torch.empty((token_num, topk, D), dtype=dtype, device=hidden_states.device)
# WARNING: when split_k > 1, this allocation has the same undersized buffer
# pattern fixed in ck_moe_stage1 (see ROCm/aiter#2508). If the CK tile
# kernel calls hipMemsetAsync with sorted_size rows, this will overflow.
# When fp32 splitk is enabled, apply the same fix: use sorted_size =
# min(token_num * topk * block_m, sorted_token_ids.shape[0]) and slice
# valid_out = tmp_out[:token_num * topk, :] before silu_and_mul/gelu_and_mul.
tmp_out = (
torch.zeros(
(token_num, topk, w1.shape[1]), dtype=hidden_states.dtype, device=out.device
Expand Down
Loading