
perf: improve gdn decode cute-dsl kernels#2405

Merged
yzh119 merged 25 commits into flashinfer-ai:main from yzh119:improve-gdn-decode
Feb 3, 2026

Conversation

@yzh119
Collaborator

@yzh119 commented Jan 23, 2026

📌 Description

Follow-up to #2370, this PR improves the benchmark scripts and adds comparisons against baselines:

  • Benchmark using CUPTI with L2 flush.
  • Compare with SGLang's fused_sigmoid_gating_delta_rule_update function (with the tile-size optimization mentioned by @vadiklyutiy).

This PR also implements several optimizations on the original GDN kernel:

  • Use fastmath wherever possible.
  • Replace division ("/") with multiplication by the reciprocal.
  • Use cutlass.range_constexpr and cutlass.const_expr whenever possible.
  • Fuse scale and inv_norm_q into a single multiplier.
  • For MTP, keep the state directly in registers instead of loading/storing through shared memory, and remove cp.async.
  • Vectorize memory accesses.
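
The division-to-multiplication and scale-fusion items can be sketched in plain NumPy (illustrative only, not the actual CuTe DSL kernel code):

```python
import numpy as np

def l2norm_div(q, scale):
    # Baseline: sqrt + division, then a separate scale multiply.
    norm = np.sqrt(np.sum(q * q) + 1e-8)
    return (q / norm) * scale

def l2norm_rsqrt_fused(q, scale):
    # Optimized shape of the computation: one reciprocal square root,
    # with the attention scale folded into the same multiplier
    # (the "fuse scale and inv_norm_q" item above).
    inv_norm = 1.0 / np.sqrt(np.sum(q * q) + 1e-8)  # rsqrt
    return q * (inv_norm * scale)

q = np.random.default_rng(0).standard_normal(128).astype(np.float32)
assert np.allclose(l2norm_div(q, 0.5), l2norm_rsqrt_fused(q, 0.5), atol=1e-5)
```

On GPU, the rsqrt form removes a division instruction from the inner loop and lets the scale ride along for free in the multiply.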

Performance on B200

Non MTP setting

> python benchmarks/bench_gdn_decode.py --batch-size 1 2 4 8 16 32 64 128 256 --compare --verify

=== Correctness Verification ===
Batch=8:
  Pretranspose: PASS
  Nontranspose: PASS
Batch=16:
  Pretranspose: PASS
  Nontranspose: PASS
Batch=32:
  Pretranspose: PASS
  Nontranspose: PASS
Batch=64:
  Pretranspose: PASS
  Nontranspose: PASS


========================================================================================================================
GDN Decode Benchmark: FlashInfer vs Triton, Pretranspose vs Nontranspose
Config: q_heads=16, k_heads=16, v_heads=32, head_size=128, dtype=bfloat16, qk_l2norm=ON
========================================================================================================================

 batch | FI-PreTr FI-NonTr | TR-PreTr TR-NonTr | FI/TR-Pre FI/TR-Non | Pre/Non-FI Pre/Non-TR
       |     (us)     (us) |     (us)     (us) |   speedup   speedup |    speedup    speedup
------------------------------------------------------------------------------------------------------------------------
     1 |     3.74     5.06 |     5.95     4.35 |    1.59x    0.86x |    1.35x    0.73x
     2 |     4.29     5.89 |     6.37     5.02 |    1.49x    0.85x |    1.37x    0.79x
     4 |     5.41     7.78 |     7.58     6.66 |    1.40x    0.86x |    1.44x    0.88x
     8 |     7.65    12.03 |     9.95    10.21 |    1.30x    0.85x |    1.57x    1.03x
    16 |    12.61    19.30 |    16.83    15.81 |    1.34x    0.82x |    1.53x    0.94x
    32 |    22.91    32.86 |    31.55    27.84 |    1.38x    0.85x |    1.43x    0.88x
    64 |    52.74    58.61 |    58.91    53.02 |    1.12x    0.90x |    1.11x    0.90x
   128 |    92.93   107.98 |   114.45   106.78 |    1.23x    0.99x |    1.16x    0.93x
   256 |   170.77   209.04 |   225.71   216.41 |    1.32x    1.04x |    1.22x    0.96x
------------------------------------------------------------------------------------------------------------------------

Legend:
  FI-PreTr  = FlashInfer Pretranspose [B, HV, V, K]
  FI-NonTr  = FlashInfer Nontranspose [B, HV, K, V]
  TR-PreTr  = Triton Pretranspose [B, HV, V, K]
  TR-NonTr  = Triton Nontranspose [B, HV, K, V]
  FI/TR speedup > 1.0 means FlashInfer is faster than Triton
  Pre/Non speedup > 1.0 means Pretranspose is faster than Nontranspose

FlashInfer vs Triton (Pretranspose) - Average speedup: 1.35x
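
The reported 1.35x is simply the mean of the FI/TR-Pre column in the table above:

```python
# Values copied from the FI/TR-Pre column of the table above.
fi_tr_pre = [1.59, 1.49, 1.40, 1.30, 1.34, 1.38, 1.12, 1.23, 1.32]
avg = sum(fi_tr_pre) / len(fi_tr_pre)
print(f"Average speedup: {avg:.2f}x")  # Average speedup: 1.35x
```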

MTP Setting (pretranspose only)

> python benchmarks/bench_gdn_decode.py --version mtp --batch-size 1 2 4 8 16 32 64 128 256 --compare --verify

=== Correctness Verification (MTP) ===
Batch=8: PASS
Batch=16: PASS
Batch=32: PASS
Batch=64: PASS


GDN MTP Comparison: FlashInfer (CuTe DSL) vs Triton
Config: q_heads=16, k_heads=16, v_heads=32, head_size=128, dtype=bfloat16, qk_l2norm=ON, cache_intermediate=OFF
--------------------------------------------------------------------------------------------------------------
 batch  seq_len FlashInfer(us)   Triton(us)  FI TFLOPS  TR TFLOPS    Speedup
--------------------------------------------------------------------------------------------------------------
     1        2           9.22        10.05       0.68       0.63       1.09x
     1        4          11.20        14.43       1.12       0.87       1.29x
     1        8          15.81        22.08       1.59       1.14       1.40x
     2        2          10.11        10.69       1.24       1.18       1.06x
     2        4          12.58        15.10       2.00       1.67       1.20x
     2        8          18.82        23.63       2.67       2.13       1.26x
     4        2          11.39        11.94       2.21       2.11       1.05x
     4        4          15.23        16.54       3.30       3.04       1.09x
     4        8          23.62        25.50       4.26       3.95       1.08x
     8        2          14.69        17.23       3.43       2.92       1.17x
     8        4          21.20        25.01       4.75       4.03       1.18x
     8        8          34.69        40.86       5.80       4.93       1.18x
    16        2          21.47        24.22       4.69       4.16       1.13x
    16        4          32.54        36.98       6.19       5.44       1.14x
    16        8          56.24        61.76       7.16       6.52       1.10x
    32        2          33.50        37.68       6.01       5.34       1.12x
    32        4          54.66        60.26       7.37       6.68       1.10x
    32        8          97.98       104.35       8.22       7.72       1.06x
    64        2          59.82        65.38       6.73       6.16       1.09x
    64        4         102.05       108.83       7.89       7.40       1.07x
    64        8         188.17       196.45       8.56       8.20       1.04x
   128        2         107.44       121.41       7.50       6.63       1.13x
   128        4         192.01       209.90       8.39       7.67       1.09x
   128        8         366.81       389.12       8.78       8.28       1.06x
   256        2         199.14       236.19       8.09       6.82       1.19x
   256        4         363.36       422.61       8.87       7.62       1.16x
   256        8         708.22       787.05       9.10       8.19       1.11x
--------------------------------------------------------------------------------------------------------------
Speedup > 1.0 means FlashInfer is faster

Summary:
  Average speedup: 1.13x
  Min speedup: 1.04x (batch=64, T=8)
  Max speedup: 1.40x (batch=1, T=8)

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • Added Triton-based benchmarks and end-to-end comparison/verify modes across multiple memory layouts (including MTP); new verification flows to compare implementations.
  • Performance Improvements
    • Batch-size-aware kernel selection, configurable tile/vec sizing, fast-math paths, reduced redundant copies, and CUPTI-backed GPU timing for more accurate benchmarks.
  • Behavior & Compatibility
    • Improved layout handling, expanded CLI presets/modes, clearer error messages and guards when Triton is unavailable; default benchmark mode updated.
  • Documentation
    • Updated usage examples and CLI guidance.


HongliMi and others added 14 commits January 21, 2026 04:01
1. Remove unnecessary state.copy_() when state is already contiguous:
   - For contiguous state, h0_source shares memory with state
   - Kernel updates state in-place, so copy_ is redundant

2. Cache h0_indices and cu_seqlens tensors:
   - These tensors have fixed values based on batch size
   - Reuse from cache instead of creating new tensors each call

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
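
The caching described in this commit can be sketched as follows (a pure-Python stand-in; the actual code caches torch.int32 tensors keyed per configuration, and the helper name here is illustrative):

```python
from functools import cache

@cache
def get_cached_indices(B: int):
    """Return (h0_indices, cu_seqlens) for a batch size, reusing the
    previously built objects instead of recreating them on every call."""
    h0_indices = list(range(B))       # state-pool slot per sequence
    cu_seqlens = list(range(B + 1))   # decode: one token per sequence -> [0, 1, ..., B]
    return h0_indices, cu_seqlens

a = get_cached_indices(4)
b = get_cached_indices(4)
assert a is b  # same cached object, no re-creation per call
```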

Use bench_gpu_time with CUPTI for GDN decode benchmark

Replace torch.profiler with flashinfer.testing.bench_gpu_time:
- More accurate kernel timing via CUPTI hardware profiling
- Simpler code without trace file parsing
- Consistent with other FlashInfer benchmarks

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Enable fastmath for exp/log/sqrt/rsqrt in GDN decode kernels

Use fastmath=True for all cute.exp, cute.log, cute.sqrt, cute.rsqrt
calls to enable faster approximate math intrinsics.

~1-2% improvement observed in small batch sizes.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Use rsqrt instead of sqrt+division in GDN decode L2 norm

Replace `norm = sqrt(sum); x = x / norm` with `inv_norm = rsqrt(sum); x = x * inv_norm`
to eliminate a division instruction. The MTP kernel already used this pattern.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Resolved conflicts by keeping GDN decode optimizations:
- fastmath=True for exp/log/rsqrt operations
- rsqrt instead of sqrt+division for L2 norm
- Cached h0_indices and cu_seqlens to avoid repeated allocation
- Optimized state.copy_ to only run when necessary
- Use bench_gpu_time (CUPTI) for benchmarks instead of torch.profiler

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@coderabbitai
Contributor

coderabbitai bot commented Jan 23, 2026

📝 Walkthrough

Walkthrough

Added Triton-based GDN decode and MTP kernels and benchmarking flows, replaced trace timing with CUPTI-backed GPU timing, added comparison and correctness-verification paths between FlashInfer and Triton, and made MTP tile/vec sizes dynamic with per-config caching and adjusted state-layout handling.

Changes

Cohort / File(s) Summary
Benchmarking & CLI
benchmarks/bench_gdn_decode.py
Added Triton wrappers (triton_gdn_decode*, triton_gdn_mtp), comparison/verification entry points (bench_comparison*, verify_correctness*), run flows (run_flashinfer_only_benchmark, run_comparison_benchmark), expanded CLI flags/presets, and changed default version to "nontranspose". Replaced trace timing with CUPTI-backed bench_gpu_time and added Triton-availability guards.
Benchmark flows / helpers
benchmarks/...
Introduced full-layout and single-layout comparison flows (pretranspose vs nontranspose), MTP comparison, correctness verification paths, and utility reorganizations for Triton vs FlashInfer runs.
GDN decode & MTP implementation
flashinfer/gdn_decode.py
Replaced fixed MTP constants with dynamic selection (get_tile_v_mtp, get_vec_size_mtp), threaded tile_v/vec_size through kernels and cache keys, added per-config caching for h0_indices/cu_seqlens, adjusted state copyback/contiguity handling, and applied vectorized loads, cutlass loops, and fastmath optimizations.
Kernel compilation & verification
flashinfer/gdn_decode.py
Updated _get_compiled_mtp_kernel, _get_compiled_decode_kernel, and gdn_verify_kernel_mtp to accept tile_v/vec_size, include them in cache keys, and propagate into launch/verification flows.
Public API surface
flashinfer/gdn_decode.py, benchmarks/bench_gdn_decode.py
Added exported functions: triton_gdn_decode, triton_gdn_decode_pretranspose, triton_gdn_mtp, bench_comparison*, bench_mtp_comparison, verify_correctness*, run_flashinfer_only_benchmark, run_comparison_benchmark; updated gated_delta_rule_mtp signatures to accept per-config params.
State & layout handling
flashinfer/gdn_decode.py, benchmarks/...
Normalized flattened state shapes for MTP/decode, adjusted memory layout handling (K-major vs V-major, pretranspose vs nontranspose), and reduced redundant copies.
Misc / refactor
benchmarks/bench_gdn_decode.py, flashinfer/gdn_decode.py
Reorganized utilities, improved error messaging for missing Triton, added Triton compatibility guards, and refined kernel parameterization across paths.

Sequence Diagram(s)

sequenceDiagram
  participant User
  participant BenchHarness
  participant FlashInfer
  participant TritonKernel
  participant GPU

  User->>BenchHarness: start benchmark (mode: compare / verify)
  BenchHarness->>FlashInfer: run GDN decode / MTP (FlashInfer path)
  FlashInfer->>GPU: launch FlashInfer kernel
  GPU-->>FlashInfer: results + CUPTI timings
  BenchHarness->>TritonKernel: run Triton GDN decode / MTP
  TritonKernel->>GPU: launch Triton kernel
  GPU-->>TritonKernel: results + CUPTI timings
  BenchHarness->>BenchHarness: verify outputs and states
  BenchHarness-->>User: report timings and verification outcome

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • yongwww
  • djmmoss
  • cyx-6
  • nvmbreughe
  • ttyio

Poem

🐇 I hopped through tiles and vecs so spry,
Triton and FlashInfer gave timing a try.
I balanced layouts with a twitch and a cheer,
Benchmarks hum, correctness near—
carrots for kernels, bugs disappear.

🚥 Pre-merge checks: 3 passed
  • Title check: ✅ Passed. The title 'perf: improve gdn decode cute-dsl kernels' accurately summarizes the main optimization changes to the GDN decode kernels using CuTe DSL.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 89.19%, which is sufficient; the required threshold is 80.00%.
  • Description check: ✅ Passed. The PR description includes clear objectives, performance benchmarks, and implementation details with mostly complete checklist sections.



@gemini-code-assist
Contributor

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on enhancing the performance and benchmarking capabilities of the GDN decode kernels. It introduces a comprehensive benchmarking framework that leverages CUPTI for precise kernel timing and enables direct performance comparisons against external Triton-based implementations. Concurrently, the core GDN CuTe DSL kernels have been optimized through fastmath usage and arithmetic operation changes, alongside improvements in state tensor management to reduce overhead.

Highlights

  • Benchmark Script Enhancements: The GDN decode benchmark script has been significantly improved to include CUPTI-based kernel timing for higher accuracy and now supports direct comparison against baseline implementations.
  • Triton Kernel Integration for Comparison: New Triton-based kernels, fused_sigmoid_gating_delta_rule_kernel and fused_sigmoid_gating_delta_rule_mtp_kernel, have been integrated into the benchmark for direct performance comparison with FlashInfer's CuTe DSL kernels, aligning with SGLang's implementation.
  • GDN Kernel Performance Optimizations: The original GDN CuTe DSL kernels have received performance optimizations, including the strategic use of fastmath for floating-point operations and replacing division with multiplication by inverse square root for L2 normalization, aiming for faster execution.
  • State Tensor Handling Optimization: Optimized the handling of state tensors in gated_delta_rule_decode_pretranspose, gated_delta_rule_decode, and gated_delta_rule_mtp to avoid unnecessary data copies when the input state is already contiguous or when the kernel updates the state in-place.
  • Cached Tensor Allocations: Introduced caching for h0_indices and cu_seqlens tensors in the GDN decode and MTP functions to prevent redundant allocations and improve efficiency across multiple calls with the same configuration.


Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request significantly improves the performance and benchmarking capabilities for the GDN decode kernels. The optimizations in the CuTe kernels, such as enabling fastmath and replacing division with multiplication via rsqrt, are well-implemented and should provide a good performance boost. The benchmark script has been substantially enhanced with the addition of a Triton kernel for baseline comparison, correctness verification, and a more accurate timing mechanism using bench_gpu_time. The code is well-structured and the new benchmarking features are very valuable. I have a couple of minor suggestions to further optimize the new Triton kernels for consistency with the performance goals of this PR.

Comment on lines +296 to +299
q_norm = tl.sqrt(tl.sum(b_q * b_q) + 1e-8)
k_norm = tl.sqrt(tl.sum(b_k * b_k) + 1e-8)
b_q = b_q / q_norm
b_k = b_k / k_norm


medium

For better performance, it's recommended to use tl.rsqrt and multiplication for normalization instead of tl.sqrt and division. This is consistent with the optimizations applied to the CuTe kernels in this PR and is a standard practice for performance-critical GPU code.

Suggested change
q_norm = tl.sqrt(tl.sum(b_q * b_q) + 1e-8)
k_norm = tl.sqrt(tl.sum(b_k * b_k) + 1e-8)
b_q = b_q / q_norm
b_k = b_k / k_norm
q_inv_norm = tl.rsqrt(tl.sum(b_q * b_q) + 1e-8)
k_inv_norm = tl.rsqrt(tl.sum(b_k * b_k) + 1e-8)
b_q = b_q * q_inv_norm
b_k = b_k * k_inv_norm

Comment on lines +491 to +494
q_norm = tl.sqrt(tl.sum(b_q * b_q) + 1e-8)
k_norm = tl.sqrt(tl.sum(b_k * b_k) + 1e-8)
b_q = b_q / q_norm
b_k = b_k / k_norm


medium

Similar to the other Triton kernel, using tl.rsqrt and multiplication for normalization will be more performant than tl.sqrt and division. This change would align the Triton baseline with the optimization principles used elsewhere in this PR.

Suggested change
q_norm = tl.sqrt(tl.sum(b_q * b_q) + 1e-8)
k_norm = tl.sqrt(tl.sum(b_k * b_k) + 1e-8)
b_q = b_q / q_norm
b_k = b_k / k_norm
q_inv_norm = tl.rsqrt(tl.sum(b_q * b_q) + 1e-8)
k_inv_norm = tl.rsqrt(tl.sum(b_k * b_k) + 1e-8)
b_q = b_q * q_inv_norm
b_k = b_k * k_inv_norm


@coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@benchmarks/bench_gdn_decode.py`:
- Around line 1926-1929: The current try/except around the batch verification
(the block that sets status = "PASS" if passed else "FAIL" and prints
Batch={batch_size}) should not catch broad Exception; replace the generic
"except Exception as e" with explicit handlers for the expected verification
errors (e.g., except AssertionError as e, except ValueError as e, except
RuntimeError as e) and handle each by printing Batch={batch_size}: ERROR -
<type>, and keep an optional final except: raise to avoid swallowing unknown
errors; reference the variables/status logic in this try/except block (status,
passed, batch_size) when updating the handlers.
🧹 Nitpick comments (4)
flashinfer/gdn_decode.py (2)

945-950: Consider adding device to cache key for multi-GPU scenarios.

The device check handles device mismatches by creating new tensors, but the cache key (cache_key at line 942) doesn't include the device. If the same configuration is used across multiple GPUs, the cached tensors will be repeatedly recreated.

Consider either:

  1. Adding q.device to the cache key, or
  2. Using a nested dict keyed by device within the cache
♻️ Suggested improvement
# Option 1: Include device in cache key
cache_key = (B, T, H, HV, K, V, q.dtype, scale, use_qk_l2norm, q.device.index)

Or store per-device tensors in the cache:

device_key = q.device.index
if "h0_indices" not in cache:
    cache["h0_indices"] = {}
    cache["cu_seqlens"] = {}
if device_key not in cache["h0_indices"]:
    cache["h0_indices"][device_key] = torch.zeros(B, dtype=torch.int32, device=q.device)
    cache["cu_seqlens"][device_key] = torch.zeros(B + 1, dtype=torch.int32, device=q.device)
h0_indices = cache["h0_indices"][device_key]
cu_seqlens = cache["cu_seqlens"][device_key]

2419-2423: Potential issue: Contiguity check may not cover all copy scenarios.

The condition checks initial_state.is_contiguous() but h0_source is derived via .to(torch.float32).reshape(...). While .to() returns self when dtype matches, this relies on implementation details.

Consider using the identity check pattern (like the nontranspose version at line 1838) for consistency:

♻️ Suggested fix
-    if not disable_state_update and not initial_state.is_contiguous():
+    # Create reshaped view/copy of initial_state
+    h0_flat = initial_state.reshape(pool_size * HV, V, K)
+    h0_source = h0_flat.contiguous() if not h0_flat.is_contiguous() else h0_flat
+    
+    # ... later after kernel execution ...
+    if not disable_state_update and h0_source.data_ptr() != initial_state.data_ptr():
         initial_state.copy_(h0_source.reshape(pool_size, HV, V, K))

Alternatively, verify that the current logic works correctly in all cases.

benchmarks/bench_gdn_decode.py (2)

219-227: Optional: Remove unused B parameter from Triton kernels.

The B parameter is declared as tl.constexpr but not used in the kernel body. The batch index is derived from i_bh // HV and bounds checking uses K_DIM and V_DIM.

♻️ Suggested fix

Remove B: tl.constexpr, from the kernel signature and the corresponding argument at call sites. This applies to all three Triton kernels: fused_sigmoid_gating_delta_rule_kernel, fused_sigmoid_gating_delta_rule_mtp_kernel, and fused_sigmoid_gating_delta_rule_kernel_pretranspose.


1346-1355: Consider renaming speedup for clarity.

The current calculation triton_median_us / flashinfer_median_us means "how many times slower Triton is" rather than "how much faster FlashInfer is". While the comment at line 2049 clarifies this, consider renaming to flashinfer_speedup or inverting the ratio for more intuitive interpretation.

♻️ Alternative naming
# Option 1: Rename for clarity
flashinfer_speedup = triton_median_us / flashinfer_median_us  # >1 means FI faster

# Option 2: Use ratio that matches conventional speedup semantics  
speedup = flashinfer_median_us / triton_median_us  # >1 means Triton faster
# Then update print statement: "Speedup > 1.0 means Triton is faster"

Comment on lines +1926 to +1929
                status = "PASS" if passed else "FAIL"
                print(f"Batch={batch_size}: {status}")
            except Exception as e:
                print(f"Batch={batch_size}: ERROR - {type(e).__name__}")


⚠️ Potential issue | 🟡 Minor

Catch specific exceptions instead of broad Exception.

Catching Exception can mask unexpected errors. Consider catching the specific exceptions that verification can raise:

🔧 Suggested fix
-            except Exception as e:
-                print(f"Batch={batch_size}: ERROR - {type(e).__name__}")
+            except (RuntimeError, torch.cuda.CudaError) as e:
+                print(f"Batch={batch_size}: ERROR - {type(e).__name__}: {e}")
🧰 Tools
🪛 Ruff (0.14.14)

1928-1928: Do not catch blind exception: Exception

(BLE001)



@coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/gdn_decode.py (1)

238-244: Fix v-indexing: sV preload only covers first 128 elements.
sV is filled with indices 0..127 regardless of start_v_tiles or V. For V > 128 or for small-batch blocks with batch_inner > 0, v_tiles * TILE_V + ... will read uninitialized values. Load v for the current tile or read v directly in the loop.

🐛 Proposed fix (read v directly per tile)
-            v_new = sV[v_tiles * TILE_V + row + row_offset] - sum_hk
+            v_idx = v_tiles * TILE_V + row + row_offset
+            v_new = cutlass.Float32(v[i_n, i_t, i_hv, v_idx]) - sum_hk
-            v_new = sV[v_tiles * TILE_V + row + row_offset] - sum_hk
+            v_idx = v_tiles * TILE_V + row + row_offset
+            v_new = cutlass.Float32(v[i_n, i_t, i_hv, v_idx]) - sum_hk

Also applies to: 346-347, 474-479, 581-582

🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around line 105-114: The get_vec_size_mtp function currently chooses vec_size
solely from batch_size which can yield threads_per_group = K / vec_size > 32
when K > 128 and batch_size <= 4, causing groups_per_warp to become zero and
divide-by-zero downstream; update get_vec_size_mtp to consider K as well (either
accept K as an additional parameter or validate K before returning vec_size) and
select vec_size=8 when K/4 > 32 (or assert/raise a clear error if K is
incompatible), and apply the same guard/selection logic to the other occurrences
noted (around lines referenced: the other get_vec_size_mtp usages at 1917-1923,
2060-2062, 2333-2336) so threads_per_group never exceeds 32.
🧹 Nitpick comments (1)
flashinfer/gdn_decode.py (1)

2241-2257: Silence unused-arg lint for cache-only parameters.
Ruff flags tile_v/vec_size as unused. Consider explicitly marking them as cache-key-only to avoid ARG001 noise.

♻️ Proposed fix
 def _get_compiled_mtp_kernel(
     B: int,
     T: int,
     H: int,
@@
     tile_v: int,  # TILE_V - configurable for batch size
     vec_size: int,  # 4 for full warp, 8 for half-warp
 ):
     """Cache compiled MTP kernel for given configuration."""
+    _ = (tile_v, vec_size)  # used by functools.cache key
     return {}

Comment on lines +105 to +114
def get_vec_size_mtp(batch_size: int) -> int:
    """Select vec_size for MTP kernel based on batch size.

    B <= 4: vec_size=4 (full warp reduction, 5 shuffles) - better for small batch
    B > 4:  vec_size=8 (half-warp reduction, 4 shuffles) - better for large batch
    """
    if batch_size <= 4:
        return 4  # Full warp: 32 threads * 4 elements = 128
    else:
        return 8  # Half-warp: 16 threads * 8 elements = 128


⚠️ Potential issue | 🟠 Major

Guard MTP vec_size against K > 128 in small batch.
With B <= 4, vec_size=4, so threads_per_group = K/4. For K > 128, this exceeds 32, making groups_per_warp = 0 and rows_per_group = tile_v // 0 invalid. Either assert K compatibility for this path or choose vec_size based on K as well.

🐛 Proposed fix (explicit guard)
-    vec_size = get_vec_size_mtp(B)
+    vec_size = get_vec_size_mtp(B)
+    assert K // vec_size <= 32, (
+        "K too large for selected vec_size; adjust vec_size selection or constrain K."
+    )

Also applies to: 1917-1923, 2060-2062, 2333-2336

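
The suggested guard can be sketched as follows (hypothetical helper name; the real fix would live in get_vec_size_mtp, and the warp-size constraint follows the review comment above):

```python
def get_vec_size_mtp_guarded(batch_size: int, K: int) -> int:
    """Pick vec_size from both batch size and K so that
    threads_per_group = K // vec_size never exceeds a 32-thread warp."""
    vec_size = 4 if batch_size <= 4 else 8
    if K // vec_size > 32:   # would overflow a warp -> widen the vector
        vec_size = 8
    if K // vec_size > 32:   # still incompatible (K > 256): fail loudly
        raise ValueError(f"K={K} is incompatible with MTP vec_size selection")
    return vec_size

assert get_vec_size_mtp_guarded(2, 128) == 4   # small-batch path unchanged
assert get_vec_size_mtp_guarded(2, 256) == 8   # guard kicks in for large K
```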


@coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@benchmarks/bench_gdn_decode.py`:
- Around line 2463-2472: The code ignores --compare for non‑MTP runs because the
version branch always calls run_all_layouts_benchmark; move or add the compare
check so args.compare is evaluated first (or added into the non‑MTP branch).
Specifically, before calling run_all_layouts_benchmark, check if args.compare
and call run_comparison_benchmark(args, dtype, use_qk_l2norm); otherwise
preserve the existing behavior (for MTP: run_comparison_benchmark vs
run_flashinfer_only_benchmark; for non‑MTP: run_all_layouts_benchmark only when
not comparing). Update the logic around the run_all_layouts_benchmark,
run_comparison_benchmark, and run_flashinfer_only_benchmark calls accordingly so
--compare affects both MTP and non‑MTP paths.
🧹 Nitpick comments (1)
benchmarks/bench_gdn_decode.py (1)

2062-2069: Summary speedup can mismatch batches when some results are missing.

fi_pre_times and tr_pre_times are filtered independently, so the zipped list can pair different batches if any entry is missing on one side.

♻️ Suggested adjustment
-    fi_pre_times = [r["fi_pretrans_us"] for r in all_results if r.get("fi_pretrans_us")]
-    tr_pre_times = [r["tr_pretrans_us"] for r in all_results if r.get("tr_pretrans_us")]
-
-    if fi_pre_times and tr_pre_times:
-        speedups = [tr / fi for fi, tr in zip(fi_pre_times, tr_pre_times, strict=False)]
+    speedups = [
+        r["tr_pretrans_us"] / r["fi_pretrans_us"]
+        for r in all_results
+        if r.get("fi_pretrans_us") and r.get("tr_pretrans_us")
+    ]
+
+    if speedups:
         print(
             f"FlashInfer vs Triton (Pretranspose) - Average speedup: {np.mean(speedups):.2f}x"
         )
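
The mispairing called out in this nitpick can be reproduced with plain Python lists (toy numbers, not real benchmark data): filtering each side independently lets zip silently pair timings from different batches.

```python
# Results for batches 1, 2, 4; the FlashInfer entry for batch 2 is missing.
all_results = [
    {"batch": 1, "fi_pretrans_us": 3.7, "tr_pretrans_us": 5.9},
    {"batch": 2, "tr_pretrans_us": 6.4},  # fi entry missing
    {"batch": 4, "fi_pretrans_us": 5.4, "tr_pretrans_us": 7.6},
]

# Independent filtering (the buggy pattern): the lists have different lengths,
# so zip pairs batch 4's FlashInfer time with batch 2's Triton time.
fi = [r["fi_pretrans_us"] for r in all_results if r.get("fi_pretrans_us")]
tr = [r["tr_pretrans_us"] for r in all_results if r.get("tr_pretrans_us")]
buggy = [t / f for f, t in zip(fi, tr)]

# Row-wise filtering (the suggested adjustment) keeps each ratio within one batch.
correct = [
    r["tr_pretrans_us"] / r["fi_pretrans_us"]
    for r in all_results
    if r.get("fi_pretrans_us") and r.get("tr_pretrans_us")
]
```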

Comment on lines +2463 to +2472
if args.version == "mtp":
# MTP mode: use comparison or flashinfer-only
if args.compare:
run_comparison_benchmark(args, dtype, use_qk_l2norm)
else:
run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm)
else:
# Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose)
run_all_layouts_benchmark(args, dtype, use_qk_l2norm)

Contributor


⚠️ Potential issue | 🟠 Major

--compare is ignored for non‑MTP paths.

For non‑MTP runs, main() always calls run_all_layouts_benchmark, so --compare (and single‑layout intent in the usage text) never takes effect. This is a user‑visible behavior bug.

🛠️ Proposed fix
-    if args.version == "mtp":
-        # MTP mode: use comparison or flashinfer-only
-        if args.compare:
-            run_comparison_benchmark(args, dtype, use_qk_l2norm)
-        else:
-            run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm)
-    else:
-        # Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose)
-        run_all_layouts_benchmark(args, dtype, use_qk_l2norm)
+    if args.version == "mtp":
+        # MTP mode: use comparison or flashinfer-only
+        if args.compare:
+            run_comparison_benchmark(args, dtype, use_qk_l2norm)
+        else:
+            run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm)
+    else:
+        # Non-MTP
+        if args.compare and args.version != "all":
+            run_comparison_benchmark(args, dtype, use_qk_l2norm)
+        else:
+            run_all_layouts_benchmark(args, dtype, use_qk_l2norm)
🤖 Prompt for AI Agents
In `@benchmarks/bench_gdn_decode.py` around lines 2463 - 2472, The code ignores
--compare for non‑MTP runs because the version branch always calls
run_all_layouts_benchmark; move or add the compare check so args.compare is
evaluated first (or added into the non‑MTP branch). Specifically, before calling
run_all_layouts_benchmark, check if args.compare and call
run_comparison_benchmark(args, dtype, use_qk_l2norm); otherwise preserve the
existing behavior (for MTP: run_comparison_benchmark vs
run_flashinfer_only_benchmark; for non‑MTP: run_all_layouts_benchmark only when
not comparing). Update the logic around the run_all_layouts_benchmark,
run_comparison_benchmark, and run_flashinfer_only_benchmark calls accordingly so
--compare affects both MTP and non‑MTP paths.

@xutizhou
Copy link
Copy Markdown
Contributor

The precision is fine compared to the Triton reference. Performance improves by approximately 20%~40% over the Triton kernel at large batch sizes. However, for small batch sizes, the Triton reference should use BV=8.

| BS | FlashInfer PR2405 MTP (us / GB/s) | Triton total (us / GB/s) | FI speedup |
|---:|---|---|---|
| 1 | 5.501 / 962.1 | 7.888 / 670.9 | 1.43x |
| 2 | 7.598 / 1393.0 | 8.774 / 1206.3 | 1.15x |
| 4 | 13.098 / 1616.3 | 12.259 / 1726.8 | 0.94x |
| 8 | 20.896 / 2026.2 | 23.745 / 1783.0 | 1.14x |
| 16 | 35.509 / 2384.7 | 47.042 / 1800.0 | 1.32x |

yzh119 and others added 5 commits January 27, 2026 20:31
- Change from interleaved to contiguous memory access pattern
- Use cute.local_tile + cute.autovec_copy for vectorized memory
  operations instead of scalar for-loops
- Set vec_size=8 (half-warp, 8 groups) for all batch sizes
- Tune tile_v: 8/16/32/64 based on batch size thresholds

This achieves 1.06x-1.45x speedup over Triton baseline (avg 1.17x).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Use 3D local_tile directly instead of slice + 1D local_tile
- Cleaner code without intermediate tensor creation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Replace scalar loop with autovec_copy for coalesced vectorized loads
- Load into BF16 registers first, then convert to FP32
- Improves average speedup from 1.18x to 1.22x over Triton

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Change q, k, v loading from interleaved to contiguous pattern
- Use autovec_copy for vectorized loads (BF16 -> FP32 conversion)
- Use local_tile + autovec_copy for h read/write in mainloop
- Applies to both small_batch and big_batch pretranspose kernels

Achieves 1.4x-1.6x speedup over Triton pretranspose baseline.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Always use vec_size=4 (32 threads/group = full warp) instead of vec_size=8
- Full warp shuffle is more efficient than half-warp shuffle
- Tune tile_v per batch size via grid search:
  B≤2: tile_v=4, B≤4: tile_v=8, B≤8: tile_v=16, B≤16: tile_v=32, B>16: tile_v=64
- Remove dead code for vec_size=8 shuffle branches
- Achieves >= 1.04x speedup vs Triton across all batch sizes (avg 1.14x)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
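The tile_v thresholds from the grid search in the commit message above can be sketched as a simple lookup. The helper name is hypothetical; in the source the selection is inlined in the dispatch logic:

```python
def select_tile_v(batch_size: int) -> int:
    """Map batch size to tile_v per the grid-search thresholds:
    B<=2: 4, B<=4: 8, B<=8: 16, B<=16: 32, B>16: 64."""
    if batch_size <= 2:
        return 4
    if batch_size <= 4:
        return 8
    if batch_size <= 8:
        return 16
    if batch_size <= 16:
        return 32
    return 64
```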
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/gdn_decode.py (2)

728-728: Dead code: expression computed but result discarded.

This expression computes a value (likely data size in MB) but doesn't assign it anywhere. Remove it or assign to a variable if needed for debugging.

🧹 Proposed fix
     num_v_tiles = cute.ceil_div(v_dim, TILE_V)
-    v_dim * k_dim * batch_size * 4 / 1024 / 1024
 
     vec_size = (

831-831: Dead code: another unused expression.

Same issue as line 728. This and the commented print statements below should be cleaned up.

🧹 Proposed fix
     num_v_tiles = cute.ceil_div(v_dim, TILE_V)
-    v_dim * k_dim * batch_size * 4 / 1024 / 1024
 
     vec_size = (
🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around line 1603-1606: The stray expression h0_indices.layout.shape[0] is
evaluated but not used; either remove this no-op or assign it to a meaningful
variable (e.g., hv_dim or use it to compute batch_size) so the value is
consumed. Locate the block where h0_source.layout.shape unpacks into
batch_hv_dim, k_dim, v_dim and the subsequent h0_indices.layout.shape[0]
expression, and either delete that expression or replace it with an assignment
that is actually used by downstream logic (same fix also apply to the analogous
occurrence around the h0_* code at lines corresponding to 1685-1687). Ensure
variable names (h0_indices, h0_source, batch_hv_dim, batch_size) are consistent
after the change.
- Around line 2416-2420: Add an upper-bound validation for K to fail fast when
the MTP kernel's warp-based grouping is violated: in the same validation block
that contains the asserts for K and V (referencing symbols K, V, and tile_v in
gdn_decode.py) add an assertion that K <= 128 with a clear error message like "K
must be at most 128, got K={K}" so the code raises a descriptive exception
instead of producing cryptic kernel failures.
🧹 Nitpick comments (1)
flashinfer/gdn_decode.py (1)

1997-2000: Thread grouping assumes K ≤ 128; document or assert this constraint.

The calculation threads_per_group = K // vec_size assumes the result fits within a warp (≤32 threads). With vec_size=4 (always returned by get_vec_size_mtp), this requires K ≤ 128. Consider adding a comment or assert to make this constraint explicit at the kernel level.
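
The assert suggested here is a one-liner at kernel setup. A sketch, with the K=128/vec_size=4 values and names taken from this discussion rather than the actual source:

```python
WARP_SIZE = 32
K, vec_size = 128, 4  # illustrative values; get_vec_size_mtp always returns 4 here
threads_per_group = K // vec_size
# Fail fast instead of letting groups_per_warp collapse to zero downstream.
assert threads_per_group <= WARP_SIZE, (
    f"threads_per_group={threads_per_group} exceeds a warp; "
    f"K={K} is too large for vec_size={vec_size}"
)
```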

Comment on lines +1603 to +1606
# h0_source is flattened to [B*HV, K, V] to ensure proper alignment for SIMT async copy
batch_hv_dim, k_dim, v_dim = h0_source.layout.shape
h0_indices.layout.shape[0]
batch_size = batch_hv_dim # batch_hv_dim = B * HV
Contributor


⚠️ Potential issue | 🟡 Minor

Unused expression: h0_indices.layout.shape[0] evaluated but not assigned.

This statement has no effect. Either remove it or assign to a variable if the value is needed.

🧹 Proposed fix
     # h0_source is flattened to [B*HV, K, V] to ensure proper alignment for SIMT async copy
-    batch_hv_dim, k_dim, v_dim = h0_source.layout.shape
-    h0_indices.layout.shape[0]
-    batch_size = batch_hv_dim  # batch_hv_dim = B * HV
+    batch_hv_dim, _k_dim, v_dim = h0_source.layout.shape
+    batch_size = batch_hv_dim  # batch_hv_dim = B * HV

Also applies to lines 1685-1687.

🧰 Tools
🪛 Ruff (0.14.14)

1604-1604: Unpacked variable k_dim is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
In `@flashinfer/gdn_decode.py` around lines 1603 - 1606, The stray expression
h0_indices.layout.shape[0] is evaluated but not used; either remove this no-op
or assign it to a meaningful variable (e.g., hv_dim or use it to compute
batch_size) so the value is consumed. Locate the block where
h0_source.layout.shape unpacks into batch_hv_dim, k_dim, v_dim and the
subsequent h0_indices.layout.shape[0] expression, and either delete that
expression or replace it with an assignment that is actually used by downstream
logic (same fix also apply to the analogous occurrence around the h0_* code at
lines corresponding to 1685-1687). Ensure variable names (h0_indices, h0_source,
batch_hv_dim, batch_size) are consistent after the change.

Comment on lines 2416 to 2420
 assert K >= 128, f"K must be at least 128, got K={K}"
 assert V >= 128, f"V must be at least 128, got V={V}"
-assert V % TILE_V_MTP == 0, (
-    f"V must be divisible by {TILE_V_MTP} to prevent out-of-bounds access, got V={V}"
+assert V % tile_v == 0, (
+    f"V must be divisible by {tile_v} to prevent out-of-bounds access, got V={V}"
 )
Contributor


⚠️ Potential issue | 🟠 Major

Missing K upper bound validation for MTP kernel.

The validation checks K >= 128 but the kernel requires K <= 128 due to the warp-based thread grouping. Add an upper bound check to fail fast with a clear error message instead of cryptic kernel failures.

🐛 Proposed fix
     # Validate K and V constraints
     assert K >= 128, f"K must be at least 128, got K={K}"
+    assert K == 128, f"MTP kernel currently requires K=128 (warp thread grouping constraint), got K={K}"
     assert V >= 128, f"V must be at least 128, got V={V}"
     assert V % tile_v == 0, (
         f"V must be divisible by {tile_v} to prevent out-of-bounds access, got V={V}"
     )
🤖 Prompt for AI Agents
In `@flashinfer/gdn_decode.py` around lines 2416 - 2420, Add an upper-bound
validation for K to fail fast when the MTP kernel's warp-based grouping is
violated: in the same validation block that contains the asserts for K and V
(referencing symbols K, V, and tile_v in gdn_decode.py) add an assertion that K
<= 128 with a clear error message like "K must be at most 128, got K={K}" so the
code raises a descriptive exception instead of producing cryptic kernel
failures.

@yzh119
Collaborator Author

yzh119 commented Jan 28, 2026

@xutizhou the performance gap for B=4 should be fixed in most recent commit.

@yzh119
Collaborator Author

yzh119 commented Jan 28, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !271 has been created, and the CI pipeline #42706808 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #42706808: 11/20 passed

@xutizhou
Contributor

xutizhou commented Jan 28, 2026 via email

@yzh119 yzh119 merged commit 5e5a866 into flashinfer-ai:main Feb 3, 2026
22 of 26 checks passed
raayandhar pushed a commit to raayandhar/flashinfer that referenced this pull request Feb 5, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

Follow-up of flashinfer-ai#2370, this PR improves the benchmark scripts and adds
comparison with baselines:
* benchmark using cupti with l2 flush
* compare with sglang's `fused_sigmoid_gating_delta_rule_update`
function (with tile size optimization mentioned by @ vadiklyutiy).

this PR also implements some optimizations on the original gdn kernel:
* use fastmath as much as we can
* change "/" to multiply
* Use `cutlass.range_constexpr` and `cutlass.const_expr` whenever
possible
* fuse scale and inv_norm_q
* For mtp, store state in registers directly, without load/write to
shared memory, and remove cpasync
* Vectorized memory access.

## Performance on B200

Non MTP setting
```
> python benchmarks/bench_gdn_decode.py --batch-size 1 2 4 8 16 32 64 128 256 --compare --verify

=== Correctness Verification ===
Batch=8:
  Pretranspose: PASS
  Nontranspose: PASS
Batch=16:
  Pretranspose: PASS
  Nontranspose: PASS
Batch=32:
  Pretranspose: PASS
  Nontranspose: PASS
Batch=64:
  Pretranspose: PASS
  Nontranspose: PASS


========================================================================================================================
GDN Decode Benchmark: FlashInfer vs Triton, Pretranspose vs Nontranspose
Config: q_heads=16, k_heads=16, v_heads=32, head_size=128, dtype=bfloat16, qk_l2norm=ON
========================================================================================================================

 batch | FI-PreTr FI-NonTr | TR-PreTr TR-NonTr | FI/TR-Pre FI/TR-Non | Pre/Non-FI Pre/Non-TR
       |     (us)     (us) |     (us)     (us) |   speedup   speedup |    speedup    speedup
------------------------------------------------------------------------------------------------------------------------
     1 |     3.74     5.06 |     5.95     4.35 |    1.59x    0.86x |    1.35x    0.73x
     2 |     4.29     5.89 |     6.37     5.02 |    1.49x    0.85x |    1.37x    0.79x
     4 |     5.41     7.78 |     7.58     6.66 |    1.40x    0.86x |    1.44x    0.88x
     8 |     7.65    12.03 |     9.95    10.21 |    1.30x    0.85x |    1.57x    1.03x
    16 |    12.61    19.30 |    16.83    15.81 |    1.34x    0.82x |    1.53x    0.94x
    32 |    22.91    32.86 |    31.55    27.84 |    1.38x    0.85x |    1.43x    0.88x
    64 |    52.74    58.61 |    58.91    53.02 |    1.12x    0.90x |    1.11x    0.90x
   128 |    92.93   107.98 |   114.45   106.78 |    1.23x    0.99x |    1.16x    0.93x
   256 |   170.77   209.04 |   225.71   216.41 |    1.32x    1.04x |    1.22x    0.96x
------------------------------------------------------------------------------------------------------------------------

Legend:
  FI-PreTr  = FlashInfer Pretranspose [B, HV, V, K]
  FI-NonTr  = FlashInfer Nontranspose [B, HV, K, V]
  TR-PreTr  = Triton Pretranspose [B, HV, V, K]
  TR-NonTr  = Triton Nontranspose [B, HV, K, V]
  FI/TR speedup > 1.0 means FlashInfer is faster than Triton
  Pre/Non speedup > 1.0 means Pretranspose is faster than Nontranspose

FlashInfer vs Triton (Pretranspose) - Average speedup: 1.35x
```

MTP Setting (pretranspose only)
```
> python benchmarks/bench_gdn_decode.py --version mtp --batch-size 1 2 4 8 16 32 64 128 256 --compare --verify

=== Correctness Verification (MTP) ===
Batch=8: PASS
Batch=16: PASS
Batch=32: PASS
Batch=64: PASS


GDN MTP Comparison: FlashInfer (CuTe DSL) vs Triton
Config: q_heads=16, k_heads=16, v_heads=32, head_size=128, dtype=bfloat16, qk_l2norm=ON, cache_intermediate=OFF
--------------------------------------------------------------------------------------------------------------
 batch  seq_len FlashInfer(us)   Triton(us)  FI TFLOPS  TR TFLOPS    Speedup
--------------------------------------------------------------------------------------------------------------
     1        2           9.22        10.05       0.68       0.63       1.09x
     1        4          11.20        14.43       1.12       0.87       1.29x
     1        8          15.81        22.08       1.59       1.14       1.40x
     2        2          10.11        10.69       1.24       1.18       1.06x
     2        4          12.58        15.10       2.00       1.67       1.20x
     2        8          18.82        23.63       2.67       2.13       1.26x
     4        2          11.39        11.94       2.21       2.11       1.05x
     4        4          15.23        16.54       3.30       3.04       1.09x
     4        8          23.62        25.50       4.26       3.95       1.08x
     8        2          14.69        17.23       3.43       2.92       1.17x
     8        4          21.20        25.01       4.75       4.03       1.18x
     8        8          34.69        40.86       5.80       4.93       1.18x
    16        2          21.47        24.22       4.69       4.16       1.13x
    16        4          32.54        36.98       6.19       5.44       1.14x
    16        8          56.24        61.76       7.16       6.52       1.10x
    32        2          33.50        37.68       6.01       5.34       1.12x
    32        4          54.66        60.26       7.37       6.68       1.10x
    32        8          97.98       104.35       8.22       7.72       1.06x
    64        2          59.82        65.38       6.73       6.16       1.09x
    64        4         102.05       108.83       7.89       7.40       1.07x
    64        8         188.17       196.45       8.56       8.20       1.04x
   128        2         107.44       121.41       7.50       6.63       1.13x
   128        4         192.01       209.90       8.39       7.67       1.09x
   128        8         366.81       389.12       8.78       8.28       1.06x
   256        2         199.14       236.19       8.09       6.82       1.19x
   256        4         363.36       422.61       8.87       7.62       1.16x
   256        8         708.22       787.05       9.10       8.19       1.11x
--------------------------------------------------------------------------------------------------------------
Speedup > 1.0 means FlashInfer is faster

Summary:
  Average speedup: 1.13x
  Min speedup: 1.04x (batch=64, T=8)
  Max speedup: 1.40x (batch=1, T=8)
```

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added Triton-based benchmarks and end-to-end comparison/verify modes
across multiple memory layouts (including MTP); new verification flows
to compare implementations.
* **Performance Improvements**
* Batch-size-aware kernel selection, configurable tile/vec sizing,
fast-math paths, reduced redundant copies, and CUPTI-backed GPU timing
for more accurate benchmarks.
* **Behavior & Compatibility**
* Improved layout handling, expanded CLI presets/modes, clearer error
messages and guards when Triton is unavailable; default benchmark mode
updated.
* **Documentation**
  * Updated usage examples and CLI guidance.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: HongliMi <1667738261@qq.com>
Co-authored-by: Hongli Mi <hmi@nvidia.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
@ameynaik-hub ameynaik-hub mentioned this pull request Feb 6, 2026
Contributor

@ameynaik-hub ameynaik-hub left a comment


Also, caching intermediate states should be True since it is required for MTP.

@@ -425,52 +1213,14 @@ def bench_gdn_mtp(
intermediate_states_buffer,
disable_state_update=True,
Contributor


Shouldn't this be False for benchmarking? We want the updated state as an output in MTP.

@@ -425,52 +1213,14 @@ def bench_gdn_mtp(
intermediate_states_buffer,
disable_state_update=True,
Contributor


@yzh119 I think this should be False; we want the updated h as output.

