
feat: [Qwen3-Next] Add Cute DSL GDN decode kernel and tests#2370

Merged
yzh119 merged 10 commits into flashinfer-ai:main from HongliMi:GDN_decode_kernel
Jan 22, 2026

Conversation

@HongliMi
Contributor

@HongliMi HongliMi commented Jan 18, 2026

Co-author: @zhou9402 @liz-badada @xutizhou

📌 Description

This PR integrates three versions of the Gated Delta Rule (GDN) Decode kernels into FlashInfer, implemented using CuTe DSL for SM90 (Hopper) and SM100 (Blackwell) GPUs. These kernels enable efficient linear attention decoding for models like Qwen3-Next.

🎯 Features Added

  1. Pretranspose Decode Kernel (gated_delta_rule_decode_pretranspose)
    • State layout: [B, HV, V, K] (K-last)

🏗️ Architecture

All three versions follow FlashInfer's integration pattern:

  • CuTe DSL kernels with JIT compilation and caching
  • Python API layer with @flashinfer_api decorator
  • Reference implementations for correctness validation
  • Comprehensive unit tests with various configurations (GQA, GVA, etc.)
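For readers unfamiliar with GDN, the recurrence these kernels implement can be sketched in a few lines of NumPy. This is an illustrative single-head reference only; the actual update order and gating details inside the CuTe DSL kernels may differ:

```python
import numpy as np

def gdn_decode_step(S, q, k, v, g, beta, use_qk_l2norm=True):
    """One gated-delta-rule decode step for a single head (illustrative sketch).

    S: [K, V] recurrent state, q/k: [K], v: [V],
    g: log-space decay (scalar), beta: write strength (scalar).
    Update: S <- exp(g) * (S - beta * k (k^T S)) + beta * k v^T
    Output: o = S^T q / sqrt(K)
    """
    if use_qk_l2norm:
        q = q / (np.linalg.norm(q) + 1e-6)
        k = k / (np.linalg.norm(k) + 1e-6)
    alpha = np.exp(g)
    # Gated delta-rule state update: decay old state, erase along k, write new kv.
    S = alpha * (S - beta * np.outer(k, k @ S)) + beta * np.outer(k, v)
    # Read out against the query.
    o = (S.T @ q) / np.sqrt(len(q))
    return o, S
```

The per-step cost is O(K * V) per head with no KV cache growth, which is what makes linear-attention decoding attractive for long contexts.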

📊 Performance Highlights

B200 (HBM3e, 192 GB, 8 TB/s)

GDN Decode Comparison: FlashInfer (Pretranspose) vs Triton
Config: q_heads=16, k_heads=16, v_heads=32, head_size=128, dtype=bfloat16, qk_l2norm=ON
----------------------------------------------------------------------------------------------------
 batch FlashInfer(us)   Triton(us)  FI TFLOPS  TR TFLOPS    Speedup
----------------------------------------------------------------------------------------------------
     1           4.03         5.92       0.78       0.53       1.47x
     2           4.58         6.46       1.37       0.97       1.41x
     4           5.74         7.49       2.19       1.68       1.30x
     8           8.61         9.95       2.92       2.53       1.16x
    16          14.19        16.83       3.55       2.99       1.19x
    32          24.51        31.90       4.11       3.16       1.30x
    64          49.60        58.50       4.06       3.44       1.18x
   128          92.26       113.90       4.36       3.54       1.23x
   256         175.36       225.57       4.59       3.57       1.29x
   512         340.45       442.26       4.73       3.64       1.30x
----------------------------------------------------------------------------------------------------
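The Speedup column is simply the ratio of Triton to FlashInfer latency, e.g. for the first four rows:

```python
# Speedup = Triton latency / FlashInfer latency (values from the table above).
flashinfer_us = [4.03, 4.58, 5.74, 8.61]
triton_us = [5.92, 6.46, 7.49, 9.95]
speedups = [round(t / f, 2) for f, t in zip(flashinfer_us, triton_us)]
print(speedups)  # [1.47, 1.41, 1.3, 1.16]
```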

📁 Files Changed

Core Implementation:

  • flashinfer/gdn_decode.py
    • Three CuTe DSL kernel variants
    • JIT compilation with caching
    • API functions with comprehensive validation

Testing:

  • tests/gdn/test_decode_delta_rule.py

    • Unit tests for all three versions
    • Tests for various head configurations (MHA, GQA, GVA)
    • Dtype tests (float16, bfloat16)
  • tests/gdn/reference_delta_rule.py

    • decode_delta_rule() - Single-token reference implementation

Benchmarking:

  • benchmarks/bench_gdn_decode.py

@coderabbitai
Contributor

coderabbitai bot commented Jan 18, 2026

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough


Adds a new CUDA‑accelerated Gated Delta Rule (GDN) decode implementation (pretranspose, non‑transpose, MTP), a profiler-driven benchmarking CLI, Python reference/verifier implementations, extensive unit tests, and updates to test conftest compute-capability gating.

Changes

Cohort / File(s) Summary
Core GDN Decode Implementation
flashinfer/gdn_decode.py
New CUDA‑accelerated GDN decode with pretranspose & non‑transpose kernel variants, small/large‑batch kernels, MTP/verify paths, runtime kernel compilation & caching, layout-aware entry points, and public wrappers (gated_delta_rule_decode*, gated_delta_rule_verify, gated_delta_rule_mtp).
Benchmarking Utility
benchmarks/bench_gdn_decode.py
New profiler-driven CLI and utilities: trace parsing (parse_trace_file), FLOPs/bytes calculators (gdn_*_flops, gdn_*_bytes), bench routines (bench_gdn_decode, bench_gdn_mtp), warmup/profile runs, Chrome trace export, kernel timing extraction, device capability checks, and result reporting.
Reference Implementations
tests/gdn/reference_delta_rule.py
Adds decode_delta_rule() (single-step) and verify_delta_rule() (multi-step, optional intermediate-state caching) reference implementations with gating, softplus stabilization, optional L2 normalization, and typed Optional return.
Unit Tests
tests/gdn/test_decode_delta_rule.py
New tests validating pretranspose, nontranspose, and MTP/verify paths against reference implementations across dtypes, batch/head configs, sequence lengths; includes randomized checks, helpers, and smoke-test main guard.
Prefill Tests
tests/gdn/test_prefill_delta_rule.py
Adds runtime guard to skip prefill-related tests unless running on SM90 GPUs (uses compute-capability check and helper _skip_if_not_sm90).
Test Fixtures / Conftest
tests/gdn/conftest.py
Removed autouse SM90a skip: import is_sm90a_supported and the autouse skip_if_not_sm90a fixture were deleted, so GDN tests are no longer auto-skipped based on that helper.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant CLI as "Benchmark CLI\n(bench_gdn_decode.py)"
    participant Profiler as "Torch Profiler"
    participant GPU as "GPU Kernels\n(pretranspose/nontranspose/MTP)"
    participant TraceParser as "Trace Parser\n(parse_trace_file)"
    participant Metrics as "Metrics Calculator\n(TFLOPS / TB/s)"

    User->>CLI: run benchmark (config)
    CLI->>GPU: validate device capabilities
    CLI->>GPU: run warmup iterations
    CLI->>Profiler: start profiling
    loop bench iterations
        CLI->>GPU: launch decode / MTP kernel
        GPU-->>GPU: execute kernel, update state
    end
    Profiler->>Profiler: export chrome trace
    CLI->>TraceParser: parse trace for kernel timings
    TraceParser-->>CLI: return kernel timing metrics
    CLI->>Metrics: compute statistics and throughput
    Metrics-->>CLI: return results
    CLI-->>User: display results table
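The trace-parsing step in the diagram boils down to reading the Chrome trace JSON that torch.profiler's export_chrome_trace() writes and summing GPU kernel durations. A hypothetical sketch (the actual parse_trace_file in the benchmark may differ):

```python
import json

def parse_trace_file(trace_path, kernel_name_filter="gdn"):
    """Sum GPU kernel durations (microseconds) from a Chrome trace JSON.

    Illustrative sketch: torch.profiler's export_chrome_trace() emits
    {"traceEvents": [...]} where GPU kernels appear as events with
    cat == "kernel" and a "dur" field in microseconds.
    """
    with open(trace_path) as f:
        trace = json.load(f)
    durations = [
        ev["dur"]
        for ev in trace.get("traceEvents", [])
        if ev.get("cat") == "kernel" and kernel_name_filter in ev.get("name", "")
    ]
    return {"count": len(durations), "total_us": sum(durations)}
```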

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Suggested reviewers

  • cyx-6
  • nvmbreughe
  • jiahanc
  • kahyunnam
  • jimmyzho
  • bkryu
  • Anerudhan

Poem

🐇 I hopped through kernels, warmups bright,
Traces twinkled in the profiling light.
States tucked neat in GPU's den,
Timings tallied, run again.
A rabbit claps — decode delight!

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 63.16% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately describes the main change: adding CuTe DSL GDN decode kernels and tests targeting Qwen3-Next, which aligns with the core implementation across three kernel variants and comprehensive test suite.
Description check ✅ Passed The PR description comprehensively addresses the template requirements with detailed sections on features, architecture, performance, and files changed. However, the template sections like pre-commit checks and reviewer notes are not explicitly marked as completed.


@gemini-code-assist
Contributor

Summary of Changes

Hello @HongliMi, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers get up to speed quickly!

This pull request significantly enhances FlashInfer's capabilities by integrating highly optimized Gated Delta Rule (GDN) decode kernels. These kernels, developed using CuTe DSL for Hopper GPUs, enable efficient linear attention decoding, crucial for modern large language models like Qwen3-Next. The PR introduces three distinct kernel versions tailored for different state layouts and processing modes (single-token and multi-token verification), ensuring both performance and flexibility.

Highlights

  • New GDN Decode Kernels: Introduced three highly optimized Gated Delta Rule (GDN) decode kernels using CuTe DSL for SM90 (Hopper) GPUs, enabling efficient linear attention decoding for models like Qwen3-Next.
  • Pretranspose Decode Kernel: Added gated_delta_rule_decode_pretranspose for single-token decoding, which expects a [B, HV, V, K] (V-major) state layout.
  • Nontranspose Decode Kernel: Implemented gated_delta_rule_decode for single-token decoding, utilizing a more natural [B, HV, K, V] (K-major) state layout, thereby eliminating the need for state transposition.
  • Multi-Token Processing (MTP) Kernel: Included a gated_delta_rule_verify (MTP) kernel designed for processing multiple tokens sequentially, supporting speculative decoding verification and intermediate state caching.
  • Comprehensive Integration and Validation: Provided comprehensive Python API wrappers, PyTorch-based reference implementations for correctness validation, and detailed benchmark scripts for performance evaluation of all new kernels.
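As a rough illustration of the verify (MTP) path described above, a multi-step reference could look like the following NumPy sketch. The names and the exact update rule here are assumptions for illustration, not the PR's implementation:

```python
import numpy as np

def verify_delta_rule(S, q, k, v, g, beta, cache_states=False):
    """Multi-token (MTP/verify) sketch: run T sequential delta-rule steps.

    Single-head shapes: S [K, V], q/k [T, K], v [T, V], g/beta [T].
    Optionally caches the state after each step, as a verify path would
    for speculative-decoding rollback.
    """
    T, K = q.shape
    outputs, cached = [], []
    for t in range(T):
        qt = q[t] / (np.linalg.norm(q[t]) + 1e-6)
        kt = k[t] / (np.linalg.norm(k[t]) + 1e-6)
        # Same gated delta-rule update applied token by token.
        S = np.exp(g[t]) * (S - beta[t] * np.outer(kt, kt @ S)) \
            + beta[t] * np.outer(kt, v[t])
        outputs.append((S.T @ qt) / np.sqrt(K))
        if cache_states:
            cached.append(S.copy())
    return np.stack(outputs), S, cached
```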


@HongliMi HongliMi changed the title Add Cute DSL GDN decode kernel and tests Jan 18, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces three new Gated Delta Rule (GDN) decode kernels using CuTe DSL, along with comprehensive benchmarks and tests. The changes are well-structured and the implementation is advanced. My review focuses on improving code reuse in the benchmark script and fixing critical bugs in the nontranspose kernel implementations.

Key feedback points:

  • Critical Bug: The nontranspose decode kernels (small_batch and big_batch) contain loops with incorrect range() arguments, which will likely cause them to not execute and produce incorrect results.
  • Code Duplication: The benchmark script has duplicated logic for parsing profiler traces and unused variables. I've suggested refactoring to improve maintainability and clarity.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around line 2141-2148: When intermediate_states_buffer is provided, validate
that its second dimension (cache_steps) is at least T before reshaping/using it:
check intermediate_states_buffer.shape[1] >= T and raise a clear ValueError (or
similar) if not; keep the existing branch that sets cache_intermediate_states
and buffer_size, compute cache_steps = intermediate_states_buffer.shape[1], and
perform the validation prior to calling
intermediate_states_buffer.to(...).reshape(...), referencing the variables
intermediate_states_buffer, cache_steps, T, and intermediate_states so the check
is colocated with the existing logic.
- Around line 839-865: The cache key for compiled decode kernels omits
compile-time parameters causing reuse of kernels with wrong specializations;
update the cache_key tuple used before calling _get_compiled_decode_kernel to
include scale and use_qk_l2norm (and any other constexpr compile parameters) for
the pretranspose path (where cache_key = (B, T, H, HV, K, V, q.dtype)), and make
the same change in the nontranspose and MTP branches that build cache_key, so
the cached entry uniquely identifies the cute.compile call
(run_gdn_decode_kernel_small_batch_pretranspose,
run_gdn_decode_kernel_big_batch_pretranspose, and corresponding nontranspose/MTP
run functions) which are invoked with cute.compile(..., scale=scale,
use_qk_l2norm=use_qk_l2norm, ...).
- Around line 881-883: Remove the blocking device-wide synchronization calls
(torch.cuda.synchronize()) in the per-token decode hot path; specifically delete
the torch.cuda.synchronize() just before the state.copy_(h0_source.reshape(B,
HV, V, K)) and the analogous calls at the other two locations referenced in the
review, relying on PyTorch's stream ordering so that state.copy_() on the
current stream properly orders kernels without a global sync; ensure no
subsequent code assumes the global sync was required and that state.copy_ and
downstream ops remain on the same stream.
- Around line 792-799: The pretranspose kernel can write out-of-bounds because K
and V constraints are not validated and the compute loop lacks the MTP-style
bounds check; add assertions after unpacking shapes (where q, v, state are used)
- assert K >= 128 and assert V % 4 == 0 (vec_size = TILE_K // 32 == 4
requirement) - to ensure vectorized loads are safe, and modify the pretranspose
kernel's compute loop to guard the sOutput write with the same pattern as the
MTP kernel (e.g. check lane_id == 0 and o_idx < V before writing sOutput[o_idx])
so writes cannot exceed V.

In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 69-74: Add GPU-architecture skip guards to the tests in
tests/gdn/test_decode_delta_rule.py by using flashinfer.utils helpers (e.g.,
get_compute_capability) and pytest.mark.skipif so the test only runs on
supported GPUs; locate the test function(s) that use the device context manager
(the "with device:" block and related variables q,k,v) and add a skip condition
that checks compute capability (via get_compute_capability) or an appropriate
flashinfer.utils predicate before running the test, using
pytest.mark.skipif(...) to decorate the test or wrapping the test body with an
early pytest.skip when the GPU architecture is unsupported.
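The cache-key issue flagged above can be illustrated with a minimal memoization sketch. Names such as get_compiled_decode_kernel and compile_fn are stand-ins for illustration, not the PR's actual API:

```python
# The memoization key for compiled kernels must include every compile-time
# specialization (here `scale` and `use_qk_l2norm`), not just shapes and
# dtype, or a kernel compiled for one specialization is silently reused
# for another.
_KERNEL_CACHE = {}

def get_compiled_decode_kernel(B, H, K, V, dtype, scale, use_qk_l2norm,
                               compile_fn):
    key = (B, H, K, V, dtype, scale, use_qk_l2norm)
    if key not in _KERNEL_CACHE:
        # compile_fn stands in for the cute.compile(...) call.
        _KERNEL_CACHE[key] = compile_fn()
    return _KERNEL_CACHE[key]
```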

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around line 1083-1084: The read using v_global (computed as v_tile *
TILE_V_SMALL_NT + v_idx) can go out of bounds when V % TILE_V_SMALL_NT != 0;
update the code that computes and uses v_global (and the analogous logic in the
big batch kernel) to either validate input (e.g., assert V % TILE_V_SMALL_NT ==
0) or guard the index access with a bounds check before calling v[i_n, 0, i_hv,
v_global], and handle the tail tile properly (skip, pad, or clamp) so no
out-of-range indexing occurs.
- Around line 1973-1974: The MTP writeback loop writes to o[(i_n, i_t, i_hv,
tidx)] for all tidx (0-127) without checking the actual V dimension; add a
bounds check using the V variable (or the output tensor's last-dim size) and
only perform the assignment when tidx < V to avoid out-of-bounds writes (same
pattern as the pretranspose kernel). Locate the write in gdn_decode.py (the loop
over i_t that assigns cutlass.BFloat16(sOutput[(i_t, tidx)])) and guard that
assignment with a conditional on tidx < V (or skip/zero-pad for tidx >= V) so
only valid channels are written.
- Around line 803-805: Add a check enforcing V >= 128 (e.g. assert V >= 128, f"V
must be at least 128, got V={V}") because kernels assume V>=128; insert this
validation alongside the existing K and V checks in the top-level validation
block (the snippet with K and V asserts) and also add the same assert at the
start of the functions gated_delta_rule_decode and gated_delta_rule_mtp so their
local kernel assumptions are validated.
- Around line 548-550: The final writeback to output tensor o in gdn_decode.py
can write out-of-bounds when tidx >= V; add a bounds check before the write in
the big-batch kernel so you only assign o[(i_n, i_t, i_hv, tidx)] =
sOutput[tidx] when tidx < V (or otherwise clamp/mask writes to V), similar to
the partial check present in the small-batch kernel; update the write
surrounding the existing cute.arch.barrier() and reference the variables
NUM_THREADS, V, sOutput, o, i_n, i_t, i_hv and tidx when making the conditional
guard.
- Around line 202-207: The code assumes V >= 128 but only validates V % 4 == 0;
add an explicit assertion in both gated_delta_rule_decode_pretranspose and
gated_delta_rule_decode immediately after the existing API validation block to
prevent out-of-bounds accesses when loading v into sV and indexing with i * 32 +
lane_id: assert V >= 128, f"V must be at least 128, got V={V}". This ensures the
vec_size/TILE_K loading loop and shared-memory writes to sV are safe.
🧹 Nitpick comments (2)
flashinfer/gdn_decode.py (2)

600-600: Remove unused variable to satisfy linter.

total_data_mb is calculated but never used (same at line 682).

-    total_data_mb = v_dim * k_dim * batch_size * 4 / 1024 / 1024

2164-2166: Consider removing .contiguous() for consistency with similar reshape patterns in the same function.

The intermediate_states tensor follows the same operation pattern as h0_source (both: .to(torch.float32).reshape(...)), but only intermediate_states includes .contiguous(). Since PyTorch's reshape() returns a view when the requested shape is compatible with memory layout, and .contiguous() is a no-op if the tensor is already contiguous, the .contiguous() call here is defensive but likely unnecessary.

Both tensors are passed to from_dlpack() for CuTe kernel conversion. If this operation requires contiguity, both should include the call; if not, removing it from intermediate_states would improve code consistency.

@yzh119
Collaborator

yzh119 commented Jan 20, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !248 has been created, and the CI pipeline #42075372 is currently running. I'll report back once the pipeline job completes.

Collaborator

@yzh119 yzh119 left a comment


The failed UTs are caused by a bug introduced in #2366, which is fixed in #2378; this PR itself should be ready to merge (bypassing the check).

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around line 900-907: Remove all Git merge conflict markers (<<<<<<<, =======,
>>>>>>>) in flashinfer/gdn_decode.py and resolve each conflict by keeping the
newer assertion checks replacing the old vectorized-load assertion: retain
"assert V >= 128" and "assert V % TILE_V == 0" for the V validation near the top
(the block touching V and TILE_V), and apply the same conflict resolution
pattern at the other listed locations (the small-batch nontranspose kernel loop,
the nontranspose API validation, and the MTP API validation) so the file
contains only valid Python assertions and logic with no leftover conflict
markers.
♻️ Duplicate comments (1)
flashinfer/gdn_decode.py (1)

1195-1196: Bounds safety depends on proper V validation.

The access v[i_n, 0, i_hv, v_global] where v_global = v_tile * TILE_V_SMALL_NT + v_idx could go out-of-bounds if V is not divisible by the tile size. The validation being added in the merge conflict resolution (V % TILE_V_NT == 0) will properly guard this, since TILE_V_NT=32 is a multiple of TILE_V_SMALL_NT=16.

Ensure the merge conflict is resolved to include the V % TILE_V_NT == 0 assertion.

🧹 Nitpick comments (3)
flashinfer/gdn_decode.py (1)

626-626: Remove dead expression.

This line computes a value but discards it (not assigned to any variable). This appears to be leftover from when it was assigned to total_data_mb for the commented-out debug print below.

-    v_dim * k_dim * batch_size * 4 / 1024 / 1024

The same issue exists at line 735.

benchmarks/bench_gdn_decode.py (2)

78-112: Consider removing or documenting unused num_k_heads parameter.

The num_k_heads parameter is never used in the FLOPs calculation. This is technically correct since FLOPs are determined by output heads (max(num_q_heads, num_v_heads)), not key heads. However, having an unused parameter can be confusing.

Options:

  1. Remove the parameter if it's not needed for API consistency
  2. Add a comment explaining why it's intentionally unused (e.g., _ = num_k_heads # Unused: FLOPs depend on output heads)
  3. Keep as-is for API consistency with gdn_decode_bytes

653-656: Consider adding else clause for defensive coding.

While the control flow guarantees only 'pretranspose' or 'nontranspose' can reach this point (due to the continue at line 623), adding an else clause would make the code more robust against future changes.

             # Determine which kernel variant was used (based on batch size threshold)
             if version == "pretranspose":
                 kernel_variant = "SmallBatch" if batch_size <= 32 else "LargeBatch"
             elif version == "nontranspose":
                 kernel_variant = "SmallBatch" if batch_size < 32 else "LargeBatch"
+            else:
+                kernel_variant = "Unknown"

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around line 565-573: The final writeback only writes indices tidx (0–127) so
when V > NUM_THREADS the tail of sOutput isn't copied; update the writeback in
the function containing the variables tidx, V, NUM_THREADS, o and sOutput to
loop with a stride of NUM_THREADS (e.g., for idx = tidx; idx < V; idx +=
NUM_THREADS) and write o[(i_n, i_t, i_hv, idx)] = sOutput[idx], or alternatively
assert/enforce V == NUM_THREADS before this writeback to guarantee full
coverage; make the same change for the corresponding small-batch pretranspose
and MTP final writeback sites.
- Around line 898-901: The K validation currently allows K > 128 but the kernel
only loads TILE_K (128) elements; change the assertion to require K == TILE_K
(i.e., assert K == TILE_K, not K >= 128) so tails are not silently ignored, and
apply the identical K == TILE_K check in the non‑transpose and MTP entry points
(the other callsites/entry functions that validate K, e.g., the non‑transpose
decoder and the MTP entry routine) to ensure consistency with the fixed 128‑wide
kernel.
- Around line 931-932: The cache key for compiled kernels (built at cache_key =
(B, T, H, HV, K, V, q.dtype, scale, use_qk_l2norm) and used with
_get_compiled_decode_kernel) omits the output tensor dtype, causing reuse of
kernels compiled for the wrong output type; update the cache key to include the
output's dtype (e.g., add output.dtype or normalized_output_dtype) or normalize
by forcing/allocating a fixed output dtype before compilation so the key and
kernel compilation always match the actual output type used. Ensure you
reference the same output variable/name used where the kernel is launched when
adding this dtype to the key so cached kernels are only reused for compatible
output dtypes.

Comment on lines +565 to +573
# ===================================================================
# Final writeback: Copy output from shared memory to global memory
# All threads write (V=128, NUM_THREADS=128)
# ===================================================================
cute.arch.barrier() # Ensure all writes to sOutput are complete

if tidx < V:
o[(i_n, i_t, i_hv, tidx)] = sOutput[tidx]


⚠️ Potential issue | 🟠 Major

Writeback only covers the first 128 channels.

tidx is 0–127, so when V > 128, the tail is never written in the big‑batch pretranspose path (same pattern exists in the small‑batch pretranspose and MTP final writeback). Either loop with stride or enforce V == NUM_THREADS.

🔧 Suggested fix (strided writeback)
-    if tidx < V:
-        o[(i_n, i_t, i_hv, tidx)] = sOutput[tidx]
+    for v_offset in range(0, V, NUM_THREADS):
+        out_idx = v_offset + tidx
+        if out_idx < V:
+            o[(i_n, i_t, i_hv, out_idx)] = sOutput[out_idx]
🤖 Prompt for AI Agents
In `@flashinfer/gdn_decode.py` around lines 565 - 573, The final writeback only
writes indices tidx (0–127) so when V > NUM_THREADS the tail of sOutput isn't
copied; update the writeback in the function containing the variables tidx, V,
NUM_THREADS, o and sOutput to loop with a stride of NUM_THREADS (e.g., for idx =
tidx; idx < V; idx += NUM_THREADS) and write o[(i_n, i_t, i_hv, idx)] =
sOutput[idx], or alternatively assert/enforce V == NUM_THREADS before this
writeback to guarantee full coverage; make the same change for the corresponding
small-batch pretranspose and MTP final writeback sites.

Comment on lines +898 to +901
# Validate K and V constraints
assert K >= 128, f"K must be at least 128, got K={K}"
assert V >= 128, f"V must be at least 128, got V={V}"
assert V % TILE_V == 0, f"V must be divisible by {TILE_V} to prevent out-of-bounds access, got V={V}"

⚠️ Potential issue | 🟠 Major

Tighten K validation to match the fixed 128‑wide kernel.

The kernels only load TILE_K=128 elements (vec_size = TILE_K // 32) and never iterate over K‑tiles. Allowing K > 128 silently ignores the tail. Please assert K == TILE_K (and apply the same constraint in the non‑transpose and MTP entry points).

🔧 Suggested fix
-    assert K >= 128, f"K must be at least 128, got K={K}"
+    assert K == TILE_K, f"K must be exactly {TILE_K}, got K={K}"
🤖 Prompt for AI Agents
In `@flashinfer/gdn_decode.py` around lines 898 - 901, The K validation currently
allows K > 128 but the kernel only loads TILE_K (128) elements; change the
assertion to require K == TILE_K (i.e., assert K == TILE_K, not K >= 128) so
tails are not silently ignored, and apply the identical K == TILE_K check in the
non‑transpose and MTP entry points (the other callsites/entry functions that
validate K, e.g., the non‑transpose decoder and the MTP entry routine) to ensure
consistency with the fixed 128‑wide kernel.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #42075372: 3/20 passed

@yzh119
Collaborator

yzh119 commented Jan 20, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !248 has been updated with latest changes, and the CI pipeline #42125544 is currently running. I'll report back once the pipeline job completes.


@vadiklyutiy vadiklyutiy left a comment


@HongliMi Could you measure against the fused_recurrent_gated_delta_rule Triton kernel from vLLM?


@vadiklyutiy vadiklyutiy left a comment


3.97 TB/s on H20 looks weird... The peak is 3.35 TB/s

@yzh119
Collaborator

yzh119 commented Jan 20, 2026

3.97 TB/s on H20 looks weird... The peak is 3.35TB/s

Agreed, @HongliMi can you double check?

@HongliMi
Contributor Author

HongliMi commented Jan 21, 2026

The H100 chip has a peak memory bandwidth of 3.35 TB/s, while the H20 has a peak bandwidth of 4.8 TB/s.

@HongliMi HongliMi closed this Jan 21, 2026
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around line 136-139: The code silently drops tail V-tiles because num_v_tiles
= V // TILE_V can have a remainder versus NUM_BLOCKS_PER_STATE; in
gated_delta_rule_decode_pretranspose (and any caller logic using
NUM_BLOCKS_PER_STATE and num_v_tiles) add a validation after computing
num_v_tiles that enforces num_v_tiles % NUM_BLOCKS_PER_STATE == 0 (raising an
assertion or ValueError with a clear message referencing V, num_v_tiles and
NUM_BLOCKS_PER_STATE), or alternatively change the block processing logic (e.g.,
in batch_idx/batch_inner handling) to explicitly handle remainder tiles instead
of using integer division; prefer adding the divisibility check near the
existing V/TILE_V checks to prevent silent data loss.
- Around line 1046-1049: The calculation in the small-batch nontranspose path
(variables block_idx, batch_idx, batch_inner, num_v_tiles,
num_v_tiles_per_block, start_v_tile, and constant NUM_BLOCKS_PER_STATE_SMALL_NT)
can drop remainder vertical tiles because num_v_tiles_per_block = num_v_tiles //
NUM_BLOCKS_PER_STATE_SMALL_NT does integer division; add a validation or
handling to prevent tile loss: either assert/raise if num_v_tiles %
NUM_BLOCKS_PER_STATE_SMALL_NT != 0 (i.e., enforce num_v_tiles per-state
divisible by NUM_BLOCKS_PER_STATE_SMALL_NT) or change the division to compute a
ceiling (and adjust start_v_tile/end_v_tile calculations) so all tiles are
covered; update the existing API validation (the V % TILE_V_NT check) to include
this new divisibility constraint or add the new check early in the
input-validation routine.
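The ceiling-division alternative mentioned above can be sketched as follows. partition_v_tiles is a hypothetical helper for illustration, not code from the PR:

```python
# Using ceiling division for tiles-per-block guarantees every V-tile is
# covered, with the last block handling the (possibly smaller) tail,
# instead of silently dropping remainder tiles under integer division.
def partition_v_tiles(num_v_tiles, num_blocks_per_state):
    per_block = -(-num_v_tiles // num_blocks_per_state)  # ceil division
    ranges = []
    for b in range(num_blocks_per_state):
        start = b * per_block
        end = min(start + per_block, num_v_tiles)
        if start < end:
            ranges.append((start, end))
    return ranges

# 7 tiles over 3 blocks: the tail block gets the single remainder tile.
print(partition_v_tiles(7, 3))  # [(0, 3), (3, 6), (6, 7)]
```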
♻️ Duplicate comments (3)
benchmarks/bench_gdn_decode.py (2)

78-85: Consider prefixing unused num_k_heads with underscore.

The num_k_heads parameter is unused in the FLOPs calculation since GDN uses the same head count for queries and keys. Prefix with _ to silence the linter while keeping the signature consistent with related functions.

 def gdn_decode_flops(
     batch_size: int,
     num_q_heads: int,
-    num_k_heads: int,
+    _num_k_heads: int,
     num_v_heads: int,

171-177: Intermediate bytes counted unconditionally for seq_len > 1.

The gdn_decode_bytes function always adds intermediate_bytes when seq_len > 1, but bench_gdn_mtp may run with cache_intermediate_states=False. This overstates memory bandwidth when intermediate state caching is disabled.

🔧 Suggested fix
 def gdn_decode_bytes(
     ...
     seq_len: int = 1,
     disable_state_update: bool = False,
+    cache_intermediate_states: bool = True,
 ) -> int:
     ...
-    if seq_len > 1:
+    if seq_len > 1 and cache_intermediate_states:
         intermediate_bytes = (
             batch_size * seq_len * num_sab_heads * head_size * head_size * 4
         )

Then update the caller in bench_gdn_mtp:

bytes_accessed = gdn_decode_bytes(
    ...
    seq_len,
    disable_state_update=True,
    cache_intermediate_states=cache_intermediate_states,  # Pass through
)
flashinfer/gdn_decode.py (1)

565-573: Writeback only covers first 128 elements when V > 128.

The final writeback uses if tidx < V but tidx ranges 0-127 (NUM_THREADS=128). When V > 128 (e.g., V=256), elements 128-255 are never written from sOutput to global memory o.

The same issue exists in the small batch kernel (lines 334-336) and MTP kernel (lines 2119-2121).

Either enforce V == NUM_THREADS in the API validation, or use strided writeback:

🔧 Strided writeback fix
     cute.arch.barrier()  # Ensure all writes to sOutput are complete

-    if tidx < V:
-        o[(i_n, i_t, i_hv, tidx)] = sOutput[tidx]
+    for v_offset in range(0, V, NUM_THREADS):
+        out_idx = v_offset + tidx
+        if out_idx < V:
+            o[(i_n, i_t, i_hv, out_idx)] = sOutput[out_idx]
🧹 Nitpick comments (2)
flashinfer/gdn_decode.py (2)

1913-1916: Initialize sOutput with bounds check for V > NUM_THREADS.

The loop initializes sOutput[(i_t, tidx)] for all tidx in 0-127, but sOutput is allocated with shape (T, V). If V > 128, indices 128+ are never initialized. While the kernel later writes to sOutput[(i_t, o_idx)] with proper bounds check at line 2104-2105, the initialization loop could leave uninitialized memory if the kernel logic changes.

Consider adding a strided initialization loop for consistency:

for i_t in range(T):
    for v_offset in range(0, V, NUM_THREADS_MTP):
        v_idx = v_offset + tidx
        if v_idx < V:
            sOutput[(i_t, v_idx)] = 0.0

2369-2373: Consider removing .contiguous() if reshape guarantees contiguity.

Past review questioned whether .contiguous() is needed after .reshape(). PyTorch's .reshape() returns a contiguous tensor when the reshape is a view of a contiguous tensor. Since intermediate_states_buffer.to(torch.float32) returns a new contiguous tensor, the subsequent .reshape() should also be contiguous.

However, keeping .contiguous() is defensive and the overhead is negligible if the tensor is already contiguous. This is a minor optimization opportunity.
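The behavior described can be checked directly with a small standalone snippet (the shapes here are arbitrary, not taken from the kernel):

```python
import torch

# `.to(torch.float32)` on a bf16 tensor materializes a new contiguous tensor,
# so the following `.reshape` is a view and already contiguous; the extra
# `.contiguous()` then returns the same tensor without copying.
buf = torch.zeros(4, 8, 16, dtype=torch.bfloat16)
x = buf.to(torch.float32).reshape(4, 128)
print(x.is_contiguous())    # True
print(x.contiguous() is x)  # True: no copy is made
```

This supports the conclusion above: the `.contiguous()` call is a no-op here, so keeping it costs nothing beyond a cheap check.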

@yzh119
Collaborator

yzh119 commented Jan 21, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !248 has been updated with latest changes, and the CI pipeline #42228162 is currently running. I'll report back once the pipeline job completes.

@yzh119
Collaborator

yzh119 commented Jan 22, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !248 has been updated with latest changes, and the CI pipeline #42252329 is currently running. I'll report back once the pipeline job completes.

@yzh119 yzh119 merged commit 3115872 into flashinfer-ai:main Jan 22, 2026
16 of 20 checks passed
@github-project-automation github-project-automation bot moved this from Todo to Done in FlashInfer Roadmap Jan 22, 2026
@vadiklyutiy

Sorry for jumping in, but what’s the point of merging a kernel that’s slower than what SGLang and vLLM are already using?

@yzh119
Collaborator

yzh119 commented Jan 22, 2026

Sorry for jumping in, but what’s the point of merging a kernel that’s slower than what SGLang and vLLM are already using?

Hi @vadiklyutiy, thanks for the reminder. Would you mind comparing with sglang's python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py, as pointed out by @liz-badada?

@vadiklyutiy

Hi @vadiklyutiy, thanks for the reminder. Would you mind comparing with sglang's python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py, as pointed out by @liz-badada?

I thought comparing against existing implementations was a mandatory part of posting a PR.

yzh119 added a commit that referenced this pull request Feb 3, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

Follow-up of #2370: this PR improves the benchmark scripts and adds
comparisons against baselines:
* benchmark using CUPTI with L2 flush
* compare with sglang's `fused_sigmoid_gating_delta_rule_update`
function (with the tile-size optimization mentioned by @ vadiklyutiy).

This PR also implements several optimizations on the original GDN kernel:
* use fast math as much as possible
* replace division ("/") with multiplication
* use `cutlass.range_constexpr` and `cutlass.const_expr` whenever
possible
* fuse `scale` and `inv_norm_q`
* for MTP, store the state directly in registers, without loads/stores to
shared memory, and remove `cp.async`
* vectorize memory accesses.
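Two of the optimizations above can be sketched host-side in PyTorch (a minimal illustration only, not the kernel code; the `l2norm_scale_*` names are hypothetical): replacing the division by the q L2 norm with a multiply by its reciprocal, and folding the attention scale into that reciprocal so each element of q is touched with a single fused multiplier.

```python
import torch

def l2norm_scale_naive(q: torch.Tensor, scale: float) -> torch.Tensor:
    # divide by the L2 norm, then apply the attention scale: two ops per element
    return q / q.norm(dim=-1, keepdim=True) * scale

def l2norm_scale_fused(q: torch.Tensor, scale: float) -> torch.Tensor:
    # one rsqrt per row, then a single fused multiplier per element
    inv_norm = torch.rsqrt((q * q).sum(dim=-1, keepdim=True))
    return q * (inv_norm * scale)

q = torch.randn(2, 16, 128)
assert torch.allclose(
    l2norm_scale_naive(q, 0.5), l2norm_scale_fused(q, 0.5), atol=1e-5
)
```

In a per-thread kernel loop the same rewrite removes a divide from the hot path, which is where fast-math and multiply-by-reciprocal pay off.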

## Performance on B200

Non MTP setting
```
> python benchmarks/bench_gdn_decode.py --batch-size 1 2 4 8 16 32 64 128 256 --compare --verify

=== Correctness Verification ===
Batch=8:
  Pretranspose: PASS
  Nontranspose: PASS
Batch=16:
  Pretranspose: PASS
  Nontranspose: PASS
Batch=32:
  Pretranspose: PASS
  Nontranspose: PASS
Batch=64:
  Pretranspose: PASS
  Nontranspose: PASS


========================================================================================================================
GDN Decode Benchmark: FlashInfer vs Triton, Pretranspose vs Nontranspose
Config: q_heads=16, k_heads=16, v_heads=32, head_size=128, dtype=bfloat16, qk_l2norm=ON
========================================================================================================================

 batch | FI-PreTr FI-NonTr | TR-PreTr TR-NonTr | FI/TR-Pre FI/TR-Non | Pre/Non-FI Pre/Non-TR
       |     (us)     (us) |     (us)     (us) |   speedup   speedup |    speedup    speedup
------------------------------------------------------------------------------------------------------------------------
     1 |     3.74     5.06 |     5.95     4.35 |    1.59x    0.86x |    1.35x    0.73x
     2 |     4.29     5.89 |     6.37     5.02 |    1.49x    0.85x |    1.37x    0.79x
     4 |     5.41     7.78 |     7.58     6.66 |    1.40x    0.86x |    1.44x    0.88x
     8 |     7.65    12.03 |     9.95    10.21 |    1.30x    0.85x |    1.57x    1.03x
    16 |    12.61    19.30 |    16.83    15.81 |    1.34x    0.82x |    1.53x    0.94x
    32 |    22.91    32.86 |    31.55    27.84 |    1.38x    0.85x |    1.43x    0.88x
    64 |    52.74    58.61 |    58.91    53.02 |    1.12x    0.90x |    1.11x    0.90x
   128 |    92.93   107.98 |   114.45   106.78 |    1.23x    0.99x |    1.16x    0.93x
   256 |   170.77   209.04 |   225.71   216.41 |    1.32x    1.04x |    1.22x    0.96x
------------------------------------------------------------------------------------------------------------------------

Legend:
  FI-PreTr  = FlashInfer Pretranspose [B, HV, V, K]
  FI-NonTr  = FlashInfer Nontranspose [B, HV, K, V]
  TR-PreTr  = Triton Pretranspose [B, HV, V, K]
  TR-NonTr  = Triton Nontranspose [B, HV, K, V]
  FI/TR speedup > 1.0 means FlashInfer is faster than Triton
  Pre/Non speedup > 1.0 means Pretranspose is faster than Nontranspose

FlashInfer vs Triton (Pretranspose) - Average speedup: 1.35x
```

MTP Setting (pretranspose only)
```
> python benchmarks/bench_gdn_decode.py --version mtp --batch-size 1 2 4 8 16 32 64 128 256 --compare --verify

=== Correctness Verification (MTP) ===
Batch=8: PASS
Batch=16: PASS
Batch=32: PASS
Batch=64: PASS


GDN MTP Comparison: FlashInfer (CuTe DSL) vs Triton
Config: q_heads=16, k_heads=16, v_heads=32, head_size=128, dtype=bfloat16, qk_l2norm=ON, cache_intermediate=OFF
--------------------------------------------------------------------------------------------------------------
 batch  seq_len FlashInfer(us)   Triton(us)  FI TFLOPS  TR TFLOPS    Speedup
--------------------------------------------------------------------------------------------------------------
     1        2           9.22        10.05       0.68       0.63       1.09x
     1        4          11.20        14.43       1.12       0.87       1.29x
     1        8          15.81        22.08       1.59       1.14       1.40x
     2        2          10.11        10.69       1.24       1.18       1.06x
     2        4          12.58        15.10       2.00       1.67       1.20x
     2        8          18.82        23.63       2.67       2.13       1.26x
     4        2          11.39        11.94       2.21       2.11       1.05x
     4        4          15.23        16.54       3.30       3.04       1.09x
     4        8          23.62        25.50       4.26       3.95       1.08x
     8        2          14.69        17.23       3.43       2.92       1.17x
     8        4          21.20        25.01       4.75       4.03       1.18x
     8        8          34.69        40.86       5.80       4.93       1.18x
    16        2          21.47        24.22       4.69       4.16       1.13x
    16        4          32.54        36.98       6.19       5.44       1.14x
    16        8          56.24        61.76       7.16       6.52       1.10x
    32        2          33.50        37.68       6.01       5.34       1.12x
    32        4          54.66        60.26       7.37       6.68       1.10x
    32        8          97.98       104.35       8.22       7.72       1.06x
    64        2          59.82        65.38       6.73       6.16       1.09x
    64        4         102.05       108.83       7.89       7.40       1.07x
    64        8         188.17       196.45       8.56       8.20       1.04x
   128        2         107.44       121.41       7.50       6.63       1.13x
   128        4         192.01       209.90       8.39       7.67       1.09x
   128        8         366.81       389.12       8.78       8.28       1.06x
   256        2         199.14       236.19       8.09       6.82       1.19x
   256        4         363.36       422.61       8.87       7.62       1.16x
   256        8         708.22       787.05       9.10       8.19       1.11x
--------------------------------------------------------------------------------------------------------------
Speedup > 1.0 means FlashInfer is faster

Summary:
  Average speedup: 1.13x
  Min speedup: 1.04x (batch=64, T=8)
  Max speedup: 1.40x (batch=1, T=8)
```

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added Triton-based benchmarks and end-to-end comparison/verify modes
across multiple memory layouts (including MTP); new verification flows
to compare implementations.
* **Performance Improvements**
* Batch-size-aware kernel selection, configurable tile/vec sizing,
fast-math paths, reduced redundant copies, and CUPTI-backed GPU timing
for more accurate benchmarks.
* **Behavior & Compatibility**
* Improved layout handling, expanded CLI presets/modes, clearer error
messages and guards when Triton is unavailable; default benchmark mode
updated.
* **Documentation**
  * Updated usage examples and CLI guidance.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: HongliMi <1667738261@qq.com>
Co-authored-by: Hongli Mi <hmi@nvidia.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
raayandhar pushed a commit to raayandhar/flashinfer that referenced this pull request Feb 5, 2026
@vadiklyutiy vadiklyutiy mentioned this pull request Feb 6, 2026
5 tasks