feat(gdn): add FlashInfer K-last SSM layout support for GDN prefill and decode for Hopper #18361

Merged
ispobock merged 64 commits into sgl-project:main from xutizhou:feat/gdn_mtp
Mar 3, 2026
Conversation

@xutizhou
Collaborator

@xutizhou xutizhou commented Feb 6, 2026

DEPENDENCY NOTICE
This PR relies on upstream changes in FlashInfer that allow zero-copy pool indexing for the GDN decode kernel.

Motivation

Co-authors: @HongliMi, @zhou9402, @liz-badada
The current GDN (Gated Delta Network) attention backend uses a V-last SSM state layout ([pool_size, HV, K, V]) with Triton-based kernels for prefill, decode, and verify (MTP). This PR adds support for a K-last SSM state layout ([pool_size, HV, V, K]) that enables FlashInfer's CUTLASS-based GDN kernels, which deliver higher throughput and lower latency across all phases.

The K-last layout is activated via the --mamba-ssm-k-last server flag. When disabled, the existing V-last Triton path remains unchanged.

Modifications

  • Add --mamba-ssm-k-last flag to ServerArgs to opt into K-last SSM state layout
  • Integrate FlashInfer GDN prefill kernel (chunk_gated_delta_rule) for K-last layout, replacing the Triton chunked prefill path
  • Integrate FlashInfer GDN decode kernel (gated_delta_rule_decode) for K-last layout, replacing the Triton fused recurrent decode path
  • Integrate FlashInfer GDN verify/MTP kernel (gated_delta_rule_mtp) for K-last layout, replacing the Triton fused recurrent verify path
  • Add kernel warmup during server startup for both MTP and decode kernels to eliminate JIT compilation overhead (~4s) on the first request
  • Pre-allocate intermediate_state_indices for CUDA graph compatibility
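
The two layouts differ only in the order of the last two SSM state dimensions. A minimal sketch of the relationship (tensor names and sizes are illustrative, not the actual pool implementation):

```python
import torch

pool_size, HV, K, V = 4, 8, 16, 32  # illustrative sizes

# V-last layout used by the Triton path: [pool_size, HV, K, V]
state_v_last = torch.randn(pool_size, HV, K, V)

# K-last layout expected by the FlashInfer CUTLASS kernels: [pool_size, HV, V, K].
# Swapping the last two dims gives a view; .contiguous() materializes the
# K-last memory order so the innermost stride runs over K.
state_k_last = state_v_last.transpose(-1, -2).contiguous()

assert state_k_last.shape == (pool_size, HV, V, K)
assert state_k_last.stride(-1) == 1  # K is now the fastest-varying dim
```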

Accuracy Tests

Speculative decoding accept length is nearly identical between K-last and V-last, confirming K-last does not affect model output quality.

| Batch | V-last | K-last |
|------:|-------:|-------:|
| 8     | 3.81   | 3.79   |
| 16    | 3.70   | 3.75   |
| 32    | 3.75   | 3.75   |

Benchmarking and Profiling

Model: Qwen3-Next-80B-A3B-Instruct, 8x NVIDIA H20, TP=8, EAGLE speculative decoding (num_steps=3, topk=1), input_len=1024, output_len=128.

Server Launch

V-last (baseline):

python -m sglang.launch_server \
    --model-path Qwen3-Next-80B-A3B-Instruct \
    --tp 8 --host 127.0.0.1 --port 30000 \
    --trust-remote-code --mem-fraction-static 0.70 --disable-radix-cache \
    --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1

K-last:

python -m sglang.launch_server \
    --model-path Qwen3-Next-80B-A3B-Instruct \
    --tp 8 --host 127.0.0.1 --port 30000 \
    --trust-remote-code --mem-fraction-static 0.70 --disable-radix-cache \
    --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 \
    --linear-attn-backend flashinfer

Benchmark Command

python -m sglang.bench_one_batch_server \
    --model None --base-url http://127.0.0.1:30000 \
    --batch-size 1 4 8 16 32 --input-len 1024 --output-len 128 --show-report

Latency (seconds, lower is better)

| Batch | V-last | K-last | Change |
|------:|-------:|-------:|-------:|
| 1     | 0.405  | 0.375  | -7.5%  |
| 4     | 0.504  | 0.481  | -4.5%  |
| 8     | 0.670  | 0.675  | +0.8%  |
| 16    | 1.051  | 0.960  | -8.6%  |
| 32    | 1.527  | 1.483  | -2.9%  |

TTFT (seconds, lower is better)

| Batch | V-last | K-last | Change |
|------:|-------:|-------:|-------:|
| 1     | 0.112  | 0.096  | -14.3% |
| 4     | 0.126  | 0.117  | -7.1%  |
| 8     | 0.189  | 0.184  | -2.6%  |
| 16    | 0.343  | 0.332  | -3.2%  |
| 32    | 0.666  | 0.652  | -2.1%  |

Prefill Throughput (input tok/s, higher is better)

| Batch | V-last | K-last | Change |
|------:|-------:|-------:|-------:|
| 1     | 9,179  | 10,705 | +16.6% |
| 4     | 32,530 | 35,055 | +7.8%  |
| 8     | 43,278 | 44,535 | +2.9%  |
| 16    | 47,720 | 49,365 | +3.4%  |
| 32    | 49,177 | 50,229 | +2.1%  |

Decode Throughput (output tok/s, higher is better)

| Batch | V-last | K-last | Change |
|------:|-------:|-------:|-------:|
| 1     | 436    | 459    | +5.3%  |
| 4     | 1,355  | 1,405  | +3.7%  |
| 8     | 2,131  | 2,084  | -2.2%  |
| 16    | 2,893  | 3,259  | +12.6% |
| 32    | 4,757  | 4,931  | +3.7%  |

Upstream FlashInfer PR: flashinfer-ai/flashinfer#2521

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

xutizhou and others added 30 commits January 11, 2026 18:27
This commit introduces a new file for the CuTe DSL GDN Verify kernel, optimized for K-last layout. The kernel supports multiple tokens in sequence, caches intermediate states for verification, and maintains gate computation within the kernel. Additionally, it includes configuration updates to support K-last layout in Mamba2StateShape and Qwen3NextConfig, as well as adjustments in the GDNAttnBackend to utilize the new kernel when applicable.

Key features:
- Efficient memory access with K-last layout.
- Support for speculative decoding verification.
- Integration into existing attention mechanisms.

This enhancement aims to improve performance and flexibility in handling GDN operations in neural network architectures.
- Optimize extend path: skip transpose for first prefill (zero initial state)
- Fix CUDA Graph compatibility: guard debug logging with _debug_enabled flag
- Use actual pool_size/cache_steps for kernel compilation
- Add K-last layout support in hybrid_linear_attn_backend
- Verified accuracy with different batch sizes (4, 8, 16)
- Add sglang_gdn_verify_kernel with optimized K-last memory layout
- Support bfloat16 computation with proper dtype handling
- Add early exit optimization for CUDA Graph padding slots (cache_idx < 0)
- Integrate with hybrid_linear_attn_backend for MTP speculative decoding
- Add warning when K-last is used without speculative decoding
- Enable debug logging via SGLANG_GDN_DEBUG environment variable
…intermediate indices

- Add _cached_has_prefix_cache to compute has_prefix_cache only at layer 0
- Remove unnecessary [:batch_size] slice from initial_state_indices in verify kernel
- Add intermediate_state_indices field to ForwardMetadata for CUDA graph
- Pre-allocate intermediate_state_indices_list in init_cuda_graph_state
- Remove all debug logging code
Add log_tensor function to hybrid_linear_attn_backend.py for debugging
precision differences between K-last (CuTe DSL) and V-last (Triton) kernels.

Features:
- Log tensor statistics (mean, std, min, max, abs_mean) and sample values
- Support for verify and extend kernel input/output logging
- Only logs on rank 0 in distributed mode
- Skip logging during CUDA graph capture
- Controlled by SGLANG_GDN_DEBUG=1 and SGLANG_GDN_DEBUG_DIR env vars

Benchmark Results (256 prompts, input=128, output=1024):
┌─────────────────────┬──────────────┬──────────────┐
│ Metric              │ K-last       │ V-last       │
├─────────────────────┼──────────────┼──────────────┤
│ Accept Length       │ 3.28         │ 3.17         │
│ Output Throughput   │ 573.05 tok/s │ 633.68 tok/s │
│ Benchmark Duration  │ 227.51s      │ 205.75s      │
│ mamba_ssm_k_last    │ True         │ False        │
└─────────────────────┴──────────────┴──────────────┘

CUDA Graph Results (256 prompts, input=128, output=1024):
┌─────────────────────┬──────────────┬──────────────┐
│ Metric              │ K-last       │ V-last       │
├─────────────────────┼──────────────┼──────────────┤
│ Accept Length       │ 3.29         │ 3.17         │
│ Output Throughput   │ 2156.38 tok/s│ 2056.61 tok/s│
│ Benchmark Duration  │ 60.46s       │ 63.39s       │
│ mamba_ssm_k_last    │ True         │ False        │
└─────────────────────┴──────────────┴──────────────┘

How to run benchmark:
  # Set environment variables
  export SGLANG_GDN_DEBUG=1
  export SGLANG_GDN_DEBUG_DIR=/path/to/log/dir

  # K-last server
  python -m sglang.launch_server \
      --model-path /path/to/model --tp 2 --trust-remote-code \
      --disable-radix-cache --mamba-ssm-k-last --disable-cuda-graph

  # V-last server (without --mamba-ssm-k-last)
  python -m sglang.launch_server \
      --model-path /path/to/model --tp 2 --trust-remote-code \
      --disable-radix-cache --disable-cuda-graph

  # Run benchmark
  python bench_serving.py \
      --backend sglang --host 127.0.0.1 --port 30000 \
      --model /path/to/model --dataset-name random \
      --random-input 128 --random-output 1024 \
      --num-prompts 256 --seed 42
Move the .item() call from forward_extend (called per layer) to
init_forward_metadata (called once per batch) to reduce GPU-CPU sync overhead.

Before: ~10ms overhead due to .item() sync in every layer
After: negligible overhead (single sync at batch init)

Profile results show K-last advantage improved from -0.4% to -1.3%.
…n V-last

Key change for precision alignment (validated by ablation study):
- Always quantize beta to activation dtype (bf16/fp16) in internal gating path
  (ablation: 0.26% argmax flip without this, +3pp on GSM8K)
- Reuse Triton fused_gdn_gating for external g_log/beta to minimize numerical drift

Removed unused ablation hooks and environment variables:
- SGLANG_CUTEDSL_VERIFY_OUT_DTYPE (bf16 output was already in 72fb)
- SGLANG_CUTEDSL_VERIFY_DISABLE_PRECOMPUTE_G_DECAY
- SGLANG_CUTEDSL_VERIFY_L2NORM_DIV
- SGLANG_CUTEDSL_VERIFY_QUANTIZE_BETA_TO_ACT

E2E validation (GSM8K 100q):
- K-last: 96% (up from 93% in 72fb, mainly due to beta quantization fix)
- V-last: 97%
- Diff: -1pp (within statistical noise)

Also adds:
- test_cutedsl_verify_vs_fused.py: correctness test comparing CuTe vs Triton
- in_place_transpose.py: Triton kernel for K-last<->V-last layout conversion
…verify

Ablation on GSM8K (200q) showed no accuracy benefit from computing gates (g_log/beta)
via Triton fused_gdn_gating and feeding them into the CuTe verify kernel.

This change removes the external gate precompute and always uses CuTe internal gating
for K-last verify (topk=1).
Updated shared memory allocation for tensors in the CuTe verify kernel by removing padding for bank conflict avoidance. This change streamlines the allocation process for sQ, sK, and sV tensors, improving clarity and potentially reducing memory usage. Additionally, added a TODO comment to consider moving the computation of q and k into the main loop to further optimize shared memory usage.
- Keep FlashInfer prefill state in K-last [V, K] to match mamba pool.
- Explicitly l2-normalize q/k since flashinfer wrapper ignores the flag.
- Remove torch.equal() check for cache_indices contiguity
- Remove .item() call for prefix cache check
- Simplify FlashInfer prefill to always use gather/scatter
- Simplify Triton fallback to use copy-based approach

This eliminates GPU-CPU synchronization in every layer forward,
improving K-last prefill performance.
- Remove K-last transpose fallback in prefill path
- Remove K-last transpose fallback in verify path
- K-last now only works when native K-last kernel is available
  (FlashInfer prefill / CuTe DSL verify)
- Otherwise fallback directly to V-last Triton kernel

This simplifies the code and removes complex transpose logic.
- Gate Qwen3Next mamba SSM K-last layout on FlashInfer prefill + CuTe DSL verify
- Remove prefill/verify transpose fallback paths (keep decode transpose)
- Derive ssm_k_last from the actual mamba cache shape and cache kernel handles
- Enforce K-last verify topk=1 (no retrieve_parent_token support)

This ensures K-last is used end-to-end only with native K-last kernels;
otherwise we fall back to the default V-last layout and Triton kernels.
- Fix Mamba2StateShape k_last layout swap direction (default matches base)
- Simplify Qwen3Next cache shape to use --mamba-ssm-k-last directly
- Drop FlashInfer backend tweaks unrelated to GDN/K-last
- Remove profiling/env toggles in FlashInfer GDN prefill path
- Always use in-place transpose in K-last decode when K==V
- Drop preallocated intermediate_state_indices from metadata/cuda-graph state
- Keep target-verify path using on-the-fly torch.arange indices
- Revert extend state tracking to base behavior
Add comprehensive benchmark script to compare K-last and V-last performance:
- E2E serving benchmark with configurable workload
- GSM8K and MMLU accuracy evaluation
- Automatic server launch with speculative decoding support
- Configurable GPU assignment and TP settings
…arison

Add integrated torch profiling via HTTP API:
- New --enable-profile flag to trigger profiling
- Configurable --profile-num-steps and --profile-activities
- Profiles saved to separate directories for K-last and V-last
- Does not require --disable-cuda-graph (profiles with CUDA graph enabled)
- Add precompile_cutedsl_gdn_verify_kernels() for warmup-time kernel compilation
- Pre-compile CuTe DSL kernels during GDN backend initialization to avoid JIT overhead
- Add communication volume tracking with SGLANG_COMM_DEBUG environment variable
- Fix profiling workflow in benchmark script (start -> run requests -> stop)
- Add watchdog timeout for JIT compilation in benchmark server launch
- Change memory fraction from 0.80 to 0.70 in server launch configuration.
- Update start_profiling function to include profile_by_stage parameter, defaulting to True.
- Modify profiling JSON data to enable merging profiles and reflect the new profile_by_stage setting in the print statement.
…ions

- Introduced `--chunked-prefill-size` parameter for server launch to optimize prefill operations.
- Added `run_one_batch_profile` function for improved profiling of single batch requests.
- Updated `run_benchmark` to utilize new profiling options, allowing for customizable batch size and input length during profiling.
- Enhanced command-line argument parsing to support new profiling parameters.
- Enabled device properties caching to reduce overhead during profiling and server operations.
- Add tensor cache for FlashInfer prefill state tensors to avoid cudaMalloc/cudaFree
- Use squeeze(0) + conditional contiguous() instead of [0].contiguous() to avoid unnecessary copies
- Use index_select with out= parameter when dtype matches to avoid temporary tensors
- Use index_copy_ for state writeback to avoid intermediate tensor creation

Benchmark results (num_prompts=256):
- K-last throughput: 1456 tok/s (+6.6% vs V-last 1366 tok/s)
- Accept length unchanged: 2.87 vs 2.90 (-1.0%)
- cudaFree calls eliminated (0 vs previous)
…aches

Remove state cache and output cache that were found unnecessary during
binary search cleanup. The simplified code maintains same performance
while being more readable.

Benchmark (num_prompts=256):
- K-last: 1450.87 tok/s, Accept Length: 2.87
- V-last: 1391.78 tok/s, Accept Length: 2.90
- Improvement: +4.2% throughput
@github-actions github-actions bot added the documentation Improvements or additions to documentation label Feb 26, 2026
…p, remove redundant casts

- Fix precision: cast g[0] to float32 before torch.exp instead of after,
  avoiding bfloat16 precision loss in exponential computation
- Consolidate _warmup_mtp_kernel and _warmup_decode_kernel into single
  _warmup_kernels function, sharing common tensor allocations (A_log, dt_bias)
- Cast ssm_cache_indices to int64 once after torch.where, removing
  redundant .to(torch.int64) at the index_copy_ call site

Benchmarked K-last spec and non-spec: no performance regression.
- Untrack .opencode/skills/gdn_benchmark.md (file kept on disk)
- Remove python/sglang/srt/grpc/ directory
Revert unrelated grpc changes from the PR diff.
…warmup

- Remove intermediate_state_indices pre-allocation from hybrid_linear_attn_backend.py
  (FlashInfer kernel accepts but never uses it; gdn_backend.py creates its own at runtime)
- Remove _warmup_kernels() from gdn_flashinfer.py (CUDA graph capture already triggers
  the same kernel compilation; no other backend has warmup)
- Remove intermediate_state_indices field from mamba2_metadata.py ForwardMetadata

Benchmarked: no perf regression (non-spec BS32: 56558/3634, spec BS32: 54283/5245)
MMLU 1000: non-spec 0.649, spec 0.648 (baseline 0.649)
Replace pool-indexed decode (state_indices parameter, requires PR#2521
source build) with explicit gather/scatter pattern that works with stock
FlashInfer 0.6.3+ pip packages.

Tested with FlashInfer 0.6.4:
- Prefill: no regression vs source build (~0-2% noise)
- Decode: ~7-9% regression from gather/scatter overhead
- MMLU 1000: 0.651 (matches baseline 0.649)
- NOTE in module docstring: FlashInfer >= 0.6.4 fixes prefill perf
  regression via cudaGetDeviceProperties caching (PR#2509)
- TODO in decode: switch to pool-indexed decode once FlashInfer PR#2521
  is released, removing gather/scatter overhead (~7-9% regression)
@xutizhou
Collaborator Author

/rerun-failed-ci

xutizhou added 5 commits March 1, 2026 18:26
The k_last parameter introduced a layout swap (H, K, V) for k_last=False
that broke Mamba2 models (NemotronH, FalconH1) whose kernels expect
(H, head_dim, state_size) = (H, V, K). Layout selection for FlashInfer
GDN kernels should be handled by --linear-attn-backend instead.
@ispobock ispobock merged commit c6377bb into sgl-project:main Mar 3, 2026
125 of 135 checks passed
@yuan-luo
Collaborator

yuan-luo commented Mar 3, 2026

Did you forget to checkin code? I have not found the new server args in the main.

➜  python git:(fuse_sigmoid_gating) ✗ python -m sglang.launch_server \
    --model-path Qwen3-Next-80B-A3B-Instruct \
    --tp 8 --host 127.0.0.1 --port 30000 \
    --trust-remote-code --mem-fraction-static 0.70 --disable-radix-cache \
    --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 \
    --mamba-ssm-k-last

......
                        [--remote-instance-weight-loader-send-weights-group-ports REMOTE_INSTANCE_WEIGHT_LOADER_SEND_WEIGHTS_GROUP_PORTS] [--remote-instance-weight-loader-backend {transfer_engine,nccl}]
                        [--remote-instance-weight-loader-start-seed-via-transfer-engine] [--enable-pdmux] [--pdmux-config-path PDMUX_CONFIG_PATH] [--sm-group-num SM_GROUP_NUM] [--config CONFIG]
                        [--mm-max-concurrent-calls MM_MAX_CONCURRENT_CALLS] [--mm-per-request-timeout MM_PER_REQUEST_TIMEOUT] [--enable-broadcast-mm-inputs-process] [--mm-process-config MM_PROCESS_CONFIG] [--mm-enable-dp-encoder]
                        [--limit-mm-data-per-request LIMIT_MM_DATA_PER_REQUEST] [--decrypted-config-file DECRYPTED_CONFIG_FILE] [--decrypted-draft-config-file DECRYPTED_DRAFT_CONFIG_FILE] [--enable-prefix-mm-cache] [--enable-mm-global-cache]
                        [--forward-hooks FORWARD_HOOKS]
launch_server.py: error: unrecognized arguments: --mamba-ssm-k-last

@xutizhou
Collaborator Author

xutizhou commented Mar 4, 2026

> Did you forget to checkin code? I have not found the new server args in the main.
> […]
> launch_server.py: error: unrecognized arguments: --mamba-ssm-k-last

I removed the --mamba-ssm-k-last server arg in the final version; K-last layout selection is now handled via --linear-attn-backend flashinfer. That example wasn't updated, which is why you see the error. Thanks for the reminder.

@shiyu7
Contributor

shiyu7 commented Mar 4, 2026

Hi @xutizhou , I just tried this PR but encountered this error.

File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_5.py", line 270, in forward
core_attn_out = self.attn(
^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/sgl-workspace/sglang/python/sglang/srt/layers/radix_linear_attention.py", line 95, in forward
return forward_batch.attn_backend.forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py", line 918, in forward
return self.forward_extend(
^^^^^^^^^^^^^^^^^^^^
File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py", line 869, in forward_extend
return self.linear_attn_backend.forward_extend(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/linear/gdn_backend.py", line 370, in forward_extend
core_attn_out = self.kernel_dispatcher.target_verify(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/linear/gdn_backend.py", line 185, in target_verify
return self.verify_kernel.target_verify(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/linear/kernels/gdn_flashinfer.py", line 313, in target_verify
output_fi, _ = self._mtp_fn(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/flashinfer/gdn_decode.py", line 2572, in gated_delta_rule_mtp
assert initial_state.dtype == torch.float32, (
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: initial_state must be float32, got torch.float16

It seems FP16 is not supported for MTP yet. Do you have any plans to support it?

Kangyan-Zhou pushed a commit to Kangyan-Zhou/sglang that referenced this pull request Mar 4, 2026
…nd decode for Hopper (sgl-project#18361)

Co-authored-by: HongliMi <106042350+HongliMi@users.noreply.github.com>
Co-authored-by: xiaozhoupy <181108106+zhou9402@users.noreply.github.com>
Co-authored-by: Jinyan Chen <93358689+liz-badada@users.noreply.github.com>
Co-authored-by: Avery Yingyi Huang <averyh@nvidia.com>
Co-authored-by: eigen <52445717+yyihuang@users.noreply.github.com>
@Swipe4057
Contributor

Swipe4057 commented Mar 7, 2026

> Hi @xutizhou , I just tried this PR but encountered this error.
> […]
> AssertionError: initial_state must be float32, got torch.float16
>
> It seems FP16 is not supported for MTP yet. Do you have any plans to support it?

@shiyu7 https://github.com/sgl-project/sglang/pull/19961/changes

@xutizhou
Collaborator Author

xutizhou commented Mar 7, 2026

> Hi @xutizhou , I just tried this PR but encountered this error.
> […]
> AssertionError: initial_state must be float32, got torch.float16
>
> It seems FP16 is not supported for MTP yet. Do you have any plans to support it?

MTP doesn't support FP16 yet in the current implementation. The initial_state is expected to be float32 by flashinfer's gated_delta_rule_mtp.
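
For reference, a minimal sketch of that constraint and the obvious interim workaround, assuming the cache can be kept (or upcast) to float32; shapes are illustrative:

```python
import torch

# flashinfer's gated_delta_rule_mtp asserts a float32 initial_state, so an
# fp16/bf16 SSM cache would need an explicit upcast (an extra copy per call)
# until fp16 is supported natively. Hypothetical shapes below.
state_fp16 = torch.zeros(2, 8, 32, 32, dtype=torch.float16)
state_fp32 = state_fp16.to(torch.float32)  # workaround copy before the kernel call

assert state_fp32.dtype == torch.float32
```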

yzh119 pushed a commit to flashinfer-ai/flashinfer that referenced this pull request Mar 9, 2026
## 📌 Description

This PR adds pool-indexed (indirect) state access to the GDN decode
kernel, enabling zero-copy integration with SGLang's state pool
architecture.

### Background: SGLang's State Pool Architecture

In SGLang, when serving linear attention models (like Qwen3-Next using
Gated Delta Rule), we maintain a **state pool** to store recurrent
states for all active requests:

`ssm_states: [num_layers, pool_size, num_heads, head_dim, head_dim]`

where `pool_size` = `max_num_reqs` (maximum concurrent requests).

Each active request has a `req_pool_idx` that maps it to a slot in this
pool. The mapping is **not contiguous** - requests come and go, so
indices can be scattered (e.g., a batch of 4 requests might have pool
indices `[3, 7, 12, 25]`).

### Motivation

The current GDN decode kernel expects state with shape `[B, H, K, V]`
where B equals batch size and there's a 1:1 mapping (batch index i →
state index i). To use it with SGLang's pool, we would need to:

1. **Gather** states from pool indices before kernel call
2. Run kernel on contiguous `[B, H, K, V]` state  
3. **Scatter** updated states back to pool indices

This adds 2 extra memory copy operations per decode step.
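
The three steps above can be sketched in PyTorch (`decode_kernel` is a stand-in for the real FlashInfer entry point, not its actual signature):

```python
import torch

def decode_with_gather_scatter(q, ssm_states, pool_indices, decode_kernel):
    """Run a batch-contiguous decode kernel against a scattered state pool."""
    # 1. Gather: copy the scattered pool slots into a contiguous [B, H, K, V] state
    state = ssm_states.index_select(0, pool_indices)
    # 2. Run the kernel, which assumes batch index i == state index i
    out, new_state = decode_kernel(q, state)
    # 3. Scatter: copy the updated states back to their pool slots
    ssm_states.index_copy_(0, pool_indices, new_state)
    return out
```

Steps 1 and 3 are exactly the two extra memory copies per decode step that the `state_indices` parameter eliminates.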

### Changes

This PR adds a `state_indices` parameter for **zero-copy pool access**:

```python
def gated_delta_rule_decode_pretranspose(
    q, k, v, beta,
    state,           # Can be [pool_size, H, K, V] instead of [B, H, K, V]
    state_indices,   # NEW: int32 tensor [B] mapping batch_idx -> pool_idx
    ...
)
```

When `state_indices` is provided:
- Kernel uses indirect addressing: `state[state_indices[batch_idx]]`
instead of `state[batch_idx]`
- Negative indices (padding slots for CUDA graph) skip computation and
write zeros to output
- Eliminates gather/scatter overhead + host-side `torch.where` for
padding (~37μs/call)


## 🔍 Related Issues

-
[sgl-project/sglang#18361](sgl-project/sglang#18361)
- FlashInfer K-last GDN integration into SGLang

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

This PR is required for integrating FlashInfer's K-last GDN kernels into
SGLang. The pool indexing feature allows SGLang to directly use its
state pool without gather/scatter overhead.




## Summary by CodeRabbit

* **New Features**
* Optional pooled (indirect) state access via a new state_indices
parameter enabling zero-copy pooled state handling; negative indices
yield zeroed outputs.

* **Improvements**
* Kernel launch paths, grid sizing, and compiled-kernel caching
differentiate pooled vs. non-pooled modes; APIs and docstrings updated
to propagate pooling flags while maintaining compatibility.

* **Tests**
* New test suite validating pooled decode correctness,
padding/negative-index behavior, state updates, and pooled vs.
non-pooled equivalence.
brandonmmusic-max pushed a commit to brandonmmusic-max/flashinfer that referenced this pull request Mar 9, 2026
## 📌 Description

This PR adds pool-indexed (indirect) state access to the GDN decode
kernel, enabling zero-copy integration with SGLang's state pool
architecture.

### Background: SGLang's State Pool Architecture

In SGLang, when serving linear attention models (like Qwen3-Next using
Gated Delta Rule), we maintain a **state pool** to store recurrent
states for all active requests:

`ssm_states: [num_layers, pool_size, num_heads, head_dim, head_dim]`

where `pool_size` = `max_num_reqs` (maximum concurrent requests).

Each active request has a `req_pool_idx` that maps it to a slot in this
pool. The mapping is **not contiguous**: requests come and go, so
indices can be scattered (e.g., a batch of 4 requests might have pool
indices `[3, 7, 12, 25]`).
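The pool layout and scattered `req_pool_idx` mapping can be sketched in numpy (sizes here are illustrative, not the real model's):

```python
import numpy as np

# Illustrative sizes only: a small pool with tiny heads/dims.
num_layers, pool_size, num_heads, head_dim = 2, 32, 4, 8

# ssm_states: [num_layers, pool_size, num_heads, head_dim, head_dim]
ssm_states = np.zeros(
    (num_layers, pool_size, num_heads, head_dim, head_dim), dtype=np.float32
)

# A batch of 4 active requests whose pool slots happen to be scattered.
req_pool_indices = np.array([3, 7, 12, 25], dtype=np.int32)

# Reading one layer's states for the batch is a gather over the pool axis
# (fancy indexing produces a copy, not a view).
layer = 0
batch_states = ssm_states[layer][req_pool_indices]  # shape [B, H, K, V]
print(batch_states.shape)  # (4, 4, 8, 8)
```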

### Motivation

The current GDN decode kernel expects state with shape `[B, H, K, V]`
where B equals batch size and there's a 1:1 mapping (batch index i →
state index i). To use it with SGLang's pool, we would need to:

1. **Gather** states from pool indices before kernel call
2. Run kernel on contiguous `[B, H, K, V]` state
3. **Scatter** updated states back to pool indices

This adds 2 extra memory copy operations per decode step.
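The gather/compute/scatter workaround above can be sketched as follows; `decode_kernel` here is a dummy stand-in for the real CUDA kernel, and the two fancy-indexing operations are the two extra copies:

```python
import numpy as np

pool_size, B, H, K, V = 32, 4, 4, 8, 8
pool = np.random.rand(pool_size, H, K, V).astype(np.float32)
state_indices = np.array([3, 7, 12, 25], dtype=np.int32)
before = pool[state_indices].copy()  # snapshot for comparison

def decode_kernel(state):
    # Dummy stand-in for the real [B, H, K, V] decode kernel:
    # pretend each step just decays the state slightly.
    return state * 0.9

# 1. Gather: copy the scattered pool slots into a contiguous buffer.
contig = pool[state_indices]            # extra copy #1
# 2. Run the kernel on the contiguous [B, H, K, V] state.
updated = decode_kernel(contig)
# 3. Scatter: copy the updated states back to their pool slots.
pool[state_indices] = updated           # extra copy #2
```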

### Changes

This PR adds a `state_indices` parameter for **zero-copy pool access**:

```python
def gated_delta_rule_decode_pretranspose(
    q, k, v, beta,
    state,           # Can be [pool_size, H, K, V] instead of [B, H, K, V]
    state_indices,   # NEW: int32 tensor [B] mapping batch_idx -> pool_idx
    ...
)
```

When `state_indices` is provided:
- Kernel uses indirect addressing: `state[state_indices[batch_idx]]`
instead of `state[batch_idx]`
- Negative indices (padding slots for CUDA graph) skip computation and
write zeros to output
- Eliminates gather/scatter overhead + host-side `torch.where` for
padding (~37μs/call)
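A rough numpy reference for the `state_indices` semantics described above (indirect addressing, in-place pool updates, zeroed output for negative padding indices). This is not the CUTLASS kernel itself; `toy_step` is a made-up per-slot update returning `(new_state, output)`:

```python
import numpy as np

def pooled_decode_reference(pool, state_indices, toy_step):
    """Reference semantics for pool-indexed decode: batch row b reads and
    writes pool[state_indices[b]]; negative indices skip compute and
    produce a zeroed output row."""
    B = state_indices.shape[0]
    H, K, V = pool.shape[1:]
    out = np.empty((B, H, K, V), dtype=pool.dtype)
    for b, idx in enumerate(state_indices):
        if idx < 0:
            out[b] = 0.0          # padding slot (CUDA graph): zero output
            continue
        new_state, o = toy_step(pool[idx])
        pool[idx] = new_state     # update the pool slot in place
        out[b] = o
    return out

pool = np.ones((8, 2, 3, 3), dtype=np.float32)
idx = np.array([5, -1, 2], dtype=np.int32)
out = pooled_decode_reference(pool, idx, lambda s: (s * 0.5, s + 1.0))
print(out[1].max())  # padding row is all zeros -> 0.0
```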

## 🔍 Related Issues

-
[sgl-project/sglang#18361](sgl-project/sglang#18361)
- FlashInfer K-last GDN integration into SGLang

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

This PR is required for integrating FlashInfer's K-last GDN kernels into
SGLang. The pool indexing feature allows SGLang to directly use its
state pool without gather/scatter overhead.


magicYang1573 pushed a commit to magicYang1573/sglang that referenced this pull request Mar 9, 2026
…nd decode for Hopper (sgl-project#18361)

Co-authored-by: HongliMi <106042350+HongliMi@users.noreply.github.com>
Co-authored-by: xiaozhoupy <181108106+zhou9402@users.noreply.github.com>
Co-authored-by: Jinyan Chen <93358689+liz-badada@users.noreply.github.com>
Co-authored-by: Avery Yingyi Huang <averyh@nvidia.com>
Co-authored-by: eigen <52445717+yyihuang@users.noreply.github.com>
frankwang28 pushed a commit to frankwang28/flashinfer that referenced this pull request Mar 18, 2026
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
murphymatt pushed a commit to fw-ai/flashinfer that referenced this pull request Mar 31, 2026

Labels: documentation (Improvements or additions to documentation), run-ci

Projects: none yet

9 participants