
[Triton] fused_flatten_fp8_group_quant: add transpose_scale param #3041

Open

Jacob0226 wants to merge 2 commits into ROCm:main from Jacob0226:jacob/fused-flatten-transpose-scale

Conversation

Jacob0226 (Contributor) commented May 6, 2026

Motivation

aiter/ops/triton/quant/fused_fp8_quant.py::fused_rms_fp8_group_quant already supports transpose_scale=True: it writes the per-token-group scale tensor in pre-transposed (num_bs_cols, M) row-major storage, so downstream consumers such as the CK bpreshuffle GEMM can read it directly without an extra .transpose(-1, -2).contiguous() copy on the consumer side.

fused_flatten_fp8_group_quant in the same file produces a scale tensor of the same shape ((M, cdiv(N1*N2, group_size))) and feeds the same downstream GEMM path in sglang's aiter_w8a8_block_fp8_linear, but it does not currently support transpose_scale. As a result, on the o_proj path of GLM-5.1-FP8 (and for any other Linear consuming a flatten-quantized activation), sglang has to perform an explicit .transpose(-1, -2).contiguous() on the scale tensor before each bpreshuffle GEMM, costing ~4 us per call.

This PR mirrors the same transpose_scale flag onto fused_flatten_fp8_group_quant so the consumer can opt in and skip that redundant copy. Companion sglang patch: Jacob0226/sglang:jacob/fused-transpose-scale (also opened upstream as sgl-project/sglang#24476; this aiter PR is its prerequisite).

Technical Details

aiter/ops/triton/quant/fused_fp8_quant.py::fused_flatten_fp8_group_quant (Python wrapper only — inner _fused_flatten_fp8_group_quant_kernel is unchanged):

  • New keyword arg transpose_scale: bool = False (default preserves existing behavior; default callers see no change).
  • When transpose_scale=True:
    • Allocate out_block_scales as (num_bs_cols, M) instead of (M, num_bs_cols).
    • Pass swapped out_bs_row_stride / out_bs_col_stride (1, M) to the inner kernel so it writes column-major into the same buffer.
    • At the end, view(M, num_bs_cols) to keep the public shape identical.
  • The pattern is identical to the existing transpose_scale handling on fused_rms_fp8_group_quant in the same file; this is just propagating it to the flatten variant.

Net diff: +33 / -4 lines, single function, single file.
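
For reference, here is a minimal self-contained illustration (plain PyTorch, not the actual wrapper code) of the allocate-transposed / swap-strides / view-back pattern described above. The strided copy_ below stands in for the unchanged Triton kernel writing through the swapped out_bs_row_stride / out_bs_col_stride; names follow the PR text.

import torch

M, num_bs_cols = 4, 3
# Scale values the kernel would compute, indexed as (m, c).
scale_values = torch.arange(M * num_bs_cols, dtype=torch.float32).reshape(M, num_bs_cols)

# transpose_scale=False: the buffer is (M, num_bs_cols) row-major, written with strides (num_bs_cols, 1).
scale_default = scale_values.clone()

# transpose_scale=True: allocate (num_bs_cols, M) and hand the kernel swapped strides (1, M),
# so element (m, c) lands at flat offset c * M + m (column-major with respect to the public shape).
buf = torch.empty(num_bs_cols, M, dtype=torch.float32)
buf.as_strided((M, num_bs_cols), (1, M)).copy_(scale_values)  # stand-in for the kernel's strided stores
scale_transposed = buf.view(M, num_bs_cols)  # view back: public shape unchanged, storage pre-transposed

# Byte-equivalent to the manual transform downstream consumers no longer need to perform:
manual = scale_default.transpose(0, 1).contiguous().view(M, num_bs_cols)
assert torch.equal(scale_transposed, manual)

Because only the allocation shape, the two stride arguments, and the final view change, transpose_scale=False callers go through exactly the same path as before.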

Test Plan

  1. Run the entire existing UT file op_tests/triton_tests/quant/test_fused_fp8_quant.py to confirm:
    • test_fused_flatten_fp8_group_quant (default path of the patched function) still passes — verifies transpose_scale=False behavior is unchanged.
    • test_fused_rms_fp8_group_quant and test_fused_rms_fp8_group_quant_transpose_scale (the function whose transpose_scale handling this patch mirrors) still pass — verifies no cross-function regression.
  2. Add test_fused_flatten_fp8_group_quant_transpose_scale to the same file: a new pytest case mirroring the sister test_fused_rms_fp8_group_quant_transpose_scale (see the sketch after this list). Parametrized over M ∈ {1, 32, 256} × (N1, N2) ∈ {(16, 128)} × dtype ∈ {float16, bfloat16} (6 cases, the same coverage as the existing default test_fused_flatten_fp8_group_quant). For each case it:
    • Calls fused_flatten_fp8_group_quant with transpose_scale=False and transpose_scale=True.
    • Asserts both outputs have shape (M, num_bs_cols) and are contiguous after the inner .view().
    • Asserts scale_transposed is byte-equivalent to scale_default.transpose(0, 1).contiguous().view(*scale_default.shape) (the manual transform consumers like sglang's aiter_w8a8_block_fp8_linear perform on the bpreshuffle path).
    • Asserts the FP8 quantized tensor is bit-identical between the two calls (atol=0, rtol=0).
  3. End-to-end performance measurement via the companion sglang PR ([AMD] Skip x_scale.transpose+contiguous before bpreshuffle GEMM via upstream pre-transposed scale sgl-project/sglang#24476) on GLM-5.1-FP8 / MI355X / TP=8.
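
A sketch of what the new test case could look like is below. The import path, call signature, and return order of fused_flatten_fp8_group_quant are assumptions inferred from the PR text; the actual test should mirror the existing test_fused_flatten_fp8_group_quant in op_tests/triton_tests/quant/test_fused_fp8_quant.py.

import pytest
import torch

# Assumed import path; the real test imports the op the same way the existing flatten test does.
from aiter.ops.triton.quant.fused_fp8_quant import fused_flatten_fp8_group_quant


@pytest.mark.parametrize("M", [1, 32, 256])
@pytest.mark.parametrize("N1, N2", [(16, 128)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_flatten_fp8_group_quant_transpose_scale(M, N1, N2, dtype):
    x = torch.randn(M, N1, N2, dtype=dtype, device="cuda")

    # Default layout vs. pre-transposed storage (call signature / return order assumed).
    out_ref, scale_ref = fused_flatten_fp8_group_quant(x)
    out_t, scale_t = fused_flatten_fp8_group_quant(x, transpose_scale=True)

    # Public shape and contiguity are unchanged after the inner .view().
    assert scale_t.shape == scale_ref.shape
    assert scale_t.is_contiguous()

    # Pre-transposed storage equals the manual transform consumers used to apply.
    manual = scale_ref.transpose(0, 1).contiguous().view(*scale_ref.shape)
    torch.testing.assert_close(scale_t, manual, atol=0, rtol=0)

    # The FP8 quantized output is bit-identical between the two calls.
    assert torch.equal(out_t.view(torch.uint8), out_ref.view(torch.uint8))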

Hardware / software: MI355X (gfx950), ROCm 7.2, docker rocm/sgl-dev:v0.5.10.post1-rocm720-mi35x-20260503.

Test Result

Reproduce

cd <aiter>
# Full file (300 cases — regression + new transpose_scale UT):
pytest op_tests/triton_tests/quant/test_fused_fp8_quant.py -v

# Just the new transpose_scale case (6 cases):
pytest op_tests/triton_tests/quant/test_fused_fp8_quant.py::test_fused_flatten_fp8_group_quant_transpose_scale -v

op_tests/triton_tests/quant/test_fused_fp8_quant.py — 300 / 300 PASS on MI355X (27s)

300 = 294 existing + 6 newly added. Highlights:

  • test_fused_flatten_fp8_group_quant (default path of the patched function): 6 / 6 PASS — confirms transpose_scale=False behavior is unchanged.
  • test_fused_flatten_fp8_group_quant_transpose_scale (NEW, added in this PR): 6 / 6 PASS — confirms transpose_scale=True produces a scale tensor byte-equivalent to default.transpose(0, 1).contiguous().view(...) and a bit-identical FP8 quantized tensor.
  • test_fused_rms_fp8_group_quant and test_fused_rms_fp8_group_quant_transpose_scale (the sister function this patch mirrors): 36 / 36 PASS — confirms no cross-function regression.

End-to-end performance (via companion sglang PR sgl-project/sglang#24476)

  • GLM-5.1-FP8 on MI355X TP=8, attention block: o_proj direct-copy kernel (~4 us / layer) eliminated in trace.
  • E2E throughput: +0.6% incremental on top of the rms variant (fused_rms_fp8_group_quant(transpose_scale=True)) opt-in. Combined with the rms opt-in, total +3.1% throughput / -3.1% TPOT averaged across 10 ISL/OSL/concurrency points (1k1k, 8k1k × concurrency 4/8/16/32/64).
  • GSM8K accuracy: 0.943 (within run-to-run noise of baseline 0.948), consistent with the optimization being numerically neutral at the GEMM level.

Submission Checklist

Mirror the transpose_scale flag already supported by fused_rms_fp8_group_quant.
When transpose_scale=True, the returned scale tensor has shape
(M, cdiv(N1*N2, group_size)) but is stored in a column-major (transposed) memory
layout, equivalent to applying .transpose(0, 1).contiguous().view(*scale.shape)
to the default-layout output.

This lets downstream consumers (CK bpreshuffle GEMM in sglang
aiter_w8a8_block_fp8_linear) skip a redundant transpose+copy that takes
~4us per call. Inner _fused_flatten_fp8_group_quant_kernel is unchanged;
only the Python wrapper swaps the stride params and views the buffer back
to (M, num_bs_cols) shape at the end.

Validated bit-exact: scale_t (this patch) has underlying storage byte-identical to
the equivalent manual transform applied to the default-path output.

Co-authored-by: Cursor <cursoragent@cursor.com>
github-actions bot commented May 6, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

  • ci:triton-300x: run an additional Triton test job on MI300X for PRs (the main branch always runs both MI35X and MI300X)
  • ci:sglang: SGLang integration tests
  • ci:atom: ATOM benchmark (DeepSeek-R1 + GPT-OSS)
  • ci:vllm: vLLM benchmark
  • ci:all: all of the above

Add labels via the sidebar or gh pr edit 3041 --add-label <label>

Jacob0226 added a commit to Jacob0226/sglang that referenced this pull request May 6, 2026
The previous commit covered fused_rms_fp8_group_quant (rms+quant for q_a/kv_a,
qkv_a_proj, wk inputs). The o_proj path uses fused_flatten_fp8_group_quant
which did not support transpose_scale, so the trace still showed a direct_copy
4.2us/layer for o_proj. With this change + an aiter-side patch adding
transpose_scale to fused_flatten_fp8_group_quant, that last direct_copy is
also gone.

Also fixes a latent bug in the previous commit's aiter_w8a8_block_fp8_linear
logic. The 'transpose-contig-view' trick is NOT a self-inverse, so when an
upstream marker indicates pre-transposed scale but the GEMM dispatch picks
triton (which wants default layout), the original code did 'apply
transpose-contig-view to undo' which actually scrambles data instead. The new
logic uses the proper inverse 'x_scale.view(num_bs_cols, M).transpose(0, 1)
.contiguous()' for that path. The previous commit's E2E result was still
correct because GLM-5.1-FP8 runtime never picks the triton path for any of
the 5 MLA GEMMs - all dispatch to CK bpreshuffle in trace - but defensive
correctness matters for future shapes and models.

Companion aiter PR (ROCm/aiter#3041): fused_flatten_fp8_group_quant gains a
transpose_scale: bool = False parameter, mirroring fused_rms_fp8_group_quant.
Implementation swaps stride params passed to the inner
_fused_flatten_fp8_group_quant_kernel and views back to shape (M, num_bs_cols)
at the end. Inner triton kernel needs no changes.

E2E impact on GLM-5.1-FP8 MI355X TP=8 averaged across 10 configs:
- Throughput: +3.1% vs baseline
- Median TPOT: -3.1% (faster) vs baseline
- GSM8K: 0.943 (within noise of baseline 0.948)
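
The non-self-inverse point in the commit message above is easy to verify with plain PyTorch, independent of either PR (the shapes here are arbitrary; num_bs_cols is the per-row block-scale count from the PR text):

import torch

M, num_bs_cols = 4, 3
x_scale = torch.arange(M * num_bs_cols, dtype=torch.float32).reshape(M, num_bs_cols)

# The "pre-transposed storage" transform used on the producer side:
pre_transposed = x_scale.transpose(0, 1).contiguous().view(M, num_bs_cols)

# Applying the same transform again does NOT recover the original values (it scrambles them)...
again = pre_transposed.transpose(0, 1).contiguous().view(M, num_bs_cols)
assert not torch.equal(again, x_scale)

# ...whereas the inverse quoted in the commit message does.
undone = pre_transposed.view(num_bs_cols, M).transpose(0, 1).contiguous()
assert torch.equal(undone, x_scale)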
Mirrors test_fused_rms_fp8_group_quant_transpose_scale on the flatten
variant: parametrized over M={1,32,256}, (N1,N2)={(16,128)}, dtype={f16,bf16}
(6 cases, same coverage as the existing default test_fused_flatten_fp8_group_quant).

Verifies that fused_flatten_fp8_group_quant(transpose_scale=True) produces:
- Same public shape (M, num_bs_cols) as transpose_scale=False.
- Scale tensor byte-equivalent to default_scale.transpose(0, 1).contiguous()
  .view(*default_scale.shape) (the manual transform the new path replaces
  on consumers like sglang aiter_w8a8_block_fp8_linear).
- Bit-identical FP8 quantized tensor (atol=0, rtol=0).

Result on MI355X: 6 / 6 PASS.
Jacob0226 marked this pull request as ready for review May 6, 2026 11:05
Jacob0226 requested review from a team and Copilot May 6, 2026 11:05
Copilot AI left a comment

Pull request overview

This PR adds a transpose_scale option to fused_flatten_fp8_group_quant so callers can request the scale tensor in the same pre-transposed storage format already supported by fused_rms_fp8_group_quant, avoiding an extra transpose().contiguous() step in downstream GEMM consumers.

Changes:

  • Add transpose_scale: bool = False to fused_flatten_fp8_group_quant, with an alternate scale-buffer allocation + stride wiring to have the kernel write directly into the transposed-storage layout.
  • Preserve the public scale shape (M, cdiv(N1*N2, group_size)) by viewing the transposed buffer back to (M, num_bs_cols) post-kernel.
  • Add a new pytest case validating transpose_scale=True is byte-equivalent to scale_default.transpose(0, 1).contiguous().view(*scale_default.shape) and that FP8 outputs match exactly.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

  • aiter/ops/triton/quant/fused_fp8_quant.py: adds transpose_scale to fused_flatten_fp8_group_quant by allocating scales in the transposed shape and passing swapped strides to the existing Triton kernel, then viewing back to the original public shape.
  • op_tests/triton_tests/quant/test_fused_fp8_quant.py: adds test_fused_flatten_fp8_group_quant_transpose_scale to verify the new layout contract and that the quantized outputs are unchanged.
