
Support preshuffled layout in indexer_k_quant_and_cache / cp_gather_indexer_k_quant_cache#2879

Merged
valarLip merged 5 commits into ROCm:main from 1am9trash:support-indexer_k_quant_and_cache_kernel-with-preshuffle
Apr 29, 2026

Conversation

@1am9trash
Member

Motivation

In sglang PR#23562, we bump the NSA indexer's page_size to 64, which switches the indexer attention path to the _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle kernel. That kernel consumes the indexer k cache in an MFMA 16×16 preshuffled layout.

Technical Details

This PR extends indexer_k_quant_and_cache and cp_gather_indexer_k_quant_cache to support writing and gathering the preshuffled layout.

We add an optional preshuffle: bool = False argument to both ops. Default is False, so existing callers are unaffected.

  • preshuffle=False (default): unchanged row-major [block_size, head_dim] layout inside each paged block.
  • preshuffle=True: each block's k region is written / read as a sequence of MFMA 16×16 tiles.
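As a rough illustration of the preshuffled layout, here is a minimal numpy sketch that rearranges a row-major [block_size, head_dim] block into a sequence of 16×16 tiles. This is only a sketch of the idea: the exact tile ordering and element swizzle used by aiter's MFMA 16×16 layout may differ.

```python
import numpy as np

def preshuffle_block(block: np.ndarray) -> np.ndarray:
    """Rearrange a row-major [block_size, head_dim] k block into a
    sequence of 16x16 tiles (row-tiles outer, column-tiles inner,
    each tile kept row-major). Illustrative only."""
    block_size, head_dim = block.shape
    assert block_size % 16 == 0 and head_dim % 16 == 0
    t = block.reshape(block_size // 16, 16, head_dim // 16, 16)
    return t.transpose(0, 2, 1, 3).reshape(-1, 16, 16)

# block_size=64 (the new page_size), head_dim=128
blk = np.arange(64 * 128).reshape(64, 128)
tiles = preshuffle_block(blk)
print(tiles.shape)  # (32, 16, 16)
```

With preshuffle=False the same region would simply be the row-major `blk`; with preshuffle=True the ops write/gather it tile by tile.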

We also extend the test case with a -p/--preshuffle flag.


@1am9trash 1am9trash requested a review from a team April 23, 2026 10:05
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

  Label        Tests
  ci:sglang    SGLang integration tests
  ci:atom      ATOM benchmark (DeepSeek-R1 + GPT-OSS)
  ci:vllm      vLLM benchmark
  ci:all       All of the above

Add labels via the sidebar or gh pr edit 2879 --add-label <label>

@valarLip
Collaborator

could you please make the ci pass...

@1am9trash
Member Author

1am9trash commented Apr 29, 2026

Apologies for the oversight! I've just pushed a commit fixing the ruff and black format errors.
CI is running now — I'll follow up to confirm it goes green.

Update: all CI checks pass.

Jacob0226 added a commit to Jacob0226/sglang that referenced this pull request Apr 29, 2026
Squashed implementation of three HIP-only optimizations that together
shrink the GLM-5-FP8 NSA tilelang decode layer on MI355X from ~397 us
to ~324 us (-73 us / -18.4%, MI355X TP=8 fp8 KV cache).

==============================================================================
1. fix(rocm): restore `_is_hip` in DeepseekV2Model.alt_stream creation
==============================================================================

Commit a1ceb2e ("[AMD] Enable MoE dual stream overlap on HIP for GLM4/GLM5")
added `_is_hip` to the alt_stream gate. The MUSA backend PR b35213b
("[MUSA][16/N] Add MUSA backend support for layers and DeepSeek models")
was branched off a parent that did not contain a1ceb2e, and on merge
inadvertently dropped `_is_hip` while adding `_is_musa`. Result: on ROCm
`self.alt_stream is None`, so `forward_normal_dual_stream` and the MLA
dual-stream fork are never entered — decode traces show only one
physical stream.

This commit restores `_is_hip` alongside `_is_musa` and re-applies the
`not _use_aiter` guard in `forward_normal_dual_stream`'s
routed_scaling_factor multiply (aiter's biased_grouped_topk already
fuses the scaling, so multiplying again would double it).

Both changes are HIP-only: CUDA / MUSA / NPU branches are unaffected.
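A toy numerical illustration of the double-scaling hazard the `not _use_aiter` guard prevents (made-up values, not sglang's actual code):

```python
# If the topk kernel already fused routed_scaling_factor into its
# output (as aiter's biased_grouped_topk does), an unconditional
# second multiply applies the factor twice.
routed_scaling_factor = 2.5
expert_output = 1.0

fused = expert_output * routed_scaling_factor   # aiter path: already scaled
double_scaled = fused * routed_scaling_factor   # unguarded extra multiply

print(fused, double_scaled)  # 2.5 6.25
```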

==============================================================================
2. perf(rocm-nsa): A_v4 dual-stream layout in forward_absorb_prepare
==============================================================================

Refactor the q_b_proj / NSA-indexer dual-stream fork in
DeepseekMLAForwardMixin.forward_absorb_prepare so that on HIP the
indexer chain on alt overlaps not just with q_b_proj but also with the
gap-fill that follows on cur (bmm w_kc absorb + rotary_emb on q_pe/k_pe,
plus fused_qk_rope_cat_and_cache_mla on the gfx95 NSA tilelang path).

Two HIP-graph capture rules drive the layout (validated by the
microbenchmark in SGLang-benchmarks/tools/glm5_proposalA_v3_test.py
variant A_v4: -18.9 us/layer over the prior layout):

  1. Dispatch order picks the physical stream — the branch dispatched
     first at the fork keeps the predecessor stream (phys 0); the
     later-dispatched branch lands on a fresh aux stream (phys 4).
     We dispatch q_b_proj on cur FIRST and only afterwards enter
     `with stream(alt):` for the indexer.
  2. `alt.wait_stream(cur)` snapshots cur's state at call time. Since
     the indexer needs only q_lora (phase1 output), placing wait_stream
     BEFORE q_b_proj lets alt's heavy indexer chain start the instant
     phase1 completes — in parallel with cur's q_b_proj plus gap-fill,
     instead of waiting for q_b_proj first.

The `cur.wait_stream(alt)` join is moved past rotary_emb so cur's
gap-fill chain overlaps with alt's indexer.

CUDA / MUSA / NPU paths are gated to keep the original PR sgl-project#23562 layout
(byte-identical) — these were not validated under the new schedule.

Drives `overlap_indexer_with_gap_fill` flag used by sub-optimization (3).

==============================================================================
3. perf(rocm-nsa): pull fused_qk_rope_cat_and_cache_mla into the dual-stream
   window, and skip the redundant CatArrayBatchedCopy that follows attn_mqa
==============================================================================

For the gfx95 NSA tilelang fused-rope path, the
`fused_qk_rope_cat_and_cache_mla` kernel that normally runs in
`forward_absorb_core` is moved into `forward_absorb_prepare` so it runs
on cur inside the dual-stream window — overlapping with the alt
indexer instead of running serially after the join. The result is
forwarded from prepare to core via a new optional `fused_qk_kv_cache`
return field; core falls back to the original inline computation when
the prepare-side fast path was not taken (non-capture, non-decode, or
non-HIP).

In addition, `forward_absorb_core` now passes the already-concatenated
`q_cat` directly to `attn_mqa` with `q_rope=None` on the decode path
(prefill keeps the split form because `nsa_backend.forward_extend`
asserts `q_rope is not None`). On the receiving side,
`nsa_backend.forward_decode` is updated to track `q_all` explicitly:

  - When caller passes split q_nope / q_rope (CUDA / non-HIP paths or
    non-decode HIP), q_all is initialized to None and each impl block
    re-cats as before — byte-identical to the pre-patch behavior.
  - When caller passes q_rope=None on HIP decode, q_all is set to a
    zero-copy `q.contiguous().view(...)` of `q_cat` and each impl block
    skips the otherwise-redundant `concat_mla_absorb_q_general` call.

The cat-skip is gated `if q_all is None or not _is_hip` so non-HIP
backends always re-cat (preserves prior behavior bit-exactly).

This eliminates the CatArrayBatchedCopy<OpaqueType<1u>, ...> kernel
that previously fired once per layer per decode step (~5 us/layer)
between fused_qk_rope_cat and main_kernel on ROCm tilelang traces:
390 invocations → 0 in DualStream0429_v2 trace.

==============================================================================
Validation
==============================================================================

  * MI355X TP=8 GLM-5.1-FP8 fp8 KV cache, NSA tilelang decode:
      - Layer latency: ~397 us → ~324 us  (-73 us / -18.4%)
      - 8k1k conc4 TPOT: 24.48 ms median (output throughput 117 tok/s)
      - GSM8K 1200q: 0.953  (PR sgl-project#23562 baseline 0.951)
  * trace: results/.../GLM-5.1-FP8-prof-DualStream0429_v2/
            prof_in8192_out1024_conc4_p8/*-TP-0-DECODE.trace.json.gz
  * Stacks on top of sgl-project#23562 (preshuffled paged MQA + page_size=64) and
    requires aiter PR ROCm/aiter#2879 (preshuffle layout in indexer
    k-cache kernels).

==============================================================================
Files
==============================================================================

  * deepseek_v2.py         (+5 -2)  alt_stream gate + routed_scaling guard
  * forward_mla.py        (+212 -73)  A_v4 layout + fused pull-up + cat-skip
                                      plumbing, HIP-only via `_is_hip` gate
  * nsa_backend.py        (+15 -4)  q_all tracking + cat-skip, HIP-only
Jacob0226 added a commit to Jacob0226/sglang that referenced this pull request Apr 29, 2026
…ual-stream

This commit lands two HIP-only optimizations on top of PR sgl-project#23562:

1. Cat-skip in nsa_backend.forward_decode (default ON, ~2.6 us / layer)
2. A_v4 NSA dual-stream layout (gated OFF by default — regresses on MI355X)

Validated on MI355X TP=8 GLM-5.1-FP8 (8k1k conc4):

  Variant                                 Median TPOT     Δ vs Thomas
  ---------------------------------------------------------------------
  Thomas (PR sgl-project#23562 only)                   21.21 ms        baseline
  This commit, default (cat-skip on,
    dual-stream off)                        20.48 ms        −3.4% (faster)
  This commit + SGLANG_ENABLE_HIP_DUAL_STREAM=1
    + --disable-shared-experts-fusion       24.45 ms        +15.3% (regression)

==============================================================================
1. Cat-skip optimization (default ON, HIP-only)
==============================================================================

In the NSA TileLang fused-rope decode path, fused_qk_rope_cat_and_cache_mla
produces a contiguous `q_cat` tensor of shape (M, num_heads, kv_lora_rank +
qk_rope_head_dim). The pre-patch flow then sliced q_cat into (q_nope_fused,
q_pe_fused) and passed them as separate args to attn_mqa, which causes
nsa_backend.forward_decode to call concat_mla_absorb_q_general(q_nope, q_rope)
to rebuild q_all. On ROCm that fallback hits torch.cat → CatArrayBatchedCopy,
producing a tensor that is byte-identical to the q_cat we already have.

forward_absorb_core now passes q_cat directly to attn_mqa with q_rope=None on
the decode path (prefill keeps the split form because forward_extend asserts
q_rope is not None). nsa_backend.forward_decode is updated to track q_all
explicitly:

  - When caller passes split q_nope / q_rope, q_all=None and each impl block
    re-cats as before — byte-identical to pre-patch behavior.
  - When caller passes q_rope=None on HIP decode, q_all is set to a zero-copy
    `q.contiguous().view(...)` and the cat is skipped.

The cat-skip is gated `if q_all is None or not _is_hip` so non-HIP backends
always re-cat (preserves CUDA / MUSA behavior bit-exactly).
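The view-vs-copy distinction behind the cat-skip can be sketched in numpy (sglang uses torch, and the shapes/names here are illustrative, but `reshape` on a contiguous buffer plays the same role as `q.contiguous().view(...)`):

```python
import numpy as np

# q_cat: the already-contiguous output of fused_qk_rope_cat
M, H, KV_LORA, ROPE = 4, 8, 512, 64
q_cat = np.arange(M * H * (KV_LORA + ROPE), dtype=np.float32).reshape(
    M, H, KV_LORA + ROPE
)

# Pre-patch path: slice q_cat, then re-concatenate -> an extra copy
q_nope, q_rope = q_cat[..., :KV_LORA], q_cat[..., KV_LORA:]
q_all_copy = np.concatenate([q_nope, q_rope], axis=-1)

# Cat-skip path: a zero-copy view of the buffer we already have
q_flat = q_cat.reshape(M, -1)                  # what attn_mqa receives as q
q_all_view = q_flat.reshape(M, H, KV_LORA + ROPE)

print(np.array_equal(q_all_copy, q_all_view))  # True: same bytes
print(np.shares_memory(q_all_view, q_cat))     # True: no copy made
```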

Effect: CatArrayBatchedCopy<OpaqueType<1u>, ...> kernel that previously fired
once per layer per decode step disappears from ROCm tilelang traces.

==============================================================================
2. A_v4 dual-stream layout (opt-in via SGLANG_ENABLE_HIP_DUAL_STREAM=1)
==============================================================================

forward_absorb_prepare gains a HIP-only A_v4 dual-stream layout that overlaps
the NSA indexer chain on alt with [q_b_proj + bmm w_kc + fused_qk_rope_cat]
on cur. Two HIP-graph capture rules drive the layout:

  1. Dispatch order picks the physical stream — the branch dispatched first
     keeps the predecessor stream (phys 0); the later-dispatched branch lands
     on a fresh aux stream (phys 4). q_b_proj is dispatched on cur FIRST,
     then `with stream(alt):` for the indexer.
  2. alt.wait_stream(cur) is placed BEFORE q_b_proj. Indexer needs only
     q_lora (phase1 output), not q_b_proj's q, so alt's heavy indexer chain
     can start the moment phase1 completes — in parallel with cur's q_b_proj
     plus gap-fill.

The cur.wait_stream(alt) join is moved past rotary_emb so cur's gap-fill
chain overlaps with alt's indexer. fused_qk_rope_cat_and_cache_mla is also
pulled from forward_absorb_core into prepare's dual-stream window, with the
result forwarded via a new optional fused_qk_kv_cache return field.

CUDA / MUSA / NPU paths take the original q_b_proj ∥ NSA-indexer layout from
PR sgl-project#23562 base (byte-identical) — the new layout was not validated on those
platforms.

Why opt-in: on MI355X the layout regresses ~30 us / layer due to three
contention sources:

  - HBM bandwidth contention: indexer's memory-bound kernels lose 0.5-2.4 us
    each when sharing HBM with cur GEMMs (+8 us total).
  - Compute-unit split: scheduler partitions 256 CUs across concurrent
    kernels, slowing both compute-bound kernels (+5 us total).
  - HIP-graph AllReduce slowdown: aiter::cross_device_reduce_1stage takes
    23 us under dual-stream graph capture vs 9.5 us single-stream — same
    kernel, same TP=8 topology. Likely caused by the AR's first-stage peer
    fence having to drain alt's KV-cache writes too. ~+26 us / layer (2 ARs).

Theoretical A_v4 saving (gap-fill ∥ indexer ≈ −10 us / layer) is dwarfed
by these costs. The layout is preserved behind SGLANG_ENABLE_HIP_DUAL_STREAM
for future ROCm releases that may fix the AR fence cost.

To enable for testing:

  SGLANG_ENABLE_HIP_DUAL_STREAM=1 ./GLM.sh --dual-stream-rocm ...

==============================================================================
Files changed
==============================================================================

  environ.py            (+8)   New env var SGLANG_ENABLE_HIP_DUAL_STREAM
  deepseek_v2.py        (+15 -2)
                              alt_stream gate now requires _is_hip + env var.
                              forward_normal_dual_stream's routed_scaling
                              multiply also adds `not _use_aiter` (aiter's
                              biased_grouped_topk already fuses the scaling).
  forward_mla.py        (+212 -73)
                              A_v4 layout in forward_absorb_prepare (gated on
                              _is_hip; degrades to serial when alt_stream is
                              None). fused_qk_rope_cat pull-up + q_rope=None
                              cat-skip plumbing in forward_absorb_core.
  nsa_backend.py        (+15 -4)
                              q_all tracking + cat-skip in forward_decode.
                              HIP-only — non-HIP always re-cats.

Stacks on top of PR sgl-project#23562 (preshuffled paged MQA + page_size=64) and
requires aiter PR ROCm/aiter#2879 (preshuffle layout in indexer k-cache
kernels).

Detailed regression analysis:
  ~/SGLang-benchmarks/tmp/dual_stream_regression_analysis.md
Jacob0226 added a commit to Jacob0226/sglang that referenced this pull request Apr 29, 2026
…ual-stream

@valarLip valarLip merged commit a6bb499 into ROCm:main Apr 29, 2026
69 of 87 checks passed
Liang-jianhao97 pushed a commit that referenced this pull request May 7, 2026
…ndexer_k_quant_cache (#2879)

* Support preshuffle in indexer_k_quant_and_cache kernel

* Fix unittest error

* Fix format issue
