integrate flash_mla_sparse_fwd#25418

Open
zcnrex wants to merge 10 commits into
sgl-project:mainfrom
zcnrex:dsv4_mla_sparse
@zcnrex
Collaborator

@zcnrex zcnrex commented May 15, 2026

Motivation

  • flash_mla_with_kvcache is slow because of its complicated loading logic. Switching to flash_mla_sparse_fwd gives a 1.35x speedup over the flash_mla_with_kvcache kernel
    • total CUDA wall: 1.1x speedup
  • Chunked prefill only works for sizes <= 8192 and fails for 32768. ([DSv4] Route large-q prefill through flash_mla_sparse_fwd #25502 has more details on this motivation.)

Modifications

Accuracy Tests

$ python3 -m sglang.test.run_eval \
    --eval-name gsm8k \
    --num-shots 64 \
    --num-examples 1319 \
    --num-threads 1319 \
    --port 30000 \
    --host 0.0.0.0

Chunk prefill size 8192

ChatCompletionSampler initialized with self.system_message=None self.temperature=0.0 self.max_tokens=2048 self.reasoning_effort=None self.extra_body=None
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1255/1255 [01:57<00:00, 10.71it/s]
Total latency: 117.380 s
Score: 0.960
Output throughput: 1342.760 token/s
[METRIC] gsm8k_score=0.9601593625498008 labels={"model": "deepseek-ai/DeepSeek-V4-Flash", "eval": "gsm8k"}
[METRIC] gsm8k_latency=117.37985401274636 labels={"model": "deepseek-ai/DeepSeek-V4-Flash", "eval": "gsm8k"}
Writing report to /tmp/gsm8k_deepseek-ai_DeepSeek-V4-Flash.html
{'score:std': np.float64(0.19558466467941954), 'score': np.float64(0.9601593625498008), 'latency': 117.37985401274636, 'output_throughput': 1342.7602319464863}
Writing results to /tmp/gsm8k_deepseek-ai_DeepSeek-V4-Flash.json

Chunk prefill size 32768

ChatCompletionSampler initialized with self.system_message=None self.temperature=0.0 self.max_tokens=2048 self.reasoning_effort=None self.extra_body=None
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1255/1255 [02:38<00:00,  7.91it/s]
Total latency: 158.834 s
Score: 0.961
Output throughput: 1037.340 token/s
[METRIC] gsm8k_score=0.9609561752988047 labels={"model": "deepseek-ai/DeepSeek-V4-Flash", "eval": "gsm8k"}
[METRIC] gsm8k_latency=158.83408985985443 labels={"model": "deepseek-ai/DeepSeek-V4-Flash", "eval": "gsm8k"}
Writing report to /tmp/gsm8k_deepseek-ai_DeepSeek-V4-Flash.html
{'score:std': np.float64(0.19369926291521494), 'score': np.float64(0.9609561752988047), 'latency': 158.83408985985443, 'output_throughput': 1037.3402847296738}
Writing results to /tmp/gsm8k_deepseek-ai_DeepSeek-V4-Flash.json
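
A quick consistency check on the two runs above, as a sketch: assuming each example is scored 0 or 1 and `score:std` is the population standard deviation, the reported std should equal sqrt(p(1-p)) for score p. The correct-answer counts (1205 and 1206 out of 1255) are inferred from the reported scores, not read from the logs.

```python
import math

def binary_score_std(num_correct: int, num_total: int) -> float:
    """Population std of a 0/1 score with num_correct ones out of num_total."""
    p = num_correct / num_total
    return math.sqrt(p * (1.0 - p))

# (chunk size, inferred correct count, reported score, reported score:std)
runs = [
    (8192, 1205, 0.9601593625498008, 0.19558466467941954),
    (32768, 1206, 0.9609561752988047, 0.19369926291521494),
]

for chunk_size, correct, score, std in runs:
    # The inferred count must reproduce the reported score ...
    assert math.isclose(correct / 1255, score, rel_tol=1e-9)
    # ... and the binary-score std must reproduce the reported std.
    assert math.isclose(binary_score_std(correct, 1255), std, rel_tol=1e-6)
    print(f"chunk={chunk_size}: {correct}/1255 correct, std is consistent")
```

Under those assumptions the two chunk sizes differ by a single example out of 1255.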

Speed Tests and Profiling

Prefill 128k, 4 x B200, DSv4 Flash, --chunked-prefill-size 8192

SGLANG_OPT_USE_TOPK_V2=0 \
SGLANG_JIT_DEEPGEMM_PRECOMPILE=0 \
SGLANG_OPT_FLASHMLA_SPARSE_PREFILL=1 \
sglang serve \
    --trust-remote-code \
    --model-path deepseek-ai/DeepSeek-V4-Flash \
    --tp 4 \
    --moe-runner-backend flashinfer_mxfp4 \
    --chunked-prefill-size 32768 \
    --disable-flashinfer-autotune \
    --swa-full-tokens-ratio 0.1 \
    --cuda-graph-max-bs 8 \
    --mem-fraction-static 0.835 \
    --host 0.0.0.0 \
    --port 8888

python3 -m sglang.test.send_one \
    --different-prompts \
    --batch-size 1 \
    --host 0.0.0.0 \
    --port 8888 \
    --profile-by-stage \
    --random-input-len 131072 \
    --profile
################################################################################
#  KERNEL DIFF SUMMARY (baseline = A, sparse_prefill = B)
################################################################################

Summary:

  Speedup A → B (no-NCCL): 1.10x  (B faster)

  Metric           A = baseline  B = sparse_prefill         Δ        Δ%
  ---------------------------------------------------------------------
  Total events        1,170,782             622,855  -547,927    -46.8%
  GPU kernels            13,190              13,725      +535     +4.1%
  Annotations               905                 905        +0     +0.0%
  CUDA total wall      785.7 ms            729.4 ms  -56.4 ms     -7.2%
  CUDA no-NCCL         701.1 ms            637.1 ms  -64.0 ms     -9.1%

  (Note: NCCL kernel wall has large run-to-run fluctuation; treat CUDA no-NCCL as the meaningful comparison.)

Breakdown:

  [+ Added]             10 kernels, only in B (sparse_prefill)
     #   count    sum_wall   name
  -----------------------------------------------------------------------------------------------------------------------------------------------
     1     215    146.78ms   void sm100::fwd::head64::sparse_attn_fwd_kernel<false, sm100::fwd::head64::TmaParams<cute::tuple<int, int, int>, cute::TiledCopy<cute::Copy_
     2     105     21.47ms   void (anonymous namespace)::topk_512_transform<true>((anonymous namespace)::TopK512Params)
     3     420      5.22ms   _dequantize_k_cache_paged_kernel
     4     115      3.95ms   _combine_topk_swa_indices_kernel
     5       5    187.14us   _build_swa_token_ids_kernel
     6      15     23.78us   void at::native::vectorized_elementwise_kernel<4, at::native::AUnaryFunctor<int, int, int, at::native::binary_internal::MulFunctor<int> >, s
     7       5     20.73us   void at::native::_scatter_gather_elementwise_kernel<128, 8, at::native::_cuda_scatter_gather_internal_kernel<false, at::native::OpaqueType<4
     8       5     10.24us   void at::native::unrolled_elementwise_kernel<at::native::CUDAFunctorOnSelf_add<int>, std::array<char*, 2ul>, 4, TrivialOffsetCalculator<1, u
     9       5      8.74us   void at::native::vectorized_elementwise_kernel<4, at::native::BinaryFunctor<int, int, int, at::native::minimum_kernel_cuda(at::TensorIterato
    10       5      7.46us   void at::native::vectorized_elementwise_kernel<4, at::native::BUnaryFunctor<int, int, int, at::native::remainder_kernel_cuda(at::TensorItera

  [- Removed]            4 kernels, only in A (baseline)
     #   count    sum_wall   name
  -----------------------------------------------------------------------------------------------------------------------------------------------
     1     215    205.89ms   void sm100::decode::head64::flash_fwd_splitkv_mla_fp8_sparse_kernel<sm100::decode::head64::KernelTemplate<(ModelType)1>, sm100::decode::head
     2     105     14.57ms   (anonymous namespace)::topk_short_transform((anonymous namespace)::TopKParams)
     3     215     12.36ms   void smxx::decode::flash_fwd_mla_combine_kernel<cutlass::bfloat16_t, 512, 8, 160, 256>(CombineParams)
     4      15      9.04ms   smxx::decode::get_mla_metadata_kernel(GetDecodeSchedMetaParams)

  [# Count changed]     12 kernels, launch count differs
     #   count_A   count_B    Δcount      wall_A      wall_B       Δwall   name
  -----------------------------------------------------------------------------------------------------------------------------------------------
     1        10        45       +35     23.33us     95.49us    +72.16us   void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, std::array<char*, 1ul> >(int, at::native::FillFunctor<int>,
     2        25        50       +25    107.94us    155.07us    +47.14us   void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kernel_cuda(at::TensorIteratorBase&)::{lambda()#3}::operator()() const:
     3        10        30       +20     13.60us     37.28us    +23.68us   void (anonymous namespace)::elementwise_kernel_with_index<int, at::native::arange_cuda_out(c10::Scalar const&, c10::Scalar const&, c10::Scal
     4       245       265       +20    790.46us    826.08us    +35.62us   void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kernel_cuda(at::TensorIteratorBase&)::{lambda()#3}::operator()() const:
     5        10        25       +15     13.63us     37.50us    +23.87us   void at::native::vectorized_elementwise_kernel<4, at::native::CUDAFunctorOnSelf_add<int>, std::array<char*, 2ul> >(int, at::native::CUDAFunc
     6         5        15       +10      5.09us     17.86us    +12.77us   void at_cuda_detail::cub::detail::scan::DeviceScanInitKernel<at_cuda_detail::cub::ScanTileState<long, true> >(at_cuda_detail::cub::ScanTileS
     7       215       225       +10    438.39us    456.90us    +18.51us   void at::native::vectorized_elementwise_kernel<4, at::native::CUDAFunctor_add<int>, std::array<char*, 3ul> >(int, at::native::CUDAFunctor_ad
     8       225       235       +10    430.78us    446.85us    +16.07us   void at::native::vectorized_elementwise_kernel<4, at::native::(anonymous namespace)::launch_clamp_scalar(at::TensorIteratorBase&, c10::Scala
     9         5        15       +10      9.79us     28.93us    +19.13us   void at_cuda_detail::cub::detail::scan::DeviceScanKernel<at_cuda_detail::cub::detail::scan::policy_hub<long, long, long, unsigned int, std::
    10        10        20       +10     95.90us    133.60us    +37.70us   void at::native::index_elementwise_kernel<128, 4, at::native::gpu_index_kernel<at::native::index_kernel_impl<at::native::OpaqueType<4> >(at:
    11         5        10        +5     16.42us     28.93us    +12.51us   void at::native::vectorized_elementwise_kernel<4, at::native::BUnaryFunctor<int, int, int, at::native::binary_internal::div_floor_kernel_cud
    12       215       220        +5    517.73us    537.88us    +20.16us   void at::native::elementwise_kernel<128, 2, at::native::gpu_kernel_impl_nocast<at::native::direct_copy_kernel_cuda(at::TensorIteratorBase&):

  [~ Top wall changed]   1 kernels, same count, |Δwall| ≥ 10us and ≥ 20%
     #   count      wall_A      wall_B       Δwall       Δ%   name
  -----------------------------------------------------------------------------------------------------------------------------------------------
     1      10     44.32us     57.31us    +12.99us     +29%   alloc_extend_kernel


Accounting:
  Added    (only in B = sparse_prefill)     +177.7 ms
  Removed  (only in A = baseline)           -241.9 ms
  Count-changed                               +0.3 ms
  Wall-changed  (incl. NCCL fluctuation)      +7.5 ms
  ---------------------------------------------------
  Total CUDA wall Δ                          -56.4 ms
  Total CUDA wall Δ (no-NCCL)                -64.0 ms
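
The headline ratios can be recomputed from the summary and accounting tables. This is a small sketch that uses only the numbers printed above; the variable names are illustrative.

```python
# Wall times in ms, copied from the summary table above.
total_a, total_b = 785.7, 729.4        # CUDA total wall (A = baseline, B = sparse_prefill)
no_nccl_a, no_nccl_b = 701.1, 637.1    # CUDA no-NCCL

speedup_total = total_a / total_b
speedup_no_nccl = no_nccl_a / no_nccl_b

# Accounting rows in ms: added, removed, count-changed, wall-changed.
accounting = [177.7, -241.9, 0.3, 7.5]
delta_total = sum(accounting)

print(f"total wall speedup: {speedup_total:.2f}x")
print(f"no-NCCL speedup:    {speedup_no_nccl:.2f}x")
print(f"accounted delta:    {delta_total:+.1f} ms")
```

The 1.10x figure corresponds to the no-NCCL comparison; total CUDA wall comes out closer to 1.08x, and the four accounting rows sum to the reported total delta.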

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

CI States

Latest PR Test (Base): ❌ Missing run-ci label -- add it to run CI tests.
Latest PR Test (Extra): ❌ Blocked -- run-ci is required first.

@github-actions github-actions Bot added quant LLM Quantization deepseek labels May 15, 2026
@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@zcnrex
Collaborator Author

zcnrex commented May 20, 2026

@yuan-luo Could you help test on H20 with flag SGLANG_OPT_FLASHMLA_SPARSE_PREFILL=1? I was able to verify it works on B200.

@zcnrex zcnrex force-pushed the dsv4_mla_sparse branch from a9d8c02 to e43253f Compare May 20, 2026 21:19
Comment on lines +1163 to +1169
if compressed_slice is not None:
    dequantize_k_cache_paged(
        extra_k_cache,
        flat_token_ids,
        page_size=extra_page_size,
        out=compressed_slice,
    )
Collaborator

Will this dequantize all of the c4 cache, or only the selected c4 cache?

Collaborator

It looks like it dequantizes all of the c4 cache: the selected c4 cache changes from layer to layer, but the sparse prefill cache here is computed only once, at the first layer (so the value of flat_token_ids doesn't change). Need @zcnrex to confirm this.

@yuan-luo
Collaborator

@yuan-luo Could you help test on H20 with flag SGLANG_OPT_FLASHMLA_SPARSE_PREFILL=1? I was able to verify it works on B200.

@zcnrex Sure, I'm kind of busy with some other stuff at hand. Will do it this weekend.

c4_sparse: means "compressed by 4" but only attend to top-512 tokens.
all related length will be clipped to 512.
"""
_LARGE_INDEXER_QUERY_THRESHOLD = 11673
Collaborator

Why is this value hardcoded? Can we avoid hardcoding it?

if envs.SGLANG_OPT_USE_JIT_INDEXER_METADATA.get():
use_jit_indexer = (
envs.SGLANG_OPT_USE_JIT_INDEXER_METADATA.get()
or self.c4_seq_lens.numel() > _LARGE_INDEXER_QUERY_THRESHOLD
Collaborator

Why is it numel here? I think the numel of c4_seq_lens should be a small value like batch size?

fp8_vals = tl.load(buf_fp8_ptr + fp8_off).to(tl.float32)

scale_u8 = tl.load(buf_uint8_ptr + token_scale_base + tile_id).to(tl.int32)
scale_pow2 = tl.exp2((scale_u8 - 127).to(tl.float32))
Collaborator

Will tl.int32 and tl.float32 cause overflow here?
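
On the overflow question, a CPU-side sketch of the same arithmetic may help. It assumes the scale byte is a biased power-of-two exponent (bias 127, as in the snippet above) and emulates float32 rounding with `struct`. The helper names are hypothetical, and this is not the Triton kernel itself.

```python
import math
import struct

SCALE_BIAS = 127  # bias taken from the snippet: exponent = scale_u8 - 127

def as_float32(x: float) -> float:
    """Round a Python float (float64) to float32; return inf on overflow."""
    try:
        return struct.unpack("<f", struct.pack("<f", x))[0]
    except OverflowError:
        # struct refuses finite values that do not fit in float32.
        return math.copysign(math.inf, x)

def decode_scale(scale_u8: int) -> float:
    """Emulate exp2((scale_u8 - 127).to(float32)) for one uint8 scale byte."""
    if not 0 <= scale_u8 <= 255:
        raise ValueError("scale byte must be in [0, 255]")
    exponent = scale_u8 - SCALE_BIAS  # range [-127, 128], far inside int32
    return as_float32(math.ldexp(1.0, exponent))

for byte in (0, 127, 254, 255):
    print(byte, decode_scale(byte))
```

In this model the int32 subtraction cannot overflow, since the exponent stays within [-127, 128]. The float32 result is finite for every byte except 255, where 2^128 exceeds the float32 range and becomes inf; byte 0 lands on a subnormal value. Whether byte 255 can actually appear in the cache depends on the quantizer, which this sketch does not cover.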

@@ -0,0 +1,584 @@
"""Per-query sparse-index combiner for the FlashMLA sparse prefill path.

Adapts vllm's ``combine_topk_swa_indices`` (vllm/v1/attention/ops/
Collaborator

Paste the link to reference file

Adapts vllm's ``combine_topk_swa_indices`` (vllm/v1/attention/ops/
deepseek_v4_ops/cache_utils.py) to sglang's flat-workspace layout. For each
query token in a prefill chunk, emits one row of combined indices into the
Collaborator

Can we change all the index tensors to tl.int64 in this file? int32 is prone to IMA (illegal memory access).
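
To illustrate the int32 concern: a flat offset computed as index times stride can exceed 2^31 - 1 and wrap to a negative value, which then addresses memory outside the buffer. The sketch below emulates 32-bit wraparound in plain Python. The index and stride values are made up for illustration and are not taken from this PR.

```python
INT32_MAX = 2**31 - 1

def wrap_int32(x: int) -> int:
    """Two's-complement wraparound, as 32-bit integer arithmetic would produce."""
    return (x + 2**31) % 2**32 - 2**31

def flat_offset(token_id: int, row_stride: int, use_int64: bool) -> int:
    """Offset of a token's row in a flat buffer (hypothetical layout)."""
    offset = token_id * row_stride
    # int64 holds the true product here; int32 silently wraps.
    return offset if use_int64 else wrap_int32(offset)

token_id, row_stride = 4_000_000, 1024  # hypothetical values
print("true offset :", flat_offset(token_id, row_stride, use_int64=True))
print("int32 offset:", flat_offset(token_id, row_stride, use_int64=False))
print("exceeds int32:", token_id * row_stride > INT32_MAX)
```

With these values the true offset is about 4.1e9, so the 32-bit result is negative, while the same computation in 64-bit stays correct.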

@@ -0,0 +1,136 @@
from typing import Optional
Collaborator

Can we add a torch reference implementation here, and add an output comparison test between the reference and the applied kernel (maybe under __main__)?

Labels

deepseek quant LLM Quantization

4 participants