
Add SplitK support for CK/CKTile Block-Scale GEMMs#2541

Closed
samremes wants to merge 14 commits into main from samremes/blockscale_splitk

@samremes
Contributor

@samremes samremes commented Mar 30, 2026

This pull request adds support for the split-K technique to the gemm_a8w8_blockscale and gemm_a8w8_blockscale_cktile kernels. The changes propagate a new splitK (or k_batch) parameter throughout the Python and C++ code, update function signatures, and implement the necessary logic to handle split-K in both kernel implementations. Additionally, it removes previous limitations and ensures output tensors are properly zeroed when using split-K.

Split-K support and API changes:

  • Added a splitK (or k_batch) parameter to the Python API (gemm_a8w8_blockscale, gemm_a8w8_blockscale_cktile) and propagated this parameter through all relevant C++ function signatures and kernel instance generators, allowing users to specify the degree of K-dimension splitting.

  • Updated kernel implementations to accept and validate the splitK parameter, compute the batch size for K (KBatch = 1 << splitK), and ensure it is within the supported range.
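
As a rough illustration of the splitK-to-KBatch mapping described above, the following Python sketch mirrors the `KBatch = 1 << splitK` computation. The function name `k_batch_from_split_k` is hypothetical, and the upper bound of 30 is taken from the validation mentioned later in this thread, not read from the kernel sources.

```python
MAX_SPLIT_K = 30  # bound quoted in the review follow-up; assumed here


def k_batch_from_split_k(split_k: int) -> int:
    """Return KBatch = 2**split_k, rejecting out-of-range exponents."""
    if not 0 <= split_k <= MAX_SPLIT_K:
        raise ValueError(f"splitK must be in [0, {MAX_SPLIT_K}], got {split_k}")
    # An integer shift avoids the floating-point round trip of pow(2, split_k).
    return 1 << split_k


# splitK = 0 means "no split-K" (KBatch = 1); each increment doubles KBatch.
k_batches = [k_batch_from_split_k(s) for s in range(4)]
```

Keeping the exponent as the tunable, rather than KBatch itself, restricts the search space to powers of two, which seems to be the intent of the parameter.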

Kernel logic and correctness:

  • Implemented logic to zero the output tensor (Y.zero_()) when using split-K in the tile kernel, ensuring correct accumulation with atomic adds.

  • Removed previous runtime errors that prevented split-K usage in the tile kernel, enabling support for split-K execution.
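
To see why the output has to start at zero, here is a small CPU-side sketch of split-K accumulation on a single dot product. It is only a model: `split_k_dot` and `out_init` are hypothetical names, and the `+=` stands in for the atomic add that each K-split would perform on the GPU.

```python
def split_k_dot(a, b, k_batch, out_init=0):
    """Accumulate a dot product as k_batch partial sums into one output."""
    k = len(a)
    chunk = (k + k_batch - 1) // k_batch  # ceil-divide K across the splits
    out = out_init  # whatever the output buffer held before the launch
    for i in range(k_batch):
        lo, hi = i * chunk, min((i + 1) * chunk, k)
        partial = sum(x * y for x, y in zip(a[lo:hi], b[lo:hi]))
        out += partial  # stands in for an atomic add into the output
    return out


a = [1, 2, 3, 4, 5, 6, 7, 8]
b = [1] * 8
reference = sum(x * y for x, y in zip(a, b))

zeroed = split_k_dot(a, b, k_batch=4, out_init=0)  # matches the reference
stale = split_k_dot(a, b, k_batch=4, out_init=5)   # off by the stale value
```

Because every split adds into the same location, any value left in the buffer from a previous use is carried into the result, which is why the output is cleared before the launch.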

Miscellaneous:

  • Cleaned up includes by removing unnecessary headers (e.g., <cmath>) in the C++ source files.

Propagate the splitK parameter (as KBatch = 2^splitK) through the
block-scale GEMM kernel infrastructure so that the tuning scripts
can sweep split-K values to improve occupancy on small-M shapes.
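
The occupancy argument can be sketched with a simplified launch model. The code below assumes one workgroup per output tile per K-split and a hypothetical 128x128 tile; neither is taken from the kernels in this PR. The compute-unit count of 256 is the MI355X figure quoted later in this thread.

```python
import math


def workgroups(m, n, tile_m, tile_n, k_batch):
    """Workgroups launched under a one-workgroup-per-tile-per-split model."""
    return math.ceil(m / tile_m) * math.ceil(n / tile_n) * k_batch


CU_NUM = 256               # MI355X figure quoted elsewhere in the thread
TILE_M = TILE_N = 128      # hypothetical tile shape, for illustration only

# A small-M decode-like shape: only 32 tiles without split-K.
no_split = workgroups(16, 4096, TILE_M, TILE_N, k_batch=1)
with_split = workgroups(16, 4096, TILE_M, TILE_N, k_batch=8)
```

Under these assumptions the unsplit launch covers only a fraction of the compute units, while KBatch = 8 produces one workgroup per unit. Whether that is actually faster depends on the atomic-add and zeroing overhead, which is what the tuner measures.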

CK path: add KBatch parameter to gemm_a8w8_blockscale_impl and call
SetKBatch on the device argument. The CK invoker handles output
zeroing and atomic accumulation internally.

CKTile path: add k_batch parameter to gemm_a8w8_blockscale_cktile_impl,
remove the "split-k is not supported yet" runtime guard, and add
hipMemsetAsync to zero the output buffer before atomic accumulation.

Non-tune entry points pass KBatch=1 (no split-K) to preserve existing
behavior. Code generation scripts (gen_instances.py, gen_instances_cktile.py)
updated to include the new parameter in generated wrappers and manifests.

Made-with: Cursor

The tuning infrastructure already sweeps splitK and writes it to the CSV,
but the production dispatch ignored it and hardcoded KBatch=1. Add splitK
as a runtime parameter to the non-tune entry points so tuned split-K
values are used without compiling the full _tune instance set.

Made-with: Cursor
@samremes samremes changed the title from "Samremes/blockscale splitk" to "Enable SplitK for CK Block-Scale GEMMs" Mar 30, 2026
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-355 | Run Triton tests on MI355 in addition to MI325 |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or gh pr edit 2541 --add-label <label>

@samremes samremes marked this pull request as ready for review March 31, 2026 16:30
@samremes samremes requested a review from a team March 31, 2026 16:30
@valarLip valarLip added the ci:all label Apr 3, 2026
nholmber added a commit to nholmber/aiter that referenced this pull request Apr 3, 2026
Enable split-K support for CK and CKTile block-scale A8W8 GEMMs.
Adds splitK parameter to kernel launch and instance generation.

Source: ROCm#2541
nholmber added a commit to nholmber/aiter that referenced this pull request Apr 3, 2026
Wire splitK parameter from tuning CSV through production blockscale
GEMM dispatch in Python and C++ layers.

Source: ROCm#2541
nholmber added a commit to nholmber/aiter that referenced this pull request Apr 3, 2026
The pybind definitions were updated by PR ROCm#2541 to accept splitK,
but the C++ function signatures in the headers and .cu files were
not. Add int splitK parameter to gemm_a8w8_blockscale and
gemm_a8w8_blockscale_cktile to match the pybind interface.
nholmber added a commit to nholmber/aiter that referenced this pull request Apr 3, 2026
The cherry-pick of PR ROCm#2541 lost the KBatch computation and pass-through
in gemm_a8w8_blockscale.cu and gemm_a8w8_blockscale_cktile.cu.
The functions accepted splitK at the pybind level but never converted it
to KBatch or forwarded it to the generated kernel wrappers.

Add back: KBatch = pow(2, splitK), pass KBatch to BlockwiseKernel calls,
fix BlockwiseKernel typedef to include the int parameter, and restore
the splitK=0 default in bpreshuffle_cktile wrapper.
@samremes samremes requested a review from Copilot April 8, 2026 12:34
Contributor

Copilot AI left a comment

Pull request overview

This PR enables end-to-end Split-K for CK and CKTile A8W8 block-scale GEMM paths, so production dispatch can consume the tuned splitK values from the existing tuning CSV instead of effectively hardcoding KBatch=1.

Changes:

  • Plumbs splitK from Python dispatch/pybind into the C++ entry points for CK and CKTile block-scale GEMMs (default splitK=0 preserves current behavior).
  • Updates CK and CKTile kernel wrappers/codegen interfaces to accept a runtime K-splitting parameter (KBatch/k_batch) and dispatch accordingly.
  • Removes the CKTile “split-k not supported yet” guard and adds output zero-initialization for atomic accumulation.
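
The dispatch-side plumbing can be pictured with the following sketch: look up the tuned entry for a shape and fall back to `splitK = 0` when the shape is untuned. The CSV columns, the sample rows, and the helper names are all hypothetical; the real tuning file and the lookup in aiter/ops/gemm_op_a8w8.py may be organized differently.

```python
import csv
import io

# Hypothetical tuning CSV; the real file's columns and values may differ.
TUNED_CSV = """M,N,K,kernelId,splitK
16,4096,7168,3,2
1024,4096,7168,7,0
"""


def load_tuned(text):
    """Parse the CSV text into a {(M, N, K): config} dictionary."""
    table = {}
    for row in csv.DictReader(io.StringIO(text)):
        key = (int(row["M"]), int(row["N"]), int(row["K"]))
        table[key] = {
            "kernelId": int(row["kernelId"]),
            "splitK": int(row["splitK"]),
        }
    return table


def lookup_split_k(table, m, n, k):
    """Return the tuned splitK, or 0 (no split-K) for an untuned shape."""
    entry = table.get((m, n, k))
    return entry["splitK"] if entry is not None else 0


table = load_tuned(TUNED_CSV)
tuned = lookup_split_k(table, 16, 4096, 7168)
untuned = lookup_split_k(table, 8, 512, 512)
```

Defaulting to 0 keeps the pre-existing behavior (KBatch = 1) for any shape without a tuned entry.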

Reviewed changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| csrc/include/rocm_ops.hpp | Adds splitK pybind args (default 0) for block-scale GEMM bindings. |
| csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale.h | Extends CK block-scale GEMM API with splitK parameter. |
| csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh | Adds KBatch plumbing into CK GEMM argument via SetKBatch. |
| csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile.h | Extends CKTile block-scale GEMM API with splitK parameter. |
| csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh | Enables split-K in CKTile path and zeros output for atomic accumulation. |
| csrc/ck_gemm_a8w8_blockscale/gen_instances.py | Updates generated CK instance signatures to accept KBatch. |
| csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py | Updates generated CKTile instance signatures to accept k_batch. |
| csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu | Wires runtime splitK into CK dispatch via computed KBatch. |
| csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.cu | For tuning path, forwards KBatch into selected CK kernel instance. |
| csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu | Wires runtime splitK into CKTile dispatch via computed KBatch. |
| csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile_tune.cu | For tuning path, forwards KBatch into selected CKTile kernel instance. |
| aiter/ops/gemm_op_a8w8.py | Reads splitK from tuning config at runtime and forwards it into CK/CKTile block-scale entry points. |

Comment thread csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu Outdated
Comment thread csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu Outdated
Comment thread csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh Outdated
Comment thread aiter/ops/gemm_op_a8w8.py
@samremes
Contributor Author

samremes commented Apr 8, 2026

@copilot apply changes based on the comments in this thread

…ue, add correctness test

Agent-Logs-Url: https://github.com/ROCm/aiter/sessions/e3b37b0f-e151-4935-ad89-fd72436d41e2

Co-authored-by: samremes <181322991+samremes@users.noreply.github.com>
Contributor

Copilot AI commented Apr 8, 2026

@copilot apply changes based on the comments in this thread

Applied all changes from the review thread in commit 2bf04dd:

  1. gemm_a8w8_blockscale.cu / gemm_a8w8_blockscale_cktile.cu: Added TORCH_CHECK(splitK >= 0 && splitK <= 30, ...) before computing KBatch; replaced static_cast<int>(std::pow(2, splitK)) with the safe integer shift 1 << splitK; removed the now-unused #include <cmath>.

  2. gemm_a8w8_blockscale_common.cuh: Added TORCH_CHECK(KBatch >= 1, ...) so invalid KBatch values fail immediately before the CK dispatch.

  3. gemm_a8w8_blockscale_cktile_common.cuh: Replaced hipMemsetAsync(Y.data_ptr(), 0, M * N * sizeof(OutDataType), ...) with Y.zero_(), which correctly zeros all rows regardless of leading-dimension stride (handles padded tensors where stride_C > N).

  4. op_tests/test_gemm_a8w8_blockscale.py: Added test_splitk_correctness that directly calls both the CK and CKTile kernels with splitK ∈ {1, 2} and verifies the results match splitK=0 within a relaxed tolerance (to account for reordered atomic accumulation).
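
Item 3 above can be illustrated with a flat Python list standing in for device memory. This is a model of the addressing only, with hypothetical helper names; it does not call HIP or PyTorch. A single memset over `M * N` elements clears a contiguous prefix, whereas the logical rows of a padded tensor are spaced `stride` elements apart.

```python
def flat_zero(buf, m, n):
    """Zero the first m*n elements, like one memset over M * N elements."""
    for i in range(m * n):
        buf[i] = 0


def strided_zero(buf, m, n, stride):
    """Zero the n valid columns of each of m rows spaced `stride` apart."""
    for row in range(m):
        start = row * stride
        for col in range(n):
            buf[start + col] = 0


def valid_elements(buf, m, n, stride):
    """Collect the logical M x N elements from the padded buffer."""
    return [buf[r * stride + c] for r in range(m) for c in range(n)]


M, N, STRIDE = 3, 4, 6  # padded layout: stride > N

flat = [9] * (M * STRIDE)
flat_zero(flat, M, N)
flat_result = valid_elements(flat, M, N, STRIDE)  # last row keeps stale 9s

strided = [9] * (M * STRIDE)
strided_zero(strided, M, N, STRIDE)
strided_result = valid_elements(strided, M, N, STRIDE)  # all zeros
```

With `stride == N` the two approaches clear the same elements, so the difference only shows up for padded outputs.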

Copilot AI changed the title from "Enable SplitK for CK Block-Scale GEMMs" to "Fix splitK validation, zeroing correctness, and add split-K regression test" Apr 8, 2026
@samremes samremes changed the title from "Fix splitK validation, zeroing correctness, and add split-K regression test" to "Add SplitK support for CK/CKTile Block-Scale GEMMs" Apr 9, 2026
@samremes
Contributor Author

@valarLip Would you have time to review this PR? All CI checks pass after a recent merge with main.

nholmber added a commit to nholmber/aiter that referenced this pull request Apr 22, 2026
PR ROCm#2862's CK bump (cbfb3e242) lacks the ABQuantGrouped/GemmTraits
APIs needed by PRs ROCm#2541 and ROCm#2487. Update to 020b6f435 which has
both the stride fix and the required CK-TILE blockscale APIs.
nholmber added a commit to nholmber/aiter that referenced this pull request Apr 25, 2026
Tuned 1482 shapes (TP1/TP2/TP4) for Qwen/Qwen3-Next-80B-A3B-Instruct-FP8
on MI355X using CK + CK-TILE backends with splitK support.

Depends on:
- PR ROCm#2862 (CK bump for stride fix in CK-TILE blockscale)
- PR ROCm#2541 (splitK support for CK/CK-TILE blockscale GEMMs)
- PR ROCm#2487 (AQLayout tunable for CK-TILE blockscale 8-warp kernels)
sunway513 added a commit that referenced this pull request May 3, 2026
azaidy added a commit that referenced this pull request May 4, 2026
aiter/fused_moe.py:
- Restore to origin/main. Per sunway513's own comment, #2457 and #2547
  were excluded from this bulk merge; per valarLip, #2687 was rejected.
  No source PR should land changes in this file. The previous state
  (+110/-119 vs main) was collateral damage from auto-resolved conflicts
  taking older sides, which silently reverted #2262 (xbf16 asm fmoe path),
  #2726 (FlyDSL a8w4 MoE wrapper params + fuse_quant), #2658 (CK fp8
  blockscale splitk tuner support), and #2620 (mxfp4_moe_sort_hip,
  flagged by valarLip).

op_tests/test_gemm_a8w8_blockscale.py:
- Replace with a clean 3-way merge of origin/main + #2541. Now +55/-0
  vs main, matching #2541's actual contribution exactly. The previous
  state was silently reverting #2645 (CK GEMM multi-arch + test infra:
  TEST_NUM_ITERS, --csv/--output args, kernel_name= param).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
sunway513 added a commit that referenced this pull request May 5, 2026
…3-Next, pa_mqa OOB) (#3005)

* fix: remap QuantType.No to per_1x32 for fp4x2 MoE weights (W4A6 support)

* Fixing two cascading bugs when running the MoE tuner

* Enable split-K for block-scale A8W8 CK and CKTile GEMMs

Propagate the splitK parameter (as KBatch = 2^splitK) through the
block-scale GEMM kernel infrastructure so that the tuning scripts
can sweep split-K values to improve occupancy on small-M shapes.

CK path: add KBatch parameter to gemm_a8w8_blockscale_impl and call
SetKBatch on the device argument. The CK invoker handles output
zeroing and atomic accumulation internally.

CKTile path: add k_batch parameter to gemm_a8w8_blockscale_cktile_impl,
remove the "split-k is not supported yet" runtime guard, and add
hipMemsetAsync to zero the output buffer before atomic accumulation.

Non-tune entry points pass KBatch=1 (no split-K) to preserve existing
behavior. Code generation scripts (gen_instances.py, gen_instances_cktile.py)
updated to include the new parameter in generated wrappers and manifests.

Made-with: Cursor

* Wire splitK from tuning CSV through production blockscale GEMM dispatch

The tuning infrastructure already sweeps splitK and writes it to the CSV,
but the production dispatch ignored it and hardcoded KBatch=1. Add splitK
as a runtime parameter to the non-tune entry points so tuned split-K
values are used without compiling the full _tune instance set.

Made-with: Cursor

* fix: ck_moe_stage1 split-K output buffer overflow from padding scatter

The CK kernel scatters output via sorted_token_ids using:
  token_offset = (fused_token & 0xffffff) * topk + (fused_token >> 24)

Padding entries use the sentinel value (topk << 24 | token_num),
which decodes to scatter position (token_num * topk + topk) -- beyond
the valid output range [0, token_num * topk). The original buffer
(token_num, topk, w1.shape[1]) only has token_num * topk rows, so
the padding scatter writes out of bounds, causing "HIP runtime error:
invalid argument" during CUDA graph capture (e.g. DeepSeek-R1 decode
with token_num=1, topk=8, block_m=16).

Fix: allocate (token_num * topk + topk + 1) rows -- the exact minimum
needed to absorb all padding scatter writes. After the kernel, slice
only the valid [0, token_num * topk) rows for the activation.
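
The arithmetic above can be checked with a short sketch that uses the decode formula and sentinel encoding quoted in this message. The helper names are hypothetical; the numbers are the DeepSeek-R1 decode case mentioned above (token_num=1, topk=8).

```python
def scatter_row(fused_token, topk):
    """Decode a sorted_token_ids entry into an output row index."""
    return (fused_token & 0xFFFFFF) * topk + (fused_token >> 24)


def padding_sentinel(token_num, topk):
    """Encode the padding entry as (topk << 24 | token_num)."""
    return (topk << 24) | token_num


token_num, topk = 1, 8
valid_rows = token_num * topk                      # rows [0, 8) are valid
pad_row = scatter_row(padding_sentinel(token_num, topk), topk)
rows_allocated = token_num * topk + topk + 1       # size used by the fix
```

Here the padding entry decodes to row 16, which is outside the original 8-row buffer but inside the 17 rows allocated by the fix.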

Related: #2508
Made-with: Cursor

* Address PR review feedback: validate splitK, fix hipMemset stride issue, add correctness test

Agent-Logs-Url: https://github.com/ROCm/aiter/sessions/e3b37b0f-e151-4935-ad89-fd72436d41e2

Co-authored-by: samremes <181322991+samremes@users.noreply.github.com>

* black format

* fix splitk test dimensions

* Add gdn fusions

* style: fix ruff F841 and black-format Triton PR files

Remove unused variable in rmsnorm FP8 test ref. Apply Black to
kernels, launchers, tests, and gated_delta_rule decode __init__.

Made-with: Cursor

* Update fused_rearrange_sigmoid_gdr.py

* Update op_tests

* Fix BLACK format problem

* Fix black check failure

* Update test_fused_rearrange_sigmoid_gdr.py

* Allow callers to pass pre-allocated moe_buf to avoid output copy

Add an optional `moe_buf` parameter through the moe_sorting and
fused_moe call chain. When provided, the sorting kernel writes
directly into the caller's buffer instead of allocating a new one,
eliminating a redundant copy on the output path.

Made-with: Cursor

* Add moe_buf pass-through test to existing test_moe_sorting

Made-with: Cursor

* Replace _fast with _single_token for causal conv1d update kernels for single token decoding

* Fix blck format error

* Add tuned a8w8 blockscale GEMM config for Qwen3-Next-80B-A3B on MI355X

Tuned 1482 shapes (TP1/TP2/TP4) for Qwen/Qwen3-Next-80B-A3B-Instruct-FP8
on MI355X using CK + CK-TILE backends with splitK support.

Depends on:
- PR #2862 (CK bump for stride fix in CK-TILE blockscale)
- PR #2541 (splitK support for CK/CK-TILE blockscale GEMMs)
- PR #2487 (AQLayout tunable for CK-TILE blockscale 8-warp kernels)

* refactor(triton): rename gated RMSNorm+FP8 op to fused_rms_gated_fp8_group_quant

Colocate the gated RMSNorm + FP8 group quant path with the other fused FP8
ops. The Triton kernel is now _fused_rms_gated_fp8_group_quant_kernel in
_triton_kernels/quant/fused_fp8_quant.py; the Python entry point is
fused_rms_gated_fp8_group_quant in quant/fused_fp8_quant.py, with a docstring
that contrasts it with fused_rms_fp8_group_quant. Remove the old
rmsnorm_input_quant_fp8 module and rms_norm_input_quant_fp8 kernel file.
Re-export the new symbol and helpers (get_fp8_min_max_bounds,
calc_rows_per_block) from aiter.ops.triton.quant. Rename the test file to
test_fused_rms_gated_fp8_group_quant.py and update test.sh.

BREAKING CHANGE: rmsnorm_input_quant_fp8 is removed; use
fused_rms_gated_fp8_group_quant instead.

Made-with: Cursor

* Retune blockscale GEMM configs to fix invalid kernelId+splitK combinations

Full retune of all 1482 shapes on MI355X (gfx950, cu_num=256).
Key changes:
- SplitK usage dropped from 613 to 88 CK shapes (splitK > 0)
- All shapes validated via --run_config (1482/1482 OK)
- E2e perf: 2-8% output throughput improvement vs untuned heuristic

* [Bug] pa_mqa_logits: mask OOB stores on OutLogits_buffer

The gluon `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle` and
`_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx` kernels have 10
`buffer_store(ptr=OutLogits_buffer, ...)` call sites that are missing the
upper-bound mask present on their sibling stores.  When
`context_length == max_model_len` (the last-token position in a long-
context decode step), `split_context_length` is rounded UP to a
`KVBlockSize` multiple at line 427 and the final prefix/suffix store then
writes up to `ChunkKPerStage` float32 elements past the logical row end.
With `stride_out_batch == max_model_len`, those writes cross into the
next row / the next allocation, causing intermittent HIP memory-access
faults on gfx950 during DeepSeek V3.2 MTP decoding.

This change adds `mask=<offset> < max_model_len` to every unmasked
`buffer_store` on `OutLogits_buffer` in both preshuffle kernels, matching
the pattern of their already-masked neighbours.  The existing
`tl.where(..., -inf)` masking of the *values* is preserved; the only
behavioural change is that out-of-row lanes no longer emit buffer
stores.  Hardware overhead is negligible: `buffer_store` with a predicate
is the same SMEM descriptor path as the unmasked variant, just with a
VCC mask setup.
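
A CPU-side model of the masked store, assuming two output rows laid out back to back with a row stride equal to `max_model_len`. The helper name and the sizes are hypothetical; the sketch only shows the addressing effect of the predicate, not the gluon kernel itself.

```python
NEG_INF = float("-inf")


def store_chunk(buf, row_base, start, values, max_model_len, masked):
    """Store values at columns start.. of the row beginning at row_base."""
    for i, value in enumerate(values):
        col = start + i
        if masked and col >= max_model_len:
            continue  # predicated-off lane: no store is emitted
        buf[row_base + col] = value


max_model_len = 10                      # also the row stride in this sketch
chunk = [1.0, 2.0, NEG_INF, NEG_INF]    # last two lanes lie past the row end

buf_unmasked = [0.0] * (2 * max_model_len)
store_chunk(buf_unmasked, 0, 8, chunk, max_model_len, masked=False)

buf_masked = [0.0] * (2 * max_model_len)
store_chunk(buf_masked, 0, 8, chunk, max_model_len, masked=True)
```

In the unmasked case the two trailing lanes land in the first columns of the next row, even though their values were already replaced by `-inf`; with the predicate they are simply not written.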

Repro + end-to-end fix evidence: see PR description.

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>

* style: fix Black formatting

* style: fix Black formatting (Python 3.12 compatible)

* ci: replace deprecated zmq package with pyzmq

The `zmq` meta-package fails to install on some CI runners because
it cannot resolve the `pyzmq` dependency. Use `pyzmq` directly,
which is the actual package providing ZeroMQ bindings for Python.

Fixes Triton Test Shard 7 setup failures.

* ci: increase pip retries and timeout for CI reliability

Set pip global retries=15 and timeout=120s in build_aiter_triton.sh
to handle transient PyPI network failures on self-hosted runners.
Shard 5/7 failures were caused by RemoteDisconnected during pip install.

* ci: make pyzmq install non-blocking in triton test setup

pyzmq is only used by aiter.dist.shm_broadcast, not by any triton
test. When PyPI is unreachable on self-hosted runners, the pyzmq
install failure should not block the entire CI shard.

Split pyzmq into a separate pip install with || fallback so triton
tests can proceed even when PyPI connectivity is degraded.

* ci: retry pip install individually on batch failure

When batch pip install fails (e.g., PyPI connectivity issues on
self-hosted runners), retry each package individually. Only pyzmq
is allowed to fail silently since it's only used by
aiter.dist.shm_broadcast and not required by any CI test suite.

Critical packages (pandas, einops, numpy) must still succeed.

* [MLA] Fix nhead=32 non-persistent decode crash on gfx950

Commit c849fd5 ("Add bf16 MLA decode kernel for gqa_ratio=64,
qseqlen=1 (non-persistent)") zeroed ptr_RP and out_16_nosplit for all
non-persistent dispatch. The legacy QH16 ASM kernel used for nhead=32
(MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co) still writes
directly to the output buffer via ptr_RP when kv_split==1.
Dereferencing nullptr causes a GPU memory access fault during CUDA
graph capture on MI355X (gfx950) with DeepSeek-V3.2 at TP4.

Fix:
- Conditionally restore ptr_RP and out_16_nosplit in the non-persistent
  path for legacy kernels (gqa_ratio * max_seqlen_q <= 64) while
  keeping nullptr for newer kernels (e.g. gqa_ratio=64).
- Restore the bf16 nhead in [32,64] early-return after stage1 when
  num_kv_splits==1 to prevent stage2 from overwriting the kernel's
  direct output.

Tested on MI355X TP4 with deepseek-ai/DeepSeek-V3.2 (nhead=32):
- No crash during CUDA graph capture
- Correct GSM8K accuracy

Made-with: Cursor

* revert: remove #2983 (MLA nhead=32 fix) — causes test_mla CI failures

Reverting cherry-pick of #2983 from this bulk merge. The MLA nhead=32
non-persistent decode fix causes deterministic test_mla k_cache and
mla_decode-absorb precision failures on CI MI35X runners (Shard 1 & 2).

#2983 should go through its own PR with proper CI validation by the
original author (frida-andersson).

* fix: restore tuple unpack for FlyDSL fused-quant stage1 return

flydsl_moe_stage1 returns (out, out_scale_sorted) when the kernel uses
fused fp4/fp8 quantization. The tuple unpack logic was removed during
earlier refactoring but the kernel behavior was not changed, causing
fused_moe_2stages to crash with:
  AttributeError: 'tuple' object has no attribute 'view'

Restore the unpack: detect tuple return, extract tensor and scale,
handle fp4 byte-packing trim, and skip redundant Python-side requant
when the kernel already produced sorted scales.
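
The unpack step might look roughly like the sketch below. `unpack_stage1` is a hypothetical helper, and plain strings stand in for tensors; the fp4 byte-packing trim and the requant skip described above are not modeled.

```python
def unpack_stage1(result):
    """Normalize a stage-1 return value into (out, scale_or_None)."""
    if isinstance(result, tuple):
        # Fused-quant kernels return (out, out_scale_sorted).
        out, out_scale_sorted = result
        return out, out_scale_sorted
    # Non-fused kernels return the output alone.
    return result, None


fused = unpack_stage1(("out", "scale"))
plain = unpack_stage1("out")
```

Callers can then decide whether to requantize based on whether the second element is `None`.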

* Revert leaked changes from excluded PRs #2457/#2547/#2687 in fused_moe.py

- Restore import to match main: use `from aiter import
  fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd` instead of
  importing from internal triton path and fp4_utils
- Replace all fp4_utils.moe_mxfp4_sort() calls with mxfp4_moe_sort_fwd()
  using correct parameter names (cols= instead of block_size=)
- Remove all moe_buf preallocated buffer additions (PR #2687 rejected):
  parameter defaults, if-guards, and pass-throughs in _moe_sorting_impl,
  moe_sorting, fused_moe, fused_moe_fake, and fused_moe_
- Fix moe_sorting_dispatch_policy type annotation: bool -> int in
  fused_moe_fake and fused_moe_
- Remove moe_buf pass-through test from test_moe_sorting.py
- Preserve legitimate fp4_utils usage (mxfp4_to_f32, e8m0_to_f32) with
  local imports in stage1/stage2 fallback functions

* fix: restore fp4_utils.moe_mxfp4_sort for new code paths (different output layout than mxfp4_moe_sort_fwd)

* style: fix Black formatting for local imports

* fix: remove rejected W4A6 QuantType remap from fused_moe_dp_shared_expert

Lingpeng explicitly rejected this change (from excluded PR #2457).
Reverts the QuantType.No -> per_1x32 remap for fp4x2 weights.

* fix: restore silently-reverted main features from bad merge resolution

aiter/fused_moe.py:
- Restore to origin/main. Per sunway513's own comment, #2457 and #2547
  were excluded from this bulk merge; per valarLip, #2687 was rejected.
  No source PR should land changes in this file. The previous state
  (+110/-119 vs main) was collateral damage from auto-resolved conflicts
  taking older sides, which silently reverted #2262 (xbf16 asm fmoe path),
  #2726 (FlyDSL a8w4 MoE wrapper params + fuse_quant), #2658 (CK fp8
  blockscale splitk tuner support), and #2620 (mxfp4_moe_sort_hip,
  flagged by valarLip).

op_tests/test_gemm_a8w8_blockscale.py:
- Replace with a clean 3-way merge of origin/main + #2541. Now +55/-0
  vs main, matching #2541's actual contribution exactly. The previous
  state was silently reverting #2645 (CK GEMM multi-arch + test infra:
  TEST_NUM_ITERS, --csv/--output args, kernel_name= param).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* chore: remove #2464 from bulk merge per author request

@xaguilar-amd asked to drop #2464 (CK MoE tuner bug fixes) from this
bulk merge — they don't need it for the uplift.

Verified that #2464 is the only PR in this bulk merge touching
aiter/jit/core.py and aiter/utility/mp_tuner.py: the diff between the
branch and origin/main on those files is exactly #2464's +9/-1 and
+5/-0, with no other PR content mixed in. Restoring both files to
origin/main therefore drops #2464 cleanly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: vecheruk-amd <vecheruk@amd.com>
Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com>
Co-authored-by: Sami Remes <samremes@amd.com>
Co-authored-by: Li <chuali@amd.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: samremes <181322991+samremes@users.noreply.github.com>
Co-authored-by: hellozhuo <zhuo.su@amd.com>
Co-authored-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com>
Co-authored-by: Niklas Holmberg <nholmber@users.noreply.github.com>
Co-authored-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: frida-andersson <fanderss@amd.com>
Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@sunway513
Collaborator

This PR's content was bulk-merged via #3005 ([Silo] Bulk merge: kernel fixes and features, merged 2026-05-05 03:34 UTC). Please close this PR as superseded.

Tracking issue: ROCm/AI-Frameworks-Dashboard#141

sunway513 added a commit that referenced this pull request May 5, 2026
…st1 (#3005)

Squash-merged from main commit 2c855fb.

Includes 8 atomic Silo PRs (4 bug fixes + 3 features + 1):
Bug fixes:
- #2457 MoE dispatch fix for Quark W4A6 (MXFP4 weights with QuantType.No)
- #2464 CK MoE tuner cascading bugs
- #2547 ck_moe_stage1 split-K buffer overflow (memory safety)
- #2866 pa_mqa_logits OOB stores fix (memory safety)
Features:
- #2423 Triton optimized decode for Qwen3-Next (GDN, conv1d, fused FP8 quant)
- #2541 SplitK support for CK/CKTile Block-Scale GEMMs
- #2687 Allow preallocated MoE sorting buffer

Conflict resolutions (3 files):
- .github/workflows/aiter-test.yaml (3 blocks): took HEAD to preserve
  release/v0.1.13's prebuilt-image-extract CI flow. Theirs would replace
  it with main's inline  flow which has not
  been validated for this release branch.
- .github/workflows/vllm_benchmark.yaml (1 block): took HEAD for the same
  CI architecture preservation reason.
- aiter/ops/triton/gated_delta_net/__init__.py (1 block): took HEAD.
  Theirs would expose  and
  , but those functions exist only on main
  (added in a separate commit not part of this PR), so taking theirs
  would break import on this release branch with NameError. The new
  fused_rearrange_sigmoid_gdr.py wrapper that #3005 introduces is
  importable directly from its module path; not exporting it via
  __init__.py simply means library consumers must use the longer import
  path. Acceptable trade-off vs broken imports.

28 files changed, +5433/-47 (3 of original 31 dropped to HEAD per above).

Driver: vLLM 0.21 freeze 2026-05-08 — Silo customers need these kernel
fixes (especially #2547 / #2866 memory safety) on the AITER release
wheel, not nightly.

Verification gates added before tag:
- ATOM 5-model accuracy unchanged within +/- 0.005 vs v0.1.13-rc1
- New Qwen3-Next decode codepath smoke (GDN + causal_conv1d_single_token
  + fused_fp8_quant must JIT-compile and produce coherent output)
- Memory safety regression check on Kimi-K2.5-MXFP4 (exercises ck_moe
  stage1) and DeepSeek-V3.2 (exercises pa_mqa_logits)
- Perf delta sample on Kimi/MiniMax/DSv3.2 c=1 + c=64 vs rc1 baseline

(cherry picked from commit 2c855fb)
@samremes
Contributor Author

samremes commented May 5, 2026

Included in #3005 - closing this PR.

@samremes samremes closed this May 5, 2026
Liang-jianhao97 pushed a commit that referenced this pull request May 7, 2026
…3-Next, pa_mqa OOB) (#3005)

* fix: remap QuantType.No to per_1x32 for fp4x2 MoE weights (W4A6 support)

* Fixing two cascading bugs when running the MoE tuner

* Enable split-K for block-scale A8W8 CK and CKTile GEMMs

Propagate the splitK parameter (as KBatch = 2^splitK) through the
block-scale GEMM kernel infrastructure so that the tuning scripts
can sweep split-K values to improve occupancy on small-M shapes.

CK path: add KBatch parameter to gemm_a8w8_blockscale_impl and call
SetKBatch on the device argument. The CK invoker handles output
zeroing and atomic accumulation internally.

CKTile path: add k_batch parameter to gemm_a8w8_blockscale_cktile_impl,
remove the "split-k is not supported yet" runtime guard, and add
hipMemsetAsync to zero the output buffer before atomic accumulation.

Non-tune entry points pass KBatch=1 (no split-K) to preserve existing
behavior. Code generation scripts (gen_instances.py, gen_instances_cktile.py)
updated to include the new parameter in generated wrappers and manifests.

Made-with: Cursor

* Wire splitK from tuning CSV through production blockscale GEMM dispatch

The tuning infrastructure already sweeps splitK and writes it to the CSV,
but the production dispatch ignored it and hardcoded KBatch=1. Add splitK
as a runtime parameter to the non-tune entry points so tuned split-K
values are used without compiling the full _tune instance set.

Made-with: Cursor

* fix: ck_moe_stage1 split-K output buffer overflow from padding scatter

The CK kernel scatters output via sorted_token_ids using:
  token_offset = (fused_token & 0xffffff) * topk + (fused_token >> 24)

Padding entries use the sentinel value (topk << 24 | token_num),
which decodes to scatter position (token_num * topk + topk) -- beyond
the valid output range [0, token_num * topk). The original buffer
(token_num, topk, w1.shape[1]) only has token_num * topk rows, so
the padding scatter writes out of bounds, causing "HIP runtime error:
invalid argument" during CUDA graph capture (e.g. DeepSeek-R1 decode
with token_num=1, topk=8, block_m=16).

Fix: allocate (token_num * topk + topk + 1) rows -- the exact minimum
needed to absorb all padding scatter writes. After the kernel, slice
only the valid [0, token_num * topk) rows for the activation.

Related: #2508
Made-with: Cursor

* Address PR review feedback: validate splitK, fix hipMemset stride issue, add correctness test

Agent-Logs-Url: https://github.com/ROCm/aiter/sessions/e3b37b0f-e151-4935-ad89-fd72436d41e2

Co-authored-by: samremes <181322991+samremes@users.noreply.github.com>

* black format

* fix splitk test dimensions

* Add gdn fusions

* style: fix ruff F841 and black-format Triton PR files

Remove unused variable in rmsnorm FP8 test ref. Apply Black to
kernels, launchers, tests, and gated_delta_rule decode __init__.

Made-with: Cursor

* Update fused_rearrange_sigmoid_gdr.py

* Update op_tests

* Fix BLACK format problem

* Fix black check failure

* Update test_fused_rearrange_sigmoid_gdr.py

* Allow callers to pass pre-allocated moe_buf to avoid output copy

Add an optional `moe_buf` parameter through the moe_sorting and
fused_moe call chain. When provided, the sorting kernel writes
directly into the caller's buffer instead of allocating a new one,
eliminating a redundant copy on the output path.

Made-with: Cursor

* Add moe_buf pass-through test to existing test_moe_sorting

Made-with: Cursor

* Replace _fast with _single_token for causal conv1d update kernels for single token decoding

* Fix blck format error

* Add tuned a8w8 blockscale GEMM config for Qwen3-Next-80B-A3B on MI355X

Tuned 1482 shapes (TP1/TP2/TP4) for Qwen/Qwen3-Next-80B-A3B-Instruct-FP8
on MI355X using CK + CK-TILE backends with splitK support.

Depends on:
- PR #2862 (CK bump for stride fix in CK-TILE blockscale)
- PR #2541 (splitK support for CK/CK-TILE blockscale GEMMs)
- PR #2487 (AQLayout tunable for CK-TILE blockscale 8-warp kernels)

* refactor(triton): rename gated RMSNorm+FP8 op to fused_rms_gated_fp8_group_quant

Colocate the gated RMSNorm + FP8 group quant path with the other fused FP8
ops. The Triton kernel is now _fused_rms_gated_fp8_group_quant_kernel in
_triton_kernels/quant/fused_fp8_quant.py; the Python entry point is
fused_rms_gated_fp8_group_quant in quant/fused_fp8_quant.py, with a docstring
that contrasts it with fused_rms_fp8_group_quant. Remove the old
rmsnorm_input_quant_fp8 module and rms_norm_input_quant_fp8 kernel file.
Re-export the new symbol and helpers (get_fp8_min_max_bounds,
calc_rows_per_block) from aiter.ops.triton.quant. Rename the test file to
test_fused_rms_gated_fp8_group_quant.py and update test.sh.

BREAKING CHANGE: rmsnorm_input_quant_fp8 is removed; use
fused_rms_gated_fp8_group_quant instead.

Made-with: Cursor

* Retune blockscale GEMM configs to fix invalid kernelId+splitK combinations

Full retune of all 1482 shapes on MI355X (gfx950, cu_num=256).
Key changes:
- CK shapes using splitK > 0 dropped from 613 to 88
- All shapes validated via --run_config (1482/1482 OK)
- E2e perf: 2-8% output throughput improvement vs untuned heuristic

* [Bug] pa_mqa_logits: mask OOB stores on OutLogits_buffer

The gluon `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle` and
`_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx` kernels have 10
`buffer_store(ptr=OutLogits_buffer, ...)` call sites that are missing the
upper-bound mask present on their sibling stores.  When
`context_length == max_model_len` (the last-token position in a long-
context decode step), `split_context_length` is rounded UP to a
`KVBlockSize` multiple at line 427 and the final prefix/suffix store then
writes up to `ChunkKPerStage` float32 elements past the logical row end.
With `stride_out_batch == max_model_len`, those writes cross into the
next row / the next allocation, causing intermittent HIP memory-access
faults on gfx950 during DeepSeek V3.2 MTP decoding.

This change adds `mask=<offset> < max_model_len` to every unmasked
`buffer_store` on `OutLogits_buffer` in both preshuffle kernels, matching
the pattern of their already-masked neighbours.  The existing
`tl.where(..., -inf)` masking of the *values* is preserved; the only
behavioural change is that out-of-row lanes no longer emit buffer
stores.  Hardware overhead is negligible: `buffer_store` with a predicate
is the same SMEM descriptor path as the unmasked variant, just with a
VCC mask setup.
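
The effect of the added mask can be illustrated with a small host-side simulation. This is not the gluon kernel: `masked_store` is a hypothetical helper, the sizes are made up, and a Python list with two sentinel slots stands in for the output row plus the memory that follows it.

```python
def masked_store(row, start, values, max_model_len):
    """Write values[i] to row[start + i] only where start + i < max_model_len.

    Returns the number of elements actually written.
    """
    written = 0
    for i, v in enumerate(values):
        offset = start + i
        if offset < max_model_len:  # the upper-bound mask added by this change
            row[offset] = v
            written += 1
    return written


max_model_len = 10   # illustrative row length
kv_block_size = 4    # illustrative block size
context_length = max_model_len

# Rounding the context length up to a block multiple overshoots the row end.
rounded = -(-context_length // kv_block_size) * kv_block_size

# Two sentinel slots after the row represent the neighbouring row or allocation.
out = [0.0] * max_model_len + [-1.0, -1.0]
written = masked_store(out, start=8, values=[1.0] * 4, max_model_len=max_model_len)
```

In this toy setup the last chunk covers offsets 8 to 11, but only offsets 8 and 9 are written, so the sentinel slots keep their values.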

Repro + end-to-end fix evidence: see PR description.

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>

* style: fix Black formatting

* style: fix Black formatting (Python 3.12 compatible)

* ci: replace deprecated zmq package with pyzmq

The `zmq` meta-package fails to install on some CI runners because
it cannot resolve the `pyzmq` dependency. Use `pyzmq` directly,
which is the actual package providing ZeroMQ bindings for Python.

Fixes Triton Test Shard 7 setup failures.

* ci: increase pip retries and timeout for CI reliability

Set pip global retries=15 and timeout=120s in build_aiter_triton.sh
to handle transient PyPI network failures on self-hosted runners.
Shard 5/7 failures were caused by RemoteDisconnected during pip install.

* ci: make pyzmq install non-blocking in triton test setup

pyzmq is only used by aiter.dist.shm_broadcast, not by any triton
test. When PyPI is unreachable on self-hosted runners, the pyzmq
install failure should not block the entire CI shard.

Split pyzmq into a separate pip install with || fallback so triton
tests can proceed even when PyPI connectivity is degraded.

* ci: retry pip install individually on batch failure

When batch pip install fails (e.g., PyPI connectivity issues on
self-hosted runners), retry each package individually. Only pyzmq
is allowed to fail silently since it's only used by
aiter.dist.shm_broadcast and not required by any CI test suite.

Critical packages (pandas, einops, numpy) must still succeed.
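
The retry policy described above can be sketched in Python, although the actual CI change lives in a shell script. `install_with_fallback` and the injected `install` callable are hypothetical; a real version would presumably shell out to pip.

```python
from typing import Callable, Iterable, List


def install_with_fallback(
    packages: Iterable[str],
    install: Callable[[List[str]], bool],
    optional=frozenset({"pyzmq"}),
) -> List[str]:
    """Try a batch install; on failure retry each package on its own.

    Returns the optional packages that were skipped. Raises if a required
    package still fails when installed individually.
    """
    pkgs = list(packages)
    if install(pkgs):
        return []
    skipped = []
    for pkg in pkgs:
        if install([pkg]):
            continue
        if pkg in optional:
            skipped.append(pkg)  # allowed to fail silently
        else:
            raise RuntimeError(f"required package failed to install: {pkg}")
    return skipped


def fake_install(pkgs: List[str]) -> bool:
    """Test double: pretend every install that includes pyzmq fails."""
    return "pyzmq" not in pkgs


skipped = install_with_fallback(["pandas", "einops", "numpy", "pyzmq"], fake_install)
```

Injecting the installer keeps the policy testable without network access.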

* [MLA] Fix nhead=32 non-persistent decode crash on gfx950

Commit c849fd5 ("Add bf16 MLA decode kernel for gqa_ratio=64,
qseqlen=1 (non-persistent)") zeroed ptr_RP and out_16_nosplit for all
non-persistent dispatch. The legacy QH16 ASM kernel used for nhead=32
(MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co) still writes
directly to the output buffer via ptr_RP when kv_split==1.
Dereferencing nullptr causes a GPU memory access fault during CUDA
graph capture on MI355X (gfx950) with DeepSeek-V3.2 at TP4.

Fix:
- Conditionally restore ptr_RP and out_16_nosplit in the non-persistent
  path for legacy kernels (gqa_ratio * max_seqlen_q <= 64) while
  keeping nullptr for newer kernels (e.g. gqa_ratio=64).
- Restore the bf16 nhead in [32,64] early-return after stage1 when
  num_kv_splits==1 to prevent stage2 from overwriting the kernel's
  direct output.

Tested on MI355X TP4 with deepseek-ai/DeepSeek-V3.2 (nhead=32):
- No crash during CUDA graph capture
- Correct GSM8K accuracy

Made-with: Cursor

* revert: remove #2983 (MLA nhead=32 fix) — causes test_mla CI failures

Reverting cherry-pick of #2983 from this bulk merge. The MLA nhead=32
non-persistent decode fix causes deterministic test_mla k_cache and
mla_decode-absorb precision failures on CI MI35X runners (Shard 1 & 2).

#2983 should go through its own PR with proper CI validation by the
original author (frida-andersson).

* fix: restore tuple unpack for FlyDSL fused-quant stage1 return

flydsl_moe_stage1 returns (out, out_scale_sorted) when the kernel uses
fused fp4/fp8 quantization. The tuple unpack logic was removed during
earlier refactoring but the kernel behavior was not changed, causing
fused_moe_2stages to crash with:
  AttributeError: 'tuple' object has no attribute 'view'

Restore the unpack: detect tuple return, extract tensor and scale,
handle fp4 byte-packing trim, and skip redundant Python-side requant
when the kernel already produced sorted scales.
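
The restored unpack logic follows a pattern along these lines. This is a sketch only: `unpack_stage1` is a hypothetical helper, strings stand in for tensors, and the fp4 byte-packing trim is omitted because its details are not given here.

```python
def unpack_stage1(result):
    """Normalise a stage1 return value to (out, scale, has_fused_scale).

    A tuple means the kernel fused quantization and returned sorted scales,
    so the Python-side requant can be skipped.
    """
    if isinstance(result, tuple):
        out, scale = result
        return out, scale, True
    # Plain tensor return: no scale was produced by the kernel.
    return result, None, False


out, scale, fused = unpack_stage1(("out_tensor", "sorted_scale"))
plain_out, plain_scale, plain_fused = unpack_stage1("out_tensor")
```

Calling `.view` only on the unpacked `out` avoids the `AttributeError` quoted above.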

* Revert leaked changes from excluded PRs #2457/#2547/#2687 in fused_moe.py

- Restore import to match main: use `from aiter import
  fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd` instead of
  importing from internal triton path and fp4_utils
- Replace all fp4_utils.moe_mxfp4_sort() calls with mxfp4_moe_sort_fwd()
  using correct parameter names (cols= instead of block_size=)
- Remove all moe_buf preallocated buffer additions (PR #2687 rejected):
  parameter defaults, if-guards, and pass-throughs in _moe_sorting_impl,
  moe_sorting, fused_moe, fused_moe_fake, and fused_moe_
- Fix moe_sorting_dispatch_policy type annotation: bool -> int in
  fused_moe_fake and fused_moe_
- Remove moe_buf pass-through test from test_moe_sorting.py
- Preserve legitimate fp4_utils usage (mxfp4_to_f32, e8m0_to_f32) with
  local imports in stage1/stage2 fallback functions

* fix: restore fp4_utils.moe_mxfp4_sort for new code paths (different output layout than mxfp4_moe_sort_fwd)

* style: fix Black formatting for local imports

* fix: remove rejected W4A6 QuantType remap from fused_moe_dp_shared_expert

Lingpeng explicitly rejected this change (from excluded PR #2457).
Reverts the QuantType.No -> per_1x32 remap for fp4x2 weights.

* fix: restore silently-reverted main features from bad merge resolution

aiter/fused_moe.py:
- Restore to origin/main. Per sunway513's own comment, #2457 and #2547
  were excluded from this bulk merge; per valarLip, #2687 was rejected.
  No source PR should land changes in this file. The previous state
  (+110/-119 vs main) was collateral damage from auto-resolved conflicts
  taking older sides, which silently reverted #2262 (xbf16 asm fmoe path),
  #2726 (FlyDSL a8w4 MoE wrapper params + fuse_quant), #2658 (CK fp8
  blockscale splitk tuner support), and #2620 (mxfp4_moe_sort_hip,
  flagged by valarLip).

op_tests/test_gemm_a8w8_blockscale.py:
- Replace with a clean 3-way merge of origin/main + #2541. Now +55/-0
  vs main, matching #2541's actual contribution exactly. The previous
  state was silently reverting #2645 (CK GEMM multi-arch + test infra:
  TEST_NUM_ITERS, --csv/--output args, kernel_name= param).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* chore: remove #2464 from bulk merge per author request

@xaguilar-amd asked to drop #2464 (CK MoE tuner bug fixes) from this
bulk merge — they don't need it for the uplift.

Verified that #2464 is the only PR in this bulk merge touching
aiter/jit/core.py and aiter/utility/mp_tuner.py: the diff between the
branch and origin/main on those files is exactly #2464's +9/-1 and
+5/-0, with no other PR content mixed in. Restoring both files to
origin/main therefore drops #2464 cleanly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: vecheruk-amd <vecheruk@amd.com>
Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com>
Co-authored-by: Sami Remes <samremes@amd.com>
Co-authored-by: Li <chuali@amd.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: samremes <181322991+samremes@users.noreply.github.com>
Co-authored-by: hellozhuo <zhuo.su@amd.com>
Co-authored-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com>
Co-authored-by: Niklas Holmberg <nholmber@users.noreply.github.com>
Co-authored-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: frida-andersson <fanderss@amd.com>
Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>