[AMD] Skip redundant CatArrayBatchedCopy in GLM-5 NSA TileLang decode#24125
Jacob0226 wants to merge 2 commits into sgl-project:main
Conversation
GLM-5 NSA TileLang decode on ROCm dispatches a `CatArrayBatchedCopy` kernel
once per layer per decode step that rebuilds an already-existing tensor.
This is a strict-improvement bug fix: ~2.6 us saved per layer, with no
changes to non-HIP backends.
==============================================================================
Root cause
==============================================================================
For the NSA TileLang fused-rope decode path (`_use_aiter_gfx95 + nsa +
nsa_decode_backend == "tilelang"`), `forward_absorb_core` calls
`fused_qk_rope_cat_and_cache_mla` which produces a contiguous q_cat tensor
of shape (M, num_heads, kv_lora_rank + qk_rope_head_dim). The pre-patch
flow then sliced q_cat into q_nope_fused / q_pe_fused and passed them as
separate args to attn_mqa.
attn_mqa -> NSABackend.forward_decode then takes the if-branch (q_rope
is not None), views the slices, and for tilelang / flashmla_sparse /
flashmla_kv / aiter decode impls calls
`concat_mla_absorb_q_general(q_nope, q_rope)` to rebuild q_all. On ROCm,
that helper falls back to `torch.cat([q_nope, q_rope], dim=-1)`, which
allocates a fresh contiguous tensor and dispatches a copy kernel. The
result is byte-identical to the q_cat we already had — the cat is pure
overhead.
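The redundant round-trip can be demonstrated at the shape level with numpy standing in for torch (the dimensions below are illustrative values, not the real model config):

```python
import numpy as np

# Illustrative dims; the real shape is (M, num_heads, kv_lora_rank + qk_rope_head_dim).
M, H, kv_lora_rank, qk_rope_head_dim = 4, 8, 16, 4
q_cat = np.random.rand(M, H, kv_lora_rank + qk_rope_head_dim).astype(np.float32)

# What the pre-patch flow did: slice the contiguous q_cat into the two halves ...
q_nope = q_cat[..., :kv_lora_rank]
q_pe = q_cat[..., kv_lora_rank:]

# ... and what the ROCm fallback of concat_mla_absorb_q_general then does: re-cat.
q_all = np.concatenate([q_nope, q_pe], axis=-1)

# The rebuilt tensor is value-identical to q_cat but lives in a fresh buffer,
# i.e. the cat is a pure allocation + copy with no new information.
assert np.array_equal(q_all, q_cat)
assert q_all.ctypes.data != q_cat.ctypes.data
```

The same buffer-identity check (`data_ptr()` in torch terms) is what makes the cat provably redundant on this path.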
==============================================================================
Fix
==============================================================================
(1) `forward_absorb_core` now passes q_cat directly to attn_mqa with
q_rope=None on the decode path. Prefill (forward_extend) keeps the
split form because `nsa_backend.forward_extend` asserts
`q_rope is not None`.
(2) `nsa_backend.forward_decode` is updated to track q_all explicitly:
- When the caller passes split q_nope / q_rope, q_all is initialized
to None and each impl block re-cats as before (byte-identical to
pre-patch behavior).
- When the caller passes q_rope=None on HIP, q_all is set to a
zero-copy `q.contiguous().view(...)` and the cat is skipped.
The cat-skip is gated `if q_all is None or not _is_hip` so non-HIP
backends always re-cat (preserves CUDA / MUSA paths bit-exactly).
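The two bullets above reduce to a small control-flow change. A minimal pure-Python sketch of the gate (function and argument names are illustrative, and Python lists stand in for tensors):

```python
def resolve_q_all(q, q_rope, is_hip):
    """Illustrative sketch of the patched forward_decode gating (names hypothetical)."""
    q_all = None
    if q_rope is None and is_hip:
        # Caller already passed the fused q_cat: take it zero-copy.
        # (The real code uses q.contiguous().view(...).)
        q_all = q
    if q_all is None or not is_hip:
        # Pre-patch path: rebuild q_all from the split halves
        # (list + stands in for concat_mla_absorb_q_general / torch.cat).
        q_all = q + q_rope
    return q_all

# Split-call path, any backend: re-cat exactly as before.
assert resolve_q_all([1, 2], [3], is_hip=False) == [1, 2, 3]
assert resolve_q_all([1, 2], [3], is_hip=True) == [1, 2, 3]

# Fused-call path on HIP: the very same object comes back, no copy.
fused = [1, 2, 3]
assert resolve_q_all(fused, None, is_hip=True) is fused
```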
==============================================================================
Validation
==============================================================================
MI355X TP=8 GLM-5.1-FP8 fp8 KV cache, NSA TileLang decode (on top of
PR sgl-project#23562 + aiter PR sgl-project#2879):
scenario | before | after | TPOT Δ
--------------------- | --------- | --------- | --------
8k1k conc4 TPOT | 21.21 ms | 20.76 ms | -2.17%
8k1k conc8 TPOT | 25.28 ms | 24.82 ms | -1.82%
8k1k conc16 TPOT | 30.79 ms | 30.33 ms | -1.49%
8k1k conc32 TPOT | 42.92 ms | 42.46 ms | -1.07%
8k1k conc64 TPOT | 61.79 ms | 61.33 ms | -0.74%
1k1k conc4 TPOT | 18.79 ms | 18.33 ms | -2.45%
1k1k conc8 TPOT | 21.14 ms | 20.66 ms | -2.27%
1k1k conc16 TPOT | 23.63 ms | 23.15 ms | -2.03%
1k1k conc32 TPOT | 29.19 ms | 28.69 ms | -1.71%
1k1k conc64 TPOT | 35.02 ms | 34.60 ms | -1.20%
Output throughput improves by the same percentage on every scenario.
Cat-skip's absolute ~2.6 us / layer benefit is constant; the relative
gain is highest at small batch + short prompt (where total layer time is
smallest) and decays with batch size.
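The constant-absolute / decaying-relative claim can be read straight off the table; a quick check of the before/after pairs:

```python
# (before, after) TPOT pairs in ms, copied from the table above.
pairs = [
    (21.21, 20.76), (25.28, 24.82), (30.79, 30.33), (42.92, 42.46), (61.79, 61.33),  # 8k1k
    (18.79, 18.33), (21.14, 20.66), (23.63, 23.15), (29.19, 28.69), (35.02, 34.60),  # 1k1k
]
deltas = [round(b - a, 2) for b, a in pairs]          # absolute saving per token
rel = [round(100 * (b - a) / b, 2) for b, a in pairs] # relative gain (%)

# Absolute delta sits in a narrow ~0.42-0.50 ms band across all ten scenarios,
# while the relative gain shrinks as baseline TPOT grows with concurrency.
assert all(0.40 <= d <= 0.52 for d in deltas)
```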
GSM8K accuracy: 0.942 vs 0.951 baseline (within run-to-run variance
observed across multiple runs of the same config: 0.946-0.953).
==============================================================================
Files
==============================================================================
forward_mla.py (+50 -16) forward_absorb_core:_skip_rope_for_nsa_tilelang_fused
branch passes q_cat with q_rope=None for decode.
nsa_backend.py (+12 -4) forward_decode tracks q_all and skips cat on HIP
when caller already provided concatenated q.
…24125 viz

tools/dual_stream_regression.html (new):
- 3-trace SVG timeline visualizing one full layer-pass on MI355X baseline (1 stream), MI355X dual-stream (cur + alt), and B200 dual-stream (4 lanes from the CUDA-graph stream pool), all on the same 4 px/us scale.
- Bubble overlays color-coded by gap kind: red diagonal stripes for >=2us real CPU stalls (visible only on MI355X dual-stream as 3 distinct 5-6us regions), amber for 1-2us small fences, faint azure for sub-1us per-kernel dispatch latency on B200.
- Zoom toolbar (Fit / 1x / 1.5x / 2x) with horizontal scrollbar so the full trace fits one screen for live presentation, then zoomable for detailed inspection.
- Hover tooltips with full kernel name + duration + start time.
- Decomposition card with kind-classification table, profiler-accounting caveat (CUPTI vs roctracer timestamp-source comparison), and per-stall annotated table for MI355X dual-stream.

tools/dual_stream_regression_analysis.md (updated):
- New "Bubble decomposition (kind classification)" section breaking the +18us / layer MI355X dual-stream bubble into sub-1us / 1-2us / >=2us gaps. 91% of it is real >=2us stalls (3 stalls per layer at 5-6us each), proving they are HIP-graph fence cost, not generic launch overhead.
- Cross-platform validation: B200 GLM-5-FP8 dual-stream has 0 us of >=2us stalls and 0 us of 1-2us fences. CUDA-graph stream-pool replay doesn't put cross-stream synchronization in the critical path; HIP graph currently does.
- Profiler-accounting caveat: B200's 23.4us of sub-1us gaps reported by CUPTI is *not* a real GPU bubble. Verified from raw trace timestamps: B200 has 4,296 negative gaps (next-start_ts < prev-end_ts, smallest -6.66us); MI355X roctracer has 0. The CP->SM dispatch latency is similar on both platforms (~200-500ns); CUPTI exposes it as an inter-kernel gap while roctracer absorbs it into kernel duration. The meaningful comparison is the >=2us real-stall column, where MI355X dual-stream has 16.5us and B200 dual-stream has 0us.
tools/pr24125_jacob-glm5-rocm-nsa-cat-skip.html (new):
- Visual companion to upstream PR sgl-project/sglang#24125 (cat-skip in NSA TileLang decode). Memory diagram + kernel pipeline + reviewer invariants + MI355X TPOT/throughput/GSM8K results; dark editorial layout for sharing with reviewers.
    # "always cat" behavior to keep CUDA / MUSA paths byte-identical.
    if self.nsa_decode_impl == "flashmla_sparse":
-       if q_rope is not None:
+       if q_all is None or not _is_hip:
I think we can only make the change on the AMD side (e.g., tilelang/aiter backend).
This change may never be reached in the NV code path (e.g., flashmla_sparse/flashmla_kv).
Good catch, thanks! Done in 1f2b7c4.
Reverted the gate change for `flashmla_sparse` and `flashmla_kv` back to the pre-patch `if q_rope is not None:` pattern. Both impls import from `sgl_kernel.flashmla_ops`, which requires the CUDA driver, so on HIP they are unreachable and the new `not _is_hip` clause was a strict no-op there. On CUDA the behavior is byte-identical (the gate is always True either way).
The cat-skip itself stays on the HIP-only `tilelang` and `aiter` blocks, which is what this PR is actually solving. Also moved the "Cat-skip is HIP-only ..." comment from above `flashmla_sparse` to above `tilelang` to match the new (narrower) scope.
Per @1am9trash review (PR sgl-project#24125 line 1587):

> I think we can only make the change in amd side (e.g, tilelang/aiter
> backend). This change may be never reached in nv code path (e.g,
> flashmla_sparse/flashmla_kv).

Revert the gate change for `flashmla_sparse` and `flashmla_kv` decode impls back to the pre-patch `if q_rope is not None:` pattern. Both impls import from `sgl_kernel.flashmla_ops`, which requires the CUDA driver, so they are unreachable on HIP and the new `not _is_hip` clause was a strict no-op there. Functionally identical on CUDA (byte-for-byte), just a smaller and more focused diff.

The cat-skip itself stays on the HIP-only `tilelang` and `aiter` decode impls, which is what the PR is actually solving. Moved the "Cat-skip is HIP-only ..." comment from above flashmla_sparse to above tilelang to match the new (narrower) scope.
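The post-revert per-impl gating can be sketched in a few lines (pure Python; `needs_cat` and its arguments are illustrative names, not the real code, and it returns whether the impl block should rebuild `q_all`):

```python
def needs_cat(impl, q_rope, q_all, is_hip):
    """Sketch of the post-revert decode gates (illustrative, not the real code)."""
    if impl in ("flashmla_sparse", "flashmla_kv"):
        # CUDA-only impls: pre-patch gate restored; unreachable on HIP anyway.
        return q_rope is not None
    if impl in ("tilelang", "aiter"):
        # HIP-only impls: skip the cat when the fused q_all is already available.
        return q_all is None or not is_hip
    raise ValueError(f"unknown decode impl: {impl}")

# CUDA flashmla path: cats exactly when split q_rope was passed, as pre-patch.
assert needs_cat("flashmla_sparse", q_rope=object(), q_all=None, is_hip=False)
# HIP tilelang path with fused q_cat: the cat is skipped.
assert not needs_cat("tilelang", q_rope=None, q_all=object(), is_hip=True)
# HIP tilelang path with split q: re-cat as before.
assert needs_cat("tilelang", q_rope=object(), q_all=None, is_hip=True)
```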
/tag-and-rerun-ci
…-transposed

PR sgl-project#24125 introduced `.transpose(-1,-2).contiguous()` on x_scale before every bpreshuffle CK GEMM call (fp8_utils.py:784), costing ~4.165 us per GEMM. aiter `fused_rms_fp8_group_quant` already supports `transpose_scale=True` (writes the scale directly in (N//128, M) column-major layout).

All 7 sglang upstream call sites (3 in forward_mha, 2 in forward_mla, 2 in communicator) now opt in and tag the returned scale tensor with `_aiter_bpreshuffle_layout=True` so the GEMM dispatch in fp8_utils.py knows it can skip the redundant copy. The logic is symmetric: if downstream wants triton (default layout) but upstream wrote the transposed layout, we still apply the transpose, so the triton path stays correct.

Net: bpreshuffle GEMMs with the marker save ~4 us each; triton GEMMs with the marker pay ~4 us each (untranspose). For the GLM-5.1-FP8 trace, which shows all 5 MLA GEMMs hitting bpreshuffle, this saves ~20 us per attention block. Verified bit-exact (within CK GEMM run-to-run noise of ~4) via direct fp8_utils call comparison.
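The marker protocol described above can be sketched without torch or aiter. In this sketch the wrapper class and function names are hypothetical; only the attribute name `_aiter_bpreshuffle_layout` comes from the commit message:

```python
class Scale:
    """Toy stand-in for the quantized-scale tensor carrying the layout marker."""
    def __init__(self, rows, transposed=False):
        self.rows = rows  # nested lists stand in for tensor data
        self._aiter_bpreshuffle_layout = transposed  # set by the quant call site

def transpose(rows):
    return [list(col) for col in zip(*rows)]

def scale_for_gemm(scale, want_bpreshuffle):
    """Symmetric dispatch sketch: copy only when the stored layout mismatches."""
    if scale._aiter_bpreshuffle_layout == want_bpreshuffle:
        return scale.rows         # layouts already match: skip the ~4 us transpose copy
    return transpose(scale.rows)  # mismatch: pay the transpose so this path stays correct

s = Scale([[1, 2], [3, 4]], transposed=True)  # upstream wrote the column-major layout
assert scale_for_gemm(s, want_bpreshuffle=True) is s.rows             # bpreshuffle: no copy
assert scale_for_gemm(s, want_bpreshuffle=False) == [[1, 3], [2, 4]]  # triton: untranspose
```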
Quick CI triage: none of the failing jobs exercise this PR's code path. The 6 unique failures are on
Motivation
In GLM-5 NSA TileLang decode on ROCm, the fused-rope path dispatches a redundant `CatArrayBatchedCopy<OpaqueType<1u>, ...>` kernel once per layer per decode step that rebuilds an already-existing tensor.

The cause: `fused_qk_rope_cat_and_cache_mla` produces a contiguous `q_cat` of shape `(M, num_heads, kv_lora_rank + qk_rope_head_dim)`. The pre-patch flow then:
- slices `q_cat` into `q_nope_fused` / `q_pe_fused`,
- passes them as separate args to `attn_mqa`,
- `nsa_backend.forward_decode` calls `concat_mla_absorb_q_general(q_nope, q_rope)` to rebuild `q_all`,
- on ROCm that helper falls back to `torch.cat([q_nope, q_rope], dim=-1)`, dispatching `CatArrayBatchedCopy` to produce a tensor byte-identical to the `q_cat` we already had.

The cat is pure overhead — same data, same layout, fresh allocation + copy.
Modifications
(1) `forward_absorb_core` in `deepseek_common/attention_forward_methods/forward_mla.py` now passes `q_cat` directly to `attn_mqa` with `q_rope=None` on the decode path. Prefill keeps the split form because `nsa_backend.forward_extend` asserts `q_rope is not None`.

(2) `forward_decode` in `layers/attention/nsa_backend.py` now tracks `q_all` explicitly:
- When the caller passes split `q_nope` / `q_rope`, `q_all = None` and each impl block re-cats as before — byte-identical to pre-patch behavior.
- When the caller passes `q_rope=None` on HIP decode, `q_all` is set to a zero-copy `q.contiguous().view(...)` of `q_cat` and the cat is skipped.
- The cat-skip is gated `if q_all is None or not _is_hip`, so CUDA / MUSA / non-HIP backends always re-cat — those paths are bit-exact to pre-patch.

Accuracy Tests
GLM-5.1-FP8 launch cmd:
MI355X GSM8K (TP=8, 1200 questions, parallel=1200): 0.941

python3 /sgl-workspace/sglang/benchmark/gsm8k/bench_sglang.py \
    --port 8552 --num-questions 1200 --parallel 1200

Speed Tests and Profiling
Per-layer profiling (NSA TileLang decode on MI355X TP=8):
`CatArrayBatchedCopy<OpaqueType<1u>, ...>` invocations per decode step: 390 -> 0 (one per layer eliminated).

Benchmark on MI355X TP=8, concurrency 4/8/16/32/64 averaged (`sglang.bench_serving --dataset-name random --random-range-ratio 0.8`):
sglang.bench_serving --dataset-name random --random-range-ratio 0.8):Cat-skip's absolute ~2.6 us / layer benefit is constant; the relative gain is highest at small batch + short prompt (where total layer time is smallest) and decays with batch size.
Bench cmd template:

python3 -m sglang.bench_serving \
    --host localhost --port 8552 \
    --model /data/huggingface/hub/zai-org/GLM-5.1-FP8 \
    --dataset-name random \
    --random-input <8192|1024> --random-output 1024 --random-range-ratio 0.8 \
    --max-concurrency <4|8|16|32|64> --num-prompt <10*cc> \
    --output-file /dev/null

Checklist
- Format your code (`pre-commit run --files <changed files>`: all hooks pass).

Review and Merge Process

/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci