Expose AQLayout as tunable parameter for CKTile blockscale 8-warp GEMM kernels #2487

Merged

Conversation
Add AQRowMajor tunable for CKTile blockscale 8-warp kernels

For 8-warp CKTile blockscale GEMM kernels, the host wrapper currently transposes x_scale from row-major to column-major at runtime before every kernel launch. This is unnecessary when the kernel can natively read row-major AQ data — the CK pipeline already supports both layouts via the AQLayout trait in TileGemmQuantTraits.

Changes:
- Add `AQRowMajor` bool field to TileKernelInstance (default False for backward compatibility). When True on an 8-warp config, the kernel uses RowMajor AQLayout and skips the host-side transpose.
- Add `AQRowMajor` template parameter to CreateTileGemmConfig / TileGemmConfig and expose it as AQRowMajor_v.
- Derive `aq_col_major` from `eight_waves && !AQRowMajor_v` to select the AQ layout in GemmTraits and condition the host-side transpose.
- Add kernel IDs 12/13 as RowMajor variants of the existing 8-warp kernels 10/11 in kernels_list_95x, so the tuner benchmarks both options.
- Update gen_instances_cktile.py to emit the new template argument.
- Also fix hardcoded strides (stride_A=K, stride_B=K) to read from tensor metadata, matching the fix in the stride PR.

Made-with: Cursor
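The layout-selection logic described in this commit can be summarized in a short Python sketch. The names (`TileKernelInstance`, `is_eight_warp`) follow the PR text, but the field layout and warp-grid encoding below are assumptions, not the actual aiter source:

```python
# Minimal sketch: an AQRowMajor flag on the kernel instance decides whether
# the host must transpose x_scale before launch. Field names are assumed.
from dataclasses import dataclass


@dataclass
class TileKernelInstance:
    M_Warp: int = 1
    N_Warp: int = 1
    K_Warp: int = 1
    AQRowMajor: bool = False  # default False keeps existing ColumnMajor behavior

    @property
    def is_eight_warp(self) -> bool:
        # 8-warp configs are those whose warp grid multiplies out to 8, e.g. 4x2x1.
        return self.M_Warp * self.N_Warp * self.K_Warp == 8


def needs_host_transpose(inst: TileKernelInstance) -> bool:
    # Mirrors `aq_col_major = eight_waves && !AQRowMajor_v`: only 8-warp
    # ColumnMajor kernels require the per-launch x_scale transpose.
    return inst.is_eight_warp and not inst.AQRowMajor


col = TileKernelInstance(M_Warp=4, N_Warp=2, K_Warp=1)                   # kernel 10/11 style
row = TileKernelInstance(M_Warp=4, N_Warp=2, K_Warp=1, AQRowMajor=True)  # kernel 12/13 style
assert needs_host_transpose(col) and not needs_host_transpose(row)
```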
Add tests for CKTile blockscale FP8 GEMM AQRowMajor optimization

Tests verify:
- TileKernelInstance name encoding with the `_aqrm` suffix
- is_eight_warp property correctness
- AQRowMajor variants exist in the candidate kernel dict
- Both ColumnMajor and RowMajor 8-warp kernels match the PyTorch reference output
- The RowMajor variant works with padded (non-contiguous) weight tensors from vLLM's _maybe_pad_fp8_weight

Made-with: Cursor
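The padded-weight case the last test item covers can be reproduced with a few lines of PyTorch: pad the weight's innermost dimension and return a sliced view, so the tensor the kernel sees is non-contiguous (stride(0) != K). The alignment value and padding mechanics below are assumptions for illustration, not vLLM's exact `_maybe_pad_fp8_weight` code:

```python
# Sketch of a vLLM-style weight pad producing a non-contiguous view.
import torch


def maybe_pad_weight(w: torch.Tensor, align: int = 256) -> torch.Tensor:
    K = w.shape[-1]
    pad = (-K) % align  # assumed alignment; vLLM's actual policy may differ
    if pad == 0:
        return w
    padded = torch.nn.functional.pad(w, (0, pad))
    return padded[..., :K]  # non-contiguous view: stride(0) == K + pad


w = torch.randn(128, 100)
wp = maybe_pad_weight(w)
assert not wp.is_contiguous() and wp.stride(0) == 256
```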
Fix AQRowMajor kernel variant and test assertions

Remove the non-8-warp kernel 12 (2x2x1) that incorrectly had AQRowMajor set. Correct the test instances to use an actual 8-warp config (4x2x1).

Made-with: Cursor
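A hedged sketch of the invariant this fix restores: the AQRowMajor flag only makes sense on a true 8-warp warp grid. The `kernels_list` structure below is illustrative, not the actual aiter dict:

```python
# Check that AQRowMajor is only set on 8-warp configurations.
kernels_list = {  # hypothetical entries; real dict shape differs
    10: dict(warps=(4, 2, 1), aq_row_major=False),
    11: dict(warps=(4, 2, 1), aq_row_major=False),
    12: dict(warps=(4, 2, 1), aq_row_major=True),  # RowMajor twin of 10/11
}

for kid, cfg in kernels_list.items():
    m, n, k = cfg["warps"]
    if cfg["aq_row_major"]:
        # A 2x2x1 (4-warp) entry here was the bug: the flag means nothing off 8 warps.
        assert m * n * k == 8, f"kernel {kid} sets AQRowMajor on a non-8-warp config"
```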
Fix f-string

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
…wmajor_tunable_rebase

Made-with: Cursor

# Conflicts:
#	csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh
Gate AQRowMajor test to gfx950 only

The AQRowMajor kernel variant is only in kernels_list_95x, not kernels_list_942, so the test fails on MI325X (gfx942) because no AQRowMajor entries exist in candidate_kernels_cktile_dict.

Made-with: Cursor
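A sketch of this kind of arch gate, assuming a pytest-style test module. The `gcnArchName` check is a common ROCm PyTorch pattern and not necessarily the exact code in the test:

```python
# Skip AQRowMajor tests on anything that is not gfx950.
import pytest
import torch


def is_gfx950() -> bool:
    if not torch.cuda.is_available():
        return False
    return "gfx950" in torch.cuda.get_device_properties(0).gcnArchName


@pytest.mark.skipif(
    not is_gfx950(),
    reason="AQRowMajor variants are registered only in kernels_list_95x (gfx950)",
)
def test_aq_rowmajor_kernels_exist():
    ...  # hypothetical test body
```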
nholmber added a commit to nholmber/aiter that referenced this pull request on Apr 3, 2026

…arp kernels

Introduces AQRowMajor as a configurable parameter for 8-warp CKTile blockscale GEMM kernels, allowing the tuner to select between ColumnMajor and RowMajor AQ layouts. Resolved conflicts with our eight_warps naming vs upstream eight_waves.

Source: ROCm#2487
nholmber added a commit to nholmber/aiter that referenced this pull request on Apr 3, 2026

…RowMajor optimization

Source: ROCm#2487
Contributor
Pull request overview
This PR makes activation-quantization scale layout (AQLayout) a tunable parameter for CKTile FP8 blockscale GEMM 8-warp kernels on gfx950, enabling an 8-warp RowMajor variant that can avoid the per-launch host-side x_scale transpose.
Changes:
- Add an `AQRowMajor` template parameter to `TileGemmConfig` and use it to select RowMajor vs ColumnMajor AQ layout for 8-warp kernels.
- Extend CKTile kernel instance generation to encode/propagate `AQRowMajor`, registering a new 8-warp RowMajor variant (kernel ID 12).
- Add a new op test validating name encoding, numerical accuracy vs reference, and padded-weight stride handling for the RowMajor AQ variant.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| op_tests/test_gemm_a8w8_blockscale_cktile_aq_rowmajor.py | Adds a gfx950-only test for RowMajor vs ColumnMajor AQ variants, including accuracy and padded weight stride coverage. |
| csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh | Introduces AQRowMajor in the config and adjusts host-side x_scale handling and AQ layout selection logic. |
| csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py | Propagates AQRowMajor into generated C++ template instantiations. |
| csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile_instance.py | Adds AQRowMajor to instance metadata and name suffix encoding, and registers the new RowMajor 8-warp kernel variant. |
Update op_tests/test_gemm_a8w8_blockscale_cktile_aq_rowmajor.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Contributor
Author

@valarLip Would you have time to review this PR? The CI passes after a recent merge with main.
nholmber added a commit to nholmber/aiter that referenced this pull request on Apr 22, 2026
nholmber added a commit to nholmber/aiter that referenced this pull request on Apr 25, 2026
Add tuned a8w8 blockscale GEMM config for Qwen3-Next-80B-A3B on MI355X

Tuned 1482 shapes (TP1/TP2/TP4) for Qwen/Qwen3-Next-80B-A3B-Instruct-FP8 on MI355X using CK + CK-TILE backends with splitK support.

Depends on:
- PR ROCm#2862 (CK bump for stride fix in CK-TILE blockscale)
- PR ROCm#2541 (splitK support for CK/CK-TILE blockscale GEMMs)
- PR ROCm#2487 (AQLayout tunable for CK-TILE blockscale 8-warp kernels)
valarLip approved these changes on May 1, 2026
chun-wan pushed a commit that referenced this pull request on May 4, 2026
…M kernels (#2487)

* Add AQRowMajor tunable for CKTile blockscale 8-warp kernels
* Add tests for CKTile blockscale FP8 GEMM AQRowMajor optimization
* Fix AQRowMajor kernel variant and test assertions
* Fix f-string
* run black
* Update gemm_a8w8_blockscale_cktile_instance.py formatting
* Update gemm_a8w8_blockscale_cktile_instance.py formatting
* Gate AQRowMajor test to gfx950 only
* Update op_tests/test_gemm_a8w8_blockscale_cktile_aq_rowmajor.py

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
sunway513 added a commit that referenced this pull request on May 5, 2026
…3-Next, pa_mqa OOB) (#3005)

* fix: remap QuantType.No to per_1x32 for fp4x2 MoE weights (W4A6 support)

* Fixing two cascading bugs when running the MoE tuner

* Enable split-K for block-scale A8W8 CK and CKTile GEMMs

  Propagate the splitK parameter (as KBatch = 2^splitK) through the block-scale GEMM kernel infrastructure so that the tuning scripts can sweep split-K values to improve occupancy on small-M shapes.

  CK path: add KBatch parameter to gemm_a8w8_blockscale_impl and call SetKBatch on the device argument. The CK invoker handles output zeroing and atomic accumulation internally.

  CKTile path: add k_batch parameter to gemm_a8w8_blockscale_cktile_impl, remove the "split-k is not supported yet" runtime guard, and add hipMemsetAsync to zero the output buffer before atomic accumulation.

  Non-tune entry points pass KBatch=1 (no split-K) to preserve existing behavior. Code generation scripts (gen_instances.py, gen_instances_cktile.py) updated to include the new parameter in generated wrappers and manifests.

  Made-with: Cursor

* Wire splitK from tuning CSV through production blockscale GEMM dispatch

  The tuning infrastructure already sweeps splitK and writes it to the CSV, but the production dispatch ignored it and hardcoded KBatch=1. Add splitK as a runtime parameter to the non-tune entry points so tuned split-K values are used without compiling the full _tune instance set.

  Made-with: Cursor

* fix: ck_moe_stage1 split-K output buffer overflow from padding scatter

  The CK kernel scatters output via sorted_token_ids using:
  token_offset = (fused_token & 0xffffff) * topk + (fused_token >> 24)

  Padding entries use the sentinel value (topk << 24 | token_num), which decodes to scatter position (token_num * topk + topk) -- beyond the valid output range [0, token_num * topk). The original buffer (token_num, topk, w1.shape[1]) only has token_num * topk rows, so the padding scatter writes out of bounds, causing "HIP runtime error: invalid argument" during CUDA graph capture (e.g. DeepSeek-R1 decode with token_num=1, topk=8, block_m=16).

  Fix: allocate (token_num * topk + topk + 1) rows -- the exact minimum needed to absorb all padding scatter writes. After the kernel, slice only the valid [0, token_num * topk) rows for the activation.

  Related: #2508
  Made-with: Cursor

* Address PR review feedback: validate splitK, fix hipMemset stride issue, add correctness test

  Agent-Logs-Url: https://github.com/ROCm/aiter/sessions/e3b37b0f-e151-4935-ad89-fd72436d41e2
  Co-authored-by: samremes <181322991+samremes@users.noreply.github.com>

* black format

* fix splitk test dimensions

* Add gdn fusions

* style: fix ruff F841 and black-format Triton PR files

  Remove unused variable in rmsnorm FP8 test ref. Apply Black to kernels, launchers, tests, and gated_delta_rule decode __init__.

  Made-with: Cursor

* Update fused_rearrange_sigmoid_gdr.py

* Update op_tests

* Fix BLACK format problem

* Fix black check failure

* Update test_fused_rearrange_sigmoid_gdr.py

* Allow callers to pass pre-allocated moe_buf to avoid output copy

  Add an optional `moe_buf` parameter through the moe_sorting and fused_moe call chain. When provided, the sorting kernel writes directly into the caller's buffer instead of allocating a new one, eliminating a redundant copy on the output path.

  Made-with: Cursor

* Add moe_buf pass-through test to existing test_moe_sorting

  Made-with: Cursor

* Replace _fast with _single_token for causal conv1d update kernels for single token decoding

* Fix blck format error

* Add tuned a8w8 blockscale GEMM config for Qwen3-Next-80B-A3B on MI355X

* refactor(triton): rename gated RMSNorm+FP8 op to fused_rms_gated_fp8_group_quant

  Colocate the gated RMSNorm + FP8 group quant path with the other fused FP8 ops. The Triton kernel is now _fused_rms_gated_fp8_group_quant_kernel in _triton_kernels/quant/fused_fp8_quant.py; the Python entry point is fused_rms_gated_fp8_group_quant in quant/fused_fp8_quant.py, with a docstring that contrasts it with fused_rms_fp8_group_quant. Remove the old rmsnorm_input_quant_fp8 module and rms_norm_input_quant_fp8 kernel file. Re-export the new symbol and helpers (get_fp8_min_max_bounds, calc_rows_per_block) from aiter.ops.triton.quant. Rename the test file to test_fused_rms_gated_fp8_group_quant.py and update test.sh.

  BREAKING CHANGE: rmsnorm_input_quant_fp8 is removed; use fused_rms_gated_fp8_group_quant instead.

  Made-with: Cursor

* Retune blockscale GEMM configs to fix invalid kernelId+splitK combinations

  Full retune of all 1482 shapes on MI355X (gfx950, cu_num=256). Key changes:
  - SplitK usage dropped from 613 to 88 CK shapes (splitK > 0)
  - All shapes validated via --run_config (1482/1482 OK)
  - E2e perf: 2-8% output throughput improvement vs untuned heuristic

* [Bug] pa_mqa_logits: mask OOB stores on OutLogits_buffer

  The gluon `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle` and `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx` kernels have 10 `buffer_store(ptr=OutLogits_buffer, ...)` call sites that are missing the upper-bound mask present on their sibling stores. When `context_length == max_model_len` (the last-token position in a long-context decode step), `split_context_length` is rounded UP to a `KVBlockSize` multiple at line 427 and the final prefix/suffix store then writes up to `ChunkKPerStage` float32 elements past the logical row end. With `stride_out_batch == max_model_len`, those writes cross into the next row / the next allocation, causing intermittent HIP memory-access faults on gfx950 during DeepSeek V3.2 MTP decoding.

  This change adds `mask=<offset> < max_model_len` to every unmasked `buffer_store` on `OutLogits_buffer` in both preshuffle kernels, matching the pattern of their already-masked neighbours. The existing `tl.where(..., -inf)` masking of the *values* is preserved; the only behavioural change is that out-of-row lanes no longer emit buffer stores. Hardware overhead is negligible: `buffer_store` with a predicate is the same SMEM descriptor path as the unmasked variant, just with a VCC mask setup.

  Repro + end-to-end fix evidence: see PR description.

  Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>

* style: fix Black formatting

* style: fix Black formatting (Python 3.12 compatible)

* ci: replace deprecated zmq package with pyzmq

  The `zmq` meta-package fails to install on some CI runners because it cannot resolve the `pyzmq` dependency. Use `pyzmq` directly, which is the actual package providing ZeroMQ bindings for Python. Fixes Triton Test Shard 7 setup failures.

* ci: increase pip retries and timeout for CI reliability

  Set pip global retries=15 and timeout=120s in build_aiter_triton.sh to handle transient PyPI network failures on self-hosted runners. Shard 5/7 failures were caused by RemoteDisconnected during pip install.

* ci: make pyzmq install non-blocking in triton test setup

  pyzmq is only used by aiter.dist.shm_broadcast, not by any triton test. When PyPI is unreachable on self-hosted runners, the pyzmq install failure should not block the entire CI shard. Split pyzmq into a separate pip install with || fallback so triton tests can proceed even when PyPI connectivity is degraded.

* ci: retry pip install individually on batch failure

  When batch pip install fails (e.g., PyPI connectivity issues on self-hosted runners), retry each package individually. Only pyzmq is allowed to fail silently since it's only used by aiter.dist.shm_broadcast and not required by any CI test suite. Critical packages (pandas, einops, numpy) must still succeed.

* [MLA] Fix nhead=32 non-persistent decode crash on gfx950

  Commit c849fd5 ("Add bf16 MLA decode kernel for gqa_ratio=64, qseqlen=1 (non-persistent)") zeroed ptr_RP and out_16_nosplit for all non-persistent dispatch. The legacy QH16 ASM kernel used for nhead=32 (MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co) still writes directly to the output buffer via ptr_RP when kv_split==1. Dereferencing nullptr causes a GPU memory access fault during CUDA graph capture on MI355X (gfx950) with DeepSeek-V3.2 at TP4.

  Fix:
  - Conditionally restore ptr_RP and out_16_nosplit in the non-persistent path for legacy kernels (gqa_ratio * max_seqlen_q <= 64) while keeping nullptr for newer kernels (e.g. gqa_ratio=64).
  - Restore the bf16 nhead in [32,64] early-return after stage1 when num_kv_splits==1 to prevent stage2 from overwriting the kernel's direct output.

  Tested on MI355X TP4 with deepseek-ai/DeepSeek-V3.2 (nhead=32):
  - No crash during CUDA graph capture
  - Correct GSM8K accuracy

  Made-with: Cursor

* revert: remove #2983 (MLA nhead=32 fix) — causes test_mla CI failures

  Reverting cherry-pick of #2983 from this bulk merge. The MLA nhead=32 non-persistent decode fix causes deterministic test_mla k_cache and mla_decode-absorb precision failures on CI MI35X runners (Shard 1 & 2). #2983 should go through its own PR with proper CI validation by the original author (frida-andersson).

* fix: restore tuple unpack for FlyDSL fused-quant stage1 return

  flydsl_moe_stage1 returns (out, out_scale_sorted) when the kernel uses fused fp4/fp8 quantization. The tuple unpack logic was removed during earlier refactoring but the kernel behavior was not changed, causing fused_moe_2stages to crash with: AttributeError: 'tuple' object has no attribute 'view'

  Restore the unpack: detect tuple return, extract tensor and scale, handle fp4 byte-packing trim, and skip redundant Python-side requant when the kernel already produced sorted scales.

* Revert leaked changes from excluded PRs #2457/#2547/#2687 in fused_moe.py

  - Restore import to match main: use `from aiter import fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd` instead of importing from internal triton path and fp4_utils
  - Replace all fp4_utils.moe_mxfp4_sort() calls with mxfp4_moe_sort_fwd() using correct parameter names (cols= instead of block_size=)
  - Remove all moe_buf preallocated buffer additions (PR #2687 rejected): parameter defaults, if-guards, and pass-throughs in _moe_sorting_impl, moe_sorting, fused_moe, fused_moe_fake, and fused_moe_
  - Fix moe_sorting_dispatch_policy type annotation: bool -> int in fused_moe_fake and fused_moe_
  - Remove moe_buf pass-through test from test_moe_sorting.py
  - Preserve legitimate fp4_utils usage (mxfp4_to_f32, e8m0_to_f32) with local imports in stage1/stage2 fallback functions

* fix: restore fp4_utils.moe_mxfp4_sort for new code paths (different output layout than mxfp4_moe_sort_fwd)

* style: fix Black formatting for local imports

* fix: remove rejected W4A6 QuantType remap from fused_moe_dp_shared_expert

  Lingpeng explicitly rejected this change (from excluded PR #2457). Reverts the QuantType.No -> per_1x32 remap for fp4x2 weights.

* fix: restore silently-reverted main features from bad merge resolution

  aiter/fused_moe.py:
  - Restore to origin/main. Per sunway513's own comment, #2457 and #2547 were excluded from this bulk merge; per valarLip, #2687 was rejected. No source PR should land changes in this file. The previous state (+110/-119 vs main) was collateral damage from auto-resolved conflicts taking older sides, which silently reverted #2262 (xbf16 asm fmoe path), #2726 (FlyDSL a8w4 MoE wrapper params + fuse_quant), #2658 (CK fp8 blockscale splitk tuner support), and #2620 (mxfp4_moe_sort_hip, flagged by valarLip).

  op_tests/test_gemm_a8w8_blockscale.py:
  - Replace with a clean 3-way merge of origin/main + #2541. Now +55/-0 vs main, matching #2541's actual contribution exactly. The previous state was silently reverting #2645 (CK GEMM multi-arch + test infra: TEST_NUM_ITERS, --csv/--output args, kernel_name= param).

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* chore: remove #2464 from bulk merge per author request

  @xaguilar-amd asked to drop #2464 (CK MoE tuner bug fixes) from this bulk merge — they don't need it for the uplift. Verified that #2464 is the only PR in this bulk merge touching aiter/jit/core.py and aiter/utility/mp_tuner.py: the diff between the branch and origin/main on those files is exactly #2464's +9/-1 and +5/-0, with no other PR content mixed in. Restoring both files to origin/main therefore drops #2464 cleanly.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: vecheruk-amd <vecheruk@amd.com>
Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com>
Co-authored-by: Sami Remes <samremes@amd.com>
Co-authored-by: Li <chuali@amd.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: samremes <181322991+samremes@users.noreply.github.com>
Co-authored-by: hellozhuo <zhuo.su@amd.com>
Co-authored-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com>
Co-authored-by: Niklas Holmberg <nholmber@users.noreply.github.com>
Co-authored-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: frida-andersson <fanderss@amd.com>
Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
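The padding-scatter bound from the ck_moe_stage1 fix quoted above can be checked in a few lines of Python; the values follow the DeepSeek-R1 example in that commit message:

```python
# Worked check: the sentinel padding entry decodes to row token_num*topk + topk,
# so the output buffer needs token_num*topk + topk + 1 rows to absorb it.
def scatter_row(fused_token: int, topk: int) -> int:
    return (fused_token & 0xFFFFFF) * topk + (fused_token >> 24)


token_num, topk = 1, 8
sentinel = (topk << 24) | token_num            # padding entry
assert scatter_row(sentinel, topk) == token_num * topk + topk  # 16, past [0, 8)
rows_needed = token_num * topk + topk + 1      # 17 rows absorbs all padding writes
```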
Liang-jianhao97 pushed a commit that referenced this pull request on May 7, 2026

…M kernels (#2487)
Liang-jianhao97 pushed a commit that referenced this pull request on May 7, 2026

…3-Next, pa_mqa OOB) (#3005)
Summary

For CKTile blockscale FP8 GEMM 8-warp kernels on gfx950, the activation quantization scale (`x_scale`) is currently transposed from RowMajor to ColumnMajor on the host before every kernel launch. This PR makes the `AQLayout` a tunable parameter, allowing the kernel to natively read `x_scale` in RowMajor layout and skip the host-side transpose entirely.

Changes
- `gemm_a8w8_blockscale_cktile_common.cuh`: Add `AQRowMajor` template parameter to `TileGemmConfig`. When enabled, the 8-warp pipeline uses RowMajor `AQLayout` and the host-side `x_scale` transpose + allocation is skipped.
- `gemm_a8w8_blockscale_cktile_instance.py`: Add `AQRowMajor` field to `TileKernelInstance` with an `is_eight_warp` property. Register a new RowMajor variant (kernel ID 12) alongside the existing ColumnMajor 8-warp kernel (ID 11), so the tuner can evaluate both.
- `gen_instances_cktile.py`: Propagate `AQRowMajor` into the generated C++ template instantiation.
- `test_gemm_a8w8_blockscale_cktile_aq_rowmajor.py`: New test covering instance name encoding, numerical accuracy (RowMajor vs ColumnMajor vs PyTorch reference), and padded weight stride handling (simulating vLLM's `_maybe_pad_fp8_weight`).
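For illustration, the `_aqrm` name-suffix encoding that the instance metadata uses might look like the following; the base-name scheme here is an assumption, only the suffix itself comes from the PR:

```python
def instance_name(base: str, aq_row_major: bool) -> str:
    # ColumnMajor instances keep their existing name; RowMajor appends `_aqrm`.
    return base + ("_aqrm" if aq_row_major else "")


base = "gemm_a8w8_blockscale_256x128_4x2x1"  # illustrative base name
assert instance_name(base, False) == base
assert instance_name(base, True).endswith("_aqrm")
```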
Performance

Benchmarked on gfx950 using the tuning script (`gemm_a8w8_blockscale_tune.py --libtype cktile`), comparing the two 8-warp variants head-to-head. The tuner automatically selects the best variant per shape; no manual configuration is needed.
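To make the per-shape selection concrete, here is a hypothetical sketch of choosing between the two kernel IDs; the timings below are invented for illustration only:

```python
# Benchmark both AQ layouts per shape and keep the faster kernel ID.
timings = {  # (M, N, K) -> {kernel_id: measured_us}; made-up numbers
    (16, 7168, 2048): {11: 42.0, 12: 38.5},
    (4096, 7168, 2048): {11: 310.0, 12: 335.0},
}
best = {shape: min(t, key=t.get) for shape, t in timings.items()}
# RowMajor (12) wins the small-M shape; ColumnMajor (11) keeps the large one.
```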
Non-8-warp kernels

Non-8-warp kernels always use RowMajor `AQLayout` and are unaffected by this change. The `AQRowMajor` flag is only meaningful for 8-warp configurations.

Test plan
- `test_instance_names`: Verifies `_aqrm` suffix encoding and the `is_eight_warp` property
- `test_accuracy`: RowMajor and ColumnMajor outputs match the PyTorch FP32 reference across 4 shapes
- `test_padded_weight_stride`: RowMajor kernel handles non-contiguous (padded) weight tensors correctly
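A plausible shape of the FP32 reference used for such accuracy checks, assuming 1x128 activation scale blocks and 128x128 weight scale blocks (the usual blockscale convention); the actual test's reference may differ in detail:

```python
# Dequantize both operands with their block scales, then matmul in FP32.
import torch


def ref_blockscale_gemm(x, w, x_scale, w_scale, blk=128):
    M, K = x.shape
    N = w.shape[0]
    # x: (M, K) fp8, x_scale: (M, ceil(K/blk)) -- one scale per 1x128 block
    xf = x.to(torch.float32) * x_scale.repeat_interleave(blk, dim=1)[:, :K]
    # w: (N, K) fp8, w_scale: (ceil(N/blk), ceil(K/blk)) -- 128x128 blocks
    wf = w.to(torch.float32) * (
        w_scale.repeat_interleave(blk, dim=0)[:N, :]
               .repeat_interleave(blk, dim=1)[:, :K]
    )
    return xf @ wf.t()  # (M, N) FP32 reference output
```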