
[AMD] Skip x_scale.transpose+contiguous before bpreshuffle GEMM via upstream pre-transposed scale #24476

Open
Jacob0226 wants to merge 2 commits into sgl-project:main from Jacob0226:jacob/fused-transpose-scale

Conversation


Jacob0226 (Contributor) commented May 6, 2026

Depends on aiter PR ROCm/aiter#3041 — must be merged first.

Motivation

#23319 ([AMD] Use bpreshuffle FP8 blockscale GEMM to replace ABScale GEMM, merged 2026-04-23) introduced a per-call x_scale.transpose(-1, -2).contiguous().view(*x_scale.shape) on the bpreshuffle path of aiter_w8a8_block_fp8_linear (fp8_utils.py:784). On GLM-5.1-FP8 / MI355X / TP=8, this dispatches an at::native::elementwise_kernel_manual_unroll<direct_copy_kernel_cuda> kernel before EVERY bpreshuffle CK GEMM, costing ~4.165 µs per call.

Trace from DeepseekV2AttentionMLA:

ReplicatedLinear (qkv_a_proj)         direct_copy 4.25 µs + CK GEMM 9.85 µs  →  2 kernels
ColumnParallelLinear (q_b_proj)       direct_copy 4.25 µs + CK GEMM 9.85 µs  →  2 kernels
Indexer > ReplicatedLinear (wk)       direct_copy 4.25 µs + CK GEMM 9.85 µs  →  2 kernels
RowParallelLinear (o_proj)            direct_copy 4.25 µs + CK GEMM 9.85 µs  →  2 kernels
                                                                                =====
                                                                                4 redundant copies / layer
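
The copy comes from the pre-patch pattern quoted above (fp8_utils.py:784), shown here in isolation:

# Pre-patch bpreshuffle path: transpose+contiguous materializes a new buffer
# on every call -- the direct_copy_kernel_cuda launch visible in the trace.
x_scale = x_scale.transpose(-1, -2).contiguous().view(*x_scale.shape)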

The aiter side already supports a transpose_scale=True flag on fused_rms_fp8_group_quant that writes the scale tensor directly in (num_bs_cols, M) column-major layout — exactly what the bpreshuffle GEMM consumes. SGLang upstream callers were just not using it. This PR opts them in and tags the returned scale tensor with a _aiter_bpreshuffle_layout=True attribute marker so the GEMM dispatch in fp8_utils.py can skip the redundant transpose+copy.
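
In sketch form, the producer-side opt-in looks like this. The quant call is stubbed so the snippet runs standalone; the real producer is aiter's fused_rms_fp8_group_quant, which also returns the fp8-quantized tensor:

import torch

# Stub standing in for aiter's fused_rms_fp8_group_quant(..., transpose_scale=True).
# The real op writes the scale buffer transposed and views it back to
# (M, num_bs_cols); only the shapes and the marker hand-off are modeled here.
def quant_stub(x, group_size=128):
    M, num_bs_cols = x.shape[0], x.shape[1] // group_size
    # buffer written transposed, then reinterpreted as (M, num_bs_cols) --
    # mirrors the real op's "view back" step; values are random placeholders
    x_scale = torch.rand(num_bs_cols, M).view(M, num_bs_cols)
    return x, x_scale

_, x_scale = quant_stub(torch.rand(16, 256))
x_scale._aiter_bpreshuffle_layout = True  # marker read by the GEMM dispatch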

Modifications


The optimization is a producer/consumer pair connected by a tensor attribute marker _aiter_bpreshuffle_layout.

  1. Consumer — fp8_utils.py::aiter_w8a8_block_fp8_linear gates the transpose+contiguous on the marker. If the marker matches the GEMM's expected layout (bpreshuffle ↔ transposed), the copy is skipped; callers that do not set the marker get the pre-patch behavior (backward compatible). See the gating sketch after this list.

  2. Producer — updates seven call sites of fused_rms_fp8_group_quant (2 in communicator.py, 2 in forward_mla.py, 3 in forward_mha.py) to pass transpose_scale=True and set result_tuple[1]._aiter_bpreshuffle_layout = True on the returned scale tensor.

  3. Producer (o_proj path) — forward_mla.py::forward_absorb_core calls fused_flatten_fp8_group_quant(transpose_scale=True) and tags the returned scale the same way. This requires the companion aiter PR ROCm/aiter#3041 ([Triton] fused_flatten_fp8_group_quant: add transpose_scale param).
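
A minimal sketch of the consumer gating (simplified; it returns the scale the chosen GEMM would consume rather than calling the actual CK/triton kernels):

import torch

def prepare_x_scale(x_scale, use_bpreshuffle):
    pre_transposed = getattr(x_scale, "_aiter_bpreshuffle_layout", False)
    if use_bpreshuffle and not pre_transposed:
        # unmarked callers: pay the copy -- exactly the pre-patch behavior
        return x_scale.transpose(-1, -2).contiguous().view(*x_scale.shape)
    if not use_bpreshuffle and pre_transposed:
        # triton wants the default layout: apply the proper inverse (see the
        # latent-bug note in the commit messages below)
        M, num_bs_cols = x_scale.shape
        return x_scale.view(num_bs_cols, M).transpose(0, 1).contiguous()
    return x_scale  # bpreshuffle + marker: no copy at all (the fast path)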

Validation

Trace verification (MI355X TP=8 GLM-5.1-FP8 NSA TileLang decode, prof_in8192_out1024_conc4_p8 / TP-0-DECODE)

DeepseekV2AttentionMLA section — number of direct_copy kernels per layer:

Baseline:  4 / 4  GEMMs hit direct_copy
This PR:   0 / 4  (all four direct_copy kernels eliminated)

DeepseekV2AttentionMLA total time per layer:

Baseline:  ~177 µs
This PR:   157.10 µs   (–19.9 µs)

Accuracy Tests

GLM-5.1-FP8 launch cmd:

export SAFETENSORS_FAST_GPU=1
export SGLANG_ROCM_FUSED_DECODE_MLA=0
export ROCM_QUICK_REDUCE_QUANTIZATION=INT4
python3 -m sglang.launch_server \
  --model GLM-5.1-FP8 \
  --tp 8 --port 8552 \
  --tool-call-parser glm47 --reasoning-parser glm45 \
  --watchdog-timeout 1200 \
  --mem-fraction-static 0.85 \
  --kv-cache-dtype fp8_e4m3 --disable-radix-cache \
  --model-loader-extra-config '{"enable_multithread_load": true, "num_threads": 8}' \
  --nsa-prefill-backend tilelang --nsa-decode-backend tilelang

MI355X GSM8K (TP=8, 1200 questions, parallel=1200), measured against the same baseline as the Speed Tests section below:

Baseline:  0.948
This PR:   0.943

Within run-to-run variance (~±1%); consistent with the optimization being byte-equivalent at the GEMM level.

Speed Tests and Profiling

Baseline: sgl-project/sglang main + #23562 (preshuffled paged MQA + page_size=64, by @1am9trash, still open) + #24125 (cat-skip, by @Jacob0226, has 2 approvals). Both are NSA TileLang prerequisites for GLM-5.1-FP8 on MI355X — the model + GPU combo this PR targets. Docker image: rocm/sgl-dev:v0.5.10.post1-rocm720-mi35x-20260503. The optimization itself is logically independent of both #23562 and #24125; the numbers would shift slightly without them but the direction holds.

Bench cmd:

python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 0.8 \
  --random-input <8192|1024> --random-output 1024 \
  --max-concurrency <C> --num-prompt <10*C> --output-file /dev/null

vs baseline (10 ISL/OSL/concurrency points)

Workload         Cnc   Δ TPUT   Δ TPOT
in1024_out1024     4   +3.83%   -3.84%
in1024_out1024     8   +3.57%   -3.55%
in1024_out1024    16   +3.91%   -3.83%
in1024_out1024    32   +3.51%   -3.41%
in1024_out1024    64   +3.19%   -3.07%
in8192_out1024     4   +3.57%   -3.57%
in8192_out1024     8   +2.65%   -2.63%
in8192_out1024    16   +2.73%   -2.99%
in8192_out1024    32   +2.49%   -2.42%
in8192_out1024    64   +1.66%   -1.58%
  • Headline: throughput +3.1% / TPOT -3.1% (faster) averaged across all 10 points.
  • Best gain at small input + high concurrency (decode-bound).
  • Smallest gain at large input + high concurrency (prefill-heavy).

Dependencies

  • ROCm/aiter#3041 ([Triton] fused_flatten_fp8_group_quant: add transpose_scale param), must be merged first.

Checklist

  • Format your code according to the Format code with pre-commit guide (pre-commit run --files <changed files>: all hooks pass).
  • Add unit tests according to the Run and add unit tests guide.
  • Update documentation according to the Write documentations guide.
  • Provide accuracy and speed benchmark results according to the Test the accuracy and Benchmark the speed guides.
  • Follow the SGLang code style guidance.


Jacob0226 force-pushed the jacob/fused-transpose-scale branch 2 times, most recently from cf91580 to 1f7288a on May 6, 2026 at 07:29.
Jacob0226 added 2 commits May 6, 2026 07:37
…-transposed

sgl-project#23319 ([AMD] Use bpreshuffle FP8 blockscale GEMM to replace ABScale GEMM)
introduced .transpose(-1,-2).contiguous() on x_scale before every bpreshuffle
CK GEMM call (fp8_utils.py:784) costing ~4.165 us per GEMM.

aiter fused_rms_fp8_group_quant already supports transpose_scale=True (writes
the scale directly in (N//128, M) column-major layout). All 7 sglang upstream
call sites (3 in forward_mha, 2 in forward_mla, 2 in communicator) now opt in
and tag the returned scale tensor with _aiter_bpreshuffle_layout=True so the
GEMM dispatch in fp8_utils.py knows it can skip the redundant copy.

Logic is symmetric: if downstream wants triton (default layout) but upstream
wrote transposed, we still apply the transpose (so triton path stays correct).
Net: bpreshuffle GEMMs with marker save 4 us each; triton GEMMs with marker
pay 4 us each (untranspose). In the GLM-5.1-FP8 trace, where all 5 MLA
GEMMs hit bpreshuffle, this saves ~20 us per attention block.

Verified bit-exact (within CK GEMM run-to-run noise of ~4) via direct
fp8_utils call comparison.
The previous commit covered fused_rms_fp8_group_quant (rms+quant for q_a/kv_a,
qkv_a_proj, wk inputs). The o_proj path uses fused_flatten_fp8_group_quant
which did not support transpose_scale, so the trace still showed a direct_copy
4.2us/layer for o_proj. With this change + an aiter-side patch adding
transpose_scale to fused_flatten_fp8_group_quant, that last direct_copy is
also gone.

Also fixes a latent bug in the previous commit's aiter_w8a8_block_fp8_linear
logic. The 'transpose-contig-view' trick is NOT a self-inverse, so when an
upstream marker indicates pre-transposed scale but the GEMM dispatch picks
triton (which wants the default layout), the original code 'applied
transpose-contig-view again to undo it', which actually scrambles the data.
The new logic uses the proper inverse 'x_scale.view(num_bs_cols, M)
.transpose(0, 1).contiguous()' for that path. The previous commit's E2E result
was still correct because the GLM-5.1-FP8 runtime never picks the triton path
for any of the 5 MLA GEMMs (all dispatch to CK bpreshuffle in the trace), but
defensive correctness matters for future shapes and models.
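
A self-contained repro of the non-self-inverse behavior and the fix (M and num_bs_cols values are illustrative; any M != num_bs_cols shows the bug):

import torch

M, C = 4, 2                          # C = num_bs_cols
a = torch.arange(M * C, dtype=torch.float32).reshape(M, C)

def trick(t):                        # the 'transpose-contig-view' op
    return t.transpose(-1, -2).contiguous().view(*t.shape)

b = trick(a)                         # pre-transposed layout, shape still (M, C)
assert not torch.equal(trick(b), a)  # applying the trick again scrambles data

a_back = b.view(C, M).transpose(0, 1).contiguous()  # the proper inverse
assert torch.equal(a_back, a)        # round-trips exactly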

Companion aiter PR (ROCm/aiter#3041): fused_flatten_fp8_group_quant gains a
transpose_scale: bool = False parameter, mirroring fused_rms_fp8_group_quant.
Implementation swaps stride params passed to the inner
_fused_flatten_fp8_group_quant_kernel and views back to shape (M, num_bs_cols)
at the end. Inner triton kernel needs no changes.
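
The stride-swap idea in sketch form (illustrative; not the actual aiter diff). A kernel that keeps indexing scale[m, c] but receives the strides of a transposed buffer writes the transposed layout directly, with no extra copy kernel:

import torch

def alloc_scale(M, num_bs_cols, transpose_scale=False):
    if transpose_scale:
        # allocate transposed and return a zero-copy view: the view's swapped
        # strides are what get passed to the inner kernel
        return torch.empty(num_bs_cols, M).t()  # shape (M, num_bs_cols)
    return torch.empty(M, num_bs_cols)          # default row-major

s = alloc_scale(64, 4, transpose_scale=True)
print(s.shape, s.stride())                      # torch.Size([64, 4]) (1, 64)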

E2E impact on GLM-5.1-FP8 MI355X TP=8 averaged across 10 configs:
- Throughput: +3.1% vs baseline
- Median TPOT: -3.1% (faster) vs baseline
- GSM8K: 0.943 (within noise of baseline 0.948)
