feat: [Qwen3-Next] Add Cute DSL GDN decode kernel and tests (#2370)

yzh119 merged 10 commits into flashinfer-ai:main
Conversation
📝 Walkthrough: Adds a new CUDA-accelerated Gated Delta Rule (GDN) decode implementation (pretranspose, non-transpose, MTP), a profiler-driven benchmarking CLI, Python reference/verifier implementations, extensive unit tests, and updates to the test conftest compute-capability gating.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant CLI as "Benchmark CLI\n(bench_gdn_decode.py)"
    participant Profiler as "Torch Profiler"
    participant GPU as "GPU Kernels\n(pretranspose/nontranspose/MTP)"
    participant TraceParser as "Trace Parser\n(parse_trace_file)"
    participant Metrics as "Metrics Calculator\n(TFLOPS / TB/s)"
    User->>CLI: run benchmark (config)
    CLI->>GPU: validate device capabilities
    CLI->>GPU: run warmup iterations
    CLI->>Profiler: start profiling
    loop bench iterations
        CLI->>GPU: launch decode / MTP kernel
        GPU-->>GPU: execute kernel, update state
    end
    Profiler->>Profiler: export chrome trace
    CLI->>TraceParser: parse trace for kernel timings
    TraceParser-->>CLI: return kernel timing metrics
    CLI->>Metrics: compute statistics and throughput
    Metrics-->>CLI: return results
    CLI-->>User: display results table
```
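The trace-parsing step in the diagram can be sketched in a few lines. This is a hypothetical simplification, not the actual `parse_trace_file` from the benchmark script: it assumes a Chrome trace as exported by `torch.profiler`'s `export_chrome_trace()`, and the kernel-name substring filter is an illustrative choice.

```python
import json

def parse_trace_file(path, kernel_substr="gdn"):
    """Collect per-kernel durations (microseconds) from a Chrome trace file.

    Hypothetical sketch: the field names ("traceEvents", "ph", "name",
    "dur") follow the Chrome trace event format; only complete events
    (ph == "X") whose name contains kernel_substr are kept.
    """
    with open(path) as f:
        events = json.load(f).get("traceEvents", [])
    return [
        ev["dur"]
        for ev in events
        if ev.get("ph") == "X" and kernel_substr in ev.get("name", "")
    ]
```

From these durations the CLI can then derive mean/percentile timings and throughput figures.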
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Summary of Changes (Gemini Code Assist): This pull request enhances FlashInfer by integrating optimized Gated Delta Rule (GDN) decode kernels. Written in CuTe DSL for Hopper GPUs, these kernels enable efficient linear-attention decoding for modern large language models such as Qwen3-Next. The PR introduces three kernel variants tailored to different state layouts and processing modes (single-token decode and multi-token verification), balancing performance and flexibility.
Code Review
This pull request introduces three new Gated Delta Rule (GDN) decode kernels using CuTe DSL, along with comprehensive benchmarks and tests. The changes are well-structured and the implementation is advanced. My review focuses on improving code reuse in the benchmark script and fixing critical bugs in the nontranspose kernel implementations.
Key feedback points:
- Critical Bug: The `nontranspose` decode kernels (`small_batch` and `big_batch`) contain loops with incorrect `range()` arguments, which will likely cause them to not execute and produce incorrect results.
- Code Duplication: The benchmark script has duplicated logic for parsing profiler traces, plus unused variables. I've suggested refactoring to improve maintainability and clarity.
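The reported `range()` bug class is easy to reproduce in plain Python. The snippet below is an illustration of the failure mode, not the kernel code: with the arguments in the wrong order, `range(stop, start)` is empty, the loop body never runs, and the result silently stays at its initial value.

```python
TILE_V = 32  # illustrative tile width, not the kernel's actual constant

def broken_tile_sum(values):
    total = 0.0
    for v_idx in range(TILE_V, 0):  # empty: start > stop with default step +1
        total += values[v_idx]
    return total  # loop never ran, so this is always 0.0

def fixed_tile_sum(values):
    total = 0.0
    for v_idx in range(0, TILE_V):  # iterates v_idx = 0 .. TILE_V-1
        total += values[v_idx]
    return total
```

Because the broken loop produces a well-typed (just wrong) result, only a correctness test against a reference implementation catches it.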
Actionable comments posted: 5
🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around line 2141-2148: When intermediate_states_buffer is provided, validate
that its second dimension (cache_steps) is at least T before reshaping/using it:
check intermediate_states_buffer.shape[1] >= T and raise a clear ValueError (or
similar) if not; keep the existing branch that sets cache_intermediate_states
and buffer_size, compute cache_steps = intermediate_states_buffer.shape[1], and
perform the validation prior to calling
intermediate_states_buffer.to(...).reshape(...), referencing the variables
intermediate_states_buffer, cache_steps, T, and intermediate_states so the check
is colocated with the existing logic.
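The suggested check can be sketched as follows. `check_intermediate_buffer` is a hypothetical helper name; the only assumption is that the buffer's second dimension holds `cache_steps`:

```python
def check_intermediate_buffer(intermediate_states_buffer, T):
    """Validate that the buffer can cache at least T decode steps.

    Hypothetical sketch of the review's suggestion: colocate the check
    with the reshape so a too-small buffer fails loudly instead of
    producing a silent out-of-bounds access later.
    """
    cache_steps = intermediate_states_buffer.shape[1]
    if cache_steps < T:
        raise ValueError(
            f"intermediate_states_buffer caches {cache_steps} steps, "
            f"but T={T} steps are required"
        )
    return cache_steps
```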
- Around line 839-865: The cache key for compiled decode kernels omits
compile-time parameters causing reuse of kernels with wrong specializations;
update the cache_key tuple used before calling _get_compiled_decode_kernel to
include scale and use_qk_l2norm (and any other constexpr compile parameters) for
the pretranspose path (where cache_key = (B, T, H, HV, K, V, q.dtype)), and make
the same change in the nontranspose and MTP branches that build cache_key, so
the cached entry uniquely identifies the cute.compile call
(run_gdn_decode_kernel_small_batch_pretranspose,
run_gdn_decode_kernel_big_batch_pretranspose, and corresponding nontranspose/MTP
run functions) which are invoked with cute.compile(..., scale=scale,
use_qk_l2norm=use_qk_l2norm, ...).
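The essence of this fix is that the cache key must cover every compile-time specialization. A minimal sketch, with hypothetical names standing in for the real `_get_compiled_decode_kernel` machinery:

```python
_kernel_cache = {}

def get_decode_kernel(B, T, H, HV, K, V, dtype, scale, use_qk_l2norm, compile_fn):
    """Return a cached compiled kernel, keyed on ALL compile-time parameters.

    Hypothetical sketch: without scale/use_qk_l2norm in the key, two calls
    differing only in those constexpr parameters would wrongly share one
    compiled kernel. compile_fn stands in for cute.compile(...).
    """
    cache_key = (B, T, H, HV, K, V, dtype, scale, use_qk_l2norm)
    if cache_key not in _kernel_cache:
        _kernel_cache[cache_key] = compile_fn(scale=scale, use_qk_l2norm=use_qk_l2norm)
    return _kernel_cache[cache_key]
```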
- Around line 881-883: Remove the blocking device-wide synchronization calls
(torch.cuda.synchronize()) in the per-token decode hot path; specifically delete
the torch.cuda.synchronize() just before the state.copy_(h0_source.reshape(B,
HV, V, K)) and the analogous calls at the other two locations referenced in the
review, relying on PyTorch's stream ordering so that state.copy_() on the
current stream properly orders kernels without a global sync; ensure no
subsequent code assumes the global sync was required and that state.copy_ and
downstream ops remain on the same stream.
- Around line 792-799: The pretranspose kernel can write out-of-bounds because K
and V constraints are not validated and the compute loop lacks the MTP-style
bounds check; add assertions after unpacking shapes (where q, v, state are used)
- assert K >= 128 and assert V % 4 == 0 (vec_size = TILE_K // 32 == 4
requirement) - to ensure vectorized loads are safe, and modify the pretranspose
kernel's compute loop to guard the sOutput write with the same pattern as the
MTP kernel (e.g. check lane_id == 0 and o_idx < V before writing sOutput[o_idx])
so writes cannot exceed V.
In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 69-74: Add GPU-architecture skip guards to the tests in
tests/gdn/test_decode_delta_rule.py by using flashinfer.utils helpers (e.g.,
get_compute_capability) and pytest.mark.skipif so the test only runs on
supported GPUs; locate the test function(s) that use the device context manager
(the "with device:" block and related variables q,k,v) and add a skip condition
that checks compute capability (via get_compute_capability) or an appropriate
flashinfer.utils predicate before running the test, using
pytest.mark.skipif(...) to decorate the test or wrapping the test body with an
early pytest.skip when the GPU architecture is unsupported.
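The gating decision itself is pure logic and can be sketched independently of pytest. The `(9, 0)` Hopper threshold is an assumption here; in the real test the capability tuple would come from a helper like `flashinfer.utils.get_compute_capability`, with `pytest.skip(reason)` acting on the returned reason:

```python
def gdn_skip_reason(compute_capability, min_cc=(9, 0)):
    """Return a skip-reason string if the GPU is unsupported, else None.

    Hypothetical sketch: compute_capability is a (major, minor) tuple,
    compared lexicographically against the minimum supported arch.
    """
    if compute_capability < min_cc:
        return (
            f"GDN decode kernels require SM{min_cc[0]}{min_cc[1]}+, "
            f"got SM{compute_capability[0]}{compute_capability[1]}"
        )
    return None
```

A test would then call `pytest.skip(reason)` at the top of its body whenever the returned reason is not `None`.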
Actionable comments posted: 5
🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around line 1083-1084: The read using v_global (computed as v_tile *
TILE_V_SMALL_NT + v_idx) can go out of bounds when V % TILE_V_SMALL_NT != 0;
update the code that computes and uses v_global (and the analogous logic in the
big batch kernel) to either validate input (e.g., assert V % TILE_V_SMALL_NT ==
0) or guard the index access with a bounds check before calling v[i_n, 0, i_hv,
v_global], and handle the tail tile properly (skip, pad, or clamp) so no
out-of-range indexing occurs.
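Of the two options, the bounds-guarded tail handling looks like this when simulated in plain Python (`TILE_V_SMALL_NT = 16` matches the constant named in the review; everything else is illustrative):

```python
TILE_V_SMALL_NT = 16

def v_tile_spans(V):
    """Enumerate (start, end) index spans per v-tile, clamping the tail
    tile so no span reaches past V even when V % TILE_V_SMALL_NT != 0."""
    spans = []
    for start in range(0, V, TILE_V_SMALL_NT):
        spans.append((start, min(start + TILE_V_SMALL_NT, V)))
    return spans
```

Asserting `V % TILE_V_SMALL_NT == 0` up front is the cheaper alternative when ragged tails are not expected in practice.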
- Around line 1973-1974: The MTP writeback loop writes to o[(i_n, i_t, i_hv,
tidx)] for all tidx (0-127) without checking the actual V dimension; add a
bounds check using the V variable (or the output tensor's last-dim size) and
only perform the assignment when tidx < V to avoid out-of-bounds writes (same
pattern as the pretranspose kernel). Locate the write in gdn_decode.py (the loop
over i_t that assigns cutlass.BFloat16(sOutput[(i_t, tidx)])) and guard that
assignment with a conditional on tidx < V (or skip/zero-pad for tidx >= V) so
only valid channels are written.
- Around line 803-805: Add a check enforcing V >= 128 (e.g. assert V >= 128, f"V
must be at least 128, got V={V}") because kernels assume V>=128; insert this
validation alongside the existing K and V checks in the top-level validation
block (the snippet with K and V asserts) and also add the same assert at the
start of the functions gated_delta_rule_decode and gated_delta_rule_mtp so their
local kernel assumptions are validated.
- Around line 548-550: The final writeback to output tensor o in gdn_decode.py
can write out-of-bounds when tidx >= V; add a bounds check before the write in
the big-batch kernel so you only assign o[(i_n, i_t, i_hv, tidx)] =
sOutput[tidx] when tidx < V (or otherwise clamp/mask writes to V), similar to
the partial check present in the small-batch kernel; update the write
surrounding the existing cute.arch.barrier() and reference the variables
NUM_THREADS, V, sOutput, o, i_n, i_t, i_hv and tidx when making the conditional
guard.
- Around line 202-207: The code assumes V >= 128 but only validates V % 4 == 0;
add an explicit assertion in both gated_delta_rule_decode_pretranspose and
gated_delta_rule_decode immediately after the existing API validation block to
prevent out-of-bounds accesses when loading v into sV and indexing with i * 32 +
lane_id: assert V >= 128, f"V must be at least 128, got V={V}". This ensures the
vec_size/TILE_K loading loop and shared-memory writes to sV are safe.
🧹 Nitpick comments (2)
flashinfer/gdn_decode.py (2)
600: Remove unused variable to satisfy the linter. `total_data_mb` is calculated but never used (same at line 682).

```diff
- total_data_mb = v_dim * k_dim * batch_size * 4 / 1024 / 1024
```
2164-2166: Consider removing `.contiguous()` for consistency with similar reshape patterns in the same function. The `intermediate_states` tensor follows the same operation pattern as `h0_source` (both: `.to(torch.float32).reshape(...)`), but only `intermediate_states` includes `.contiguous()`. Since PyTorch's `reshape()` returns a view when the requested shape is compatible with the memory layout, and `.contiguous()` is a no-op if the tensor is already contiguous, the `.contiguous()` call here is defensive but likely unnecessary. Both tensors are passed to `from_dlpack()` for CuTe kernel conversion. If that conversion requires contiguity, both should include the call; if not, removing it from `intermediate_states` would improve code consistency.
/bot run
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around line 900-907: Remove all Git merge conflict markers (<<<<<<<, =======,
>>>>>>>) in flashinfer/gdn_decode.py and resolve each conflict by keeping the
newer assertion checks replacing the old vectorized-load assertion: retain
"assert V >= 128" and "assert V % TILE_V == 0" for the V validation near the top
(the block touching V and TILE_V), and apply the same conflict resolution
pattern at the other listed locations (the small-batch nontranspose kernel loop,
the nontranspose API validation, and the MTP API validation) so the file
contains only valid Python assertions and logic with no leftover conflict
markers.
♻️ Duplicate comments (1)
flashinfer/gdn_decode.py (1)
1195-1196: Bounds safety depends on proper V validation. The access `v[i_n, 0, i_hv, v_global]`, where `v_global = v_tile * TILE_V_SMALL_NT + v_idx`, could go out of bounds if V is not divisible by the tile size. The validation being added in the merge-conflict resolution (`V % TILE_V_NT == 0`) will properly guard this, since `TILE_V_NT=32` is a multiple of `TILE_V_SMALL_NT=16`. Ensure the merge conflict is resolved to include the `V % TILE_V_NT == 0` assertion.
🧹 Nitpick comments (3)
flashinfer/gdn_decode.py (1)
626: Remove dead expression. This line computes a value but discards it (it is not assigned to any variable). It appears to be left over from when it was assigned to `total_data_mb` for the commented-out debug print below.

```diff
- v_dim * k_dim * batch_size * 4 / 1024 / 1024
```

The same issue exists at line 735.
benchmarks/bench_gdn_decode.py (2)
78-112: Consider removing or documenting the unused `num_k_heads` parameter. The `num_k_heads` parameter is never used in the FLOPs calculation. This is technically correct since FLOPs are determined by output heads (`max(num_q_heads, num_v_heads)`), not key heads. However, an unused parameter can be confusing. Options:
- Remove the parameter if it's not needed for API consistency
- Add a comment explaining why it's intentionally unused (e.g., `_ = num_k_heads  # Unused: FLOPs depend on output heads`)
- Keep as-is for API consistency with `gdn_decode_bytes`
653-656: Consider adding an else clause for defensive coding. While the control flow guarantees only 'pretranspose' or 'nontranspose' can reach this point (due to the `continue` at line 623), an else clause would make the code more robust against future changes.

```diff
 # Determine which kernel variant was used (based on batch size threshold)
 if version == "pretranspose":
     kernel_variant = "SmallBatch" if batch_size <= 32 else "LargeBatch"
 elif version == "nontranspose":
     kernel_variant = "SmallBatch" if batch_size < 32 else "LargeBatch"
+else:
+    kernel_variant = "Unknown"
```
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around line 565-573: The final writeback only writes indices tidx (0–127) so
when V > NUM_THREADS the tail of sOutput isn't copied; update the writeback in
the function containing the variables tidx, V, NUM_THREADS, o and sOutput to
loop with a stride of NUM_THREADS (e.g., for idx = tidx; idx < V; idx +=
NUM_THREADS) and write o[(i_n, i_t, i_hv, idx)] = sOutput[idx], or alternatively
assert/enforce V == NUM_THREADS before this writeback to guarantee full
coverage; make the same change for the corresponding small-batch pretranspose
and MTP final writeback sites.
- Around line 898-901: The K validation currently allows K > 128 but the kernel
only loads TILE_K (128) elements; change the assertion to require K == TILE_K
(i.e., assert K == TILE_K, not K >= 128) so tails are not silently ignored, and
apply the identical K == TILE_K check in the non‑transpose and MTP entry points
(the other callsites/entry functions that validate K, e.g., the non‑transpose
decoder and the MTP entry routine) to ensure consistency with the fixed 128‑wide
kernel.
- Around line 931-932: The cache key for compiled kernels (built at cache_key =
(B, T, H, HV, K, V, q.dtype, scale, use_qk_l2norm) and used with
_get_compiled_decode_kernel) omits the output tensor dtype, causing reuse of
kernels compiled for the wrong output type; update the cache key to include the
output's dtype (e.g., add output.dtype or normalized_output_dtype) or normalize
by forcing/allocating a fixed output dtype before compilation so the key and
kernel compilation always match the actual output type used. Ensure you
reference the same output variable/name used where the kernel is launched when
adding this dtype to the key so cached kernels are only reused for compatible
output dtypes.
```python
# ===================================================================
# Final writeback: Copy output from shared memory to global memory
# All threads write (V=128, NUM_THREADS=128)
# ===================================================================
cute.arch.barrier()  # Ensure all writes to sOutput are complete

if tidx < V:
    o[(i_n, i_t, i_hv, tidx)] = sOutput[tidx]
```
Writeback only covers the first 128 channels.
tidx is 0–127, so when V > 128, the tail is never written in the big‑batch pretranspose path (same pattern exists in the small‑batch pretranspose and MTP final writeback). Either loop with stride or enforce V == NUM_THREADS.
🔧 Suggested fix (strided writeback)

```diff
- if tidx < V:
-     o[(i_n, i_t, i_hv, tidx)] = sOutput[tidx]
+ for v_offset in range(0, V, NUM_THREADS):
+     out_idx = v_offset + tidx
+     if out_idx < V:
+         o[(i_n, i_t, i_hv, out_idx)] = sOutput[out_idx]
```

🤖 Prompt for AI Agents
+ o[(i_n, i_t, i_hv, out_idx)] = sOutput[out_idx]🤖 Prompt for AI Agents
In `@flashinfer/gdn_decode.py` around lines 565 - 573, The final writeback only
writes indices tidx (0–127) so when V > NUM_THREADS the tail of sOutput isn't
copied; update the writeback in the function containing the variables tidx, V,
NUM_THREADS, o and sOutput to loop with a stride of NUM_THREADS (e.g., for idx =
tidx; idx < V; idx += NUM_THREADS) and write o[(i_n, i_t, i_hv, idx)] =
sOutput[idx], or alternatively assert/enforce V == NUM_THREADS before this
writeback to guarantee full coverage; make the same change for the corresponding
small-batch pretranspose and MTP final writeback sites.
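A quick plain-Python simulation shows why the strided form covers every channel: with `NUM_THREADS = 128` threads each stepping by the thread count, indices 0..V-1 are written exactly once even when V > 128. This is purely illustrative; the real kernel runs this per-thread on the GPU.

```python
NUM_THREADS = 128  # thread-block width assumed by the kernel

def strided_writeback(s_output):
    """Simulate all threads performing the strided writeback."""
    V = len(s_output)
    out = [None] * V
    for tidx in range(NUM_THREADS):              # one pass per "thread"
        for idx in range(tidx, V, NUM_THREADS):  # idx = tidx, tidx+128, ...
            out[idx] = s_output[idx]
    return out
```

With the original `if tidx < V` form, any `s_output` entry past index 127 would stay `None` in this simulation, which is exactly the reported bug.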
flashinfer/gdn_decode.py (outdated)

```python
# Validate K and V constraints
assert K >= 128, f"K must be at least 128, got K={K}"
assert V >= 128, f"V must be at least 128, got V={V}"
assert V % TILE_V == 0, f"V must be divisible by {TILE_V} to prevent out-of-bounds access, got V={V}"
```
Tighten K validation to match the fixed 128‑wide kernel.
The kernels only load TILE_K=128 elements (vec_size = TILE_K // 32) and never iterate over K‑tiles. Allowing K > 128 silently ignores the tail. Please assert K == TILE_K (and apply the same constraint in the non‑transpose and MTP entry points).
🔧 Suggested fix

```diff
- assert K >= 128, f"K must be at least 128, got K={K}"
+ assert K == TILE_K, f"K must be exactly {TILE_K}, got K={K}"
```

🤖 Prompt for AI Agents
In `@flashinfer/gdn_decode.py` around lines 898 - 901, The K validation currently
allows K > 128 but the kernel only loads TILE_K (128) elements; change the
assertion to require K == TILE_K (i.e., assert K == TILE_K, not K >= 128) so
tails are not silently ignored, and apply the identical K == TILE_K check in the
non‑transpose and MTP entry points (the other callsites/entry functions that
validate K, e.g., the non‑transpose decoder and the MTP entry routine) to ensure
consistency with the fixed 128‑wide kernel.
[FAILED] Pipeline #42075372: 3/20 passed

/bot run
vadiklyutiy
left a comment
@HongliMi Could you make measurement vs fused_reccurent_gated_delta_rule Triton kernel from vLLM?
vadiklyutiy
left a comment
3.97 TB/s on H20 looks weird... The peak is 3.35TB/s
Agreed, @HongliMi can you double check?
The H100 chip has a peak memory bandwidth of 3.35 TB/s, while the H20 has a peak bandwidth of 4.8 TB/s.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around line 136-139: The code silently drops tail V-tiles because num_v_tiles
= V // TILE_V can have a remainder versus NUM_BLOCKS_PER_STATE; in
gated_delta_rule_decode_pretranspose (and any caller logic using
NUM_BLOCKS_PER_STATE and num_v_tiles) add a validation after computing
num_v_tiles that enforces num_v_tiles % NUM_BLOCKS_PER_STATE == 0 (raising an
assertion or ValueError with a clear message referencing V, num_v_tiles and
NUM_BLOCKS_PER_STATE), or alternatively change the block processing logic (e.g.,
in batch_idx/batch_inner handling) to explicitly handle remainder tiles instead
of using integer division; prefer adding the divisibility check near the
existing V/TILE_V checks to prevent silent data loss.
- Around line 1046-1049: The calculation in the small-batch nontranspose path
(variables block_idx, batch_idx, batch_inner, num_v_tiles,
num_v_tiles_per_block, start_v_tile, and constant NUM_BLOCKS_PER_STATE_SMALL_NT)
can drop remainder vertical tiles because num_v_tiles_per_block = num_v_tiles //
NUM_BLOCKS_PER_STATE_SMALL_NT does integer division; add a validation or
handling to prevent tile loss: either assert/raise if num_v_tiles %
NUM_BLOCKS_PER_STATE_SMALL_NT != 0 (i.e., enforce num_v_tiles per-state
divisible by NUM_BLOCKS_PER_STATE_SMALL_NT) or change the division to compute a
ceiling (and adjust start_v_tile/end_v_tile calculations) so all tiles are
covered; update the existing API validation (the V % TILE_V_NT check) to include
this new divisibility constraint or add the new check early in the
input-validation routine.
♻️ Duplicate comments (3)
benchmarks/bench_gdn_decode.py (2)
78-85: Consider prefixing the unused `num_k_heads` with an underscore. The `num_k_heads` parameter is unused in the FLOPs calculation since GDN uses the same head count for queries and keys. Prefix it with `_` to silence the linter while keeping the signature consistent with related functions.

```diff
 def gdn_decode_flops(
     batch_size: int,
     num_q_heads: int,
-    num_k_heads: int,
+    _num_k_heads: int,
     num_v_heads: int,
```
171-177: Intermediate bytes counted unconditionally for seq_len > 1. The `gdn_decode_bytes` function always adds `intermediate_bytes` when `seq_len > 1`, but `bench_gdn_mtp` may run with `cache_intermediate_states=False`. This overstates memory bandwidth when intermediate state caching is disabled.

🔧 Suggested fix

```diff
 def gdn_decode_bytes(
     ...
     seq_len: int = 1,
     disable_state_update: bool = False,
+    cache_intermediate_states: bool = True,
 ) -> int:
     ...
-    if seq_len > 1:
+    if seq_len > 1 and cache_intermediate_states:
         intermediate_bytes = (
             batch_size * seq_len * num_sab_heads * head_size * head_size * 4
         )
```

Then update the caller in `bench_gdn_mtp`:

```python
bytes_accessed = gdn_decode_bytes(
    ...
    seq_len,
    disable_state_update=True,
    cache_intermediate_states=cache_intermediate_states,  # Pass through
)
```

flashinfer/gdn_decode.py (1)
565-573: Writeback only covers the first 128 elements when V > 128. The final writeback uses `if tidx < V`, but `tidx` ranges over 0-127 (NUM_THREADS=128). When V > 128 (e.g., V=256), elements 128-255 are never written from `sOutput` to global memory `o`. The same issue exists in the small-batch kernel (lines 334-336) and the MTP kernel (lines 2119-2121). Either enforce `V == NUM_THREADS` in the API validation, or use a strided writeback:

🔧 Strided writeback fix

```diff
 cute.arch.barrier()  # Ensure all writes to sOutput are complete
- if tidx < V:
-     o[(i_n, i_t, i_hv, tidx)] = sOutput[tidx]
+ for v_offset in range(0, V, NUM_THREADS):
+     out_idx = v_offset + tidx
+     if out_idx < V:
+         o[(i_n, i_t, i_hv, out_idx)] = sOutput[out_idx]
```
🧹 Nitpick comments (2)
flashinfer/gdn_decode.py (2)
1913-1916: Initialize `sOutput` with a bounds check for V > NUM_THREADS. The loop initializes `sOutput[(i_t, tidx)]` for all `tidx` in 0-127, but `sOutput` is allocated with shape `(T, V)`. If V > 128, indices 128+ are never initialized. While the kernel later writes to `sOutput[(i_t, o_idx)]` with a proper bounds check at lines 2104-2105, the initialization loop could leave uninitialized memory if the kernel logic changes. Consider a strided initialization loop for consistency:

```python
for i_t in range(T):
    for v_offset in range(0, V, NUM_THREADS_MTP):
        v_idx = v_offset + tidx
        if v_idx < V:
            sOutput[(i_t, v_idx)] = 0.0
```
2369-2373: Consider removing `.contiguous()` if the reshape guarantees contiguity. A past review questioned whether `.contiguous()` is needed after `.reshape()`. PyTorch's `reshape()` returns a contiguous tensor when it is a view of a contiguous tensor, and since `intermediate_states_buffer.to(torch.float32)` returns a new contiguous tensor, the subsequent `.reshape()` should also be contiguous. That said, keeping `.contiguous()` is defensive and its overhead is negligible when the tensor is already contiguous; this is a minor optimization opportunity.
/bot run

/bot run
Sorry for jumping in, but what's the point of merging a kernel that's slower than what SGLang and vLLM are already using?

Hi @vadiklyutiy, thanks for the reminder. Would you mind comparing with sglang's python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py, as pointed out by @liz-badada?

I thought comparing against existing implementations was a mandatory part of posting a PR.
<!-- .github/pull_request_template.md -->

## 📌 Description

Follow-up of #2370. This PR improves the benchmark scripts and adds comparison with baselines:

* benchmark using CUPTI with L2 flush
* compare with sglang's `fused_sigmoid_gating_delta_rule_update` function (with the tile-size optimization mentioned by @vadiklyutiy)

This PR also implements some optimizations on the original GDN kernel:

* use fastmath as much as we can
* change "/" to multiply
* use `cutlass.range_constexpr` and `cutlass.const_expr` whenever possible
* fuse scale and inv_norm_q
* for MTP, store state in registers directly, without load/write to shared memory, and remove cp.async
* vectorized memory access

## Performance on B200

Non-MTP setting:

```
> python benchmarks/bench_gdn_decode.py --batch-size 1 2 4 8 16 32 64 128 256 --compare --verify

=== Correctness Verification ===
Batch=8:  Pretranspose: PASS  Nontranspose: PASS
Batch=16: Pretranspose: PASS  Nontranspose: PASS
Batch=32: Pretranspose: PASS  Nontranspose: PASS
Batch=64: Pretranspose: PASS  Nontranspose: PASS

GDN Decode Benchmark: FlashInfer vs Triton, Pretranspose vs Nontranspose
Config: q_heads=16, k_heads=16, v_heads=32, head_size=128, dtype=bfloat16, qk_l2norm=ON

batch | FI-PreTr FI-NonTr | TR-PreTr TR-NonTr | FI/TR-Pre FI/TR-Non | Pre/Non-FI Pre/Non-TR
      |   (us)     (us)   |   (us)     (us)   |  speedup   speedup  |  speedup    speedup
------------------------------------------------------------------------------------------
    1 |   3.74     5.06   |   5.95     4.35   |   1.59x     0.86x   |   1.35x      0.73x
    2 |   4.29     5.89   |   6.37     5.02   |   1.49x     0.85x   |   1.37x      0.79x
    4 |   5.41     7.78   |   7.58     6.66   |   1.40x     0.86x   |   1.44x      0.88x
    8 |   7.65    12.03   |   9.95    10.21   |   1.30x     0.85x   |   1.57x      1.03x
   16 |  12.61    19.30   |  16.83    15.81   |   1.34x     0.82x   |   1.53x      0.94x
   32 |  22.91    32.86   |  31.55    27.84   |   1.38x     0.85x   |   1.43x      0.88x
   64 |  52.74    58.61   |  58.91    53.02   |   1.12x     0.90x   |   1.11x      0.90x
  128 |  92.93   107.98   | 114.45   106.78   |   1.23x     0.99x   |   1.16x      0.93x
  256 | 170.77   209.04   | 225.71   216.41   |   1.32x     1.04x   |   1.22x      0.96x

Legend:
  FI-PreTr = FlashInfer Pretranspose [B, HV, V, K]
  FI-NonTr = FlashInfer Nontranspose [B, HV, K, V]
  TR-PreTr = Triton Pretranspose [B, HV, V, K]
  TR-NonTr = Triton Nontranspose [B, HV, K, V]
  FI/TR speedup > 1.0 means FlashInfer is faster than Triton
  Pre/Non speedup > 1.0 means Pretranspose is faster than Nontranspose

FlashInfer vs Triton (Pretranspose) - Average speedup: 1.35x
```

MTP setting (pretranspose only):

```
> python benchmarks/bench_gdn_decode.py --version mtp --batch-size 1 2 4 8 16 32 64 128 256 --compare --verify

=== Correctness Verification (MTP) ===
Batch=8: PASS   Batch=16: PASS   Batch=32: PASS   Batch=64: PASS

GDN MTP Comparison: FlashInfer (CuTe DSL) vs Triton
Config: q_heads=16, k_heads=16, v_heads=32, head_size=128, dtype=bfloat16, qk_l2norm=ON, cache_intermediate=OFF

batch  seq_len  FlashInfer(us)  Triton(us)  FI TFLOPS  TR TFLOPS  Speedup
    1        2            9.22       10.05       0.68       0.63    1.09x
    1        4           11.20       14.43       1.12       0.87    1.29x
    1        8           15.81       22.08       1.59       1.14    1.40x
    2        2           10.11       10.69       1.24       1.18    1.06x
    2        4           12.58       15.10       2.00       1.67    1.20x
    2        8           18.82       23.63       2.67       2.13    1.26x
    4        2           11.39       11.94       2.21       2.11    1.05x
    4        4           15.23       16.54       3.30       3.04    1.09x
    4        8           23.62       25.50       4.26       3.95    1.08x
    8        2           14.69       17.23       3.43       2.92    1.17x
    8        4           21.20       25.01       4.75       4.03    1.18x
    8        8           34.69       40.86       5.80       4.93    1.18x
   16        2           21.47       24.22       4.69       4.16    1.13x
   16        4           32.54       36.98       6.19       5.44    1.14x
   16        8           56.24       61.76       7.16       6.52    1.10x
   32        2           33.50       37.68       6.01       5.34    1.12x
   32        4           54.66       60.26       7.37       6.68    1.10x
   32        8           97.98      104.35       8.22       7.72    1.06x
   64        2           59.82       65.38       6.73       6.16    1.09x
   64        4          102.05      108.83       7.89       7.40    1.07x
   64        8          188.17      196.45       8.56       8.20    1.04x
  128        2          107.44      121.41       7.50       6.63    1.13x
  128        4          192.01      209.90       8.39       7.67    1.09x
  128        8          366.81      389.12       8.78       8.28    1.06x
  256        2          199.14      236.19       8.09       6.82    1.19x
  256        4          363.36      422.61       8.87       7.62    1.16x
  256        8          708.22      787.05       9.10       8.19    1.11x

Speedup > 1.0 means FlashInfer is faster
Summary:
  Average speedup: 1.13x
  Min speedup: 1.04x (batch=64, T=8)
  Max speedup: 1.40x (batch=1, T=8)
```

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Added Triton-based benchmarks and end-to-end comparison/verify modes across multiple memory layouts (including MTP); new verification flows to compare implementations.
* **Performance Improvements**
  * Batch-size-aware kernel selection, configurable tile/vec sizing, fast-math paths, reduced redundant copies, and CUPTI-backed GPU timing for more accurate benchmarks.
* **Behavior & Compatibility**
  * Improved layout handling, expanded CLI presets/modes, clearer error messages and guards when Triton is unavailable; default benchmark mode updated.
* **Documentation**
  * Updated usage examples and CLI guidance.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: HongliMi <1667738261@qq.com>
Co-authored-by: Hongli Mi <hmi@nvidia.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
## 📌 Description

Follow up of flashinfer-ai#2370, this PR improves the benchmark scripts and adds comparisons with baselines:

* benchmark using CUPTI with L2 flush
* compare with sglang's `fused_sigmoid_gating_delta_rule_update` function (with the tile size optimization mentioned by @ vadiklyutiy)

This PR also implements some optimizations on the original GDN kernel:

* use fastmath as much as we can
* change "/" to multiply
* use `cutlass.range_constexpr` and `cutlass.const_expr` whenever possible
* fuse scale and inv_norm_q
* for MTP, store state in registers directly, without load/write to shared memory, and remove cpasync
* vectorized memory access

## Performance on B200

Non-MTP setting

```
> python benchmarks/bench_gdn_decode.py --batch-size 1 2 4 8 16 32 64 128 256 --compare --verify

=== Correctness Verification ===
Batch=8:  Pretranspose: PASS  Nontranspose: PASS
Batch=16: Pretranspose: PASS  Nontranspose: PASS
Batch=32: Pretranspose: PASS  Nontranspose: PASS
Batch=64: Pretranspose: PASS  Nontranspose: PASS

========================================================================================================================
GDN Decode Benchmark: FlashInfer vs Triton, Pretranspose vs Nontranspose
Config: q_heads=16, k_heads=16, v_heads=32, head_size=128, dtype=bfloat16, qk_l2norm=ON
========================================================================================================================
batch | FI-PreTr FI-NonTr | TR-PreTr TR-NonTr | FI/TR-Pre FI/TR-Non | Pre/Non-FI Pre/Non-TR
      |     (us)     (us) |     (us)     (us) |   speedup   speedup |    speedup    speedup
------------------------------------------------------------------------------------------------------------------------
    1 |     3.74     5.06 |     5.95     4.35 |     1.59x     0.86x |      1.35x      0.73x
    2 |     4.29     5.89 |     6.37     5.02 |     1.49x     0.85x |      1.37x      0.79x
    4 |     5.41     7.78 |     7.58     6.66 |     1.40x     0.86x |      1.44x      0.88x
    8 |     7.65    12.03 |     9.95    10.21 |     1.30x     0.85x |      1.57x      1.03x
   16 |    12.61    19.30 |    16.83    15.81 |     1.34x     0.82x |      1.53x      0.94x
   32 |    22.91    32.86 |    31.55    27.84 |     1.38x     0.85x |      1.43x      0.88x
   64 |    52.74    58.61 |    58.91    53.02 |     1.12x     0.90x |      1.11x      0.90x
  128 |    92.93   107.98 |   114.45   106.78 |     1.23x     0.99x |      1.16x      0.93x
  256 |   170.77   209.04 |   225.71   216.41 |     1.32x     1.04x |      1.22x      0.96x
------------------------------------------------------------------------------------------------------------------------
Legend:
  FI-PreTr = FlashInfer Pretranspose [B, HV, V, K]
  FI-NonTr = FlashInfer Nontranspose [B, HV, K, V]
  TR-PreTr = Triton Pretranspose [B, HV, V, K]
  TR-NonTr = Triton Nontranspose [B, HV, K, V]
  FI/TR speedup > 1.0 means FlashInfer is faster than Triton
  Pre/Non speedup > 1.0 means Pretranspose is faster than Nontranspose
FlashInfer vs Triton (Pretranspose) - Average speedup: 1.35x
```

MTP setting (pretranspose only)

```
> python benchmarks/bench_gdn_decode.py --version mtp --batch-size 1 2 4 8 16 32 64 128 256 --compare --verify

=== Correctness Verification (MTP) ===
Batch=8:  PASS
Batch=16: PASS
Batch=32: PASS
Batch=64: PASS

GDN MTP Comparison: FlashInfer (CuTe DSL) vs Triton
Config: q_heads=16, k_heads=16, v_heads=32, head_size=128, dtype=bfloat16, qk_l2norm=ON, cache_intermediate=OFF
--------------------------------------------------------------------------------------------------------------
 batch  seq_len  FlashInfer(us)  Triton(us)  FI TFLOPS  TR TFLOPS  Speedup
--------------------------------------------------------------------------------------------------------------
     1        2            9.22       10.05       0.68       0.63    1.09x
     1        4           11.20       14.43       1.12       0.87    1.29x
     1        8           15.81       22.08       1.59       1.14    1.40x
     2        2           10.11       10.69       1.24       1.18    1.06x
     2        4           12.58       15.10       2.00       1.67    1.20x
     2        8           18.82       23.63       2.67       2.13    1.26x
     4        2           11.39       11.94       2.21       2.11    1.05x
     4        4           15.23       16.54       3.30       3.04    1.09x
     4        8           23.62       25.50       4.26       3.95    1.08x
     8        2           14.69       17.23       3.43       2.92    1.17x
     8        4           21.20       25.01       4.75       4.03    1.18x
     8        8           34.69       40.86       5.80       4.93    1.18x
    16        2           21.47       24.22       4.69       4.16    1.13x
    16        4           32.54       36.98       6.19       5.44    1.14x
    16        8           56.24       61.76       7.16       6.52    1.10x
    32        2           33.50       37.68       6.01       5.34    1.12x
    32        4           54.66       60.26       7.37       6.68    1.10x
    32        8           97.98      104.35       8.22       7.72    1.06x
    64        2           59.82       65.38       6.73       6.16    1.09x
    64        4          102.05      108.83       7.89       7.40    1.07x
    64        8          188.17      196.45       8.56       8.20    1.04x
   128        2          107.44      121.41       7.50       6.63    1.13x
   128        4          192.01      209.90       8.39       7.67    1.09x
   128        8          366.81      389.12       8.78       8.28    1.06x
   256        2          199.14      236.19       8.09       6.82    1.19x
   256        4          363.36      422.61       8.87       7.62    1.16x
   256        8          708.22      787.05       9.10       8.19    1.11x
--------------------------------------------------------------------------------------------------------------
Speedup > 1.0 means FlashInfer is faster
Summary:
  Average speedup: 1.13x
  Min speedup: 1.04x (batch=64, T=8)
  Max speedup: 1.40x (batch=1, T=8)
```

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **New Features**
  * Added Triton-based benchmarks and end-to-end comparison/verify modes across multiple memory layouts (including MTP); new verification flows to compare implementations.
* **Performance Improvements**
  * Batch-size-aware kernel selection, configurable tile/vec sizing, fast-math paths, reduced redundant copies, and CUPTI-backed GPU timing for more accurate benchmarks.
* **Behavior & Compatibility**
  * Improved layout handling, expanded CLI presets/modes, clearer error messages and guards when Triton is unavailable; default benchmark mode updated.
* **Documentation**
  * Updated usage examples and CLI guidance.

Co-authored-by: HongliMi <1667738261@qq.com>
Co-authored-by: Hongli Mi <hmi@nvidia.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
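The "change `/` to multiply" and "fuse scale and inv_norm_q" items both amount to precomputing one reciprocal coefficient instead of dividing per element. A scalar sketch in plain Python of that pattern (function names `q_scale_naive`/`q_scale_fused` are hypothetical, for illustration only; the real kernel does this in CuTe DSL):

```python
import math

def q_scale_naive(q, scale):
    # per-element divide by the L2 norm, then a separate scale multiply
    norm = math.sqrt(sum(x * x for x in q))
    return [x / norm * scale for x in q]

def q_scale_fused(q, scale):
    # fold scale into the reciprocal norm once: one multiply per element,
    # no division in the inner loop
    coef = scale * (1.0 / math.sqrt(sum(x * x for x in q)))
    return [x * coef for x in q]
```

On GPUs, divisions are considerably more expensive than multiplies, so hoisting the division out of the per-element path is a standard micro-optimization.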
Co-authors: @zhou9402 @liz-badada @xutizhou
📌 Description
This PR integrates three versions of the Gated Delta Rule (GDN) Decode kernels into FlashInfer, implemented using CuTe DSL for SM90 (Hopper) and SM100 (Blackwell) GPUs. These kernels enable efficient linear attention decoding for models like Qwen3-Next.
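As a rough illustration of what one decode step computes, here is a minimal single-head PyTorch sketch of a gated delta rule update. This is my own hedged reading of the standard formulation, not the code in this PR (the PR's actual reference lives in `tests/gdn/reference_delta_rule.py`):

```python
import torch

def gdn_decode_step(S, q, k, v, g, beta):
    """One gated delta rule decode step for a single head.

    S: [K, V] recurrent state; q, k: [K]; v: [V];
    g: scalar log-decay; beta: scalar write gate in (0, 1].
    """
    q = q / q.norm()                  # qk_l2norm, as in the benchmark config
    k = k / k.norm()
    S = S * torch.exp(g)              # gated decay of the old state
    pred = k @ S                      # value currently stored under key k
    S = S + torch.outer(k, beta * (v - pred))  # delta-rule correction
    o = q @ S                         # decode output, [V]
    return S, o
```

A handy sanity check for any implementation: with a zero state, `g = 0`, and `beta = 1`, querying with `q = k` returns `v` exactly.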
🎯 Features Added
Three kernel versions are added:

- `gated_delta_rule_decode_pretranspose`: state layout `[B, HV, V, K]` (K-last)
- a non-transpose variant: state layout `[B, HV, K, V]` (V-last)
- an MTP (multi-token prediction/verification) variant

🏗️ Architecture
All three versions follow FlashInfer's integration pattern:
- `@flashinfer_api` decorator

📊 Performance Highlights
Benchmarked on NVIDIA B200 (HBM3e, 192 GB, 8 TB/s).
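The pretranspose `[B, HV, V, K]` and non-transpose `[B, HV, K, V]` layouts hold the same state with the last two axes swapped, so the decode read is the same contraction either way. A small PyTorch sketch (hypothetical shapes, for illustration only):

```python
import torch

B, HV, K, V = 2, 4, 128, 128
q = torch.randn(B, HV, K)

state_kv = torch.randn(B, HV, K, V)                  # non-transpose: V-last
state_vk = state_kv.transpose(-1, -2).contiguous()   # pretranspose: K-last

# the decode read o = q . S contracts over K in either layout
o_kv = torch.einsum("bhk,bhkv->bhv", q, state_kv)
o_vk = torch.einsum("bhk,bhvk->bhv", q, state_vk)
```

The layouts only differ in which axis is contiguous in memory, which changes how the kernel can coalesce loads and stores, not what it computes.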
📁 Files Changed
Core Implementation:
- `flashinfer/gdn_decode.py`

Testing:
- `tests/gdn/test_decode_delta_rule.py`
- `tests/gdn/reference_delta_rule.py`: `decode_delta_rule()`, the single-token reference

Benchmarking:
- `benchmarks/bench_gdn_decode.py`
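For the MTP path, a reference can simply replay the single-token recurrence over the T draft tokens. A self-contained sketch under the same hedged reading of the gated delta rule as above (the PR's canonical reference is `reference_delta_rule.py`):

```python
import torch

def gdn_decode_step(S, q, k, v, g, beta):
    # single-head gated delta rule step; S: [K, V], q/k: [K], v: [V]
    q = q / q.norm()
    k = k / k.norm()
    S = S * torch.exp(g)                        # gated decay
    S = S + torch.outer(k, beta * (v - k @ S))  # delta-rule write
    return S, q @ S

def gdn_mtp_reference(S, qs, ks, vs, gs, betas):
    # verify T draft tokens sequentially; qs/ks: [T, K], vs: [T, V]
    outs = []
    for t in range(qs.shape[0]):
        S, o = gdn_decode_step(S, qs[t], ks[t], vs[t], gs[t], betas[t])
        outs.append(o)
    return S, torch.stack(outs)
```

The CUDA MTP kernel fuses this loop so the state stays in registers across the T tokens instead of round-tripping through memory per token, which is where the speedups in the tables above come from.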