[FlashInfer v0.6.10] [RL] [DSv32] [GLM-5] Add --dsa-topk-backend and integrate FlashInfer and pytorch topk#22851
Conversation
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
SGLANG_NSA_TORCH_TOPK--nsa-topk-backend and integrate FlashInfer and pytorch topk
|
cc @nvjullin |
|
qq: Does flashinfer kernel support cuda-graph? I know flashinfer may dispatch to different algorithms based on static sequence length, but is that safe under CUDA graph? |
|
Hi @DarkSharpness , thank you for calling this out. This is indeed a valid concern. Current FlashInfer's dispatch heuritics use |
|
Hold until flashinfer-ai/flashinfer#3133 , which introduces a graph safe mode. |
<!-- .github/pull_request_template.md --> ## 📌 Description @HumansAnd Parent PR: #3095 SGLang PR: sgl-project/sglang#22851 Add `row_starts` and `dsa_graph_safe` for SGLang DSA integration. <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues sgl-project/sglang#22851 (comment) <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added dsa_graph_safe flag to top-k APIs to opt into DSA-graph safe execution. * Added optional row_starts parameter to page-table and ragged top-k transforms to support per-row score offsets. * **Behavior** * When dsa_graph_safe=True the optimized clusters fast-path is disabled to ensure safe execution. * **Tests** * Added tests covering row_starts behavior for page-table and ragged transforms. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
|
Hold until flashinfer v0.6.10 release. |
|
Hold until #24452 . Also need to add v4 support. |
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
|
According to discussion with @DarkSharpness we will add v4 support in later PRs. Also V4 has pending metadata refactoing work. |
--dsa-topk-backend and integrate FlashInfer and pytorch topk--dsa-topk-backend and integrate FlashInfer and pytorch topk
|
@Fridge003 could you check if your comments have been addressed? thanks! |
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces a modular backend system for DeepSeek Sparse Attention (DSA) top-k operations, allowing users to choose between sgl-kernel, torch, and flashinfer implementations via the new --dsa-topk-backend argument. The changes include the addition of a DSATopKBackend class to encapsulate the logic for both fused and unfused top-k transformations, along with new environment variables to configure FlashInfer's deterministic behavior and tie-breaking modes. The existing DSA attention backend was refactored to integrate this new system, and extensive tests were added to verify the correctness and equivalence of the backends. Reviewer feedback focused on performance optimizations, specifically suggesting the replacement of torch.diff with manual slicing and subtraction in performance-critical sections.
|
Posted an additional GSM8K Platinum accuracy run for the FlashInfer fused DSA topk path. Server was launched once, then the benchmark was repeated three times against the same server process. Server command: SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK=large \
SGLANG_DSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD=0 \
SGLANG_DSA_FUSE_TOPK=1 \
PYTHONPATH=/sgl-workspace/sglang/python:${PYTHONPATH:-} \
python3 -m sglang.launch_server \
--host 127.0.0.1 \
--port 30000 \
--dsa-topk-backend flashinfer \
--kv-cache-dtype bf16 \
--model /data/models/ziangli_v32/DeepSeek-V3.2 \
--tp 8 \
--dp 8 \
--enable-dp-attention \
--moe-runner-backend flashinfer_trtllm_routed \
--attention-backend dsa \
--dsa-decode-backend flashmla_sparse \
--dsa-prefill-backend flashmla_sparse \
--page-size 64 \
--trust-remote-codeBenchmark command, repeated 3 times: python3 benchmark/gsm8k/bench_sglang.py \
--num-shots 8 \
--num-questions 1209 \
--parallel 1209 \
--platinum
The |
|
/tag-and-rerun-ci |
|
/rerun-failed-ci |
2 similar comments
|
/rerun-failed-ci |
|
/rerun-failed-ci |
|
/rerun-stage base-c-test-4-gpu-b200 base-c-test-4-gpu-h100 base-c-test-8-gpu-h200 |
|
Stage granularity is too coarse — a stage usually doesn't map to one feature, so rerunning a stage re-pays the cost of unrelated tests. If you don't know which exact test files to rerun, you shouldn't be using Use one of these instead:
AMD CI: stage-level dispatch is still available via Actions UI → PR Test (AMD) / PR Test ROCm 7.2 (AMD) → Run workflow → pick a stage from the dropdown. |
|
/rerun-failed-ci |
1 similar comment
|
/rerun-failed-ci |
|
nv ci passed |
|
@Fridge003 this PR has passed the CI. Could you help to merge? Thanks! |
Motivation
@HumansAnd
Add
--dsa-topk-backendfor configurable topk backend implementation selection.torch.topkis used by GLM-5 for RL.FlashInfer topk has determinism and configurable tie break (flashinfer-ai/flashinfer#3095), and better long context performance.
Modifications
--dsa-topk-backend, default to existingsgl-kernelSGLANG_DSA_TOPK_FLASHINFER_TIE_BREAKandSGLANG_DSA_TOPK_FLASHINFER_DETERMINISTICAccuracy Tests
New unit test
python3 -m pytest -q test/registered/kernels/test_dsa_indexer.py -k test_topk_unfused_backends_valid_selectionpassed.Speed Tests and Profiling
Checklist
Review and Merge Process
/tag-and-rerun-ci,/tag-run-ci-label,/rerun-failed-ciCI States
Latest PR Test (Base): ✅ Run #26347798409
Latest PR Test (Extra): ✅ Run #26347798377