
Adjust flashinfer workspace size for Qwen2 models#2879

Merged
zhyncs merged 2 commits into sgl-project:main from ispobock:fix-qwen2
Jan 14, 2025
Conversation

@ispobock
Collaborator

Motivation

Fixes #2344, #2102, #1405.

Qwen2 models (mostly Qwen2-7B) require a larger flashinfer workspace size. The issue occurs on different types of hardware (both SM80 and SM90). This PR temporarily fixes it by enlarging the workspace for these models.
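For illustration, the kind of adjustment described here could look like the sketch below (the constant names, the 512 MB figure, and the arch check are assumptions, not this PR's actual diff; flashinfer's attention wrappers do take a caller-allocated workspace buffer):

```python
import torch

DEFAULT_WORKSPACE_SIZE = 384 * 1024 * 1024  # assumed prior default (bytes)
QWEN2_WORKSPACE_SIZE = 512 * 1024 * 1024    # assumed enlarged size (bytes)

def flashinfer_workspace_size(model_arch: str) -> int:
    # Qwen2 (notably Qwen2-7B) needs more scratch space on both SM80
    # and SM90; every other model keeps the default.
    if model_arch == "Qwen2ForCausalLM":
        return QWEN2_WORKSPACE_SIZE
    return DEFAULT_WORKSPACE_SIZE

# The buffer handed to flashinfer's prefill/decode wrappers at init time.
workspace_buffer = torch.empty(
    flashinfer_workspace_size("Qwen2ForCausalLM"),
    dtype=torch.uint8,
    device="cuda",
)
```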

cc: @yzh119 @merrymercy @Ying1123 @xiezhq-hermann

@zhyncs
Collaborator

zhyncs commented Jan 14, 2025

As a temporary solution, let's merge this first and revisit it after upgrading to the new flashinfer version.

@zhyncs merged commit c19d848 into sgl-project:main on Jan 14, 2025
timethink pushed a commit to timethink/sglang that referenced this pull request Mar 9, 2025
HaiShaw added a commit to HaiShaw/sglang that referenced this pull request May 1, 2026
HaiShaw added a commit that referenced this pull request May 1, 2026
vguduruTT pushed a commit to vguduruTT/sglang that referenced this pull request May 2, 2026
Jacob0226 added a commit to Jacob0226/sglang that referenced this pull request May 6, 2026

GLM-5 NSA TileLang decode on ROCm dispatches a `CatArrayBatchedCopy` kernel
once per layer per decode step that rebuilds an already-existing tensor.
This is a strict-improvement bug fix: ~2.6 us saved per layer, with no
changes for non-HIP backends.

==============================================================================
Root cause
==============================================================================

For the NSA TileLang fused-rope decode path (`_use_aiter_gfx95 + nsa +
nsa_decode_backend == "tilelang"`), `forward_absorb_core` calls
`fused_qk_rope_cat_and_cache_mla` which produces a contiguous q_cat tensor
of shape (M, num_heads, kv_lora_rank + qk_rope_head_dim). The pre-patch
flow then sliced q_cat into q_nope_fused / q_pe_fused and passed them as
separate args to attn_mqa.

attn_mqa -> NSABackend.forward_decode then takes the if-branch (q_rope
is not None), views the slices, and for tilelang / flashmla_sparse /
flashmla_kv / aiter decode impls calls
`concat_mla_absorb_q_general(q_nope, q_rope)` to rebuild q_all. On ROCm,
that helper falls back to `torch.cat([q_nope, q_rope], dim=-1)`, which
allocates a fresh contiguous tensor and dispatches a copy kernel. The
result is byte-identical to the q_cat we already had; the cat is pure
overhead.
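
The redundancy is easy to reproduce in isolation (a standalone sketch;
the shapes are stand-ins for the real model config):

```python
import torch

M, num_heads = 4, 16
kv_lora_rank, qk_rope_head_dim = 512, 64

# q_cat as produced by fused_qk_rope_cat_and_cache_mla: already contiguous.
q_cat = torch.randn(M, num_heads, kv_lora_rank + qk_rope_head_dim)

# Pre-patch flow: slice into the nope / rope halves...
q_nope = q_cat[..., :kv_lora_rank]
q_pe = q_cat[..., kv_lora_rank:]

# ...then the ROCm fallback of concat_mla_absorb_q_general re-cats them,
# allocating a fresh tensor and dispatching a copy kernel.
q_all = torch.cat([q_nope, q_pe], dim=-1)

assert torch.equal(q_all, q_cat)             # byte-identical result
assert q_all.data_ptr() != q_cat.data_ptr()  # but a brand-new allocation
```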

==============================================================================
Fix
==============================================================================

(1) `forward_absorb_core` now passes q_cat directly to attn_mqa with
    q_rope=None on the decode path. Prefill (forward_extend) keeps the
    split form because `nsa_backend.forward_extend` asserts
    `q_rope is not None`.

(2) `nsa_backend.forward_decode` is updated to track q_all explicitly:

    - When the caller passes split q_nope / q_rope, q_all is initialized
      to None and each impl block re-cats as before (byte-identical to
      pre-patch behavior).
    - When the caller passes q_rope=None on HIP, q_all is set to a
      zero-copy `q.contiguous().view(...)` and the cat is skipped.

    The cat-skip is gated `if q_all is None or not _is_hip` so non-HIP
    backends always re-cat (preserves CUDA / MUSA paths bit-exactly).
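
Condensed as control flow, the decode-side change looks roughly like the
sketch below (not the actual nsa_backend.py; the helper stub, argument
names, and the hard-coded rope width are assumptions):

```python
import torch

QK_ROPE_HEAD_DIM = 64  # assumed rope width, for the sketch only

def concat_mla_absorb_q_general(q_nope, q_rope):
    # Stand-in for the real helper; on ROCm it falls back to a plain cat.
    return torch.cat([q_nope, q_rope], dim=-1)

def forward_decode(q, q_rope, num_heads, head_dim, _is_hip=True):
    if q_rope is None:
        # Caller already concatenated q: a zero-copy view is enough.
        q_all = q.contiguous().view(-1, num_heads, head_dim)
        q_nope = q_all[..., :-QK_ROPE_HEAD_DIM]
        q_pe = q_all[..., -QK_ROPE_HEAD_DIM:]
    else:
        q_all = None  # split inputs: re-cat below, byte-identical to pre-patch
        q_nope, q_pe = q, q_rope

    # Cat-skip gate: non-HIP backends always re-cat (keeps CUDA / MUSA bit-exact).
    if q_all is None or not _is_hip:
        q_all = concat_mla_absorb_q_general(q_nope, q_pe)
    return q_all

# Fused-input caller on HIP: no cat kernel is dispatched.
M, H, D = 4, 16, 576
q_cat = torch.randn(M, H, D)
out = forward_decode(q_cat.view(M, -1), None, H, D)
assert out.data_ptr() == q_cat.data_ptr()
```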

==============================================================================
Validation
==============================================================================

MI355X TP=8 GLM-5.1-FP8 fp8 KV cache, NSA TileLang decode (on top of
PR sgl-project#23562 + aiter PR sgl-project#2879):

   scenario              | before    | after     | TPOT  Δ
   --------------------- | --------- | --------- | --------
   8k1k conc4   TPOT     | 21.21 ms  | 20.76 ms  | -2.17%
   8k1k conc8   TPOT     | 25.28 ms  | 24.82 ms  | -1.82%
   8k1k conc16  TPOT     | 30.79 ms  | 30.33 ms  | -1.49%
   8k1k conc32  TPOT     | 42.92 ms  | 42.46 ms  | -1.07%
   8k1k conc64  TPOT     | 61.79 ms  | 61.33 ms  | -0.74%
   1k1k conc4   TPOT     | 18.79 ms  | 18.33 ms  | -2.45%
   1k1k conc8   TPOT     | 21.14 ms  | 20.66 ms  | -2.27%
   1k1k conc16  TPOT     | 23.63 ms  | 23.15 ms  | -2.03%
   1k1k conc32  TPOT     | 29.19 ms  | 28.69 ms  | -1.71%
   1k1k conc64  TPOT     | 35.02 ms  | 34.60 ms  | -1.20%

Output throughput improves by the same percentage in every scenario. The
cat-skip's absolute ~2.6 us / layer benefit is constant; the relative
gain is highest at small batch + short prompt (where per-token layer
time is smallest) and decays as batch size grows.
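
Recomputing from the table's own numbers confirms the decay is just a
constant absolute delta divided by a growing TPOT:

```python
# Relative gain = constant absolute saving / TPOT; inputs from the table above.
delta_ms = 0.46  # the roughly constant before-after gap in every row
for before_ms in (18.79, 21.21, 30.79, 61.79):
    print(f"TPOT {before_ms:5.2f} ms -> -{100 * delta_ms / before_ms:.2f}%")
# -> -2.45%, -2.17%, -1.49%, -0.74%: matching the measured TPOT deltas.
```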

GSM8K accuracy: 0.942 vs 0.951 baseline (within run-to-run variance
observed across multiple runs of the same config: 0.946-0.953).

==============================================================================
Files
==============================================================================

  forward_mla.py    (+50 -16)  forward_absorb_core:_skip_rope_for_nsa_tilelang_fused
                                branch passes q_cat with q_rope=None for decode.
  nsa_backend.py    (+12  -4)  forward_decode tracks q_all and skips cat on HIP
                                when caller already provided concatenated q.