18 changes: 10 additions & 8 deletions aiter/fused_moe.py

@@ -1707,13 +1707,14 @@ def ck_moe_stage1(
 ):
     token_num = hidden_states.shape[0]
     is_splitk = quant_type is aiter.QuantType.per_1x128 and splitk > 1
-    tmp_out = (

[Review comment]

cktile_moe_stage1 has the same undersized buffer pattern for split_k > 1. It calls aiter.moe_cktile2stages_gemm1, which is a different kernel, but if that kernel has the same hipMemsetAsync behavior, it would have the same overflow. The TODO at line 1723 (# TODO: support fp32 splitk) suggests this path may not be actively used yet, but it is worth either applying the same fix or adding a guard/comment.
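
One possible shape for that guard, offered only as a sketch: the placement inside cktile_moe_stage1 and the message text are assumptions, not code from this PR, and the names come from the surrounding diff.

```python
# Sketch of a defensive guard for cktile_moe_stage1's splitk path, assuming
# it allocates tmp_out the way ck_moe_stage1 did before this fix. Refusing
# the path is safer than risking the same hipMemsetAsync overflow until
# aiter.moe_cktile2stages_gemm1's memset behavior is confirmed.
if is_splitk:
    raise NotImplementedError(
        "cktile_moe_stage1 splitk path may share ck_moe_stage1's undersized "
        "tmp_out buffer; apply the sorted_size fix before enabling it"
    )
```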

-        torch.zeros(
-            (token_num, topk, w1.shape[1]), dtype=dtypes.fp32, device=out.device
+    if is_splitk:
+        # CK splitK kernel hipMemsetAsync zeros sorted_size * w1.shape[1] floats
+        sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0])
+        tmp_out = torch.zeros(

[Review comment from a Collaborator]

This doesn't look good to me: it looks like we zero the buffer twice, once here and once in the CK kernel. We need to remove one.

[Review comment from @rbrugaro-amd (Contributor), Mar 31, 2026]

ROCm/rocm-libraries#5225 already addressed the proper allocation from the C++ side. I have swapped to torch.empty to avoid the double zeroing in #2551 (I created a new PR because I cannot edit the current branch).
@ChuanLi1101 I see you also created another PR, but I think the double zeroing is still there?
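
A sketch of that torch.empty swap, assuming the kernel-side hipMemsetAsync from ROCm/rocm-libraries#5225 makes the Python-side zeroing redundant; the names are the same as in the diff above.

```python
if is_splitk:
    # The CK splitK kernel already zeroes sorted_size * w1.shape[1] floats
    # via hipMemsetAsync (per ROCm/rocm-libraries#5225), so allocating an
    # uninitialized buffer here avoids paying for the memset twice.
    sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0])
    tmp_out = torch.empty(
        (sorted_size, w1.shape[1]), dtype=dtypes.fp32, device=out.device
    )
else:
    tmp_out = out
```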

+            (sorted_size, w1.shape[1]), dtype=dtypes.fp32, device=out.device
         )
-        if is_splitk
-        else out
-    )
+    else:
+        tmp_out = out
     aiter.ck_moe_stage1_fwd(
         hidden_states,
         w1,
@@ -1735,10 +1736,11 @@ def ck_moe_stage1(
         out.dtype,
     )
     if is_splitk:
+        valid_out = tmp_out[: token_num * topk, :].contiguous()
         if activation == ActivationType.Silu:
-            aiter.silu_and_mul(out, tmp_out.view(dtypes.fp32))
+            aiter.silu_and_mul(out, valid_out.view(dtypes.fp32))
         else:
-            aiter.gelu_and_mul(out, tmp_out.view(dtypes.fp32))
+            aiter.gelu_and_mul(out, valid_out.view(dtypes.fp32))
     return out
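
As a quick standalone check of the sizing logic above (the numbers are illustrative, not taken from this PR): the old (token_num, topk, w1.shape[1]) allocation can hold far fewer fp32 elements than the sorted_size * w1.shape[1] region the splitK kernel memsets, which is the overflow this change fixes.

```python
# Standalone arithmetic check of the buffer sizing; the shapes are made up.
token_num, topk, block_m, n1 = 4, 2, 32, 128
sorted_ids_len = 512  # capacity of sorted_token_ids for this launch

old_elems = token_num * topk * n1                              # 1024 elements
sorted_size = min(token_num * topk * block_m, sorted_ids_len)  # 256 rows
memset_elems = sorted_size * n1                                # 32768 elements

# The kernel's hipMemsetAsync would have written 32x past the old buffer.
assert memset_elems > old_elems
print(f"old buffer: {old_elems} fp32, kernel memset: {memset_elems} fp32")
```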

