
(alternative to #2508) Fix/splitk tmp_out undersized buffer, avoid double-zeroing #2551

Merged
rbrugaro-amd merged 8 commits into ROCm:main from rbrugaro-amd:fix/splitk-tmpout-and-ck-memset
Mar 31, 2026

Conversation

Contributor

@rbrugaro-amd rbrugaro-amd commented Mar 31, 2026

The C++ side of this crash was fixed in ROCm/rocm-libraries#5225, which corrected the hipMemsetAsync size from arg.M * arg.N to arg.NumTokens * arg.TopK * arg.N. This PR fixes the Python side to match.

Technical Details

Root Cause

In ck_moe_stage1(), the tmp_out buffer was allocated as (token_num, topk, w1.shape[1]) which is undersized when splitK > 1. The CK kernel operates on sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0]) rows, so the buffer must be (sorted_size, w1.shape[1]).

For DeepSeek V3 decode (token_num=1, topk=8, block_m=16):

  • Old Python buffer: 1 * 8 * 4096 * 4 = 128 KB
  • CK kernel expects: 128 * 2048 * 4 * 2 = 2 MB
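The size mismatch above can be reproduced with a short sketch. Here `splitk_tmp_out_rows` is a hypothetical helper mirroring the `sorted_size` formula from the root-cause description; `N = 4096` columns, fp32 elements, and a `sorted_token_ids` length of 1024 are illustrative assumptions, not values read from the kernel source.

```python
# Hedged sketch: deriving the split-K tmp_out row count.
def splitk_tmp_out_rows(token_num, topk, block_m, sorted_token_ids_len):
    # The CK kernel operates on up to block_m padded rows per
    # (token, expert) pair, capped by the sorted_token_ids length.
    return min(token_num * topk * block_m, sorted_token_ids_len)

# DeepSeek V3 decode shape from the PR description.
token_num, topk, block_m, N, elem_bytes = 1, 8, 16, 4096, 4

old_rows = token_num * topk  # undersized allocation (ignores block_m)
new_rows = splitk_tmp_out_rows(token_num, topk, block_m, 1024)

print(old_rows * N * elem_bytes)  # 131072 bytes = 128 KB
print(new_rows * N * elem_bytes)  # 2097152 bytes = 2 MB
```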

Fix

  1. ck_moe_stage1: Allocate tmp_out with sorted_size rows using torch.empty (CK kernel zeros the buffer via hipMemsetAsync, avoiding redundant double-zeroing). After the kernel, slice valid_out = tmp_out[:token_num * topk, :] before silu_and_mul/gelu_and_mul.

  2. cktile_moe_stage1: Added warning comment flagging the same undersized buffer pattern. The code is left unchanged since fp32 splitk is not yet active (see existing TODO: support fp32 splitk), but the comment documents the fix to apply when that path is enabled.
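A minimal sketch of fix #1, the steps being allocate `sorted_size` rows uninitialized, let the kernel zero them, then slice the valid region. Plain Python lists stand in for torch tensors, a fill loop stands in for the kernel's `hipMemsetAsync`, and `ck_moe_stage1_fixed` and its arguments are illustrative names rather than the actual aiter signature.

```python
# Hedged sketch of the ck_moe_stage1 buffer fix.
def ck_moe_stage1_fixed(token_num, topk, block_m, N, sorted_token_ids_len):
    sorted_size = min(token_num * topk * block_m, sorted_token_ids_len)
    # torch.empty analogue: uninitialized rows, no Python-side zeroing.
    tmp_out = [[None] * N for _ in range(sorted_size)]
    # The CK kernel zeroes the whole buffer (hipMemsetAsync when
    # KBatch > 1), then accumulates partial results into it.
    for row in tmp_out:
        for j in range(N):
            row[j] = 0.0
    # Only the first token_num * topk rows hold valid activations;
    # slice them before silu_and_mul/gelu_and_mul.
    valid_out = tmp_out[: token_num * topk]
    return tmp_out, valid_out

tmp_out, valid_out = ck_moe_stage1_fixed(1, 8, 16, 4, 1024)
print(len(tmp_out), len(valid_out))  # 128 8
```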

Why torch.empty instead of torch.zeros

With ROCm/rocm-libraries#5225 merged, the CK kernel correctly zeros the buffer via hipMemsetAsync when KBatch > 1. Using torch.empty avoids double-zeroing (once by Python, once by CK), eliminating a redundant GPU kernel launch.
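The redundancy can be made concrete by counting zeroing passes. In this sketch `host_zero` stands in for the fill done by `torch.zeros` and `kernel_memset` for CK's `hipMemsetAsync` when `KBatch > 1`; both names are illustrative, not aiter/CK API.

```python
# Hedged sketch: torch.zeros + kernel memset does two full passes
# over the buffer; torch.empty + kernel memset does one.
passes = []

def host_zero(buf):          # torch.zeros analogue
    passes.append("host")
    for i in range(len(buf)):
        buf[i] = 0.0

def kernel_memset(buf):      # hipMemsetAsync analogue (KBatch > 1)
    passes.append("kernel")
    for i in range(len(buf)):
        buf[i] = 0.0

old_buf = [None] * 8         # old path: torch.zeros, then kernel zeroes again
host_zero(old_buf)
kernel_memset(old_buf)

new_buf = [None] * 8         # new path: torch.empty, kernel zeroes once
kernel_memset(new_buf)

print(passes)  # ['host', 'kernel', 'kernel']
```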

Test Result

DeepSeek-R1-0528 (FP8, 8xMI355X): runtime error resolved, splitK active with correct results.


@rbrugaro-amd rbrugaro-amd requested a review from a team March 31, 2026 07:56
@rbrugaro-amd rbrugaro-amd changed the title Fix/splitk tmpout and ck memset (alternative to #2508) Fix/splitk tmp_out undersized buffer Mar 31, 2026
@rbrugaro-amd rbrugaro-amd changed the title (alternative to #2508) Fix/splitk tmp_out undersized buffer (alternative to #2508) Fix/splitk tmp_out undersized buffer avoid double-zeroing Mar 31, 2026
@rbrugaro-amd rbrugaro-amd merged commit e47cc0e into ROCm:main Mar 31, 2026
38 of 39 checks passed
daydayup-lh pushed a commit that referenced this pull request Apr 1, 2026
…ble-zeroing (#2551)

* fix: align ck_moe_stage1 split-K tmp_out buffer with CK kernel

* Update fused_moe.py

* tmp_out to use torch.empty vs. torch.zeros to avoid double zeroing

Signed-off-by: rbrugaro <rita.brugarolasbrufau@amd.com>

* tighten valid_out slice: drop redundant .contiguous() 

Signed-off-by: rbrugaro <rita.brugarolasbrufau@amd.com>

* restore .view(dtypes.fp32) on valid_out for silu_and_mul/gelu_and_mul

---------

Signed-off-by: rbrugaro <rita.brugarolasbrufau@amd.com>
Co-authored-by: Karan Verma <karan.verma@amd.com>
LJ-underdog pushed a commit that referenced this pull request Apr 24, 2026
Verified via canary tests + tp=2 inference that the three +2 row
padding fixes from commit 68fc7d48b are not needed for the BF16
no-quant path on gfx950. CK kernels skip the entire block when
expert_id is the sentinel (=E), so sorted_ids sentinel (K<<24)|T
never triggers OOB scatter to a2[T*K+K] / moe_out[M].

Reverted three locations:
- L339-345: drop moe_out_padded, pass moe_buf directly
- L1262-1264: zeros((token_num+2, ...)) -> empty((token_num, ...))
- L1349: a2.view(token_num+2, ...) -> a2.view(token_num, ...)

Verification:
- /tmp/test_moe_canary.py: a2[T*K+K] pristine after stage1
- /tmp/test_moe_canary_stage2.py: moe_out[M] pristine after stage2
- tp=2 Step-3.5-Flash inference: 4 prompts complete normally,
  no NaN, no crash, latency 1.97s/req
- Required fixes still in place: V1->V3 force (block_m=128) and
  shuffle_weight() preprocessing

Note: PR #2551 +2 padding is only required for split-K + per_1x128
quant path, which is a different code branch.
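The canary check described in this commit can be sketched as follows. The scatter loop is a stand-in for the CK stage1 kernel, `E` is the expert count reused as the sentinel `expert_id`, and the buffer layout and canary value are illustrative; only the skip-on-sentinel behavior is taken from the commit message.

```python
# Hedged sketch of the canary test: plant a sentinel pattern in the
# row just past the valid region, run the simulated scatter, and
# verify the canary row was never written.
def scatter_stage1(a2, assignments, num_experts):
    # assignments: (row_index, expert_id) pairs; the kernel skips the
    # whole block when expert_id is the sentinel (=E), so padded
    # entries never scatter out of bounds.
    for row, expert in assignments:
        if expert == num_experts:  # sentinel block: skipped entirely
            continue
        a2[row] = 1.0              # stand-in for the real accumulation

T, K, E = 1, 8, 4
CANARY = 12345.0
a2 = [0.0] * (T * K) + [CANARY]    # one extra canary row past the end

# Valid rows map to real experts; the padded entry points past the
# valid region but carries the sentinel expert_id.
assignments = [(i, i % E) for i in range(T * K)] + [(T * K, E)]
scatter_stage1(a2, assignments, E)

assert a2[T * K] == CANARY         # canary pristine: no OOB scatter
```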