
[DSv4] Route large-q prefill through flash_mla_sparse_fwd#25502

Open
yuan-luo wants to merge 1 commit into
sgl-project:mainfrom
yuan-luo:dsv4-sparse-fwd-prefill

Collaborator

@yuan-luo yuan-luo commented May 17, 2026

Motivation

Fix a crash that occurs with large --chunked-prefill-size values.

DSv3 uses dense MLA: every query attends to all KV, no per-token
routing. The corresponding FlashMLA path is dense_decode.h, which takes
q in shape [num_seq, seq_len, h, d]. Here b == num_seq, capped by
max_running_requests (typically a few tens), so the per-b SMEM
allocation inside get_decoding_sched_meta stays in the single-digit KB
range and never trips the opt-in cap.

DSv4 introduces sparse attention with per-token top-k routing
(sparse_decode.h). The current SGLang DSv4 backend reuses the existing
DSv3 dense-decode infrastructure (same FlashMLA entrypoint, same
schedule kernel, same SMEM accounting), and only patches the q layout
to carry per-token routing. Concretely, since each token now needs its
own top-k indices, q is reshaped to [num_tokens, 1, h, d] before
dispatch — the kernel's b flips from "number of active sequences" to
"number of query tokens in the chunk", a 100–1000× jump. With
get_decoding_sched_meta still allocating 4 * (b * 5 + 1) bytes of
dynamic shared memory and the SM90/SM100 opt-in SMEM cap at 228 KB,
any b > 11673 (= ⌊(228 × 1024 / 4 − 1) / 5⌋) makes
cudaFuncSetAttribute return invalid argument and the scheduler
crashes:

csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu:111
Assertion error: smem_size <= SM90ArchSpec::smem_capacity

As a result, any chunked_prefill_size > 11673 crashes the DSv4 backend on
the very first long-prompt request. This PR adds the plumbing required to
run DSv4-Flash safely at larger chunk sizes — specifically the
--chunked-prefill-size 32768 configuration we want for long-context
prefill throughput.
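
The threshold arithmetic above can be checked with a short sketch. The 4 * (b * 5 + 1) formula and the 228 KB cap are taken from the description; the function names are hypothetical and are not part of FlashMLA or SGLang.

```python
# Sketch of the SMEM accounting described above. Function names are
# illustrative only; the formula and the cap come from the PR text.
SMEM_CAP_BYTES = 228 * 1024  # opt-in cap quoted for SM90/SM100


def sched_meta_smem_bytes(b: int) -> int:
    """Dynamic shared memory requested by the schedule kernel for batch b."""
    return 4 * (b * 5 + 1)


def max_safe_batch(cap_bytes: int = SMEM_CAP_BYTES) -> int:
    """Largest b that satisfies 4 * (b * 5 + 1) <= cap_bytes."""
    return (cap_bytes // 4 - 1) // 5


if __name__ == "__main__":
    b_max = max_safe_batch()
    print(b_max)                             # 11673
    print(sched_meta_smem_bytes(b_max))      # 233464, under the 233472-byte cap
    print(sched_meta_smem_bytes(b_max + 1))  # 233484, over the cap
    print(sched_meta_smem_bytes(32768))      # 655364, roughly 640 KB
```

With --chunked-prefill-size 32768 the request is nearly three times the cap, which is why shrinking the per-b footprint slightly would not be enough.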

Modifications

The fix covers three concerns, and all three must land together:
omitting any one still leaves a crash on the long-prompt path, and each earlier crash masks the ones behind it.

1. Route large-q prefill through flash_mla_sparse_fwd

When q.shape[0] > 11673 the attention call now dispatches to
flash_mla.flash_mla_sparse_fwd (the dedicated prefill kernel — BF16 flat
KV, no schedule kernel, tile shapes tuned for prefill) instead of
flash_mla_with_kvcache (the decode kernel — FP8 paged KV, SMEM-limited
schedule kernel). FlashMLA exposes both kernels on SM90a and SM100f, so the
runtime dispatch happens inside FlashMLA and no Python-side arch branching
is required. The two-path pattern (sparse-decode for decode, sparse-fwd for
prefill) mirrors the NSA backend for DSv3.2.
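
A minimal sketch of the routing rule, assuming the decision depends only on the number of query rows. The helper name and the string return value are hypothetical; only the threshold and the two FlashMLA entrypoint names come from the description, and the real branch in deepseek_v4_backend.py calls the kernels directly.

```python
# Hypothetical dispatch helper. Only the threshold and the two FlashMLA
# entrypoint names are taken from the PR description.
SPARSE_FWD_THRESHOLD = 11673  # largest q.shape[0] the decode kernel can schedule


def select_attention_kernel(num_q_rows: int) -> str:
    """Return the name of the FlashMLA entrypoint to use for this call."""
    if num_q_rows > SPARSE_FWD_THRESHOLD:
        # Prefill kernel: BF16 flat KV, no SMEM-limited schedule kernel.
        return "flash_mla_sparse_fwd"
    # Decode kernel: FP8 paged KV, schedule kernel bounded by the SMEM cap.
    return "flash_mla_with_kvcache"
```

Decode batches and short prefill chunks stay on the existing path, so only calls that would otherwise crash change behavior.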

For layers with compress_ratio in {4, 128} the sparse-fwd path additionally
needs to attend over [swa_kv; extra_kv] as a single flat KV tensor.
combine_swa_extra_indices handles index merging — it shifts each extra
index by s_kv_swa so the same flat-index assumption holds for both halves
of the concatenated buffer. Unit tests cover the SWA-only, both-halves
populated, and asymmetric per-token-length cases.
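
The index merge can be illustrated with a pure-Python reference. This is a sketch under stated assumptions, not the Triton kernel from the PR: the function name is hypothetical, and both the use of -1 as a padding value and the compaction of valid entries to the front of each row are assumptions made for the example. Only the shift of extra indices by s_kv_swa comes from the description.

```python
from typing import List

PAD = -1  # assumed marker for an unused index slot


def combine_indices_reference(
    swa_indices: List[List[int]],
    extra_indices: List[List[int]],
    s_kv_swa: int,
) -> List[List[int]]:
    """Merge per-token SWA and extra indices into one flat index space.

    Extra indices are shifted by s_kv_swa so they address the second half
    of the concatenated [swa_kv; extra_kv] buffer.
    """
    rows = []
    for swa_row, extra_row in zip(swa_indices, extra_indices):
        row = [i for i in swa_row if i != PAD]
        row += [i + s_kv_swa for i in extra_row if i != PAD]
        rows.append(row)
    # Pad every row to the same width so the result is rectangular.
    width = max((len(r) for r in rows), default=0)
    return [r + [PAD] * (width - len(r)) for r in rows]
```

For example, with s_kv_swa = 4, a token holding SWA indices [0, 1] and extra index [0] ends up with [0, 1, 4], because extra position 0 sits at flat position 4 in the concatenated buffer.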

Files: deepseek_v4_backend.py (the branch + plumbing),
dsv4/sparse_prefill_utils.py (helper kernel + Python wrapper),
test/registered/kernels/test_combine_swa_extra_indices.py (UT).

2. DSv4-Flash MXFP8 page-SOA dequantizer

flash_mla_sparse_fwd requires a BF16 KV tensor, so the sparse-fwd branch
needs to dequantize the FP8 K cache before calling it. NSA already has a
dequantize_k_cache helper, but it was written for the DSv3.2 layout
(656 bytes per token, per-token AOS: 512 B nope FP8 + 16 B fp32 scales +
128 B BF16 rope), and reusing it for DSv4-Flash trips an assertion:

swa_kv_bf16 = dequantize_k_cache(swa_k_cache)   # _forward_prefill_sparse_fwd
  → assert dim_quant == 656                     # AssertionError

DSv4-Flash uses an entirely different layout:

| field | DSv3.2 (656) | DSv4-Flash (584) |
|---|---|---|
| nope element | 512 dim FP8 e4m3 = 512 B | 448 dim FP8 e4m3 = 448 B |
| nope scale | 4 × fp32 = 16 B | 7 × ue8m0 = 7 B (+1 pad) |
| rope element | 64 dim BF16 = 128 B | 64 dim BF16 = 128 B |
| per-token total | 656 B (AOS) | 584 B |
| per-page arrangement | per-token AOS | page-SOA (data region 256×576 B then scale region 256×8 B) |

NSA's helper understands neither the new byte count nor the per-page
SOA split. Reusing it trips the dim assertion, and with the assertion bypassed it would misread the bytes and produce garbage data.
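
To make the page-SOA arrangement concrete, the sketch below computes where one token's data record and scale record would start. The region sizes come from the table; the helper name is hypothetical, and tight packing of consecutive pages with no padding between them is an assumption.

```python
from typing import Tuple

# Geometry from the layout table. Tight page packing is an assumption.
TOKENS_PER_PAGE = 256
DATA_BYTES_PER_TOKEN = 448 + 128  # FP8 nope bytes + BF16 rope bytes = 576 B
SCALE_BYTES_PER_TOKEN = 7 + 1     # 7 ue8m0 scales + 1 pad byte = 8 B
DATA_REGION_BYTES = TOKENS_PER_PAGE * DATA_BYTES_PER_TOKEN
PAGE_BYTES = TOKENS_PER_PAGE * (DATA_BYTES_PER_TOKEN + SCALE_BYTES_PER_TOKEN)


def token_offsets(token_index: int) -> Tuple[int, int]:
    """Return (data_offset, scale_offset) in bytes for a flat token index."""
    page, slot = divmod(token_index, TOKENS_PER_PAGE)
    page_base = page * PAGE_BYTES
    # All data records of a page come first, then all of its scale records.
    data_offset = page_base + slot * DATA_BYTES_PER_TOKEN
    scale_offset = page_base + DATA_REGION_BYTES + slot * SCALE_BYTES_PER_TOKEN
    return data_offset, scale_offset
```

A per-token AOS reader would instead look for scales right after each token's nope bytes, which is why the DSv3.2 helper cannot be reused here.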

A new module sglang/srt/layers/attention/dsv4/dequant_k_cache.py mirrors
the writer side in dsv4/index_buf_accessor.py and
dsv4/quant_k_cache.py. One Triton kernel program per token; 2D fused load
of the (NUM_TILES, GROUP_SIZE) nope block; single-load scale row;
single-shift bitcast for ue8m0 → BF16 ((s << 7).bitcast<bf16> is exactly
2^(s-127) for non-degenerate exponents); FP32 multiply for full mantissa.
The function is called from both compress_ratio paths in
_forward_prefill_sparse_fwd. Verified by a writer → reader roundtrip
with max relative error 5.9% (within FP8 e4m3 precision).
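
The single-shift bitcast can be verified on the CPU. The sketch below is not the Triton kernel; it rebuilds the BF16 bit pattern with integer shifts and widens it to FP32 (BF16 is the upper 16 bits of an FP32 value) so that the standard library can decode it. The function name is hypothetical.

```python
import struct


def ue8m0_to_float(s: int) -> float:
    """Decode a ue8m0 scale byte via the (s << 7) BF16 bit pattern."""
    bf16_bits = (s & 0xFF) << 7  # exponent field of BF16, sign and mantissa zero
    fp32_bits = bf16_bits << 16  # BF16 occupies the high half of an FP32 word
    return struct.unpack("<f", struct.pack("<I", fp32_bits))[0]


if __name__ == "__main__":
    print(ue8m0_to_float(127))  # 1.0
    print(ue8m0_to_float(130))  # 8.0
    print(ue8m0_to_float(0))    # 0.0, a degenerate exponent
```

For s between 1 and 254 the result equals 2^(s-127) exactly. The two degenerate encodings decode to 0.0 (s = 0) and infinity (s = 255), which is the caveat behind "non-degenerate exponents" above.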

3. Auto-route the DeepGEMM indexer for large query batches

Before attention ever runs, the DSv4 indexer builds a metadata tensor via
deep_gemm.get_paged_mqa_logits_metadata. On SM90 this call raises
CUDA_ERROR_INVALID_VALUE once the per-query batch grows past roughly the
FlashMLA SMEM cap (observed at chunked_prefill_size >= 16384):

File ".../sglang/srt/layers/attention/dsv4/metadata.py", line 119, in __post_init__
    self.deep_gemm_metadata = get_paged_mqa_logits_metadata(...)
tvm.error.InternalError: CUDA driver error
  (/deepgemm/csrc/apis/../jit_kernels/impls/../../jit/handle.hpp:178):
  1 (CUDA_ERROR_INVALID_VALUE, invalid argument)

This happens before the sparse-fwd path is even reached, so without
auto-routing the indexer, large-q requests crash in the indexer and
items 1–2 never get exercised. The same call has a working alternative in
sglang/jit_kernel/deepseek_v4.py::get_paged_mqa_logits_metadata (vendored
Triton kernel, no SMEM cap), gated today by
SGLANG_OPT_USE_JIT_INDEXER_METADATA.

Change: metadata.py now also picks the JIT kernel when
c4_seq_lens.numel() > 11673. An explicit
SGLANG_OPT_USE_JIT_INDEXER_METADATA=1 still wins for users who want JIT
everywhere.
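
The selection rule amounts to a logical OR of the explicit opt-in and the size check. A minimal sketch, assuming the environment flag has already been parsed into a boolean; the function and constant names are hypothetical and do not reflect the actual code in metadata.py.

```python
# Hypothetical helper. Parsing of SGLANG_OPT_USE_JIT_INDEXER_METADATA is
# not modeled here; the caller passes the parsed flag as a boolean.
JIT_INDEXER_THRESHOLD = 11673  # same constant that gates the sparse-fwd branch


def use_jit_indexer_metadata(num_queries: int, env_force_jit: bool) -> bool:
    """Pick the JIT metadata kernel if forced or if the batch is too large."""
    return env_force_jit or num_queries > JIT_INDEXER_THRESHOLD
```

Here num_queries stands for c4_seq_lens.numel(). Because the size check is OR-ed in, setting the flag can only widen the set of shapes that use the JIT kernel, never narrow it.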

A note on the threshold: 11673 is the FlashMLA decode-kernel SMEM cap
(4 * (b*5+1) ≤ 228 KB ⇒ b ≤ 11673) that the prefill branch above uses to
gate _forward_prefill_sparse_fwd. It is not the indexer's observed
crash point — the indexer was seen to crash at
chunked_prefill_size ≥ 16384, so its true ceiling lives somewhere in
(11673, 16384]. Reusing the same constant ties the two switches together
(whenever attention falls back to sparse-fwd, the indexer falls back to JIT
in lockstep), keeps the code reader from juggling two thresholds, and
stays comfortably below the observed indexer ceiling. The two metadata
kernels are equivalent in cost, so over-routing in the (11673, 16384]
band has no measurable impact.

Files

| file | lines changed | purpose |
|---|---|---|
| python/sglang/srt/layers/attention/deepseek_v4_backend.py | +207 −0 | sparse-fwd routing branch, _forward_prefill_sparse_fwd, two callsites to new dequant |
| python/sglang/srt/layers/attention/dsv4/sparse_prefill_utils.py | +187 (new) | combine_swa_extra_indices Triton kernel + wrapper |
| python/sglang/srt/layers/attention/dsv4/dequant_k_cache.py | +214 (new) | DSv4-Flash MXFP8 page-SOA dequant Triton kernel |
| python/sglang/srt/layers/attention/dsv4/metadata.py | +13 −1 | auto-route to JIT indexer at c4_seq_lens.numel() > 11673 |
| test/registered/kernels/test_combine_swa_extra_indices.py | +265 (new) | unit test for index-merge helper |

Verification

Tested on H20-3e × 8 (TP=8) with DeepSeek-V4-Flash, EAGLE-3 spec decoding,
393 216-token context, --moe-runner-backend flashinfer_mxfp4.

  • Correctness: chat-formatted long-prompt requests (~27 500 tokens)
    produce coherent output via the sparse-fwd path. Short prompts
    (q.shape[0] ≤ 11673) are unchanged — they still go through
    flash_mla_with_kvcache and match the pre-PR output character for
    character.

  • Performance (single-request prefill latency, temperature=0.0,
    max_tokens=8):

    | prompt tokens | chunked=8192 (decode-kernel path) | chunked=32768 (sparse-fwd path, this PR) | note |
    |---|---|---|---|
    | 5 602 | 0.66 s | 0.67 s | both branches hit the decode kernel |
    | 13 302 | 1.36 s | 1.55 s | sparse-fwd ~14% slower at medium chunks |
    | 30 352 | 3.11 s | 3.27 s | sparse-fwd ~5% slower, close to the decode kernel |

    The ~14% regression at the 13 K point is a known characteristic of
    flash_mla_sparse_fwd: its per-call setup overhead doesn't amortize as
    well as the decode kernel on medium-sized prefills. At the 30 K point,
    the primary motivation for the larger chunk size, the gap narrows to
    about 5%.

  • Backwards compat: SGLANG_OPT_USE_JIT_INDEXER_METADATA=1 continues
    to force JIT for all shapes (the new auto-route is OR-ed in, not
    overridden); SGLANG_FP8_PAGED_MQA_LOGITS_TORCH=1 still bypasses
    deep_gemm entirely.


CI States

Latest PR Test (Base): Run #25982140150
Latest PR Test (Extra): ⚠️ Not enabled — add run-ci-extra label to opt in.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements a fallback mechanism for DeepSeek-V4 attention to handle large query batches that exceed the shared memory capacity of the FlashMLA sparse decode kernel. It introduces a dedicated prefill path that utilizes new Triton kernels for dequantizing the DSv4-Flash K cache and combining sparse indices. Review feedback highlights opportunities to reduce peak memory usage by avoiding large transient tensors during concatenation and eliminating unnecessary memory copies caused by .contiguous() calls. Other suggestions include correcting a docstring reference to the cache layout and centralizing a duplicated threshold constant for better maintainability.

Comment thread python/sglang/srt/layers/attention/deepseek_v4_backend.py Outdated
Comment thread python/sglang/srt/layers/attention/deepseek_v4_backend.py
Comment thread python/sglang/srt/layers/attention/dsv4/dequant_k_cache.py
Comment thread python/sglang/srt/layers/attention/dsv4/metadata.py
@yuan-luo
Collaborator Author

/tag-and-rerun-ci

@yuan-luo yuan-luo force-pushed the dsv4-sparse-fwd-prefill branch from 1c719cb to 5edccd9 Compare May 17, 2026 05:03
The DSv4 attention backend currently routes both prefill and decode
through flash_mla.flash_mla_with_kvcache. Its internal
get_decoding_sched_meta kernel allocates 4 * (b * 5 + 1) bytes of
dynamic shared memory, where b == q.shape[0]. SM90 and SM100 both have
a 228 KB opt-in smem cap (CUDA Programming Guide Table 16; Blackwell
whitepaper), so b > 11673 makes cudaFuncSetAttribute return "invalid
argument" and the scheduler crashes:

  csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu:111
  Assertion error: smem_size <= SM90ArchSpec::smem_capacity

For decode b == active sequence count, always small. For DSv4
prefill-extend SGLang reshapes q so b == tokens in the chunk, and any
chunked_prefill_size > 11673 crashes the backend.

Fix: when q.shape[0] > 11673, route through flash_mla.flash_mla_sparse_fwd
(BF16 flat KV, no smem-limited schedule kernel, tile shapes tuned for
prefill). flash_mla provides both kernels on SM90a and SM100f, so the
runtime dispatch happens inside flash_mla and no Python-side arch
branching is needed. The two-path pattern (sparse_decode for decode,
sparse_fwd for prefill) matches the NSA backend for DSv3.2.
@yuan-luo yuan-luo force-pushed the dsv4-sparse-fwd-prefill branch from 5edccd9 to 065a563 Compare May 17, 2026 05:15
@zcnrex
Collaborator

zcnrex commented May 17, 2026

#25418 cherry picked the routing logic so it works for long prefix. Could you help take a look? The PR also optimizes the sparse prefill path so it's faster than the main as analyzed with kernel profiling.

@zcnrex zcnrex mentioned this pull request May 17, 2026
@yuan-luo
Collaborator Author

yuan-luo commented May 17, 2026

#25418 cherry picked the routing logic so it works for long prefix. Could you help take a look? The PR also optimizes the sparse prefill path so it's faster than the main as analyzed with kernel profiling.

@zcnrex Let me share some background.

My original intention for this PR was to fully realize the gains from #24816 (introducing FlashInfer Cutlass MoE for DSv4-Flash, which is M-hungry and benefits from larger chunk sizes). When pushing per-expert M higher I hit the same SMEM cap crash you ran into for ultra-long context.

After comparing, #25418 implements the same sparse-fwd prefill direction and goes further on the parts I had left as follow-ups:

  • gather-aware paged dequant vs. my full-cache materialization
  • SparsePrefillChunkCache scaffolding for cross-chunk/cross-layer reuse
  • vLLM-aligned combine_topk_swa_indices (more general than my combine_swa_extra_indices)

I'm happy to see some of the early commits made it into your PR, so let's land this via #25418. I can help by running the H20 e2e workload I used here (DSv4-Flash, TP=8, chunked-prefill-size=32768, EAGLE-3) on #25418 and reporting numbers.

Thanks @zcnrex for the work.

@yuan-luo
Collaborator Author

@zcnrex Btw, feel free to lift any of the writeup from this PR's description (the SMEM cap derivation, DSv3 / DSv4 reshape rationale, etc.) into #25418 if it's useful for reviewers.
