Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions aiter/fused_moe.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1707,13 +1707,18 @@ def ck_moe_stage1(
):
token_num = hidden_states.shape[0]
is_splitk = quant_type is aiter.QuantType.per_1x128 and splitk > 1
tmp_out = (
torch.zeros(
(token_num, topk, w1.shape[1]), dtype=dtypes.fp32, device=out.device
if is_splitk:
# CK kernel uses sorted_size as its M dimension and scatters output via
# sorted_token_ids. The output buffer must have at least sorted_size rows
# so that the kernel's tile-based writes stay in bounds.
sorted_size = min(
token_num * topk * block_m, sorted_token_ids.shape[0]
)
if is_splitk
else out
)
tmp_out = torch.zeros(
(sorted_size, w1.shape[1]), dtype=dtypes.fp32, device=out.device
)
else:
tmp_out = out
aiter.ck_moe_stage1_fwd(
hidden_states,
w1,
Expand All @@ -1735,10 +1740,11 @@ def ck_moe_stage1(
out.dtype,
)
if is_splitk:
valid_out = tmp_out[: token_num * topk, :]
if activation == ActivationType.Silu:
aiter.silu_and_mul(out, tmp_out.view(dtypes.fp32))
aiter.silu_and_mul(out, valid_out)
else:
aiter.gelu_and_mul(out, tmp_out.view(dtypes.fp32))
aiter.gelu_and_mul(out, valid_out)
return out


Expand Down
Loading