[DSv4] Route large-q prefill through flash_mla_sparse_fwd #25502
yuan-luo wants to merge 1 commit into
Code Review
This pull request implements a fallback mechanism for DeepSeek-V4 attention to handle large query batches that exceed the shared memory capacity of the FlashMLA sparse decode kernel. It introduces a dedicated prefill path that utilizes new Triton kernels for dequantizing the DSv4-Flash K cache and combining sparse indices. Review feedback highlights opportunities to reduce peak memory usage by avoiding large transient tensors during concatenation and eliminating unnecessary memory copies caused by .contiguous() calls. Other suggestions include correcting a docstring reference to the cache layout and centralizing a duplicated threshold constant for better maintainability.
The DSv4 attention backend currently routes both prefill and decode through flash_mla.flash_mla_with_kvcache. Its internal get_decoding_sched_meta kernel allocates 4 * (b * 5 + 1) bytes of dynamic shared memory, where b == q.shape[0]. SM90 and SM100 both have a 228 KB opt-in smem cap (CUDA Programming Guide Table 16; Blackwell whitepaper), so b > 11673 makes cudaFuncSetAttribute return "invalid argument" and the scheduler crashes:

csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu:111 Assertion error: smem_size <= SM90ArchSpec::smem_capacity

For decode, b == active sequence count, which is always small. For DSv4 prefill-extend, SGLang reshapes q so that b == tokens in the chunk, and any chunked_prefill_size > 11673 crashes the backend.

Fix: when q.shape[0] > 11673, route through flash_mla.flash_mla_sparse_fwd (BF16 flat KV, no smem-limited schedule kernel, tile shapes tuned for prefill). flash_mla provides both kernels on SM90a and SM100f, so the runtime dispatch happens inside flash_mla and no Python-side arch branching is needed. The two-path pattern (sparse_decode for decode, sparse_fwd for prefill) matches the NSA backend for DSv3.2.
#25418 cherry-picked the routing logic so that it works for long prefixes. Could you help take a look? That PR also optimizes the sparse prefill path, which kernel profiling shows to be faster than main.
@zcnrex Some background: my original intention for this PR was to fully realize the gains from #24816 (which introduces FlashInfer Cutlass MoE for DSv4-Flash; it is M-hungry and benefits from larger chunk sizes). When pushing per-expert M higher, I hit the same SMEM cap crash you ran into for ultra-long context. After comparing the two, #25418 implements the same sparse-fwd prefill direction and goes further on the parts I had left as follow-ups:
I'm happy to see that some of the early commits made it into your PR. Let's land it via #25418. I can run the H20 e2e workload I used here (DSv4-Flash, TP=8, chunked-prefill-size=32768, EAGLE-3) on #25418 and report numbers. Thanks @zcnrex for the work.
Motivation
Fix a crash when --chunked-prefill-size is large.
DSv3 uses dense MLA: every query attends to all KV, with no per-token routing. The corresponding FlashMLA path is dense_decode.h, which takes q in shape [num_seq, seq_len, h, d]. Here b == num_seq, capped by max_running_requests (typically a few tens), so the per-b SMEM allocation inside get_decoding_sched_meta stays in the single-digit KB range and never trips the opt-in cap.
DSv4 introduces sparse attention with per-token top-k routing (sparse_decode.h). The current SGLang DSv4 backend reuses the existing DSv3 dense-decode infrastructure (same FlashMLA entrypoint, same schedule kernel, same SMEM accounting), and only patches the q layout to carry per-token routing. Concretely, since each token now needs its own top-k indices, q is reshaped to [num_tokens, 1, h, d] before dispatch, so the kernel's b flips from "number of active sequences" to "number of query tokens in the chunk", a 100–1000× jump. With get_decoding_sched_meta still allocating 4 * (b * 5 + 1) bytes of dynamic shared memory and the SM90/SM100 opt-in SMEM cap at 228 KB, any b > 11673 (= ⌊(228 × 1024 / 4 − 1) / 5⌋) makes cudaFuncSetAttribute return invalid argument, and the scheduler crashes on the assertion smem_size <= SM90ArchSpec::smem_capacity (csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu:111).
Concretely, any chunked_prefill_size > 11673 crashes the DSv4 backend on the very first long-prompt request. This PR adds the plumbing required to run DSv4-Flash safely at larger chunk sizes, specifically the --chunked-prefill-size 32768 configuration we want for long-context prefill throughput.
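The 11673 threshold follows directly from the allocation formula and the cap quoted above. A minimal sketch of that arithmetic is below; the function and constant names are illustrative and do not come from FlashMLA or SGLang.

```python
# Arithmetic behind the 11673 threshold, using only the figures quoted above.
# Names here are hypothetical; they are not FlashMLA or SGLang identifiers.

SMEM_CAP_BYTES = 228 * 1024  # opt-in shared-memory cap cited for SM90/SM100


def sched_meta_smem_bytes(b: int) -> int:
    """Dynamic shared memory requested by the schedule kernel for batch size b."""
    return 4 * (b * 5 + 1)


def max_safe_batch(cap_bytes: int = SMEM_CAP_BYTES) -> int:
    """Largest b satisfying 4 * (b * 5 + 1) <= cap_bytes."""
    return (cap_bytes // 4 - 1) // 5


if __name__ == "__main__":
    b_max = max_safe_batch()
    print("largest safe b:", b_max)
    print("bytes at b_max:", sched_meta_smem_bytes(b_max))
    print("bytes at b_max + 1:", sched_meta_smem_bytes(b_max + 1))
    print("bytes at b = 32768:", sched_meta_smem_bytes(32768))
```

At b = 32768, the chunk size targeted by this PR, the request is several times larger than the cap, which is why the decode kernel cannot serve that configuration.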
Modifications
The fix is split across three concerns. All three must land together:
removing any one makes the long-prompt request crash before the next can even be triggered.
1. Route large-q prefill through flash_mla_sparse_fwd
When q.shape[0] > 11673, the attention call now dispatches to flash_mla.flash_mla_sparse_fwd (the dedicated prefill kernel: BF16 flat KV, no schedule kernel, tile shapes tuned for prefill) instead of flash_mla_with_kvcache (the decode kernel: FP8 paged KV, SMEM-limited schedule kernel). FlashMLA exposes both kernels on SM90a and SM100f, so the runtime dispatch happens inside FlashMLA and no Python-side arch branching is required. The two-path pattern (sparse-decode for decode, sparse-fwd for prefill) mirrors the NSA backend for DSv3.2.
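The routing decision reduces to a single comparison on the leading dimension of q. The sketch below shows only that decision. The constant and function names are hypothetical, and the real branch in deepseek_v4_backend.py calls the FlashMLA kernels rather than returning a label.

```python
# Sketch of the two-path routing decision. Names are hypothetical and the
# return values are labels standing in for the real kernel calls.

FLASHMLA_DECODE_MAX_Q = 11673  # from 4 * (b * 5 + 1) <= 228 KB


def select_attention_path(num_q_tokens: int) -> str:
    """Pick the kernel family from q.shape[0] alone; no arch check is needed."""
    if num_q_tokens > FLASHMLA_DECODE_MAX_Q:
        return "sparse_fwd"    # prefill kernel: BF16 flat KV, no schedule kernel
    return "sparse_decode"     # decode kernel: FP8 paged KV, SMEM-limited scheduler


if __name__ == "__main__":
    for n in (64, 11673, 11674, 32768):
        print(n, "->", select_attention_path(n))
```

Keeping the threshold in one named constant also addresses the review comment about the duplicated literal.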
For layers with compress_ratio in {4, 128}, the sparse-fwd path additionally needs to attend over [swa_kv; extra_kv] as a single flat KV tensor. combine_swa_extra_indices handles index merging: it shifts each extra index by s_kv_swa so the same flat-index assumption holds for both halves of the concatenated buffer. Unit tests cover the SWA-only, both-halves-populated, and asymmetric per-token-length cases.
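The index merge described above can be illustrated in plain Python. This is a reference sketch rather than the Triton kernel. It assumes that -1 marks a padded (invalid) slot and that each token's merged row is its SWA indices followed by its shifted extra indices. Neither assumption is confirmed by the PR text.

```python
# Reference sketch of merging SWA and extra top-k indices into one flat index
# space over [swa_kv; extra_kv]. Assumptions (not confirmed by the PR): -1 marks
# padding, and each output row is the SWA row followed by the shifted extra row.
from typing import List

INVALID = -1  # assumed padding marker


def combine_indices(
    swa_idx: List[List[int]],
    extra_idx: List[List[int]],
    s_kv_swa: int,
) -> List[List[int]]:
    """Shift valid extra indices by s_kv_swa and append them to the SWA indices."""
    combined = []
    for swa_row, extra_row in zip(swa_idx, extra_idx):
        shifted = [i + s_kv_swa if i != INVALID else INVALID for i in extra_row]
        combined.append(list(swa_row) + shifted)
    return combined


if __name__ == "__main__":
    swa = [[0, 1], [2, INVALID]]
    extra = [[0, 3], [INVALID, INVALID]]  # second token has no extra entries
    print(combine_indices(swa, extra, s_kv_swa=4))
```

With s_kv_swa = 4, extra index 0 lands on flat slot 4, the first row after the SWA half of the concatenated buffer.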
Files: deepseek_v4_backend.py (the branch + plumbing), dsv4/sparse_prefill_utils.py (helper kernel + Python wrapper), test/registered/kernels/test_combine_swa_extra_indices.py (UT).
2. DSv4-Flash MXFP8 page-SOA dequantizer
flash_mla_sparse_fwd requires a BF16 KV tensor, so the sparse-fwd branch needs to dequantize the FP8 K cache before calling it. NSA already has a dequantize_k_cache helper, but it was written for the DSv3.2 layout (656 bytes per token, per-token AOS: 512 bytes of FP8 nope + 16 bytes of FP32 scales + 128 bytes of BF16 rope) and reusing it for DSv4-Flash trips an assertion:
DSv4-Flash uses an entirely different layout:
NSA's helper doesn't understand either the new byte count or the per-page
SOA split — re-using it produces garbage data and trips the dim assertion.
A new module sglang/srt/layers/attention/dsv4/dequant_k_cache.py mirrors the writer side in dsv4/index_buf_accessor.py and dsv4/quant_k_cache.py. One Triton kernel program per token; 2D fused load of the (NUM_TILES, GROUP_SIZE) nope block; single-load scale row; single-shift bitcast for ue8m0 → BF16 ((s << 7).bitcast<bf16> is exactly 2^(s-127) for non-degenerate exponents); FP32 multiply for full mantissa. The function is called from both compress_ratio paths in _forward_prefill_sparse_fwd. Verified by a writer → reader roundtrip with max relative error 5.9% (within FP8 e4m3 precision).
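The single-shift bitcast can be checked on the host without Triton. BF16 has 1 sign bit, 8 exponent bits, and 7 mantissa bits, so placing the ue8m0 byte in the exponent field and leaving the rest zero gives a power of two. The sketch below emulates the bitcast with the standard library; the function name is illustrative.

```python
# Host-side check of the ue8m0 -> BF16 shift trick, using only the stdlib.
# BF16 is the upper 16 bits of an IEEE-754 FP32, so we widen and reinterpret.
import struct


def ue8m0_to_float(s: int) -> float:
    """Interpret (s << 7) as BF16 bits and return the value as a Python float."""
    if not 0 <= s <= 255:
        raise ValueError("ue8m0 scale must fit in one byte")
    bf16_bits = s << 7           # exponent field = s, sign = 0, mantissa = 0
    fp32_bits = bf16_bits << 16  # BF16 occupies the top half of an FP32 word
    return struct.unpack("<f", struct.pack("<I", fp32_bits))[0]


if __name__ == "__main__":
    # Every normal exponent (1..254) decodes to exactly 2^(s - 127).
    assert all(ue8m0_to_float(s) == 2.0 ** (s - 127) for s in range(1, 255))
    # The degenerate encodings decode to 0.0 and +inf under this shortcut.
    print(ue8m0_to_float(0), ue8m0_to_float(127), ue8m0_to_float(255))
```

The two degenerate encodings (s = 0 and s = 255) are the cases the "non-degenerate exponents" qualifier excludes.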
3. Auto-route the DeepGEMM indexer for large query batches
Before attention ever runs, the DSv4 indexer builds a metadata tensor via deep_gemm.get_paged_mqa_logits_metadata. On SM90 this call raises CUDA_ERROR_INVALID_VALUE once the per-query batch grows past roughly the FlashMLA SMEM cap (observed at chunked_prefill_size >= 16384):
This happens before the sparse-fwd path is even reached, so without this auto-routing, large-q requests crash in the indexer and items 1–2 never get exercised. The same call has a working alternative in sglang/jit_kernel/deepseek_v4.py::get_paged_mqa_logits_metadata (vendored Triton kernel, no SMEM cap), gated today by SGLANG_OPT_USE_JIT_INDEXER_METADATA.
Change: metadata.py now also picks the JIT kernel when c4_seq_lens.numel() > 11673. An explicit SGLANG_OPT_USE_JIT_INDEXER_METADATA=1 still wins for users who want JIT everywhere.
A note on the threshold: 11673 is the FlashMLA decode-kernel SMEM cap (4 * (b*5+1) ≤ 228 KB ⇒ b ≤ 11673) that the prefill branch above uses to gate _forward_prefill_sparse_fwd. It is not the indexer's observed crash point: the indexer was seen to crash at chunked_prefill_size ≥ 16384, so its true ceiling lives somewhere in (11673, 16384]. Reusing the same constant ties the two switches together (whenever attention falls back to sparse-fwd, the indexer falls back to JIT in lockstep), keeps the code reader from juggling two thresholds, and stays comfortably below the observed indexer ceiling. The two metadata kernels are equivalent in cost, so over-routing in the (11673, 16384] band has no measurable impact.
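The gate for the indexer metadata kernel is a logical OR of the existing environment override and the new size check. The sketch below illustrates that logic. The function name is hypothetical, and treating the string "1" as the only truthy value is an assumption about how SGLang parses the variable.

```python
# Sketch of the OR-ed gate for the indexer metadata kernel. The env-var name and
# threshold come from the PR text; the function name and the "== '1'" parsing
# are assumptions made for illustration.
from typing import Mapping

JIT_ENV_VAR = "SGLANG_OPT_USE_JIT_INDEXER_METADATA"
LARGE_Q_THRESHOLD = 11673  # shared with the attention-path switch


def use_jit_indexer_metadata(num_queries: int, env: Mapping[str, str]) -> bool:
    """Return True when the JIT (Triton) metadata kernel should be used."""
    forced = env.get(JIT_ENV_VAR, "0") == "1"  # explicit opt-in always wins
    return forced or num_queries > LARGE_Q_THRESHOLD


if __name__ == "__main__":
    print(use_jit_indexer_metadata(512, {}))                  # small batch, no override
    print(use_jit_indexer_metadata(512, {JIT_ENV_VAR: "1"}))  # override forces JIT
    print(use_jit_indexer_metadata(16384, {}))                # large batch auto-routes
```

Passing the environment as a mapping keeps the function easy to test; a caller would typically hand in os.environ.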
Files
python/sglang/srt/layers/attention/deepseek_v4_backend.py: _forward_prefill_sparse_fwd, two callsites to new dequant
python/sglang/srt/layers/attention/dsv4/sparse_prefill_utils.py: combine_swa_extra_indices Triton kernel + wrapper
python/sglang/srt/layers/attention/dsv4/dequant_k_cache.py
python/sglang/srt/layers/attention/dsv4/metadata.py: c4_seq_lens.numel() > 11673
test/registered/kernels/test_combine_swa_extra_indices.py
Verification
Tested on H20-3e × 8 (TP=8) with DeepSeek-V4-Flash, EAGLE-3 spec decoding, 393,216-token context, --moe-runner-backend flashinfer_mxfp4.
Correctness: chat-formatted long-prompt requests (~27,500 tokens) produce coherent output via the sparse-fwd path. Short prompts (q.shape[0] ≤ 11673) are unchanged: they still go through flash_mla_with_kvcache and match the pre-PR output character for character.
Performance (single-request prefill latency, temperature=0.0, max_tokens=8):
The 15% regression at the 13 K point is a known characteristic of flash_mla_sparse_fwd: its per-call setup overhead doesn't amortize as well as the decode kernel on medium-sized prefills. The 30 K point, the primary motivation for the larger chunk size, already matches.
Backwards compat: SGLANG_OPT_USE_JIT_INDEXER_METADATA=1 continues to force JIT for all shapes (the new auto-route is OR-ed in, not overridden); SGLANG_FP8_PAGED_MQA_LOGITS_TORCH=1 still bypasses deep_gemm entirely.