feat(gdn): add FlashInfer K-last SSM layout support for GDN prefill and decode for Hopper (#18361)

ispobock merged 64 commits into sgl-project:main
Conversation
This commit introduces a new file for the CuTe DSL GDN verify kernel, optimized for the K-last layout. The kernel supports multiple tokens per sequence, caches intermediate states for verification, and keeps gate computation inside the kernel. It also updates Mamba2StateShape and Qwen3NextConfig to support the K-last layout, and adjusts GDNAttnBackend to use the new kernel when applicable.

Key features:
- Efficient memory access with the K-last layout.
- Support for speculative decoding verification.
- Integration into the existing attention mechanisms.

This enhancement aims to improve the performance and flexibility of GDN operations.
- Optimize extend path: skip transpose for the first prefill (zero initial state)
- Fix CUDA Graph compatibility: guard debug logging with the _debug_enabled flag
- Use actual pool_size/cache_steps for kernel compilation
- Add K-last layout support in hybrid_linear_attn_backend
- Verified accuracy with batch sizes 4, 8, and 16
…prevent OOB access for padding slots
- Add sglang_gdn_verify_kernel with an optimized K-last memory layout
- Support bfloat16 computation with proper dtype handling
- Add early-exit optimization for CUDA Graph padding slots (cache_idx < 0)
- Integrate with hybrid_linear_attn_backend for MTP speculative decoding
- Add a warning when K-last is used without speculative decoding
- Enable debug logging via the SGLANG_GDN_DEBUG environment variable
…intermediate indices

- Add _cached_has_prefix_cache to compute has_prefix_cache only at layer 0
- Remove the unnecessary [:batch_size] slice from initial_state_indices in the verify kernel
- Add an intermediate_state_indices field to ForwardMetadata for CUDA graph
- Pre-allocate intermediate_state_indices_list in init_cuda_graph_state
- Remove all debug logging code
Add a log_tensor function to hybrid_linear_attn_backend.py for debugging precision differences between the K-last (CuTe DSL) and V-last (Triton) kernels.

Features:
- Log tensor statistics (mean, std, min, max, abs_mean) and sample values
- Support for verify and extend kernel input/output logging
- Only logs on rank 0 in distributed mode
- Skips logging during CUDA graph capture
- Controlled by the SGLANG_GDN_DEBUG=1 and SGLANG_GDN_DEBUG_DIR env vars

Benchmark results (256 prompts, input=128, output=1024):

┌─────────────────────┬───────────────┬───────────────┐
│ Metric              │ K-last        │ V-last        │
├─────────────────────┼───────────────┼───────────────┤
│ Accept Length       │ 3.28          │ 3.17          │
│ Output Throughput   │ 573.05 tok/s  │ 633.68 tok/s  │
│ Benchmark Duration  │ 227.51s       │ 205.75s       │
│ mamba_ssm_k_last    │ True          │ False         │
└─────────────────────┴───────────────┴───────────────┘

CUDA Graph results (256 prompts, input=128, output=1024):

┌─────────────────────┬───────────────┬───────────────┐
│ Metric              │ K-last        │ V-last        │
├─────────────────────┼───────────────┼───────────────┤
│ Accept Length       │ 3.29          │ 3.17          │
│ Output Throughput   │ 2156.38 tok/s │ 2056.61 tok/s │
│ Benchmark Duration  │ 60.46s        │ 63.39s        │
│ mamba_ssm_k_last    │ True          │ False         │
└─────────────────────┴───────────────┴───────────────┘

How to run the benchmark:

# Set environment variables
export SGLANG_GDN_DEBUG=1
export SGLANG_GDN_DEBUG_DIR=/path/to/log/dir

# K-last server
python -m sglang.launch_server --model-path /path/to/model --tp 2 --trust-remote-code --disable-radix-cache --mamba-ssm-k-last --disable-cuda-graph

# V-last server (without --mamba-ssm-k-last)
python -m sglang.launch_server --model-path /path/to/model --tp 2 --trust-remote-code --disable-radix-cache --disable-cuda-graph

# Run benchmark
python bench_serving.py --backend sglang --host 127.0.0.1 --port 30000 --model /path/to/model --dataset-name random --random-input 128 --random-output 1024 --num-prompts 256 --seed 42
Move the .item() call from forward_extend (called per layer) to init_forward_metadata (called once per batch) to reduce GPU-CPU sync overhead.

Before: ~10ms overhead due to the .item() sync in every layer
After: negligible overhead (a single sync at batch init)

Profile results show the K-last advantage improved from -0.4% to -1.3%.
…n V-last

Key change for precision alignment (validated by ablation study):
- Always quantize beta to the activation dtype (bf16/fp16) in the internal gating path (ablation: 0.26% argmax flip without this, +3pp on GSM8K)
- Reuse Triton fused_gdn_gating for external g_log/beta to minimize numerical drift

Removed unused ablation hooks and environment variables:
- SGLANG_CUTEDSL_VERIFY_OUT_DTYPE (bf16 output was already in 72fb)
- SGLANG_CUTEDSL_VERIFY_DISABLE_PRECOMPUTE_G_DECAY
- SGLANG_CUTEDSL_VERIFY_L2NORM_DIV
- SGLANG_CUTEDSL_VERIFY_QUANTIZE_BETA_TO_ACT

E2E validation (GSM8K 100q):
- K-last: 96% (up from 93% in 72fb, mainly due to the beta quantization fix)
- V-last: 97%
- Diff: -1pp (within statistical noise)

Also adds:
- test_cutedsl_verify_vs_fused.py: correctness test comparing CuTe vs Triton
- in_place_transpose.py: Triton kernel for K-last<->V-last layout conversion
…verify Ablation on GSM8K (200q) showed no accuracy benefit from computing gates (g_log/beta) via Triton fused_gdn_gating and feeding them into the CuTe verify kernel. This change removes the external gate precompute and always uses CuTe internal gating for K-last verify (topk=1).
Updated shared memory allocation for tensors in the CuTe verify kernel by removing padding for bank conflict avoidance. This change streamlines the allocation process for sQ, sK, and sV tensors, improving clarity and potentially reducing memory usage. Additionally, added a TODO comment to consider moving the computation of q and k into the main loop to further optimize shared memory usage.
- Keep FlashInfer prefill state in K-last [V, K] to match the mamba pool.
- Explicitly l2-normalize q/k, since the flashinfer wrapper ignores the flag.
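The explicit l2-normalization mentioned above can be sketched in plain Python (a stand-in for the per-head-vector normalization applied to q/k; the `eps` value is an assumption, not taken from the PR):

```python
import math

def l2_normalize(vec, eps=1e-6):
    # Per-vector L2 normalization, as applied to each q/k head vector
    # before the delta-rule update. eps guards against division by zero
    # (assumed value, not from the PR).
    norm = math.sqrt(sum(x * x for x in vec)) + eps
    return [x / norm for x in vec]

q = [3.0, 4.0]
q_n = l2_normalize(q)
# resulting vector has (approximately) unit L2 norm
assert abs(math.sqrt(sum(x * x for x in q_n)) - 1.0) < 1e-4
```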
- Remove the torch.equal() check for cache_indices contiguity
- Remove the .item() call for the prefix cache check
- Simplify FlashInfer prefill to always use gather/scatter
- Simplify the Triton fallback to use a copy-based approach

This eliminates a GPU-CPU synchronization in every layer forward, improving K-last prefill performance.
- Remove the K-last transpose fallback in the prefill path
- Remove the K-last transpose fallback in the verify path
- K-last now only works when a native K-last kernel is available (FlashInfer prefill / CuTe DSL verify); otherwise fall back directly to the V-last Triton kernel

This simplifies the code and removes complex transpose logic.
- Gate the Qwen3Next mamba SSM K-last layout on FlashInfer prefill + CuTe DSL verify
- Remove prefill/verify transpose fallback paths (keep the decode transpose)
- Derive ssm_k_last from the actual mamba cache shape and cache kernel handles
- Enforce topk=1 for K-last verify (no retrieve_parent_token support)

This ensures K-last is used end-to-end only with native K-last kernels; otherwise we fall back to the default V-last layout and Triton kernels.
- Fix the Mamba2StateShape k_last layout swap direction (default matches base)
- Simplify the Qwen3Next cache shape to use --mamba-ssm-k-last directly
- Drop FlashInfer backend tweaks unrelated to GDN/K-last
- Remove profiling/env toggles in the FlashInfer GDN prefill path
- Always use the in-place transpose in K-last decode when K == V
- Drop the preallocated intermediate_state_indices from metadata/CUDA-graph state
- Keep the target-verify path using on-the-fly torch.arange indices
- Revert extend-state tracking to base behavior
Add a comprehensive benchmark script to compare K-last and V-last performance:
- E2E serving benchmark with a configurable workload
- GSM8K and MMLU accuracy evaluation
- Automatic server launch with speculative decoding support
- Configurable GPU assignment and TP settings
…arison

Add integrated torch profiling via the HTTP API:
- New --enable-profile flag to trigger profiling
- Configurable --profile-num-steps and --profile-activities
- Profiles saved to separate directories for K-last and V-last
- Does not require --disable-cuda-graph (profiles with CUDA graph enabled)
- Add precompile_cutedsl_gdn_verify_kernels() for warmup-time kernel compilation
- Pre-compile CuTe DSL kernels during GDN backend initialization to avoid JIT overhead
- Add communication volume tracking via the SGLANG_COMM_DEBUG environment variable
- Fix the profiling workflow in the benchmark script (start -> run requests -> stop)
- Add a watchdog timeout for JIT compilation during benchmark server launch
- Change the memory fraction from 0.80 to 0.70 in the server launch configuration.
- Update start_profiling to take a profile_by_stage parameter, defaulting to True.
- Modify the profiling JSON data to enable merging profiles, and reflect the new profile_by_stage setting in the print statement.
…ions

- Introduce a --chunked-prefill-size parameter for server launch to optimize prefill operations.
- Add a run_one_batch_profile function for improved profiling of single-batch requests.
- Update run_benchmark to use the new profiling options, allowing a customizable batch size and input length during profiling.
- Extend command-line argument parsing to support the new profiling parameters.
- Enable device-properties caching to reduce overhead during profiling and server operation.
- Add a tensor cache for FlashInfer prefill state tensors to avoid cudaMalloc/cudaFree
- Use squeeze(0) + conditional contiguous() instead of [0].contiguous() to avoid unnecessary copies
- Use index_select with the out= parameter when dtypes match to avoid temporary tensors
- Use index_copy_ for the state writeback to avoid intermediate tensor creation

Benchmark results (num_prompts=256):
- K-last throughput: 1456 tok/s (+6.6% vs V-last 1366 tok/s)
- Accept length unchanged: 2.87 vs 2.90 (-1.0%)
- cudaFree calls eliminated (0 vs previous)
…aches

Remove the state cache and output cache that were found unnecessary during a binary-search cleanup. The simplified code maintains the same performance while being more readable.

Benchmark (num_prompts=256):
- K-last: 1450.87 tok/s, accept length 2.87
- V-last: 1391.78 tok/s, accept length 2.90
- Improvement: +4.2% throughput
…p, remove redundant casts

- Fix precision: cast g[0] to float32 before torch.exp instead of after, avoiding bfloat16 precision loss in the exponential computation
- Consolidate _warmup_mtp_kernel and _warmup_decode_kernel into a single _warmup_kernels function, sharing common tensor allocations (A_log, dt_bias)
- Cast ssm_cache_indices to int64 once after torch.where, removing the redundant .to(torch.int64) at the index_copy_ call site

Benchmarked K-last spec and non-spec: no performance regression.
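The precision point can be illustrated without torch: emulating the bfloat16 round-trip in pure Python (truncating a float32 to its top 16 bits as a stand-in for a bf16 cast; the gate value is illustrative, not from the PR) shows why the upcast must happen before the exponential:

```python
import math
import struct

def to_bf16(x: float) -> float:
    # Truncate a float32 to bfloat16 precision (keep the sign, exponent,
    # and top 7 mantissa bits) -- a pure-Python stand-in for a
    # torch.bfloat16 cast.
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]

g = -0.37519  # an illustrative log-gate value (assumption)
# Buggy order: exponentiate at bf16 precision, then round again
buggy = to_bf16(math.exp(to_bf16(g)))
# Fixed order: upcast to full precision first, then exponentiate
fixed = math.exp(g)
assert buggy != fixed          # the bf16 round-trips perturb exp(g)
assert abs(buggy - fixed) < 0.01  # close, but not bit-identical
```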
- Untrack .opencode/skills/gdn_benchmark.md (file kept on disk)
- Remove the python/sglang/srt/grpc/ directory
Revert unrelated grpc changes from the PR diff.
…warmup

- Remove the intermediate_state_indices pre-allocation from hybrid_linear_attn_backend.py (the FlashInfer kernel accepts but never uses it; gdn_backend.py creates its own at runtime)
- Remove _warmup_kernels() from gdn_flashinfer.py (CUDA graph capture already triggers the same kernel compilation; no other backend has a warmup)
- Remove the intermediate_state_indices field from ForwardMetadata in mamba2_metadata.py

Benchmarked: no perf regression (non-spec BS32: 56558/3634, spec BS32: 54283/5245).
MMLU 1000: non-spec 0.649, spec 0.648 (baseline 0.649).
Replace the pool-indexed decode (the state_indices parameter, which requires a PR#2521 source build) with an explicit gather/scatter pattern that works with stock FlashInfer 0.6.3+ pip packages.

Tested with FlashInfer 0.6.4:
- Prefill: no regression vs the source build (~0-2% noise)
- Decode: ~7-9% regression from gather/scatter overhead
- MMLU 1000: 0.651 (matches the 0.649 baseline)
- NOTE in the module docstring: FlashInfer >= 0.6.4 fixes the prefill perf regression via cudaGetDeviceProperties caching (PR#2509)
- TODO in decode: switch to pool-indexed decode once FlashInfer PR#2521 is released, removing the gather/scatter overhead (~7-9% regression)
/rerun-failed-ci
The k_last parameter introduced a layout swap (H, K, V) for k_last=False that broke Mamba2 models (NemotronH, FalconH1) whose kernels expect (H, head_dim, state_size) = (H, V, K). Layout selection for FlashInfer GDN kernels should be handled by --linear-attn-backend instead.
Did you forget to check in code? I could not find the new server args on main.
I removed the new server args in the final version. Thanks for the reminder. This example hasn't been updated yet, which is why it errors.
Hi @xutizhou, I just tried this PR but encountered this error.
It seems FP16 is not supported for MTP yet. Do you have any plans to support it?
…nd decode for Hopper (sgl-project#18361) Co-authored-by: HongliMi <106042350+HongliMi@users.noreply.github.com> Co-authored-by: xiaozhoupy <181108106+zhou9402@users.noreply.github.com> Co-authored-by: Jinyan Chen <93358689+liz-badada@users.noreply.github.com> Co-authored-by: Avery Yingyi Huang <averyh@nvidia.com> Co-authored-by: eigen <52445717+yyihuang@users.noreply.github.com>
shiyu7 https://github.com/sgl-project/sglang/pull/19961/changes
MTP doesn't support FP16 yet in the current implementation. The initial_state is expected to be float32 by flashinfer's gated_delta_rule_mtp. |
## 📌 Description
This PR adds pool-indexed (indirect) state access to the GDN decode
kernel, enabling zero-copy integration with SGLang's state pool
architecture.
### Background: SGLang's State Pool Architecture
In SGLang, when serving linear attention models (like Qwen3-Next using
Gated Delta Rule), we maintain a **state pool** to store recurrent
states for all active requests:
`ssm_states: [num_layers, pool_size, num_heads, head_dim, head_dim]`
where `pool_size` = `max_num_reqs` (maximum concurrent requests).
Each active request has a `req_pool_idx` that maps it to a slot in this
pool. The mapping is **not contiguous** - requests come and go, so
indices can be scattered (e.g., a batch of 4 requests might have pool
indices `[3, 7, 12, 25]`).
### Motivation
The current GDN decode kernel expects state with shape `[B, H, K, V]`
where B equals batch size and there's a 1:1 mapping (batch index i →
state index i). To use it with SGLang's pool, we would need to:
1. **Gather** states from pool indices before kernel call
2. Run kernel on contiguous `[B, H, K, V]` state
3. **Scatter** updated states back to pool indices
This adds 2 extra memory copy operations per decode step.
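The three steps above can be sketched with plain Python lists standing in for GPU tensors (all names here are illustrative, not the actual SGLang API):

```python
def decode_step_with_copies(pool, pool_indices, kernel):
    # 1. Gather: copy each request's state out of the pool into a
    #    contiguous batch (extra memory copy #1).
    batch_states = [pool[i] for i in pool_indices]
    # 2. Run the kernel on the contiguous batch (batch index i -> state i).
    new_states = [kernel(s) for s in batch_states]
    # 3. Scatter: copy the updated states back to their pool slots
    #    (extra memory copy #2).
    for b, i in enumerate(pool_indices):
        pool[i] = new_states[b]

pool = [0.0] * 32                # pool_size slots, most of them idle
pool_indices = [3, 7, 12, 25]    # scattered slots for a batch of 4
decode_step_with_copies(pool, pool_indices, lambda s: s + 1.0)
# slots 3, 7, 12, 25 now hold 1.0; every other slot is untouched
```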
### Changes
This PR adds a `state_indices` parameter for **zero-copy pool access**:
```python
def gated_delta_rule_decode_pretranspose(
q, k, v, beta,
state, # Can be [pool_size, H, K, V] instead of [B, H, K, V]
state_indices, # NEW: int32 tensor [B] mapping batch_idx -> pool_idx
...
)
```
When `state_indices` is provided:
- Kernel uses indirect addressing: `state[state_indices[batch_idx]]`
instead of `state[batch_idx]`
- Negative indices (padding slots for CUDA graph) skip computation and
write zeros to output
- Eliminates gather/scatter overhead + host-side `torch.where` for
padding (~37μs/call)
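A reference model of these semantics, with scalar floats standing in for the per-request `[H, K, V]` state tensors (the `step` callback and all names are illustrative, not the kernel's actual signature):

```python
def pooled_decode_reference(pool, state_indices, step):
    # Reference semantics of the pooled (indirect) decode path.
    outputs = []
    for idx in state_indices:
        if idx < 0:
            # Padding slot from CUDA graph capture: skip computation,
            # write zeros to the output.
            outputs.append(0.0)
            continue
        out, new_state = step(pool[idx])
        pool[idx] = new_state  # updated in place: no gather/scatter copies
        outputs.append(out)
    return outputs

pool = [10.0, 20.0, 30.0, 40.0]
outs = pooled_decode_reference(pool, [2, -1, 0],
                               lambda s: (s * 0.5, s + 1.0))
# outs == [15.0, 0.0, 5.0]; pool == [11.0, 20.0, 31.0, 40.0]
```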
## 🔍 Related Issues
- [sgl-project/sglang#18361](sgl-project/sglang#18361) - FlashInfer K-last GDN integration into SGLang
## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.
### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).
## 🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
## Reviewer Notes
This PR is required for integrating FlashInfer's K-last GDN kernels into
SGLang. The pool indexing feature allows SGLang to directly use its
state pool without gather/scatter overhead.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Optional pooled (indirect) state access via a new state_indices
parameter enabling zero-copy pooled state handling; negative indices
yield zeroed outputs.
* **Improvements**
* Kernel launch paths, grid sizing, and compiled-kernel caching
differentiate pooled vs. non-pooled modes; APIs and docstrings updated
to propagate pooling flags while maintaining compatibility.
* **Tests**
* New test suite validating pooled decode correctness,
padding/negative-index behavior, state updates, and pooled vs.
non-pooled equivalence.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
…nd decode for Hopper (sgl-project#18361) Co-authored-by: HongliMi <106042350+HongliMi@users.noreply.github.com> Co-authored-by: xiaozhoupy <181108106+zhou9402@users.noreply.github.com> Co-authored-by: Jinyan Chen <93358689+liz-badada@users.noreply.github.com> Co-authored-by: Avery Yingyi Huang <averyh@nvidia.com> Co-authored-by: eigen <52445717+yyihuang@users.noreply.github.com>
## 📌 Description
This PR adds pool-indexed (indirect) state access to the GDN decode
kernel, enabling zero-copy integration with SGLang's state pool
architecture.
### Background: SGLang's State Pool Architecture
In SGLang, when serving linear attention models (like Qwen3-Next using
Gated Delta Rule), we maintain a **state pool** to store recurrent
states for all active requests:
`ssm_states: [num_layers, pool_size, num_heads, head_dim, head_dim]`
where `pool_size` = `max_num_reqs` (maximum concurrent requests).
Each active request has a `req_pool_idx` that maps it to a slot in this
pool. The mapping is **not contiguous** - requests come and go, so
indices can be scattered (e.g., a batch of 4 requests might have pool
indices `[3, 7, 12, 25]`).
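As a concrete illustration of this scattered mapping, here is a minimal torch-only sketch; the sizes are made up for the example and do not reflect real server configuration:

```python
import torch

# Hypothetical sizes for illustration only; real values come from the server config.
num_layers, pool_size, num_heads, head_dim = 2, 32, 4, 8

# One recurrent-state slot per pool entry, per layer.
ssm_states = torch.zeros(num_layers, pool_size, num_heads, head_dim, head_dim)

# A batch of 4 active requests whose req_pool_idx values are scattered.
req_pool_idx = torch.tensor([3, 7, 12, 25], dtype=torch.int32)

# Reading this batch's states for layer 0 is a gather over the pool dimension.
batch_states = ssm_states[0, req_pool_idx.long()]
print(batch_states.shape)  # torch.Size([4, 4, 8, 8])
```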
### Motivation
The current GDN decode kernel expects state with shape `[B, H, K, V]`
where B equals batch size and there's a 1:1 mapping (batch index i →
state index i). To use it with SGLang's pool, we would need to:
1. **Gather** states from pool indices before kernel call
2. Run kernel on contiguous `[B, H, K, V]` state
3. **Scatter** updated states back to pool indices
This adds 2 extra memory copy operations per decode step.
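For reference, the baseline path those three steps describe can be sketched as follows; the kernel here is a hypothetical stand-in (it just scales the state), since only the copy pattern matters:

```python
import torch

torch.manual_seed(0)
pool_size, H, K, V = 32, 2, 8, 8
pool = torch.randn(pool_size, H, K, V)
ref = pool.clone()  # reference copy, to check the scatter below
idx = torch.tensor([3, 7, 12, 25])

def fake_decode_kernel(state):
    # Hypothetical stand-in for the real GDN decode kernel.
    return state * 0.5

# Step 1: gather states into a contiguous [B, H, K, V] buffer (extra copy #1).
gathered = pool[idx]
# Step 2: run the kernel on the contiguous buffer.
updated = fake_decode_kernel(gathered)
# Step 3: scatter updated states back to their pool slots (extra copy #2).
pool[idx] = updated
```

With `state_indices`, the kernel indexes `pool[state_indices[b]]` directly, so both extra copies disappear.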
### Changes
This PR adds a `state_indices` parameter for **zero-copy pool access**:
```python
def gated_delta_rule_decode_pretranspose(
q, k, v, beta,
state, # Can be [pool_size, H, K, V] instead of [B, H, K, V]
state_indices, # NEW: int32 tensor [B] mapping batch_idx -> pool_idx
...
)
```
When `state_indices` is provided:
- Kernel uses indirect addressing: `state[state_indices[batch_idx]]`
instead of `state[batch_idx]`
- Negative indices (padding slots for CUDA graph) skip computation and
write zeros to output
- Eliminates gather/scatter overhead + host-side `torch.where` for
padding (~37μs/call)
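These semantics can be expressed as a pure-PyTorch reference. This is not the actual kernel: the per-slot math below is a placeholder, and only the indexing and negative-index behavior mirror the description above.

```python
import torch

def decode_ref(pool, state_indices):
    # Reference semantics only: output row b is computed from
    # pool[state_indices[b]]; negative indices (CUDA-graph padding)
    # skip computation and leave the output row zeroed.
    out = torch.zeros(state_indices.numel())
    for b in range(state_indices.numel()):
        s = int(state_indices[b])
        if s < 0:
            continue  # padding slot: output stays zero, pool untouched
        out[b] = pool[s].sum()  # placeholder for the real decode math
    return out

pool = torch.ones(8, 2, 4, 4)  # [pool_size, H, K, V]
idx = torch.tensor([3, -1, 5], dtype=torch.int32)
print(decode_ref(pool, idx))  # rows with a negative index stay zero
```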
## 🔍 Related Issues
-
[sgl-project/sglang#18361](sgl-project/sglang#18361)
- FlashInfer K-last GDN integration into SGLang
## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.
### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).
## 🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
## Reviewer Notes
This PR is required for integrating FlashInfer's K-last GDN kernels into
SGLang. The pool indexing feature allows SGLang to directly use its
state pool without gather/scatter overhead.
## Dependency Notice
This PR relies on upstream changes in FlashInfer that allow zero-copy pool indexing for the GDN decode kernel.
## Motivation
Co-authors: @HongliMi, @zhou9402, @liz-badada
The current GDN (Gated Delta Network) attention backend uses a V-last SSM state layout (`[pool_size, HV, K, V]`) with Triton-based kernels for prefill, decode, and verify (MTP). This PR adds support for a K-last SSM state layout (`[pool_size, HV, V, K]`) that enables FlashInfer's CUTLASS-based GDN kernels, which deliver higher throughput and lower latency across all phases. The K-last layout is activated via the `--mamba-ssm-k-last` server flag. When disabled, the existing V-last Triton path remains unchanged.
## Modifications
- Add a `--mamba-ssm-k-last` flag to `ServerArgs` to opt into the K-last SSM state layout
- Use FlashInfer's prefill kernel (`chunk_gated_delta_rule`) for the K-last layout, replacing the Triton chunked prefill path
- Use FlashInfer's decode kernel (`gated_delta_rule_decode`) for the K-last layout, replacing the Triton fused recurrent decode path
- Use FlashInfer's verify kernel (`gated_delta_rule_mtp`) for the K-last layout, replacing the Triton fused recurrent verify path
- Add `intermediate_state_indices` for CUDA graph compatibility
## Accuracy Tests
Speculative decoding accept length is nearly identical between K-last and V-last, confirming K-last does not affect model output quality.
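This is expected: the K-last pool stores the same state with the last two dimensions swapped, so a transpose maps one layout onto the other exactly. A minimal sketch (the sizes are illustrative, not the model's real head dimensions):

```python
import torch

pool_size, HV, K, V = 16, 4, 8, 12  # illustrative sizes only

state_v_last = torch.randn(pool_size, HV, K, V)             # Triton V-last layout
state_k_last = state_v_last.transpose(-1, -2).contiguous()  # K-last layout

print(state_k_last.shape)  # torch.Size([16, 4, 12, 8])
```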
## Benchmarking and Profiling
Model: Qwen3-Next-80B-A3B-Instruct, 8x NVIDIA H20, TP=8, EAGLE speculative decoding (num_steps=3, topk=1), input_len=1024, output_len=128.
### Server Launch
V-last (baseline):

```shell
python -m sglang.launch_server \
  --model-path Qwen3-Next-80B-A3B-Instruct \
  --tp 8 --host 127.0.0.1 --port 30000 \
  --trust-remote-code --mem-fraction-static 0.70 --disable-radix-cache \
  --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1
```

K-last:

```shell
python -m sglang.launch_server \
  --model-path Qwen3-Next-80B-A3B-Instruct \
  --tp 8 --host 127.0.0.1 --port 30000 \
  --trust-remote-code --mem-fraction-static 0.70 --disable-radix-cache \
  --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 \
  --linear-attn-backend flashinfer
```

### Benchmark Command

```shell
python -m sglang.bench_one_batch_server \
  --model None --base-url http://127.0.0.1:30000 \
  --batch-size 1 4 8 16 32 --input-len 1024 --output-len 128 --show-report
```

**Latency (seconds, lower is better)**

**TTFT (seconds, lower is better)**

**Prefill Throughput (input tok/s, higher is better)**

**Decode Throughput (output tok/s, higher is better)**
flashinfer-ai/flashinfer#2521