Skip to content

[AMD] Skip redundant CatArrayBatchedCopy in GLM-5 NSA TileLang decode#24125

Open
Jacob0226 wants to merge 2 commits intosgl-project:mainfrom
Jacob0226:jacob/glm5-rocm-nsa-cat-skip
Open

[AMD] Skip redundant CatArrayBatchedCopy in GLM-5 NSA TileLang decode#24125
Jacob0226 wants to merge 2 commits intosgl-project:mainfrom
Jacob0226:jacob/glm5-rocm-nsa-cat-skip

Conversation

@Jacob0226
Copy link
Copy Markdown
Contributor

@Jacob0226 Jacob0226 commented Apr 30, 2026

Motivation

In GLM-5 NSA TileLang decode on ROCm, the fused-rope path dispatches a redundant CatArrayBatchedCopy<OpaqueType<1u>, ...> kernel once per layer per decode step that rebuilds an already-existing tensor.

The cause: fused_qk_rope_cat_and_cache_mla produces a contiguous q_cat of shape (M, num_heads, kv_lora_rank + qk_rope_head_dim). The pre-patch flow then

  1. slices q_cat into q_nope_fused / q_pe_fused,
  2. passes them as separate args to attn_mqa,
  3. nsa_backend.forward_decode calls concat_mla_absorb_q_general(q_nope, q_rope) to rebuild q_all,
  4. on ROCm that helper falls back to torch.cat([q_nope, q_rope], dim=-1), dispatching CatArrayBatchedCopy to produce a tensor byte-identical to the q_cat we already had.

The cat is pure overhead — same data, same layout, fresh allocation + copy.

Modifications

(1) forward_absorb_core in deepseek_common/attention_forward_methods/forward_mla.py now passes q_cat directly to attn_mqa with q_rope=None on the decode path. Prefill keeps the split form because nsa_backend.forward_extend asserts q_rope is not None.

(2) forward_decode in layers/attention/nsa_backend.py now tracks q_all explicitly:

  • When the caller passes split q_nope / q_rope, q_all = None and each impl block re-cats as before — byte-identical to pre-patch behavior.
  • When the caller passes q_rope=None on HIP decode, q_all is set to a zero-copy q.contiguous().view(...) of q_cat and the cat is skipped.
  • The cat-skip is gated if q_all is None or not _is_hip so CUDA / MUSA / non-HIP backends always re-cat — those paths are bit-exact to pre-patch.

Accuracy Tests

GLM-5.1-FP8 launch cmd:

export SAFETENSORS_FAST_GPU=1
export SGLANG_ROCM_FUSED_DECODE_MLA=0
export ROCM_QUICK_REDUCE_QUANTIZATION=INT4
python3 -m sglang.launch_server \
  --model GLM-5.1-FP8 \
  --tp 8 --port 8552 \
  --tool-call-parser glm47 --reasoning-parser glm45 \
  --watchdog-timeout 1200 \
  --mem-fraction-static 0.85 \
  --kv-cache-dtype fp8_e4m3 --disable-radix-cache \
  --model-loader-extra-config '{"enable_multithread_load": true, "num_threads": 8}' \
  --nsa-prefill-backend tilelang --nsa-decode-backend tilelang

MI355X GSM8K (TP=8, 1200 questions, parallel=1200): 0.941

python3 /sgl-workspace/sglang/benchmark/gsm8k/bench_sglang.py \
    --port 8552 --num-questions 1200 --parallel 1200

Speed Tests and Profiling

Baseline: Thomas's PR #23562 (preshuffled paged MQA + page_size=64) applied on top of sgl-project/sglang main, plus aiter PR ROCm/aiter#2879. This PR is the delta on top of that baseline. The optimization is logically independent of #23562 — it only requires the _use_aiter_gfx95 + nsa_decode_backend == "tilelang" fast path which already exists on main — but #23562 is what we run as our perf-comparison reference.

Per-layer profiling (NSA TileLang decode on MI355X TP=8):

  • CatArrayBatchedCopy<OpaqueType<1u>, ...> invocations per decode step: 390 -> 0 (one per layer eliminated).
  • Saving: ~2.6 us / layer (constant; absolute, not relative).

Benchmark on MI355X TP=8, concurrency 4/8/16/32/64 averaged (sglang.bench_serving --dataset-name random --random-range-ratio 0.8):

Cat-skip's absolute ~2.6 us / layer benefit is constant; the relative gain is highest at small batch + short prompt (where total layer time is smallest) and decays with batch size.

Bench cmd template:

python3 -m sglang.bench_serving \
    --host localhost --port 8552 \
    --model /data/huggingface/hub/zai-org/GLM-5.1-FP8 \
    --dataset-name random \
    --random-input <8192|1024> --random-output 1024 --random-range-ratio 0.8 \
    --max-concurrency <4|8|16|32|64> --num-prompt <10*cc> \
    --output-file /dev/null

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

GLM-5 NSA TileLang decode on ROCm dispatches a `CatArrayBatchedCopy` kernel
once per layer per decode step that rebuilds an already-existing tensor.
This is a strict-improvement bug fix: ~2.6 us / layer saved, 0 changes for
non-HIP backends.

==============================================================================
Root cause
==============================================================================

For the NSA TileLang fused-rope decode path (`_use_aiter_gfx95 + nsa +
nsa_decode_backend == "tilelang"`), `forward_absorb_core` calls
`fused_qk_rope_cat_and_cache_mla` which produces a contiguous q_cat tensor
of shape (M, num_heads, kv_lora_rank + qk_rope_head_dim). The pre-patch
flow then sliced q_cat into q_nope_fused / q_pe_fused and passed them as
separate args to attn_mqa.

attn_mqa -> NSABackend.forward_decode then takes the if-branch (q_rope
is not None), views the slices, and for tilelang / flashmla_sparse /
flashmla_kv / aiter decode impls calls
`concat_mla_absorb_q_general(q_nope, q_rope)` to rebuild q_all. On ROCm,
that helper falls back to `torch.cat([q_nope, q_rope], dim=-1)`, which
allocates a fresh contiguous tensor and dispatches a copy kernel. The
result is byte-identical to the q_cat we already had — the cat is pure
overhead.

==============================================================================
Fix
==============================================================================

(1) `forward_absorb_core` now passes q_cat directly to attn_mqa with
    q_rope=None on the decode path. Prefill (forward_extend) keeps the
    split form because `nsa_backend.forward_extend` asserts
    `q_rope is not None`.

(2) `nsa_backend.forward_decode` is updated to track q_all explicitly:

    - When the caller passes split q_nope / q_rope, q_all is initialized
      to None and each impl block re-cats as before (byte-identical to
      pre-patch behavior).
    - When the caller passes q_rope=None on HIP, q_all is set to a
      zero-copy `q.contiguous().view(...)` and the cat is skipped.

    The cat-skip is gated `if q_all is None or not _is_hip` so non-HIP
    backends always re-cat (preserves CUDA / MUSA paths bit-exactly).

==============================================================================
Validation
==============================================================================

MI355X TP=8 GLM-5.1-FP8 fp8 KV cache, NSA TileLang decode (on top of
PR sgl-project#23562 + aiter PR sgl-project#2879):

   scenario              | before    | after     | TPOT  Δ
   --------------------- | --------- | --------- | --------
   8k1k conc4   TPOT     | 21.21 ms  | 20.76 ms  | -2.17%
   8k1k conc8   TPOT     | 25.28 ms  | 24.82 ms  | -1.82%
   8k1k conc16  TPOT     | 30.79 ms  | 30.33 ms  | -1.49%
   8k1k conc32  TPOT     | 42.92 ms  | 42.46 ms  | -1.07%
   8k1k conc64  TPOT     | 61.79 ms  | 61.33 ms  | -0.74%
   1k1k conc4   TPOT     | 18.79 ms  | 18.33 ms  | -2.45%
   1k1k conc8   TPOT     | 21.14 ms  | 20.66 ms  | -2.27%
   1k1k conc16  TPOT     | 23.63 ms  | 23.15 ms  | -2.03%
   1k1k conc32  TPOT     | 29.19 ms  | 28.69 ms  | -1.71%
   1k1k conc64  TPOT     | 35.02 ms  | 34.60 ms  | -1.20%

Output throughput improves by the same percentage on every scenario.
Cat-skip's absolute ~2.6 us / layer benefit is constant; the relative
gain is highest at small batch + short prompt (where total layer time is
smallest) and decays with batch size.

GSM8K accuracy: 0.942 vs 0.951 baseline (within run-to-run variance
observed across multiple runs of the same config: 0.946-0.953).

==============================================================================
Files
==============================================================================

  forward_mla.py    (+50 -16)  forward_absorb_core:_skip_rope_for_nsa_tilelang_fused
                                branch passes q_cat with q_rope=None for decode.
  nsa_backend.py    (+12  -4)  forward_decode tracks q_all and skips cat on HIP
                                when caller already provided concatenated q.
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@Jacob0226 Jacob0226 marked this pull request as ready for review April 30, 2026 08:08
@Jacob0226 Jacob0226 changed the title [AMD] perf(rocm-nsa): skip redundant CatArrayBatchedCopy in GLM-5 NSA TileLang decode [AMD] Skip redundant CatArrayBatchedCopy in GLM-5 NSA TileLang decode Apr 30, 2026
Jacob0226 added a commit to Jacob0226/SGLang-benchmarks that referenced this pull request Apr 30, 2026
…24125 viz

tools/dual_stream_regression.html (new):
- 3-trace SVG timeline visualizing one full layer-pass on MI355X
  baseline (1 stream), MI355X dual-stream (cur + alt), and B200
  dual-stream (4 lanes from CUDA-graph stream pool), all on the same
  4 px/us scale.
- Bubble overlays color-coded by gap kind: red diagonal stripes for
  >=2us real CPU stalls (visible only on MI355X dual-stream as 3
  distinct 5-6us regions), amber for 1-2us small fences, faint azure
  for sub-1us per-kernel dispatch latency on B200.
- Zoom toolbar (Fit / 1x / 1.5x / 2x) with horizontal scrollbar so the
  full trace fits one screen for live presentation, then zoomable for
  detailed inspection.
- Hover tooltips with full kernel name + duration + start time.
- Decomposition card with kind classification table, profiler-
  accounting caveat (CUPTI vs roctracer ts source comparison), and
  per-stall annotated table for MI355X dual-stream.

tools/dual_stream_regression_analysis.md (updated):
- New 'Bubble decomposition (kind classification)' section breaking
  the +18us / layer MI355X dual-stream bubble into sub-1us / 1-2us /
  >=2us gaps. 91% of it is real >=2us stalls (3 stalls per layer at
  5-6us each), proving they are HIP-graph fence cost, not generic
  launch overhead.
- Cross-platform validation: B200 GLM-5-FP8 dual-stream has 0 us of
  >=2us stalls and 0 us of 1-2us fences. CUDA-graph stream-pool replay
  doesn't put cross-stream synchronization in the critical path; HIP
  graph currently does.
- Profiler-accounting caveat: B200's 23.4us of sub-1us gaps reported
  by CUPTI is *not* a real GPU bubble. Verified from raw trace
  timestamps: B200 has 4,296 negative gaps (next-start_ts < prev-end_ts,
  smallest -6.66us), MI355X roctracer has 0. The CP->SM dispatch
  latency is similar on both platforms (~200-500ns); CUPTI exposes it
  as inter-kernel gap while roctracer absorbs it into kernel duration.
  The meaningful comparison is the >=2us real-stall column, where
  MI355X dual-stream has 16.5us and B200 dual-stream has 0us.

tools/pr24125_jacob-glm5-rocm-nsa-cat-skip.html (new):
- Visual companion to upstream PR sgl-project/sglang#24125 (cat-skip
  in NSA TileLang decode). Memory diagram + kernel pipeline + reviewer
  invariants + MI355X TPOT/throughput/GSM8K results, dark editorial
  layout for sharing with reviewers.
# "always cat" behavior to keep CUDA / MUSA paths byte-identical.
if self.nsa_decode_impl == "flashmla_sparse":
if q_rope is not None:
if q_all is None or not _is_hip:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can only make the change in amd side (e.g, tilelang/aiter backend).
This change may be never reached in nv code path (e.g, flashmla_sparse/flashmla_kv).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, thanks! Done in 1f2b7c4.

Reverted the gate change for flashmla_sparse and flashmla_kv back to the pre-patch if q_rope is not None: pattern. Both impls import from sgl_kernel.flashmla_ops which requires CUDA driver, so on HIP they are unreachable and the new not _is_hip clause was a strict no-op there. On CUDA the behavior is byte-identical (gate is always True either way).

The cat-skip itself stays on the HIP-only tilelang and aiter blocks, which is what this PR is actually solving. Also moved the "Cat-skip is HIP-only ..." comment from above flashmla_sparse to above tilelang to match the new (narrower) scope.

Per @1am9trash review (PR sgl-project#24125 line 1587):
> I think we can only make the change in amd side (e.g, tilelang/aiter
> backend). This change may be never reached in nv code path (e.g,
> flashmla_sparse/flashmla_kv).

Revert the gate change for `flashmla_sparse` and `flashmla_kv` decode
impls back to the pre-patch `if q_rope is not None:` pattern. Both
impls import from `sgl_kernel.flashmla_ops`, which requires CUDA
driver, so they are unreachable on HIP and the new `not _is_hip`
clause was a strict no-op there. Functionally identical on CUDA
(byte-for-byte), just smaller and more focused diff.

The cat-skip itself stays on the HIP-only `tilelang` and `aiter`
decode impls, which is what the PR is actually solving. Moved the
"Cat-skip is HIP-only ..." comment from above flashmla_sparse to
above tilelang to match the new (narrower) scope.
Copy link
Copy Markdown
Collaborator

@1am9trash 1am9trash left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Copy Markdown
Collaborator

@kkHuang-amd kkHuang-amd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@HaiShaw
Copy link
Copy Markdown
Collaborator

HaiShaw commented May 5, 2026

/tag-and-rerun-ci

@github-actions github-actions Bot added the run-ci label May 5, 2026
Jacob0226 added a commit to Jacob0226/sglang that referenced this pull request May 6, 2026
…-transposed

PR sgl-project#24125 introduced .transpose(-1,-2).contiguous() on x_scale before every
bpreshuffle CK GEMM call (fp8_utils.py:784) costing ~4.165 us per GEMM.

aiter fused_rms_fp8_group_quant already supports transpose_scale=True (writes
the scale directly in (N//128, M) column-major layout). All 7 sglang upstream
call sites (3 in forward_mha, 2 in forward_mla, 2 in communicator) now opt in
and tag the returned scale tensor with _aiter_bpreshuffle_layout=True so the
GEMM dispatch in fp8_utils.py knows it can skip the redundant copy.

Logic is symmetric: if downstream wants triton (default layout) but upstream
wrote transposed, we still apply the transpose (so triton path stays correct).
Net: bpreshuffle GEMMs with marker save 4 us each; triton GEMMs with marker
pay 4 us each (untranspose). For GLM-5.1-FP8 trace which shows all 5 MLA
GEMMs hit bpreshuffle, this saves ~20 us per attention block.

Verified bit-exact (within CK GEMM run-to-run noise of ~4) via direct
fp8_utils call comparison.
@Jacob0226
Copy link
Copy Markdown
Contributor Author

Quick CI triage: none of the failing jobs exercise this PR's code path.

The 6 unique failures are on gpt-oss-20b, Qwen3-Reranker-0.6B, Llama-3.1-8B-Instruct, DeepSeek-OCR (XPU), LLaDA2-Mini (NPU), and Wan2.2 T2V (NPU) — none of which use MLA + NSA TileLang decode on gfx95, which is the only path this PR touches. The cat-skip is also gated if q_all is None or not _is_hip, so CUDA / MUSA / XPU paths are byte-identical to pre-patch.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants