fix: align ck_moe_stage1 split-K tmp_out buffer with CK kernel #2508
Changes from all commits
```diff
@@ -1707,13 +1707,14 @@ def ck_moe_stage1(
 ):
     token_num = hidden_states.shape[0]
     is_splitk = quant_type is aiter.QuantType.per_1x128 and splitk > 1
-    tmp_out = (
-        torch.zeros(
-            (token_num, topk, w1.shape[1]), dtype=dtypes.fp32, device=out.device
+    if is_splitk:
+        # CK splitK kernel hipMemsetAsync zeros sorted_size * w1.shape[1] floats
+        sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0])
+        tmp_out = torch.zeros(
```
Collaborator:
This doesn't look good to me. It looks like we zero the buffer twice, once here and once in CK; we need to remove one.

Contributor:
ROCm/rocm-libraries#5225 already addresses the proper allocation from the C++ side. I have swapped to torch.empty to avoid the double zeroing in #2551 (I created a new PR because I cannot edit the current branch).
```diff
+            (sorted_size, w1.shape[1]), dtype=dtypes.fp32, device=out.device
         )
-        if is_splitk
-        else out
-    )
+    else:
+        tmp_out = out
     aiter.ck_moe_stage1_fwd(
         hidden_states,
         w1,
```
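For context on the sizing mismatch the hunk above fixes: `sorted_token_ids` is padded per expert to a multiple of `block_m`, so the CK split-K kernel's hipMemsetAsync can touch far more rows than the old `(token_num, topk, w1.shape[1])` allocation held. A minimal sketch of the arithmetic, with made-up shapes (only the names come from the diff):

```python
# Hypothetical shapes, for illustration only; names follow the diff above.
token_num, topk, block_m = 7, 2, 32
inter_dim = 256                # w1.shape[1]
num_sorted_ids = 448           # sorted_token_ids.shape[0] (block_m-padded)

old_elems = token_num * topk * inter_dim                       # 3,584 floats
sorted_size = min(token_num * topk * block_m, num_sorted_ids)  # 448 rows
memset_elems = sorted_size * inter_dim                         # 114,688 floats

# The kernel zeroes memset_elems floats, so the old buffer was undersized.
assert memset_elems > old_elems
```

As the thread above notes, this fix zeroes the buffer twice (torch.zeros here, hipMemsetAsync in the kernel); the follow-up #2551 swaps the allocation to torch.empty so that only the kernel zeroes it.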
```diff
@@ -1735,10 +1736,11 @@ def ck_moe_stage1(
         out.dtype,
     )
     if is_splitk:
+        valid_out = tmp_out[: token_num * topk, :].contiguous()
         if activation == ActivationType.Silu:
-            aiter.silu_and_mul(out, tmp_out.view(dtypes.fp32))
+            aiter.silu_and_mul(out, valid_out.view(dtypes.fp32))
         else:
-            aiter.gelu_and_mul(out, tmp_out.view(dtypes.fp32))
+            aiter.gelu_and_mul(out, valid_out.view(dtypes.fp32))
     return out
```
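Because tmp_out is now sized by sorted_size rather than token_num * topk, this second hunk slices out only the real-token rows before the activation. For readers unfamiliar with the gated-activation kernels, here is a reference sketch of the semantics assumed for aiter.silu_and_mul (the common gate/up packing convention; this is an assumption about the kernel's behavior, not its actual implementation):

```python
import torch
import torch.nn.functional as F

def silu_and_mul_ref(out: torch.Tensor, x: torch.Tensor) -> None:
    # Assumed semantics: x packs gate and up projections side by side along
    # its last dimension; out receives silu(gate) * up, cast to out.dtype.
    d = x.shape[-1] // 2
    out.copy_((F.silu(x[..., :d]) * x[..., d:]).to(out.dtype))
```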
cktile_moe_stage1 has the identical undersized buffer pattern for split_k > 1. It calls aiter.moe_cktile2stages_gemm1 which is a different kernel, but if that kernel has the same hipMemsetAsync behavior, it would have the same overflow. The TODO at line 1723 (# TODO: support fp32 splitk) suggests this path may not be actively used yet, but it's worth either applying the same fix or adding a guard/comment.
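If the same hardening is wanted for the cktile path before the memset behavior of moe_cktile2stages_gemm1 is confirmed, the suggested guard could look like this sketch (hypothetical code, placed where cktile_moe_stage1 would select its tmp_out):

```python
# Hypothetical guard inside cktile_moe_stage1; whether moe_cktile2stages_gemm1
# performs the same hipMemsetAsync over sorted_size rows is unverified.
if split_k > 1:
    # TODO: support fp32 splitk -- tmp_out may be undersized the same way
    # ck_moe_stage1's buffer was before this PR.
    raise NotImplementedError(
        "split_k > 1 in cktile_moe_stage1 needs the sorted_size-based "
        "tmp_out allocation (see #2508) before it can be enabled safely"
    )
```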