Skip to content

[FlashInfer v0.6.10] [RL] [DSv32] [GLM-5] Add --dsa-topk-backend and integrate FlashInfer and pytorch topk#22851

Merged
Fridge003 merged 13 commits into
sgl-project:mainfrom
zianglih:torch-topk
May 25, 2026
Merged

[FlashInfer v0.6.10] [RL] [DSv32] [GLM-5] Add --dsa-topk-backend and integrate FlashInfer and pytorch topk#22851
Fridge003 merged 13 commits into
sgl-project:mainfrom
zianglih:torch-topk

Conversation

@zianglih
Copy link
Copy Markdown
Contributor

@zianglih zianglih commented Apr 15, 2026

Motivation

@HumansAnd

Add --dsa-topk-backend for configurable topk backend implementation selection.

torch.topk is used by GLM-5 for RL.
FlashInfer topk has determinism and configurable tie break (flashinfer-ai/flashinfer#3095), and better long context performance.

Modifications

  • Add --dsa-topk-backend, default to existing sgl-kernel
  • Integrate flashinfer and torch topk for unfused code path
  • Integrate flashinfer topk for fused code path
  • Add SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK and SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC
  • Add new unit test

Accuracy Tests

New unit test python3 -m pytest -q test/registered/kernels/test_dsa_indexer.py -k test_topk_unfused_backends_valid_selection passed.

SGLANG_DSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD=0 python3 -m sglang.launch_server --dsa-topk-backend sgl-kernel --kv-cache-dtype bf16 --model /data/models/ziangli_v32/DeepSeek-V3.2 --tp 8 --dp 8 --enable-dp-attention --moe-runner-backend flashinfer_trtllm_routed --attention-backend dsa --dsa-decode-backend flashmla_sparse --dsa-prefill-backend flashmla_sparse --page-size 64 --trust-remote-code
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.977
Invalid: 0.000
Latency: 13.146 s
Output throughput: 8566.746 token/s
Accuracy: 0.978
Invalid: 0.000
Latency: 12.749 s
Output throughput: 8878.813 token/s
Accuracy: 0.981
Invalid: 0.000
Latency: 17.272 s
Output throughput: 6584.294 token/s
# torch unfused
SGLANG_DSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD=0 SGLANG_DSA_FUSE_TOPK=0 python3 -m sglang.launch_server --dsa-topk-backend torch --kv-cache-dtype bf16 --model /data/models/ziangli_v32/DeepSeek-V3.2 --tp 8 --dp 8 --enable-dp-attention --moe-runner-backend flashinfer_trtllm_routed --attention-backend dsa --dsa-decode-backend flashmla_sparse --dsa-prefill-backend flashmla_sparse --page-size 64 --trust-remote-code
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.982
Invalid: 0.000
Latency: 18.256 s
Output throughput: 6183.790 token/s
Accuracy: 0.983
Invalid: 0.000
Latency: 17.637 s
Output throughput: 6388.987 token/s
Accuracy: 0.980
Invalid: 0.000
Latency: 17.609 s
Output throughput: 6403.039 token/s
# flashinfer unfused
SGLANG_DSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD=0 SGLANG_DSA_FUSE_TOPK=0 python3 -m sglang.launch_server --dsa-topk-backend flashinfer --kv-cache-dtype bf16 --model /data/models/ziangli_v32/DeepSeek-V3.2 --tp 8 --dp 8 --enable-dp-attention --moe-runner-backend flashinfer_trtllm_routed --attention-backend dsa --dsa-decode-backend flashmla_sparse --dsa-prefill-backend flashmla_sparse --page-size 64 --trust-remote-code
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.978
Invalid: 0.000
Latency: 20.846 s
Output throughput: 5413.876 token/s
Accuracy: 0.978
Invalid: 0.000
Latency: 24.896 s
Output throughput: 4557.003 token/s
Accuracy: 0.979
Invalid: 0.000
Latency: 21.313 s
Output throughput: 5292.839 token/s
# flashinfer fused
SGLANG_DSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD=0 SGLANG_DSA_FUSE_TOPK=1 python3 -m sglang.launch_server --dsa-topk-backend flashinfer --kv-cache-dtype bf16 --model /data/models/ziangli_v32/DeepSeek-V3.2 --tp 8 --dp 8 --enable-dp-attention --moe-runner-backend flashinfer_trtllm_routed --attention-backend dsa --dsa-decode-backend flashmla_sparse --dsa-prefill-backend flashmla_sparse --page-size 64 --trust-remote-code
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.980
Invalid: 0.000
Latency: 13.531 s
Output throughput: 8320.213 token/s
Accuracy: 0.981
Invalid: 0.000
Latency: 12.771 s
Output throughput: 8832.274 token/s
Accuracy: 0.978
Invalid: 0.000
Latency: 12.121 s
Output throughput: 9267.255 token/s
# flashinfer fused with tie_break=1
SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK=1 SGLANG_DSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD=0 SGLANG_DSA_FUSE_TOPK=1 python3 -m sglang.launch_server --dsa-topk-backend flashinfer --kv-cache-dtype bf16 --model /data/models/ziangli_v32/DeepSeek-V3.2 --tp 8 --dp 8 --enable-dp-attention --moe-runner-backend flashinfer_trtllm_routed --attention-backend dsa --dsa-decode-backend flashmla_sparse --dsa-prefill-backend flashmla_sparse --page-size 64 --trust-remote-code
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.978
Invalid: 0.000
Latency: 13.716 s
Output throughput: 8219.616 token/s
Accuracy: 0.980
Invalid: 0.000
Latency: 13.008 s
Output throughput: 8652.700 token/s
Accuracy: 0.978
Invalid: 0.000
Latency: 17.669 s
Output throughput: 6457.714 token/s
# flashinfer fused with tie_break=2
SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK=2 SGLANG_DSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD=0 SGLANG_DSA_FUSE_TOPK=1 python3 -m sglang.launch_server --dsa-topk-backend flashinfer --kv-cache-dtype bf16 --model /data/models/ziangli_v32/DeepSeek-V3.2 --tp 8 --dp 8 --enable-dp-attention --moe-runner-backend flashinfer_trtllm_routed --attention-backend dsa --dsa-decode-backend flashmla_sparse --dsa-prefill-backend flashmla_sparse --page-size 64 --trust-remote-code
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.979
Invalid: 0.000
Latency: 13.370 s
Output throughput: 8438.633 token/s
Accuracy: 0.982
Invalid: 0.000
Latency: 13.129 s
Output throughput: 8628.890 token/s
Accuracy: 0.980
Invalid: 0.000
Latency: 12.498 s
Output throughput: 9047.713 token/s

Speed Tests and Profiling

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

CI States

Latest PR Test (Base): ✅ Run #26347798409
Latest PR Test (Extra): ✅ Run #26347798377

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@github-actions github-actions Bot added the documentation Improvements or additions to documentation label Apr 15, 2026
@ziang-and ziang-and requested a review from wisclmy0611 as a code owner April 21, 2026 02:54
@zianglih zianglih changed the title [RL] [V3.2] [GLM-5] Add SGLANG_NSA_TORCH_TOPK [RL] [DSv32] [GLM-5] Add --nsa-topk-backend and integrate FlashInfer and pytorch topk Apr 21, 2026
@nvpohanh
Copy link
Copy Markdown
Collaborator

cc @nvjullin

@DarkSharpness
Copy link
Copy Markdown
Collaborator

qq: Does flashinfer kernel support cuda-graph? I know flashinfer may dispatch to different algorithms based on static sequence length, but is that safe under CUDA graph?

@zianglih
Copy link
Copy Markdown
Contributor Author

Hi @DarkSharpness , thank you for calling this out. This is indeed a valid concern. Current FlashInfer's dispatch heuritics use max_len, which is not CUDA graph safe in current implementation. We are also working with CCCL team for a graph safe topk (flashinfer-ai/flashinfer#3091 etc) which will be integrated into flashinfer soon. As of now for this PR we can disallow cuda graph if flashinfer topk backend is used.

@zianglih
Copy link
Copy Markdown
Contributor Author

zianglih commented Apr 21, 2026

Hold until flashinfer-ai/flashinfer#3133 , which introduces a graph safe mode.

kahyunnam pushed a commit to flashinfer-ai/flashinfer that referenced this pull request Apr 24, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description
@HumansAnd
Parent PR: #3095
SGLang PR: sgl-project/sglang#22851

Add `row_starts` and `dsa_graph_safe` for SGLang DSA integration.
<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues
sgl-project/sglang#22851 (comment)

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added dsa_graph_safe flag to top-k APIs to opt into DSA-graph safe
execution.
* Added optional row_starts parameter to page-table and ragged top-k
transforms to support per-row score offsets.

* **Behavior**
* When dsa_graph_safe=True the optimized clusters fast-path is disabled
to ensure safe execution.

* **Tests**
* Added tests covering row_starts behavior for page-table and ragged
transforms.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
@zianglih
Copy link
Copy Markdown
Contributor Author

Hold until flashinfer v0.6.10 release.

@zianglih
Copy link
Copy Markdown
Contributor Author

Hold until #24452 . Also need to add v4 support.

@b8zhong b8zhong added the run-ci label May 14, 2026
@zianglih zianglih marked this pull request as draft May 14, 2026 17:42
@zianglih zianglih marked this pull request as ready for review May 14, 2026 18:02
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@zianglih
Copy link
Copy Markdown
Contributor Author

According to discussion with @DarkSharpness we will add v4 support in later PRs. Also V4 has pending metadata refactoing work.

@zianglih zianglih requested a review from zijiexia as a code owner May 16, 2026 07:27
@zianglih zianglih changed the title [RL] [DSv32] [GLM-5] Add --dsa-topk-backend and integrate FlashInfer and pytorch topk [FlashInfer v0.6.10] [RL] [DSv32] [GLM-5] Add --dsa-topk-backend and integrate FlashInfer and pytorch topk May 21, 2026
@nvpohanh
Copy link
Copy Markdown
Collaborator

@Fridge003 could you check if your comments have been addressed? thanks!

@zianglih
Copy link
Copy Markdown
Contributor Author

/gemini review

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a modular backend system for DeepSeek Sparse Attention (DSA) top-k operations, allowing users to choose between sgl-kernel, torch, and flashinfer implementations via the new --dsa-topk-backend argument. The changes include the addition of a DSATopKBackend class to encapsulate the logic for both fused and unfused top-k transformations, along with new environment variables to configure FlashInfer's deterministic behavior and tie-breaking modes. The existing DSA attention backend was refactored to integrate this new system, and extensive tests were added to verify the correctness and equivalence of the backends. Reviewer feedback focused on performance optimizations, specifically suggesting the replacement of torch.diff with manual slicing and subtraction in performance-critical sections.

Comment thread python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py Outdated
Comment thread python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py Outdated
Copy link
Copy Markdown
Contributor Author

Posted an additional GSM8K Platinum accuracy run for the FlashInfer fused DSA topk path. Server was launched once, then the benchmark was repeated three times against the same server process.

Server command:

SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK=large \
SGLANG_DSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD=0 \
SGLANG_DSA_FUSE_TOPK=1 \
PYTHONPATH=/sgl-workspace/sglang/python:${PYTHONPATH:-} \
python3 -m sglang.launch_server \
  --host 127.0.0.1 \
  --port 30000 \
  --dsa-topk-backend flashinfer \
  --kv-cache-dtype bf16 \
  --model /data/models/ziangli_v32/DeepSeek-V3.2 \
  --tp 8 \
  --dp 8 \
  --enable-dp-attention \
  --moe-runner-backend flashinfer_trtllm_routed \
  --attention-backend dsa \
  --dsa-decode-backend flashmla_sparse \
  --dsa-prefill-backend flashmla_sparse \
  --page-size 64 \
  --trust-remote-code

Benchmark command, repeated 3 times:

python3 benchmark/gsm8k/bench_sglang.py \
  --num-shots 8 \
  --num-questions 1209 \
  --parallel 1209 \
  --platinum

gsm8k_1.log final output:

Accuracy: 0.976
Invalid: 0.000
Latency: 13.918 s
Output throughput: 8094.320 token/s

gsm8k_2.log final output:

Accuracy: 0.979
Invalid: 0.000
Latency: 13.249 s
Output throughput: 8498.528 token/s

gsm8k_3.log final output:

Accuracy: 0.981
Invalid: 0.000
Latency: 11.975 s
Output throughput: 9395.941 token/s

The large tie-break setting is the current env-var API equivalent of the previous numeric tie-break mode 2. Runtime logs were captured under run id dsa_topk_flashinfer_gsm8k_3x_20260522_184257.

@Fridge003
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@zianglih
Copy link
Copy Markdown
Contributor Author

/rerun-failed-ci

2 similar comments
@zianglih
Copy link
Copy Markdown
Contributor Author

/rerun-failed-ci

@zianglih
Copy link
Copy Markdown
Contributor Author

/rerun-failed-ci

@zianglih
Copy link
Copy Markdown
Contributor Author

/rerun-stage base-c-test-4-gpu-b200 base-c-test-4-gpu-h100 base-c-test-8-gpu-h200

@github-actions
Copy link
Copy Markdown
Contributor

⚠️ /rerun-stage has been deprecated.

Stage granularity is too coarse — a stage usually doesn't map to one feature, so rerunning a stage re-pays the cost of unrelated tests. If you don't know which exact test files to rerun, you shouldn't be using /rerun-stage or /rerun-test in the first place.

Use one of these instead:

  • Selective tests (you know exactly which files to rerun):
    /rerun-test test_foo.py test_bar.py
    
  • Rerun only failed jobs:
    /rerun-failed-ci
    
  • Full CI rerun (with extra coverage): add the run-ci or run-ci-extra label and push a new commit (or use /tag-and-rerun-ci).

AMD CI: stage-level dispatch is still available via Actions UI → PR Test (AMD) / PR Test ROCm 7.2 (AMD)Run workflow → pick a stage from the dropdown.

@zianglih
Copy link
Copy Markdown
Contributor Author

/rerun-failed-ci

1 similar comment
@zianglih
Copy link
Copy Markdown
Contributor Author

/rerun-failed-ci

@zianglih
Copy link
Copy Markdown
Contributor Author

nv ci passed

@nvpohanh
Copy link
Copy Markdown
Collaborator

@Fridge003 this PR has passed the CI. Could you help to merge? Thanks!

@Fridge003 Fridge003 merged commit 2b9dd9c into sgl-project:main May 25, 2026
224 of 279 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation run-ci run-ci-extra

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants