[Triton] fused_flatten_fp8_group_quant: add transpose_scale param #3041
Open
Conversation
Mirrors the `transpose_scale` flag already supported by `fused_rms_fp8_group_quant`. When `transpose_scale=True`, the returned scale tensor keeps the shape `(M, cdiv(N1*N2, group_size))` but is stored in column-major (transposed) memory layout — equivalent to applying `.transpose(0, 1).contiguous().view(*scale.shape)` to the default-layout output. This lets downstream consumers (CK bpreshuffle GEMM in sglang's `aiter_w8a8_block_fp8_linear`) skip a redundant transpose+copy that takes ~4us per call.

The inner `_fused_flatten_fp8_group_quant_kernel` is unchanged; only the Python wrapper swaps the stride params and views the buffer back to `(M, num_bs_cols)` shape at the end.

Validated bit-exact: `scale_t` (this patch) has byte-identical underlying storage to the equivalent manual transform from the default path.

Co-authored-by: Cursor <cursoragent@cursor.com>
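A minimal standalone PyTorch sketch of that layout contract (toy shapes, not aiter code) — column-major storage viewed back under the public shape is byte-identical to the manual transpose+copy consumers currently apply:

```python
import torch

M, num_bs_cols = 32, 4
default_scale = torch.randn(M, num_bs_cols)  # scale in the default row-major layout

# transpose_scale=True stores the same values column-major, under the same
# public (M, num_bs_cols) shape:
transposed_storage = default_scale.transpose(0, 1).contiguous()  # (num_bs_cols, M)
scale_t = transposed_storage.view(M, num_bs_cols)  # public shape restored

# Byte-identical to the manual transform consumers currently do:
manual = default_scale.transpose(0, 1).contiguous().view(*default_scale.shape)
assert torch.equal(scale_t, manual)
```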
Jacob0226 added a commit to Jacob0226/sglang that referenced this pull request — May 6, 2026
The previous commit covered `fused_rms_fp8_group_quant` (rms+quant for q_a/kv_a, qkv_a_proj, wk inputs). The o_proj path uses `fused_flatten_fp8_group_quant`, which did not support `transpose_scale`, so the trace still showed a direct_copy of 4.2us/layer for o_proj. With this change plus an aiter-side patch adding `transpose_scale` to `fused_flatten_fp8_group_quant`, that last direct_copy is also gone.

Also fixes a latent bug in the previous commit's `aiter_w8a8_block_fp8_linear` logic. The 'transpose-contig-view' trick is NOT a self-inverse, so when an upstream marker indicates a pre-transposed scale but the GEMM dispatch picks triton (which wants the default layout), the original code applied transpose-contig-view again "to undo" it, which actually scrambles the data instead. The new logic uses the proper inverse `x_scale.view(num_bs_cols, M).transpose(0, 1).contiguous()` for that path. The previous commit's E2E result was still correct because the GLM-5.1-FP8 runtime never picks the triton path for any of the 5 MLA GEMMs — all dispatch to CK bpreshuffle in the trace — but defensive correctness matters for future shapes and models.

Companion aiter PR (ROCm/aiter#3041): `fused_flatten_fp8_group_quant` gains a `transpose_scale: bool = False` parameter, mirroring `fused_rms_fp8_group_quant`. The implementation swaps the stride params passed to the inner `_fused_flatten_fp8_group_quant_kernel` and views back to shape `(M, num_bs_cols)` at the end. The inner triton kernel needs no changes.

E2E impact on GLM-5.1-FP8 MI355X TP=8, averaged across 10 configs:
- Throughput: +3.1% vs baseline
- Median TPOT: -3.1% (faster) vs baseline
- GSM8K: 0.943 (within noise of baseline 0.948)
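To make the non-self-inverse point concrete, a small standalone PyTorch sketch (toy shapes, not the sglang code):

```python
import torch

M, num_bs_cols = 4, 3  # any M != num_bs_cols shows the effect
x = torch.arange(M * num_bs_cols, dtype=torch.float32).view(M, num_bs_cols)

# Forward transform: pre-transposed storage under the same public shape.
t = x.transpose(0, 1).contiguous().view(M, num_bs_cols)

# Applying the same transform again does NOT undo it — it scrambles the data.
wrong = t.transpose(0, 1).contiguous().view(M, num_bs_cols)
assert not torch.equal(wrong, x)

# The proper inverse (the fixed logic): reinterpret the storage as the
# (num_bs_cols, M) buffer it really is, then transpose back.
recovered = t.view(num_bs_cols, M).transpose(0, 1).contiguous()
assert torch.equal(recovered, x)
```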
Mirrors `test_fused_rms_fp8_group_quant_transpose_scale` on the flatten variant: parametrized over M={1,32,256}, (N1,N2)={(16,128)}, dtype={f16,bf16} (6 cases, same coverage as the existing default `test_fused_flatten_fp8_group_quant`).

Verifies that `fused_flatten_fp8_group_quant(transpose_scale=True)` produces:
- The same public shape `(M, num_bs_cols)` as `transpose_scale=False`.
- A scale tensor byte-equivalent to `default_scale.transpose(0, 1).contiguous().view(*default_scale.shape)` (the manual transform the new path replaces on consumers like sglang's `aiter_w8a8_block_fp8_linear`).
- A bit-identical FP8 quantized tensor (atol=0, rtol=0).

Result on MI355X: 6 / 6 PASS.
Pull request overview
This PR adds a transpose_scale option to fused_flatten_fp8_group_quant so callers can request the scale tensor in the same pre-transposed storage format already supported by fused_rms_fp8_group_quant, avoiding an extra transpose().contiguous() step in downstream GEMM consumers.
Changes:
- Add `transpose_scale: bool = False` to `fused_flatten_fp8_group_quant`, with an alternate scale-buffer allocation + stride wiring to have the kernel write directly into the transposed-storage layout.
- Preserve the public scale shape `(M, cdiv(N1*N2, group_size))` by viewing the transposed buffer back to `(M, num_bs_cols)` post-kernel.
- Add a new pytest case validating that `transpose_scale=True` is byte-equivalent to `scale_default.transpose(0, 1).contiguous().view(*scale_default.shape)` and that FP8 outputs match exactly.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| `aiter/ops/triton/quant/fused_fp8_quant.py` | Adds `transpose_scale` to `fused_flatten_fp8_group_quant` by allocating scales in transposed shape and passing swapped strides to the existing Triton kernel, then viewing back to the original public shape. |
| `op_tests/triton_tests/quant/test_fused_fp8_quant.py` | Adds `test_fused_flatten_fp8_group_quant_transpose_scale` to verify the new layout contract and that quantized outputs are unchanged. |
Motivation
`aiter/ops/triton/quant/fused_fp8_quant.py::fused_rms_fp8_group_quant` already supports `transpose_scale=True`: it writes the per-token-group scale tensor in pre-transposed `(num_bs_cols, M)` row-major storage so downstream consumers like the CK bpreshuffle GEMM can read it directly, without an extra `.transpose(-1, -2).contiguous()` copy on the consumer side.

`fused_flatten_fp8_group_quant` in the same file produces a scale tensor of the same shape (`(M, cdiv(N1*N2, group_size))`) and feeds into the same downstream GEMM path in sglang's `aiter_w8a8_block_fp8_linear`, but does NOT currently support `transpose_scale`. As a result, on the o_proj path of GLM-5.1-FP8 (and any other Linear consuming a flatten-quanted activation), sglang has to do an explicit `.transpose(-1, -2).contiguous()` on the scale tensor before each bpreshuffle GEMM, costing ~4us per call.

This PR mirrors the same `transpose_scale` flag onto `fused_flatten_fp8_group_quant` so the consumer can opt in and skip that redundant copy. Companion sglang patch: Jacob0226/sglang:jacob/fused-transpose-scale (also opened upstream as sgl-project/sglang#24476; this aiter PR is its prerequisite).

Technical Details
`aiter/ops/triton/quant/fused_fp8_quant.py::fused_flatten_fp8_group_quant` (Python wrapper only — the inner `_fused_flatten_fp8_group_quant_kernel` is unchanged):
- New parameter `transpose_scale: bool = False` (the default preserves existing behavior; default callers see no change).
- With `transpose_scale=True`:
  - Allocates `out_block_scales` as `(num_bs_cols, M)` instead of `(M, num_bs_cols)`.
  - Passes swapped `out_bs_row_stride` / `out_bs_col_stride` (1, M) to the inner kernel so it writes column-major into the same buffer.
  - Applies `.view(M, num_bs_cols)` at the end to keep the public shape identical.
- Mirrors the existing `transpose_scale` handling on `fused_rms_fp8_group_quant` in the same file; this just propagates it to the flatten variant.

Net diff: +33 / -4 lines, single function, single file.
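As a sketch of how the stride swap works (the helper name `allocate_scales` and its simplified signature are illustrative, not the PR's literal code; the real wrapper passes many more arguments to `_fused_flatten_fp8_group_quant_kernel`):

```python
import torch

def allocate_scales(M: int, num_bs_cols: int, transpose_scale: bool):
    if transpose_scale:
        # Transposed storage: a (num_bs_cols, M) row-major buffer. The kernel
        # still writes logical element (m, c); with strides (1, M) the write
        # lands at c * M + m, i.e. column-major for the public shape.
        buf = torch.empty(num_bs_cols, M, dtype=torch.float32, device="cuda")
        out_bs_row_stride, out_bs_col_stride = 1, M
    else:
        # Default storage: (M, num_bs_cols) row-major, element (m, c) lands
        # at m * num_bs_cols + c.
        buf = torch.empty(M, num_bs_cols, dtype=torch.float32, device="cuda")
        out_bs_row_stride, out_bs_col_stride = num_bs_cols, 1
    return buf, out_bs_row_stride, out_bs_col_stride

# After the (unchanged) kernel launch, the transposed buffer is viewed back so
# every caller sees the same public shape:
#     out_block_scales = buf.view(M, num_bs_cols) if transpose_scale else buf
```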
Test Plan
Ran the full existing `op_tests/triton_tests/quant/test_fused_fp8_quant.py` suite to confirm:
- `test_fused_flatten_fp8_group_quant` (default path of the patched function) still passes — verifies `transpose_scale=False` behavior is unchanged.
- `test_fused_rms_fp8_group_quant` and `test_fused_rms_fp8_group_quant_transpose_scale` (the function whose `transpose_scale` handling this patch mirrors) still pass — verifies no cross-function regression.

Added `test_fused_flatten_fp8_group_quant_transpose_scale` to the same file — a new pytest case mirroring the sister `test_fused_rms_fp8_group_quant_transpose_scale`, parametrized over `M ∈ {1, 32, 256} × (N1, N2) ∈ {(16, 128)} × dtype ∈ {float16, bfloat16}` (6 cases, same coverage as the existing default `test_fused_flatten_fp8_group_quant`). For each case it:
- Runs `fused_flatten_fp8_group_quant` with `transpose_scale=False` and with `transpose_scale=True`.
- Checks both scale tensors have the public shape `(M, num_bs_cols)` and are contiguous after the inner `.view()`.
- Checks `scale_transposed` is byte-equivalent to `scale_default.transpose(0, 1).contiguous().view(*scale_default.shape)` (the manual transform consumers like sglang's `aiter_w8a8_block_fp8_linear` perform on the bpreshuffle path).
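A hedged pytest sketch of the new case — the import path matches the PR's file layout, but the exact call signature and return order of `fused_flatten_fp8_group_quant` are assumptions for illustration:

```python
import pytest
import torch

# Assumed import path and (q, scale) return order — see the PR for the real API.
from aiter.ops.triton.quant.fused_fp8_quant import fused_flatten_fp8_group_quant

@pytest.mark.parametrize("M", [1, 32, 256])
@pytest.mark.parametrize("N1,N2", [(16, 128)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_flatten_fp8_group_quant_transpose_scale(M, N1, N2, dtype):
    x = torch.randn(M, N1, N2, dtype=dtype, device="cuda")

    q_def, s_def = fused_flatten_fp8_group_quant(x, transpose_scale=False)
    q_t, s_t = fused_flatten_fp8_group_quant(x, transpose_scale=True)

    # Both paths expose the same public shape, contiguous after .view().
    assert s_t.shape == s_def.shape and s_t.is_contiguous()

    # Transposed path == the manual consumer-side transform, byte for byte.
    manual = s_def.transpose(0, 1).contiguous().view(*s_def.shape)
    assert torch.equal(s_t, manual)

    # FP8 payload is bit-identical (compare raw bytes; atol=0, rtol=0).
    assert torch.equal(q_t.view(torch.uint8), q_def.view(torch.uint8))
```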
Hardware / software: MI355X (gfx950), ROCm 7.2, docker `rocm/sgl-dev:v0.5.10.post1-rocm720-mi35x-20260503`.

Test Result
Reproduce
`op_tests/triton_tests/quant/test_fused_fp8_quant.py` — 300 / 300 PASS on MI355X (27s). 300 = 294 existing + 6 newly added. Highlights:
- `test_fused_flatten_fp8_group_quant` (default path of the patched function): 6 / 6 PASS — confirms `transpose_scale=False` behavior is unchanged.
- `test_fused_flatten_fp8_group_quant_transpose_scale` (NEW, added in this PR): 6 / 6 PASS — confirms `transpose_scale=True` produces a scale tensor byte-equivalent to `default.transpose(0, 1).contiguous().view(...)` and a bit-identical FP8 quantized tensor.
- `test_fused_rms_fp8_group_quant` and `test_fused_rms_fp8_group_quant_transpose_scale` (the sister function this patch mirrors): 36 / 36 PASS — confirms no cross-function regression.

End-to-end performance (via companion sglang PR sgl-project/sglang#24476)
Combined with the existing `fused_rms_fp8_group_quant(transpose_scale=True)` opt-in: total +3.1% throughput / -3.1% TPOT averaged across 10 ISL/OSL/concurrency points (1k1k, 8k1k × concurrency 4/8/16/32/64).

Submission Checklist