Skip to content

Support Flashinfer Cute-DSL MLA attention#24737

Open
b8zhong wants to merge 5 commits into
mainfrom
brayden/cutedsl-mla
Open

Support Flashinfer Cute-DSL MLA attention#24737
b8zhong wants to merge 5 commits into
mainfrom
brayden/cutedsl-mla

Conversation

@b8zhong
Copy link
Copy Markdown
Collaborator

@b8zhong b8zhong commented May 9, 2026

Motivation

@nvpohanh

Ref:
flashinfer-ai/flashinfer#2805
flashinfer-ai/flashinfer#2743

(closed, could need new/reopened PR) Flashinfer autotune PR: flashinfer-ai/flashinfer#3086

Modifications

Add as a new backend (for the purposes of debugging and easily switching impls for now, ideally, in the future, the trtllm_mla backend will still be allowed to be autotuned to use the cute-dsl implementation, as we don't pass in the explicit backend string.
The little int8 change is because, for some weird reason, the cute-dsl MLA backend doesn't support unsigned. Change it for simplicity because no real difference.

Kernel limitations:
Head dim only support: DSR1 DP attention, only support decode (Kimi K2 dim through padding here: flashinfer-ai/flashinfer#3161)

Accuracy Tests

Before:

python3 benchmark/gsm8k/bench_sglang.py --num-shots 20 --num-questions 1209 --parallel 1209 --platinum
Loading GSM8K Platinum dataset from HuggingFace...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1209/1209 [00:50<00:00, 23.89it/s]
Accuracy: 0.979
Invalid: 0.000
Latency: 50.623 s
Output throughput: 2378.002 token/s

After

python3 benchmark/gsm8k/bench_sglang.py --num-shots 20 --num-questions 1209 --parallel 1209 --platinum
Loading GSM8K Platinum dataset from HuggingFace...
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1209/1209 [00:42<00:00, 28.24it/s]
Accuracy: 0.978
Invalid: 0.000
Latency: 42.813 s
Output throughput: 2843.533 token/s

After in TP mode:

python3 benchmark/gsm8k/bench_sglang.py --num-shots 20 --num-questions 1209 --parallel 1209 --platinum
Loading GSM8K Platinum dataset from HuggingFace...
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1209/1209 [01:50<00:00, 10.95it/s]
Accuracy: 0.980
Invalid: 0.000
Latency: 110.392 s
Output throughput: 1121.063 token/s

Looks fine.

Speed Tests and Profiling

python3 -m sglang.launch_server \
    --model-path nvidia/DeepSeek-R1-0528-FP4-v2 \
    --tp 4 --ep 4 --dp 4 --enable-dp-attention \
    --attention-backend cutedsl_mla \
    --quantization modelopt_fp4 \
    --max-running-requests 512 \
    --reasoning-parser deepseek-r1 \
    --tool-call-parser deepseekv3
SGLANG_TORCH_PROFILER_DIR="./" \
python3 -m sglang.bench_serving \
    --backend sglang \
    --base-url http://localhost:30000 \
    --dataset-name random \
    --random-input-len 8192 \
    --random-output-len 64 \
    --random-range-ratio 0.5 \
    --num-prompts 512 \
    --max-concurrency 512 \
    --profile \
    --profile-steps 10 \
    --profile-start-step 20

On SM103

Before:
image

After:
image
Attention speedup: around 18%

For TP mode:

CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m sglang.launch_server \
    --model-path nvidia/DeepSeek-R1-0528-NVFP4-v2 \
    --tp 4 --ep 4 \
    --attention-backend cutedsl_mla \
    --quantization modelopt_fp4 \
    --reasoning-parser deepseek-r1 \
    --tool-call-parser deepseekv3 \
    --speculative-algorithm EAGLE --speculative-attention-mode decode
SGLANG_TORCH_PROFILER_DIR="./" \
SGLANG_PROFILE_RECORD_SHAPES=true \
SGLANG_PROFILE_WITH_STACK=true \
python3 -m sglang.bench_one_batch_server \
  --model baseten-admin/glm-4.7-fp8-attn-fp4-mlp \
  --base-url http://localhost:30000 \
  --batch-size 8 \
  --input-len 80000 \
  --output-len 64 \
  --profile \
  --profile-steps 10 \
  --show-report \
  --profile-by-stage

Before:
Screenshot 2026-05-20 at 3 11 14 PM

After:
Screenshot 2026-05-20 at 11 35 26 PM


CI States

Latest PR Test (Base): ✅ Run #26258690377
Latest PR Test (Extra): ❌ Run #26258690334

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@nvpohanh
Copy link
Copy Markdown
Collaborator

cc @leejnau

Comment thread python/sglang/srt/server_args.py
Comment thread python/sglang/srt/speculative/draft_utils.py Outdated
Comment thread python/sglang/srt/server_args.py
Comment thread python/sglang/srt/server_args.py
@b8zhong b8zhong force-pushed the brayden/cutedsl-mla branch from 27abb6f to e4ab12b Compare May 11, 2026 18:29
Comment thread docs_new/docs/advanced_features/attention_backend.mdx Outdated
@b8zhong b8zhong assigned Fridge003 and Qiaolin-Yu and unassigned Fridge003 May 13, 2026
@nvpohanh
Copy link
Copy Markdown
Collaborator

Just FYI, more FlashInfer optimizations for cute-dsl MLA decode: flashinfer-ai/flashinfer#3309

Request to autotune between trtllm-gen and cutedsl MLA in FlashInfer: flashinfer-ai/flashinfer#2891

@b8zhong b8zhong requested a review from zijiexia as a code owner May 20, 2026 18:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

blackwell SM100/SM120 documentation Improvements or additions to documentation run-ci

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants