[AMD] Skip x_scale.transpose+contiguous before bpreshuffle GEMM via upstream pre-transposed scale#24476
Open
Jacob0226 wants to merge 2 commits into sgl-project:main
Conversation
Contributor
Force-pushed cf91580 to 1f7288a
…-transposed

sgl-project#23319 ([AMD] Use bpreshuffle FP8 blockscale GEMM to replace ABScale GEMM) introduced `.transpose(-1, -2).contiguous()` on x_scale before every bpreshuffle CK GEMM call (fp8_utils.py:784), costing ~4.165 µs per GEMM. aiter `fused_rms_fp8_group_quant` already supports `transpose_scale=True`, which writes the scale directly in (N//128, M) column-major layout. All 7 sglang upstream call sites (3 in forward_mha, 2 in forward_mla, 2 in communicator) now opt in and tag the returned scale tensor with `_aiter_bpreshuffle_layout=True`, so the GEMM dispatch in fp8_utils.py knows it can skip the redundant copy.

The logic is symmetric: if downstream wants triton (default layout) but upstream wrote the transposed layout, we still apply the transpose, so the triton path stays correct. Net: bpreshuffle GEMMs with the marker save ~4 µs each; triton GEMMs with the marker pay ~4 µs each (untranspose). For the GLM-5.1-FP8 trace, where all 5 MLA GEMMs hit bpreshuffle, this saves ~20 µs per attention block. Verified bit-exact (within CK GEMM run-to-run noise of ~4) via direct fp8_utils call comparison.
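A minimal sketch of the producer/consumer marker mechanism described above. The marker name is taken from the PR; the quant and dispatch bodies below are stand-ins, not the actual aiter/sglang implementations:

```python
import torch

def quantize_with_transposed_scale(x: torch.Tensor, group: int = 128):
    """Stand-in for aiter's fused quant with transpose_scale=True.

    Writes the per-group scale in (K // group, M) column-major layout and
    tags it so the GEMM dispatch can skip its transpose+contiguous copy.
    """
    M, K = x.shape
    scale = x.abs().reshape(M, K // group, group).amax(-1) / 448.0  # (M, K // group)
    scale = scale.transpose(0, 1).contiguous()  # pre-transposed layout
    scale._aiter_bpreshuffle_layout = True      # marker read by the consumer
    return x, scale

def dispatch_scale(x_scale: torch.Tensor, use_bpreshuffle: bool) -> torch.Tensor:
    """Stand-in for the consumer gate in aiter_w8a8_block_fp8_linear."""
    pre_transposed = getattr(x_scale, "_aiter_bpreshuffle_layout", False)
    if use_bpreshuffle and not pre_transposed:
        # Pre-patch behavior: pay the ~4 µs transpose+contiguous copy.
        x_scale = x_scale.transpose(-1, -2).contiguous()
    # If use_bpreshuffle and pre_transposed: layout already matches, no copy.
    return x_scale
```

When the marker is set and the bpreshuffle path is taken, the tensor is returned as-is (same storage), which is exactly the copy being skipped.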
The previous commit covered `fused_rms_fp8_group_quant` (rms+quant for q_a/kv_a, qkv_a_proj, wk inputs). The o_proj path uses `fused_flatten_fp8_group_quant`, which did not support transpose_scale, so the trace still showed a direct_copy of 4.2 µs/layer for o_proj. With this change plus an aiter-side patch adding transpose_scale to `fused_flatten_fp8_group_quant`, that last direct_copy is also gone.

This also fixes a latent bug in the previous commit's `aiter_w8a8_block_fp8_linear` logic. The 'transpose-contig-view' trick is NOT a self-inverse, so when an upstream marker indicates a pre-transposed scale but the GEMM dispatch picks triton (which wants the default layout), the original code applied transpose-contig-view again 'to undo', which actually scrambles the data. The new logic uses the proper inverse `x_scale.view(num_bs_cols, M).transpose(0, 1).contiguous()` for that path. The previous commit's E2E result was still correct because the GLM-5.1-FP8 runtime never picks the triton path for any of the 5 MLA GEMMs (all dispatch to CK bpreshuffle in the trace), but defensive correctness matters for future shapes and models.

Companion aiter PR (ROCm/aiter#3041): `fused_flatten_fp8_group_quant` gains a `transpose_scale: bool = False` parameter, mirroring `fused_rms_fp8_group_quant`. The implementation swaps the stride params passed to the inner `_fused_flatten_fp8_group_quant_kernel` and views back to shape (M, num_bs_cols) at the end. The inner triton kernel needs no changes.

E2E impact on GLM-5.1-FP8 MI355X TP=8, averaged across 10 configs:
- Throughput: +3.1% vs baseline
- Median TPOT: -3.1% (faster) vs baseline
- GSM8K: 0.943 (within noise of baseline 0.948)
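The "not a self-inverse" claim can be checked in a few lines. This NumPy sketch stands in for the torch tensors (`ascontiguousarray` plays the role of `.contiguous()`, `reshape` the role of `.view`):

```python
import numpy as np

M, num_bs_cols = 2, 3
x_scale = np.arange(M * num_bs_cols).reshape(M, num_bs_cols)

# The 'transpose-contig-view' trick: same (M, num_bs_cols) shape,
# but the memory now holds the column-major (transposed) data.
trick = np.ascontiguousarray(x_scale.T).reshape(M, num_bs_cols)

# Applying the trick a second time does NOT restore the original.
scrambled = np.ascontiguousarray(trick.T).reshape(M, num_bs_cols)

# The proper inverse: reinterpret as (num_bs_cols, M), then transpose.
restored = np.ascontiguousarray(trick.reshape(num_bs_cols, M).T)

print(np.array_equal(scrambled, x_scale))  # False
print(np.array_equal(restored, x_scale))   # True
```

The trick only commutes with itself for degenerate shapes; for a general (M, num_bs_cols) scale, the round-trip through the default-layout path needs the explicit `view(num_bs_cols, M).transpose(0, 1).contiguous()` inverse.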
Force-pushed 1f7288a to 1e1a1e9
Depends on aiter PR ROCm/aiter#3041 — must be merged first.
Motivation
#23319 ([AMD] Use bpreshuffle FP8 blockscale GEMM to replace ABScale GEMM, merged 2026-04-23) introduced a per-call `x_scale.transpose(-1, -2).contiguous().view(*x_scale.shape)` on the bpreshuffle path of `aiter_w8a8_block_fp8_linear` (fp8_utils.py:784). On GLM-5.1-FP8 / MI355X / TP=8, this dispatches an `at::native::elementwise_kernel_manual_unroll<direct_copy_kernel_cuda>` kernel before EVERY bpreshuffle CK GEMM, costing ~4.165 µs per call. Trace from DeepseekV2AttentionMLA:
The aiter side already supports a `transpose_scale=True` flag on `fused_rms_fp8_group_quant` that writes the scale tensor directly in (num_bs_cols, M) column-major layout — exactly what the bpreshuffle GEMM consumes. SGLang upstream callers were just not using it. This PR opts them in and tags the returned scale tensor with a `_aiter_bpreshuffle_layout=True` attribute marker so the GEMM dispatch in fp8_utils.py can skip the redundant transpose+copy.

Modifications
The optimization is a producer/consumer pair connected by a tensor attribute marker `_aiter_bpreshuffle_layout`.

Consumer — `fp8_utils.py::aiter_w8a8_block_fp8_linear` gates the transpose+contiguous on the marker. If the marker matches the GEMM's expected layout (bpreshuffle ↔ transposed), skip the copy. Callers that do not set the marker get the pre-patch behavior (backward compatible).

Producer — updates seven call sites of `fused_rms_fp8_group_quant` (2 in communicator.py, 2 in forward_mla.py, 3 in forward_mha.py) to pass `transpose_scale=True` and set `result_tuple[1]._aiter_bpreshuffle_layout = True` on the returned scale tensor.

Producer (o_proj path) — `forward_mla.py::forward_absorb_core` calls `fused_flatten_fp8_group_quant(transpose_scale=True)` and tags the returned scale the same way. This requires the companion aiter PR [Triton] fused_flatten_fp8_group_quant: add transpose_scale param ROCm/aiter#3041.

Validation
Trace verification (MI355X TP=8 GLM-5.1-FP8 NSA TileLang decode, prof_in8192_out1024_conc4_p8 / TP-0-DECODE)
DeepseekV2AttentionMLA section — number of `direct_copy` kernels per layer:

DeepseekV2AttentionMLA total time per layer:
Accuracy Tests
GLM-5.1-FP8 launch cmd:
MI355X GSM8K (TP=8, 1200 questions, parallel=1200), measured against the same baseline as the Speed Tests section below:
Within run-to-run variance (~±1%); consistent with the optimization being byte-equivalent at the GEMM level.
Speed Tests and Profiling
Baseline: sgl-project/sglang main + #23562 (preshuffled paged MQA + page_size=64, by @1am9trash, still open) + #24125 (cat-skip, by @Jacob0226, has 2 approvals). Both are NSA TileLang prerequisites for GLM-5.1-FP8 on MI355X — the model + GPU combo this PR targets. Docker image: `rocm/sgl-dev:v0.5.10.post1-rocm720-mi35x-20260503`. The optimization itself is logically independent of both #23562 and #24125; the numbers would shift slightly without them, but the direction holds.

Bench cmd: `sglang.bench_serving --dataset-name random --random-range-ratio 0.8 --random-input <8192|1024> --random-output 1024 --max-concurrency <C> --num-prompt 10*C --output-file /dev/null` vs baseline (10 ISL/OSL/concurrency points)
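A sweep over the 10 points might look like the following. The PR gives the command template but not the exact ISL/concurrency grid, so the values below are illustrative only; commands are printed rather than executed so the sketch runs without a live server:

```shell
# Hypothetical grid: 2 input lengths x 5 concurrency levels = 10 points.
for ISL in 8192 1024; do
  for C in 1 4 16 64 128; do
    echo "python -m sglang.bench_serving --dataset-name random \
      --random-range-ratio 0.8 --random-input $ISL --random-output 1024 \
      --max-concurrency $C --num-prompt $((10 * C)) --output-file /dev/null"
  done
done
```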
Dependencies
ROCm/aiter#3041 — adds transpose_scale to `fused_flatten_fp8_group_quant`, required by the o_proj opt-in. Without it, that call site will raise a TypeError at runtime.

Checklist
`pre-commit run --files <changed files>`: all hooks pass.